Skip to content

Commit fbf0fe7

Browse files
committed
hotfix(eval): updating num_frames into eval.
1 parent d0c7c59 commit fbf0fe7

3 files changed

Lines changed: 9 additions & 6 deletions

File tree

eval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def main(cfg):
4848
checkpoint_params = DictConfig(torch_load_ckpt["hyper_parameters"])
4949
cfg.output = checkpoint_params.cfg.output + f"-e{torch_load_ckpt['epoch']}-{cfg.av2_mode}-v{cfg.leaderboard_version}"
5050
cfg.model.update(checkpoint_params.cfg.model)
51+
cfg.num_frames = cfg.model.target.get('num_frames', checkpoint_params.cfg.get('num_frames', cfg.get('num_frames', 2)))
5152

5253
mymodel = ModelWrapper.load_from_checkpoint(cfg.checkpoint, cfg=cfg, eval=True)
5354
print(f"\n---LOG[eval]: Loaded model from {cfg.checkpoint}. The backbone network is {checkpoint_params.cfg.model.name}.\n")
@@ -63,7 +64,7 @@ def main(cfg):
6364
trainer.validate(model = mymodel, \
6465
dataloaders = DataLoader( \
6566
HDF5Dataset(cfg.dataset_path + f"/{cfg.av2_mode}", \
66-
n_frames=checkpoint_params.cfg.num_frames if 'num_frames' in checkpoint_params.cfg else 2, \
67+
n_frames=cfg.num_frames, \
6768
eval=True, leaderboard_version=cfg.leaderboard_version), \
6869
batch_size=1, shuffle=False))
6970
wandb.finish()

save.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def main(cfg):
4545
checkpoint_params = DictConfig(torch.load(cfg.checkpoint)["hyper_parameters"])
4646
cfg.output = checkpoint_params.cfg.output
4747
cfg.model.update(checkpoint_params.cfg.model)
48+
cfg.num_frames = cfg.model.target.get('num_frames', checkpoint_params.cfg.get('num_frames', cfg.get('num_frames', 2)))
4849
mymodel = ModelWrapper.load_from_checkpoint(cfg.checkpoint, cfg=cfg, eval=True)
4950

5051
wandb_logger = WandbLogger(save_dir=output_dir,
@@ -57,7 +58,7 @@ def main(cfg):
5758
# NOTE(Qingwen): search & check in pl_model.py : def test_step(self, batch, res_dict)
5859
trainer.test(model = mymodel, \
5960
dataloaders = DataLoader(\
60-
HDF5Dataset(cfg.dataset_path, n_frames=checkpoint_params.cfg.num_frames if 'num_frames' in checkpoint_params.cfg else 2), \
61+
HDF5Dataset(cfg.dataset_path, n_frames=cfg.num_frames), \
6162
batch_size=1, shuffle=False))
6263
wandb.finish()
6364

src/trainer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def __init__(self, cfg, eval=False):
9494
print(f"We are in {cfg.av2_mode}, results will be saved in: {self.save_res_path} with version: {self.leaderboard_version} format for online leaderboard.")
9595

9696
# self.test_total_num = 0
97+
print(cfg)
9798
self.save_hyperparameters()
9899

99100
# FIXME(Qingwen 2025-08-20): update the loss_calculation fn alone to make all things pretty here....
@@ -208,7 +209,7 @@ def on_validation_epoch_end(self):
208209
self.model.timer.print(random_colors=False, bold=False)
209210

210211
if self.av2_mode == 'test':
211-
print(f"\nModel: {self.model.__class__.__name__}, Checkpoint from: {self.load_checkpoint_path}")
212+
print(f"\nModel: {self.model.__class__.__name__}, Checkpoint from: {self.checkpoint}")
212213
print(f"Test results saved in: {self.save_res_path}, Please run submit command and upload to online leaderboard for results.")
213214
if self.leaderboard_version == 1:
214215
print(f"\nevalai challenge 2010 phase 4018 submit --file {self.save_res_path}.zip --large --private\n")
@@ -221,8 +222,8 @@ def on_validation_epoch_end(self):
221222
return
222223

223224
if self.av2_mode == 'val':
224-
print(f"\nModel: {self.model.__class__.__name__}, Checkpoint from: {self.load_checkpoint_path}")
225-
print(f"More details parameters and training status are in checkpoints")
225+
print(f"\nModel: {self.model.__class__.__name__}, Checkpoint from: {self.checkpoint}")
226+
print(f"More details parameters and training status are in the checkpoint file.")
226227

227228
self.metrics.normalize()
228229

@@ -339,7 +340,7 @@ def test_step(self, batch, batch_idx):
339340

340341
def on_test_epoch_end(self):
341342
self.model.timer.print(random_colors=False, bold=False)
342-
print(f"\n\nModel: {self.model.__class__.__name__}, Checkpoint from: {self.load_checkpoint_path}")
343+
print(f"\n\nModel: {self.model.__class__.__name__}, Checkpoint from: {self.checkpoint}")
343344
print(f"We already write the flow_est into the dataset, please run following commend to visualize the flow. Copy and paste it to your terminal:")
344345
print(f"python tools/visualization.py --res_name '{self.vis_name}' --data_dir {self.dataset_path}")
345346
print(f"Enjoy! ^v^ ------ \n")

0 commit comments

Comments
 (0)