Skip to content

Commit ec5eae0

Browse files
committed
Minor fixes to scripts
1. Remove redefined flags 2. Comment out a debugging line 3. Pipe in a max global steps with a flag
1 parent a673f58 commit ec5eae0

2 files changed

Lines changed: 7 additions & 6 deletions

File tree

docker/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ RUN cd /algorithmic-efficiency && git fetch origin
8484
RUN cd /algorithmic-efficiency && git pull
8585

8686
# Todo: remove this, this is temporary for developing
87-
COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh
87+
# COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh
8888
RUN chmod a+x /algorithmic-efficiency/docker/scripts/startup.sh
8989

9090
ENTRYPOINT ["bash", "/algorithmic-efficiency/docker/scripts/startup.sh"]

scoring/utils/slurm/make_job_config.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,14 @@
3535
'experiments',
3636
'Path to experiment dir where logs will be saved.',
3737
)
38-
flags.DEFINE_string(
39-
'experiment_dir',
40-
'experiments/',
41-
'Path to experiment dir where logs will be saved.',
42-
)
4338
flags.DEFINE_enum(
4439
'framework',
4540
'jax',
4641
enum_values=['jax', 'pytorch'],
4742
help='Can be either pytorch or jax.',
4843
)
4944
flags.DEFINE_integer('seed', 0, 'RNG seed to to generate study seeds from.')
45+
flags.DEFINE_integer('max_global_steps', None, 'Number of steps to run each workload for')
5046
flags.DEFINE_enum(
5147
'tuning_ruleset',
5248
'self',
@@ -74,6 +70,7 @@
7470
'librispeech_deepspeech': {'dataset': 'librispeech'},
7571
'criteo1tb': {'dataset': 'criteo1tb'},
7672
'librispeech_conformer': {'dataset': 'librispeech'},
73+
'finewebedu_lm': {'dataset': 'fineweb_edu_10B'}
7774
}
7875

7976

@@ -112,6 +109,8 @@ def main(_):
112109
job['hparam_end_index'] = hparam_index + 1
113110
job['tuning_search_space'] = FLAGS.tuning_search_space
114111
job['tuning_ruleset'] = FLAGS.tuning_ruleset
112+
if FLAGS.max_global_steps:
113+
job['max_global_steps'] = FLAGS.max_global_steps
115114
jobs.append(job)
116115
print(job)
117116

@@ -130,6 +129,8 @@ def main(_):
130129
job['rng_seed'] = seed
131130
job['tuning_ruleset'] = FLAGS.tuning_ruleset
132131
job['num_tuning_trials'] = 1
132+
if FLAGS.max_global_steps:
133+
job['max_global_steps'] = FLAGS.max_global_steps
133134

134135
jobs.append(job)
135136
print(job)

0 commit comments

Comments
 (0)