Skip to content

Commit 7002699

Browse files
committed
Refactor backup file handling in main_animal3d.py to enable file copying for checkpoints, including model and script files, improving data management during training.
1 parent 1673cad commit 7002699

1 file changed

Lines changed: 12 additions & 15 deletions

File tree

animals/scripts/main_animal3d.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -217,21 +217,18 @@ def get_parameter_number(net):
217217
os.makedirs(args.checkpoint)
218218

219219
# backup files
220-
# import shutil
221-
# file_path = os.path.abspath(__file__)
222-
# file_name = os.path.basename(file_path)
223-
# shutil.copyfile(src=file_path, dst=os.path.join(args.checkpoint, args.create_time + "_" + file_name))
224-
# shutil.copyfile(src=os.path.abspath("common/arguments.py"), dst=os.path.join(args.checkpoint, args.create_time + "_arguments.py"))
225-
# # backup the selected model file (from --model_path if provided)
226-
# if getattr(args, 'model_path', ''):
227-
# model_src_path = os.path.abspath(args.model_path)
228-
# model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path)
229-
# shutil.copyfile(src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name))
230-
# # shutil.copyfile(src="common/utils.py", dst = os.path.join(args.checkpoint, args.create_time + "_utils.py"))
231-
# sh_base = os.path.basename(args.sh_file)
232-
# dst_name = f"{args.create_time}_" + sh_base
233-
# sh_src = os.path.abspath(args.sh_file)
234-
# shutil.copyfile(src=sh_src, dst=os.path.join(args.checkpoint, dst_name))
220+
import shutil
221+
file_path = os.path.abspath(__file__)
222+
file_name = os.path.basename(file_path)
223+
shutil.copyfile(src=file_path, dst=os.path.join(args.checkpoint, args.create_time + "_" + file_name))
224+
if getattr(args, 'model_path', ''):
225+
model_src_path = os.path.abspath(args.model_path)
226+
model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path)
227+
shutil.copyfile(src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name))
228+
sh_base = os.path.basename(args.sh_file)
229+
dst_name = f"{args.create_time}_" + sh_base
230+
sh_src = os.path.abspath(args.sh_file)
231+
shutil.copyfile(src=sh_src, dst=os.path.join(args.checkpoint, dst_name))
235232

236233
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y/%m/%d %H:%M:%S', \
237234
filename=os.path.join(args.checkpoint, 'train.log'), level=logging.INFO)

0 commit comments

Comments
 (0)