@@ -363,11 +363,7 @@ def infer(self):
363363 device_ids = self .config ['testing' ]['gpus' ]
364364 device = torch .device ("cuda:{0:}" .format (device_ids [0 ]))
365365 self .net .to (device )
366- # load network parameters and set the network as evaluation mode
367- checkpoint_name = self .get_checkpoint_name ()
368- checkpoint = torch .load (checkpoint_name , map_location = device )
369- self .net .load_state_dict (checkpoint ['model_state_dict' ])
370-
366+
371367 if (self .config ['testing' ]['evaluation_mode' ] == True ):
372368 self .net .eval ()
373369 if (self .config ['testing' ]['test_time_dropout' ] == True ):
@@ -377,6 +373,20 @@ def test_time_dropout(m):
377373 m .train ()
378374 self .net .apply (test_time_dropout )
379375
376+ ckpt_mode = self .config ['testing' ]['ckpt_mode' ]
377+ ckpt_name = self .get_checkpoint_name ()
378+ if (ckpt_mode == 3 ):
379+ assert (isinstance (ckpt_name , (tuple , list )))
380+ self .infer_with_multiple_checkpoints ()
381+ return
382+ else :
383+ if (isinstance (ckpt_name , (tuple , list ))):
384+ raise ValueError ("ckpt_mode should be 3 if ckpt_name is a list" )
385+
386+ # load network parameters and set the network as evaluation mode
387+ checkpoint = torch .load (ckpt_name , map_location = device )
388+ self .net .load_state_dict (checkpoint ['model_state_dict' ])
389+
380390 infer_cfg = self .config ['testing' ]
381391 infer_cfg ['class_num' ] = self .config ['network' ]['class_num' ]
382392 infer_obj = Inferer (self .net , infer_cfg )
@@ -416,6 +426,59 @@ def test_time_dropout(m):
416426 time_avg , time_std = infer_time_list .mean (), infer_time_list .std ()
417427 print ("testing time {0:} +/- {1:}" .format (time_avg , time_std ))
418428
429+ def infer_with_multiple_checkpoints (self ):
430+ """
431+ inference with ensemble of multilple check points
432+ """
433+ device_ids = self .config ['testing' ]['gpus' ]
434+ device = torch .device ("cuda:{0:}" .format (device_ids [0 ]))
435+
436+ ckpt_names = self .config ['testing' ]['ckpt_name' ]
437+ infer_cfg = self .config ['testing' ]
438+ infer_cfg ['class_num' ] = self .config ['network' ]['class_num' ]
439+ infer_obj = Inferer (self .net , infer_cfg )
440+ infer_time_list = []
441+ with torch .no_grad ():
442+ for data in self .test_loder :
443+ images = self .convert_tensor_type (data ['image' ])
444+ images = images .to (device )
445+
446+ # for debug
447+ # for i in range(images.shape[0]):
448+ # image_i = images[i][0]
449+ # label_i = images[i][0]
450+ # image_name = "temp/{0:}_image.nii.gz".format(names[0])
451+ # label_name = "temp/{0:}_label.nii.gz".format(names[0])
452+ # save_nd_array_as_image(image_i, image_name, reference_name = None)
453+ # save_nd_array_as_image(label_i, label_name, reference_name = None)
454+ # continue
455+ start_time = time .time ()
456+ predict_list = []
457+ for ckpt_name in ckpt_names :
458+ checkpoint = torch .load (ckpt_name , map_location = device )
459+ self .net .load_state_dict (checkpoint ['model_state_dict' ])
460+
461+ pred = infer_obj .run (images )
462+ # convert tensor to numpy
463+ if (isinstance (pred , (tuple , list ))):
464+ pred = [item .cpu ().numpy () for item in pred ]
465+ else :
466+ pred = pred .cpu ().numpy ()
467+ predict_list .append (pred )
468+ pred = np .mean (predict_list , axis = 0 )
469+ data ['predict' ] = pred
470+ # inverse transform
471+ for transform in self .transform_list [::- 1 ]:
472+ if (transform .inverse ):
473+ data = transform .inverse_transform_for_prediction (data )
474+
475+ infer_time = time .time () - start_time
476+ infer_time_list .append (infer_time )
477+ self .save_ouputs (data )
478+ infer_time_list = np .asarray (infer_time_list )
479+ time_avg , time_std = infer_time_list .mean (), infer_time_list .std ()
480+ print ("testing time {0:} +/- {1:}" .format (time_avg , time_std ))
481+
419482 def save_ouputs (self , data ):
420483 output_dir = self .config ['testing' ]['output_dir' ]
421484 ignore_dir = self .config ['testing' ].get ('filename_ignore_dir' , True )
0 commit comments