Skip to content

Commit db005f3

Browse files
committed
Refactor file backup process in FMPose3D_main.py
- Updated the backup file logic to use the new model_weights_path instead of saved_model_path for consistency. - Cleaned up commented-out code and streamlined the backup process for better readability and maintainability.
1 parent d6dcf06 commit db005f3

1 file changed

Lines changed: 18 additions & 25 deletions

File tree

scripts/FMPose3D_main.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -267,36 +267,29 @@ def print_error_action(action_error_sum, is_train):
267267
args.checkpoint = "./checkpoint/" + folder_name
268268
elif args.train == False:
269269
# create a new folder for the test results
270-
args.previous_dir = os.path.dirname(args.saved_model_path)
270+
args.previous_dir = os.path.dirname(args.model_weights_path)
271271
args.checkpoint = os.path.join(args.previous_dir, folder_name)
272272

273273
if not os.path.exists(args.checkpoint):
274274
os.makedirs(args.checkpoint)
275275

276276
# backup files
277-
# import shutil
278-
# file_name = os.path.basename(__file__)
279-
# shutil.copyfile(
280-
# src=file_name,
281-
# dst=os.path.join(args.checkpoint, args.create_time + "_" + file_name),
282-
# )
283-
# shutil.copyfile(
284-
# src="common/arguments.py",
285-
# dst=os.path.join(args.checkpoint, args.create_time + "_arguments.py"),
286-
# )
287-
# if getattr(args, "model_path", ""):
288-
# model_src_path = os.path.abspath(args.model_path)
289-
# model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path)
290-
# shutil.copyfile(
291-
# src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name)
292-
# )
293-
# shutil.copyfile(
294-
# src="common/utils.py",
295-
# dst=os.path.join(args.checkpoint, args.create_time + "_utils.py"),
296-
# )
297-
# sh_base = os.path.basename(args.sh_file)
298-
# dst_name = f"{args.create_time}_" + sh_base
299-
# shutil.copyfile(src=args.sh_file, dst=os.path.join(args.checkpoint, dst_name))
277+
import shutil
278+
script_path = os.path.abspath(__file__)
279+
script_name = os.path.basename(script_path)
280+
shutil.copyfile(
281+
src=script_path,
282+
dst=os.path.join(args.checkpoint, args.create_time + "_" + script_name),
283+
)
284+
if getattr(args, "model_path", ""):
285+
model_src_path = os.path.abspath(args.model_path)
286+
model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path)
287+
shutil.copyfile(
288+
src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name)
289+
)
290+
sh_base = os.path.basename(args.sh_file)
291+
dst_name = f"{args.create_time}_" + sh_base
292+
shutil.copyfile(src=args.sh_file, dst=os.path.join(args.checkpoint, dst_name))
300293

301294
logging.basicConfig(
302295
format="%(asctime)s %(message)s",
@@ -347,7 +340,7 @@ def print_error_action(action_error_sum, is_train):
347340

348341
if args.reload:
349342
model_dict = model["CFM"].state_dict()
350-
model_path = args.saved_model_path
343+
model_path = args.model_weights_path
351344
print(model_path)
352345
pre_dict = torch.load(model_path)
353346
for name, key in model_dict.items():

0 commit comments

Comments
 (0)