Skip to content

Commit aca5935

Browse files
committed
support multiple checkpoint ensemble during inference
1 parent f91d617 commit aca5935

1 file changed

Lines changed: 68 additions & 5 deletions

File tree

pymic/net_run/agent_seg.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,7 @@ def infer(self):
363363
device_ids = self.config['testing']['gpus']
364364
device = torch.device("cuda:{0:}".format(device_ids[0]))
365365
self.net.to(device)
366-
# load network parameters and set the network as evaluation mode
367-
checkpoint_name = self.get_checkpoint_name()
368-
checkpoint = torch.load(checkpoint_name, map_location = device)
369-
self.net.load_state_dict(checkpoint['model_state_dict'])
370-
366+
371367
if(self.config['testing']['evaluation_mode'] == True):
372368
self.net.eval()
373369
if(self.config['testing']['test_time_dropout'] == True):
@@ -377,6 +373,20 @@ def test_time_dropout(m):
377373
m.train()
378374
self.net.apply(test_time_dropout)
379375

376+
ckpt_mode = self.config['testing']['ckpt_mode']
377+
ckpt_name = self.get_checkpoint_name()
378+
if(ckpt_mode == 3):
379+
assert(isinstance(ckpt_name, (tuple, list)))
380+
self.infer_with_multiple_checkpoints()
381+
return
382+
else:
383+
if(isinstance(ckpt_name, (tuple, list))):
384+
raise ValueError("ckpt_mode should be 3 if ckpt_name is a list")
385+
386+
# load network parameters and set the network as evaluation mode
387+
checkpoint = torch.load(ckpt_name, map_location = device)
388+
self.net.load_state_dict(checkpoint['model_state_dict'])
389+
380390
infer_cfg = self.config['testing']
381391
infer_cfg['class_num'] = self.config['network']['class_num']
382392
infer_obj = Inferer(self.net, infer_cfg)
@@ -416,6 +426,59 @@ def test_time_dropout(m):
416426
time_avg, time_std = infer_time_list.mean(), infer_time_list.std()
417427
print("testing time {0:} +/- {1:}".format(time_avg, time_std))
418428

429+
def infer_with_multiple_checkpoints(self):
    """
    Run inference with an ensemble of multiple checkpoints.

    For each batch from the test loader, the weights stored in every
    checkpoint listed in ``config['testing']['ckpt_name']`` are loaded into
    ``self.net`` in turn, a prediction is produced for each of them, and the
    predictions are averaged over the checkpoints. The averaged prediction is
    stored in ``data['predict']``, passed through the inverse transforms and
    saved with ``self.save_ouputs``. The mean and standard deviation of the
    per-batch inference time are printed at the end.

    Raises:
        ValueError: if ``config['testing']['ckpt_name']`` is not a non-empty
            list or tuple of checkpoint names.
    """
    device_ids = self.config['testing']['gpus']
    device = torch.device("cuda:{0:}".format(device_ids[0]))

    ckpt_names = self.config['testing']['ckpt_name']
    if(not isinstance(ckpt_names, (tuple, list)) or len(ckpt_names) == 0):
        raise ValueError("ckpt_name should be a non-empty list for " +
            "inference with multiple checkpoints")

    # Read every checkpoint from disk only once. Loading them inside the
    # batch loop would repeat the same file I/O for every test batch.
    # The weights are kept on the CPU so that the ensemble does not occupy
    # extra device memory; load_state_dict copies them into the parameters
    # of self.net.
    state_dicts = []
    for ckpt_name in ckpt_names:
        checkpoint = torch.load(ckpt_name, map_location = 'cpu')
        state_dicts.append(checkpoint['model_state_dict'])

    infer_cfg = self.config['testing']
    infer_cfg['class_num'] = self.config['network']['class_num']
    infer_obj = Inferer(self.net, infer_cfg)
    infer_time_list = []
    with torch.no_grad():
        for data in self.test_loder:
            images = self.convert_tensor_type(data['image'])
            images = images.to(device)

            start_time = time.time()
            predict_list = []
            for state_dict in state_dicts:
                self.net.load_state_dict(state_dict)
                pred = infer_obj.run(images)
                # convert tensor to numpy
                if(isinstance(pred, (tuple, list))):
                    pred = [item.cpu().numpy() for item in pred]
                else:
                    pred = pred.cpu().numpy()
                predict_list.append(pred)

            # average the predictions over the checkpoints
            if(isinstance(predict_list[0], list)):
                # Multiple outputs per checkpoint: average each output
                # separately and keep the list form, so that outputs with
                # different shapes are not stacked into a single array.
                pred = [np.mean(items, axis = 0) for items in zip(*predict_list)]
            else:
                pred = np.mean(predict_list, axis = 0)
            data['predict'] = pred

            # inverse transform
            for transform in self.transform_list[::-1]:
                if (transform.inverse):
                    data = transform.inverse_transform_for_prediction(data)

            infer_time = time.time() - start_time
            infer_time_list.append(infer_time)
            self.save_ouputs(data)
    infer_time_list = np.asarray(infer_time_list)
    time_avg, time_std = infer_time_list.mean(), infer_time_list.std()
    print("testing time {0:} +/- {1:}".format(time_avg, time_std))
481+
419482
def save_ouputs(self, data):
420483
output_dir = self.config['testing']['output_dir']
421484
ignore_dir = self.config['testing'].get('filename_ignore_dir', True)

0 commit comments

Comments
 (0)