Skip to content

Commit 42d9ae1

Browse files
Merge pull request mlcommons#855 from mlcommons/dev
Dev -> main
2 parents 1d81455 + ed13e81 commit 42d9ae1

3 files changed

Lines changed: 5 additions & 1 deletion

File tree

prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""
22

3+
import collections
34
import math
45
from typing import Any, Dict, Iterator, List, Optional, Tuple
56

@@ -24,6 +25,7 @@
2425
"weight_decay": 0.08121616522670176,
2526
"warmup_factor": 0.02
2627
}
28+
HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)
2729

2830

2931
# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py.

prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""
22

3+
import collections
34
import math
45
from typing import Any, Dict, Iterator, List, Optional, Tuple
56

@@ -24,6 +25,7 @@
2425
"weight_decay": 0.08121616522670176,
2526
"warmup_factor": 0.02
2627
}
28+
HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)
2729

2830

2931
# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py.

submission_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ def score_submission_on_workload(workload: spec.Workload,
668668
score, _ = train_once(
669669
workload, workload_name, global_batch_size, global_eval_batch_size,
670670
data_dir, imagenet_v2_data_dir,
671-
init_optimizer_state, update_params, data_selection,
671+
init_optimizer_state, update_params, data_selection, prepare_for_eval,
672672
None, rng_seed, rng, profiler, max_global_steps, log_dir,
673673
save_checkpoints=save_checkpoints)
674674
return score

0 commit comments

Comments (0)