33import logging
44import os
55import sys
6- import shutil
6+ from datetime import datetime
77from pymic .util .parse_config import *
88from pymic .net_run .agent_cls import ClassificationAgent
99from pymic .net_run .agent_seg import SegmentationAgent
@@ -12,34 +12,31 @@ def main():
1212 """
1313 The main function for running a network for training or inference.
1414 """
15- if (len (sys .argv ) < 3 ):
16- print ('Number of arguments should be 3 . e.g.' )
17- print (' pymic_run train config.cfg' )
15+ if (len (sys .argv ) < 2 ):
16+ print ('Number of arguments should be 2 . e.g.' )
17+ print (' pymic_test config.cfg' )
1818 exit ()
19- stage = str (sys .argv [1 ])
20- cfg_file = str (sys .argv [2 ])
19+ cfg_file = str (sys .argv [1 ])
2120 config = parse_config (cfg_file )
2221 config = synchronize_config (config )
23- log_dir = config ['training ' ]['ckpt_save_dir ' ]
22+ log_dir = config ['testing ' ]['output_dir ' ]
2423 if (not os .path .exists (log_dir )):
2524 os .makedirs (log_dir , exist_ok = True )
26- if (stage == "train" ):
27- dst_cfg = cfg_file if "/" not in cfg_file else cfg_file .split ("/" )[- 1 ]
28- shutil .copy (cfg_file , log_dir + "/" + dst_cfg )
25+
2926 if sys .version .startswith ("3.9" ):
30- logging .basicConfig (filename = log_dir + "/log_{0:} .txt" . format ( stage ), level = logging . INFO ,
31- format = '%(message)s' , force = True ) # for python 3.9
27+ logging .basicConfig (filename = log_dir + "/log_test .txt" ,
28+ level = logging . INFO , format = '%(message)s' , force = True ) # for python 3.9
3229 else :
33- logging .basicConfig (filename = log_dir + "/log_{0:} .txt" . format ( stage ), level = logging . INFO ,
34- format = '%(message)s' ) # for python 3.6
30+ logging .basicConfig (filename = log_dir + "/log_test .txt" ,
31+ level = logging . INFO , format = '%(message)s' ) # for python 3.6
3532 logging .getLogger ().addHandler (logging .StreamHandler (sys .stdout ))
3633 logging_config (config )
3734 task = config ['dataset' ]['task_type' ]
3835 assert task in ['cls' , 'cls_nexcl' , 'seg' ]
3936 if (task == 'cls' or task == 'cls_nexcl' ):
40- agent = ClassificationAgent (config , stage )
37+ agent = ClassificationAgent (config , 'test' )
4138 else :
42- agent = SegmentationAgent (config , stage )
39+ agent = SegmentationAgent (config , 'test' )
4340 agent .run ()
4441
4542if __name__ == "__main__" :
0 commit comments