Skip to content

Commit 876cfd4

Browse files
committed
small fix on ssl_label to None if under supervise training.
1 parent 3c82d48 commit 876cfd4

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def main(cfg):
5757
precheck_cfg_valid(cfg)
5858
pl.seed_everything(cfg.seed, workers=True)
5959

60-
train_dataset = HDF5Dataset(cfg.train_data, n_frames=cfg.num_frames, ssl_label=cfg.ssl_label)
60+
train_dataset = HDF5Dataset(cfg.train_data, n_frames=cfg.num_frames, ssl_label=cfg.get('ssl_label', None))
6161
train_loader = DataLoader(train_dataset,
6262
batch_size=cfg.batch_size,
6363
shuffle=True,
@@ -134,7 +134,7 @@ def main(cfg):
134134
print("Initiating wandb and trainer successfully. ^V^ ")
135135
print(f"We will use {cfg.gpus} GPUs to train the model. Check the checkpoints in {output_dir} checkpoints folder.")
136136
print("Total Train Dataset Size: ", len(train_dataset))
137-
if cfg.add_seloss is not None and cfg.loss_fn in ['seflowLoss', 'seflowppLoss']:
137+
if cfg.get('add_seloss', None) is not None and cfg.loss_fn in ['seflowLoss', 'seflowppLoss']:
138138
print(f"Note: We are in **self-supervised** training now. No ground truth label is used.")
139139
print(f"We will use these loss items in {cfg.loss_fn}: {cfg.add_seloss}")
140140
print("-"*40+"\n")

0 commit comments

Comments
 (0)