3535 'experiments' ,
3636 'Path to experiment dir where logs will be saved.' ,
3737)
38- flags .DEFINE_string (
39- 'experiment_dir' ,
40- 'experiments/' ,
41- 'Path to experiment dir where logs will be saved.' ,
42- )
4338flags .DEFINE_enum (
4439 'framework' ,
4540 'jax' ,
4641 enum_values = ['jax' , 'pytorch' ],
4742 help = 'Can be either pytorch or jax.' ,
4843)
4944flags .DEFINE_integer ('seed' , 0 , 'RNG seed to to generate study seeds from.' )
45+ flags .DEFINE_integer ('max_global_steps' , None , 'Number of steps to run each workload for' )
5046flags .DEFINE_enum (
5147 'tuning_ruleset' ,
5248 'self' ,
7470 'librispeech_deepspeech' : {'dataset' : 'librispeech' },
7571 'criteo1tb' : {'dataset' : 'criteo1tb' },
7672 'librispeech_conformer' : {'dataset' : 'librispeech' },
73+ 'finewebedu_lm' : {'dataset' : 'fineweb_edu_10B' }
7774}
7875
7976
@@ -112,6 +109,8 @@ def main(_):
112109 job ['hparam_end_index' ] = hparam_index + 1
113110 job ['tuning_search_space' ] = FLAGS .tuning_search_space
114111 job ['tuning_ruleset' ] = FLAGS .tuning_ruleset
112+ if FLAGS .max_global_steps :
113+ job ['max_global_steps' ] = FLAGS .max_global_steps
115114 jobs .append (job )
116115 print (job )
117116
@@ -130,6 +129,8 @@ def main(_):
130129 job ['rng_seed' ] = seed
131130 job ['tuning_ruleset' ] = FLAGS .tuning_ruleset
132131 job ['num_tuning_trials' ] = 1
132+ if FLAGS .max_global_steps :
133+ job ['max_global_steps' ] = FLAGS .max_global_steps
133134
134135 jobs .append (job )
135136 print (job )
0 commit comments