
Commit b21be29

Merge pull request #921 from mlcommons/dev
Dev -> main
2 parents f364a0b + d77c538

10 files changed

Lines changed: 377 additions & 121 deletions

docker/Dockerfile

Lines changed: 2 additions & 2 deletions
@@ -83,8 +83,8 @@ RUN if [ "$framework" = "jax" ] ; then \
 RUN cd /algorithmic-efficiency && git fetch origin
 RUN cd /algorithmic-efficiency && git pull
 
-# Todo: remove this, this is temporary for developing
-COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh
+# Uncomment this for developing purposes
+# COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh
 RUN chmod a+x /algorithmic-efficiency/docker/scripts/startup.sh
 
 ENTRYPOINT ["bash", "/algorithmic-efficiency/docker/scripts/startup.sh"]

docker/build_docker_images.sh

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ done
 
 # Artifact repostiory
 if [ "$PROJECT" = "mlcommons-algoperf" ]; then
-  ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo"
+  ARTIFACT_REPO="europe-west4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo"
 else
   ARTIFACT_REPO="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo"
 fi

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -46,7 +46,8 @@ dependencies = [
   "clu==0.0.12",
   "matplotlib>=3.9.2",
   "tabulate==0.9.0",
-  "wandb==0.21.0"
+  "wandb==0.21.0",
+  "importlib_resources"
 ]
 
 [build-system]
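
The newly added `importlib_resources` is the PyPI backport of the standard library's `importlib.resources` API. A generic usage sketch, with a hypothetical package and resource name (not actual files in this repo):

```python
# Hypothetical example of the importlib_resources backport; the
# 'algoperf' package and 'defaults.json' resource are placeholders.
import importlib_resources

resource = importlib_resources.files('algoperf') / 'defaults.json'
print(resource.read_text())
```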

scoring/performance_profile.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@
 # workloads and rules for the scoring to be correct.
 # We do not use the workload registry since it contains test and development
 # workloads as well.
-NUM_BASE_WORKLOADS = 8
+NUM_BASE_WORKLOADS = 9
 NUM_VARIANT_WORKLOADS = 0
 NUM_TRIALS = 5
 NUM_STUDIES = 3

scoring/score_submissions.py

Lines changed: 11 additions & 41 deletions
@@ -75,10 +75,10 @@
 FLAGS = flags.FLAGS
 
 
-def get_summary_df(workload, workload_df, include_test_split=False):
+def get_summary_df(workload, workload_df):
   print(f' WORKLOAD: {workload}')
   validation_metric, validation_target = (
-    scoring_utils.get_workload_metrics_and_targets(workload, split='validation')
+    scoring_utils.get_workload_metrics_and_targets(workload)
   )
 
   is_minimized = performance_profile.check_if_minimized(validation_metric)
@@ -127,7 +127,7 @@ def get_summary_df(workload, workload_df, include_test_split=False):
 
   # compute the step times
   def delta(series):
-    return series.shift(1, fill_value=0) - series
+    return series.apply(lambda x: np.diff(x, prepend=0))
 
   accumulated_time_intervals = delta(workload_df['accumulated_submission_time'])
   step_intervals = delta(workload_df['global_step'])
@@ -136,57 +136,27 @@ def delta(series):
       f'WARNING: The number of evals may be too low to calculate reliable step time for {workload}'
     )
 
-  summary_df['step_time (s)'] = np.median(
-    (accumulated_time_intervals / step_intervals).iloc[0]
-  )
-
-  summary_df['step_hint'] = scoring_utils.get_workload_stephint(workload)
-
-  # test metrics
-  if include_test_split:
-    test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(
-      workload, split='test'
+  # Flatten all intervals from all trials and take the global median
+  with np.errstate(divide='ignore', invalid='ignore'):
+    all_ratios = np.concatenate(
+      (accumulated_time_intervals / step_intervals).values
     )
+  summary_df['step_time (s)'] = np.nanmedian(all_ratios)
 
-  summary_df['test target metric name'] = test_metric
-  summary_df['test target metric value'] = test_target
-
-  summary_df['test target reached'] = (
-    workload_df[test_metric]
-    .apply(lambda x: target_op(x, test_target))
-    .apply(np.any)
-  )
-  summary_df['best metric value on test'] = workload_df[test_metric].apply(
-    lambda x: best_op(x)
-  )
-  workload_df['index best eval on test'] = workload_df[test_metric].apply(
-    lambda x: idx_op(x)
-  )
-  summary_df['time to best eval on test (s)'] = workload_df.apply(
-    lambda x: x['accumulated_submission_time'][x['index best eval on test']],
-    axis=1,
-  )
-  summary_df['time to target on test (s)'] = summary_df.apply(
-    lambda x: x['time to best eval on test (s)']
-    if x['test target reached']
-    else np.inf,
-    axis=1,
-  )
+  summary_df['step_hint'] = scoring_utils.get_workload_stephint(workload)
 
   return summary_df
 
 
-def get_submission_summary(df, include_test_split=False):
+def get_submission_summary(df):
   """Summarizes the submission results into metric and time tables
   organized by workload.
   """
 
   dfs = []
   print(df)
   for workload, group in df.groupby('workload'):
-    summary_df = get_summary_df(
-      workload, group, include_test_split=include_test_split
-    )
+    summary_df = get_summary_df(workload, group)
     dfs.append(summary_df)
 
   df = pd.concat(dfs)
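
For context on the step-time change above: the old `delta` shifted the whole per-trial pandas Series, whereas the new one diffs within each trial's array of cumulative values, and the median is now taken over intervals pooled from all trials instead of only the first trial. A minimal sketch of the new behavior, with made-up toy values:

```python
# Toy illustration of the new per-trial interval computation; the
# DataFrame contents are invented, not real scoring logs.
import numpy as np
import pandas as pd

# One row per trial; each cell is the cumulative value at each eval.
workload_df = pd.DataFrame({
  'accumulated_submission_time': [np.array([30.0, 70.0, 120.0])],
  'global_step': [np.array([100, 200, 300])],
})

def delta(series):
  # Differences between consecutive evals within each trial, anchored at 0.
  return series.apply(lambda x: np.diff(x, prepend=0))

accumulated_time_intervals = delta(workload_df['accumulated_submission_time'])
step_intervals = delta(workload_df['global_step'])

# Pool every trial's per-interval step times, then take a NaN-safe median.
with np.errstate(divide='ignore', invalid='ignore'):
  all_ratios = np.concatenate(
    (accumulated_time_intervals / step_intervals).values
  )
print(np.nanmedian(all_ratios))  # -> 0.4 seconds per step for this toy trial
```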

scoring/scoring_utils.py

Lines changed: 3 additions & 7 deletions
@@ -214,7 +214,7 @@ def get_experiment_df(experiment_dir):
 
 
 ## Get workload properties
-def get_workload_metrics_and_targets(workload, split='validation'):
+def get_workload_metrics_and_targets(workload):
   """Returns workload target metric name and value."""
   workload_name = re.match(WORKLOAD_NAME_PATTERN, workload).group(1)
   framework = re.match(WORKLOAD_NAME_PATTERN, workload).group(2)
@@ -233,12 +233,8 @@ def get_workload_metrics_and_targets(workload, split='validation'):
     workload_init_kwargs=workload_init_kwargs,
   )
   metric_name = workload_obj.target_metric_name
-  if split == 'validation':
-    metric = f'validation/{metric_name}'
-    target = workload_obj.validation_target_value
-  elif split == 'test':
-    metric = f'test/{metric_name}'
-    target = workload_obj.test_target_value
+  metric = f'validation/{metric_name}'
+  target = workload_obj.validation_target_value
   return metric, target
 
 
scoring/utils/slurm/README.md

Lines changed: 23 additions & 0 deletions
@@ -48,6 +48,29 @@ LOGS_BUCKET="algoperf-runs-internal"
 sbatch run_jobs.sh
 ```
 
+## Convenient bash script to launch SLURM jobs
+
+The run_submission.sh script does all of the steps above for you. It is intended to be used on a SLURM login node. However, it expects a very specific directory structure: run it from your $HOME directory, with the algorithmic-efficiency and submissions_algorithms git repos checked out there.
+
+```
+$USER$@$USER$:~/$ tree -L 1
+.
+├── algorithmic-efficiency
+└── submissions_algorithms
+```
+
+You then run the script with a command like:
+
+```
+./algorithmic-efficiency/scoring/utils/slurm/run_submission.sh \
+  --submission_path submissions_algorithms/submissions/self_tuning/schedule_free_adamw_v2 \
+  --dry_run false
+```
+
+The submission path points to the directory where the submission lives (in the submissions git repo). `dry_run` is set to true by default (which limits max global steps to 10) to prevent accidental commands from wasting resources; explicitly set it to false for full runs.
+
+The script figures out the rest and runs everything for you: it creates the config, saves it to a path with a reasonable name, and invokes the sbatch script with the right flags.
+
 # Set up new SLURM cluster
 
 If you are setting up a new cluster, we recommend using the [HPC toolkit to set up a SLURM cluster](https://cloud.google.com/cluster-toolkit/docs/quickstarts/slurm-cluster).

scoring/utils/slurm/make_job_config.py

Lines changed: 27 additions & 14 deletions
@@ -9,6 +9,7 @@
 
 import json
 import os
+import struct
 
 import jax
 from absl import app, flags
@@ -17,8 +18,6 @@
 TUNING_SEARCH_SPACE = (
   'reference_algorithms/paper_baselines/adamw/tuning_search_space.json'
 )
-NUM_TUNING_TRIALS = 3  # For external tuning ruleset
-NUM_STUDIES = 3
 
 flags.DEFINE_string(
   'submission_path',
@@ -35,11 +34,6 @@
   'experiments',
   'Path to experiment dir where logs will be saved.',
 )
-flags.DEFINE_string(
-  'experiment_dir',
-  'experiments/',
-  'Path to experiment dir where logs will be saved.',
-)
 flags.DEFINE_enum(
   'framework',
   'jax',
@@ -56,14 +50,13 @@
 flags.DEFINE_string(
   'workloads', None, help='Comma seperated list of workloads to run.'
 )
-flags.DEFINE_integer('num_studies', NUM_STUDIES, help='Number of studies.')
+flags.DEFINE_integer('num_studies', None, help='Number of studies.')
+flags.DEFINE_integer('num_tuning_trials', None, help='Number of tuning trials.')
 
 FLAGS = flags.FLAGS
 
 MIN_INT = -(2 ** (31))
 MAX_INT = 2 ** (31) - 1
-NUM_TUNING_TRIALS = 5  # For external tuning ruleset
-NUM_STUDIES = 3
 
 WORKLOADS = {
   'imagenet_resnet': {'dataset': 'imagenet'},
@@ -74,6 +67,12 @@
   'librispeech_deepspeech': {'dataset': 'librispeech'},
   'criteo1tb': {'dataset': 'criteo1tb'},
   'librispeech_conformer': {'dataset': 'librispeech'},
+  'finewebedu_lm': {'dataset': 'fineweb_edu_10B'},
+}
+
+RULESET_CONFIGS = {
+  'self': {'num_studies': 3, 'num_tuning_trials': 1},
+  'external': {'num_studies': 3, 'num_tuning_trials': 5},
 }
 
@@ -83,17 +82,31 @@ def main(_):
   else:
     workloads = FLAGS.workloads.split(',')
 
-  key = jax.random.key(FLAGS.seed)
+  if not FLAGS.seed:
+    FLAGS.seed = struct.unpack('I', os.urandom(4))[0]
+
+  # Set defaults based on tuning_ruleset if not provided by user
+  num_studies = FLAGS.num_studies
+  if num_studies is None:
+    num_studies = RULESET_CONFIGS[FLAGS.tuning_ruleset]['num_studies']
+
+  num_tuning_trials = FLAGS.num_tuning_trials
+  if num_tuning_trials is None:
+    num_tuning_trials = RULESET_CONFIGS[FLAGS.tuning_ruleset][
+      'num_tuning_trials'
+    ]
+
+  key = jax.random.PRNGKey(FLAGS.seed)
 
   jobs = []
 
   for workload in workloads:
     # Fold in hash(workload) mod(max(uint32))
     workload_key = jax.random.fold_in(key, hash(workload) % (2**32 - 1))
-    for study_index in range(NUM_STUDIES):
+    for study_index in range(num_studies):
       study_key = jax.random.fold_in(workload_key, study_index)
       if FLAGS.tuning_ruleset == 'external':
-        for hparam_index in range(NUM_TUNING_TRIALS):
+        for hparam_index in range(num_tuning_trials):
           run_key = jax.random.fold_in(study_key, hparam_index)
           seed = jax.random.randint(run_key, (1,), MIN_INT, MAX_INT)[0].item()
           print(seed)
@@ -107,7 +120,7 @@ def main(_):
           job['experiment_dir'] = study_dir
           job['rng_seed'] = seed
           job['tuning_ruleset'] = FLAGS.tuning_ruleset
-          job['num_tuning_trials'] = NUM_TUNING_TRIALS
+          job['num_tuning_trials'] = num_tuning_trials
           job['hparam_start_index'] = hparam_index
           job['hparam_end_index'] = hparam_index + 1
           job['tuning_search_space'] = FLAGS.tuning_search_space