@@ -267,36 +267,29 @@ def print_error_action(action_error_sum, is_train):
267267 args .checkpoint = "./checkpoint/" + folder_name
268268 elif args .train == False :
269269 # create a new folder for the test results
270- args .previous_dir = os .path .dirname (args .saved_model_path )
270+ args .previous_dir = os .path .dirname (args .model_weights_path )
271271 args .checkpoint = os .path .join (args .previous_dir , folder_name )
272272
273273 if not os .path .exists (args .checkpoint ):
274274 os .makedirs (args .checkpoint )
275275
276276 # backup files
277- # import shutil
278- # file_name = os.path.basename(__file__)
279- # shutil.copyfile(
280- # src=file_name,
281- # dst=os.path.join(args.checkpoint, args.create_time + "_" + file_name),
282- # )
283- # shutil.copyfile(
284- # src="common/arguments.py",
285- # dst=os.path.join(args.checkpoint, args.create_time + "_arguments.py"),
286- # )
287- # if getattr(args, "model_path", ""):
288- # model_src_path = os.path.abspath(args.model_path)
289- # model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path)
290- # shutil.copyfile(
291- # src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name)
292- # )
293- # shutil.copyfile(
294- # src="common/utils.py",
295- # dst=os.path.join(args.checkpoint, args.create_time + "_utils.py"),
296- # )
297- # sh_base = os.path.basename(args.sh_file)
298- # dst_name = f"{args.create_time}_" + sh_base
299- # shutil.copyfile(src=args.sh_file, dst=os.path.join(args.checkpoint, dst_name))
277+ import shutil
278+ script_path = os .path .abspath (__file__ )
279+ script_name = os .path .basename (script_path )
280+ shutil .copyfile (
281+ src = script_path ,
282+ dst = os .path .join (args .checkpoint , args .create_time + "_" + script_name ),
283+ )
284+ if getattr (args , "model_path" , "" ):
285+ model_src_path = os .path .abspath (args .model_path )
286+ model_dst_name = f"{ args .create_time } _" + os .path .basename (model_src_path )
287+ shutil .copyfile (
288+ src = model_src_path , dst = os .path .join (args .checkpoint , model_dst_name )
289+ )
290+ sh_base = os .path .basename (args .sh_file )
291+ dst_name = f"{ args .create_time } _" + sh_base
292+ shutil .copyfile (src = args .sh_file , dst = os .path .join (args .checkpoint , dst_name ))
300293
301294 logging .basicConfig (
302295 format = "%(asctime)s %(message)s" ,
@@ -347,7 +340,7 @@ def print_error_action(action_error_sum, is_train):
347340
348341 if args .reload :
349342 model_dict = model ["CFM" ].state_dict ()
350- model_path = args .saved_model_path
343+ model_path = args .model_weights_path
351344 print (model_path )
352345 pre_dict = torch .load (model_path )
353346 for name , key in model_dict .items ():
0 commit comments