Skip to content

Commit e09bbf5

Browse files
add train_state to all instances of `update_params`, passing it by (shallow) copy in submission_runner
1 parent a23b5ea commit e09bbf5

33 files changed

Lines changed: 88 additions & 25 deletions

File tree

DOCUMENTATION.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ def update_params(
199199
batch: Dict[str, Tensor],
200200
loss_type: LossType,
201201
optimizer_state: OptimizerState,
202+
train_state: Dict[str, Any],
202203
eval_results: List[Tuple[int, float]],
203204
global_step: int,
204205
rng: RandomState

algorithmic_efficiency/spec.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ def init_optimizer_state(workload: Workload,
401401
Dict[str, Tensor],
402402
LossType,
403403
OptimizerState,
404+
Dict[str, Any],
404405
List[Tuple[int, float]],
405406
int,
406407
RandomState
@@ -422,6 +423,7 @@ def update_params(workload: Workload,
422423
batch: Dict[str, Tensor],
423424
loss_type: LossType,
424425
optimizer_state: OptimizerState,
426+
train_state: Dict[str, Any],
425427
eval_results: List[Tuple[int, float]],
426428
global_step: int,
427429
rng: RandomState) -> UpdateReturn:

prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,12 +260,14 @@ def update_params(workload: spec.Workload,
260260
batch: Dict[str, spec.Tensor],
261261
loss_type: spec.LossType,
262262
optimizer_state: spec.OptimizerState,
263+
train_state: Dict[str, Any],
263264
eval_results: List[Tuple[int, float]],
264265
global_step: int,
265266
rng: spec.RandomState) -> spec.UpdateReturn:
266267
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
267268
del current_params_types
268269
del loss_type
270+
del train_state
269271
del eval_results
270272

271273
optimizer_state, opt_update_fn = optimizer_state

prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,12 +260,14 @@ def update_params(workload: spec.Workload,
260260
batch: Dict[str, spec.Tensor],
261261
loss_type: spec.LossType,
262262
optimizer_state: spec.OptimizerState,
263+
train_state: Dict[str, Any],
263264
eval_results: List[Tuple[int, float]],
264265
global_step: int,
265266
rng: spec.RandomState) -> spec.UpdateReturn:
266267
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
267268
del current_params_types
268269
del loss_type
270+
del train_state
269271
del eval_results
270272

271273
optimizer_state, opt_update_fn = optimizer_state

prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""
22

33
import math
4-
from typing import Dict, Iterator, List, Tuple
4+
from typing import Dict, Iterator, List, Tuple, Any
55

66
from absl import logging
77
import torch
@@ -232,12 +232,14 @@ def update_params(workload: spec.Workload,
232232
batch: Dict[str, spec.Tensor],
233233
loss_type: spec.LossType,
234234
optimizer_state: spec.OptimizerState,
235+
train_state: Dict[str, Any],
235236
eval_results: List[Tuple[int, float]],
236237
global_step: int,
237238
rng: spec.RandomState) -> spec.UpdateReturn:
238239
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
239240
del current_params_types
240241
del loss_type
242+
del train_state
241243
del eval_results
242244

243245
current_model = current_param_container

prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""
22

33
import math
4-
from typing import Dict, Iterator, List, Tuple
4+
from typing import Dict, Iterator, List, Tuple, Any
55

66
from absl import logging
77
import torch
@@ -232,12 +232,14 @@ def update_params(workload: spec.Workload,
232232
batch: Dict[str, spec.Tensor],
233233
loss_type: spec.LossType,
234234
optimizer_state: spec.OptimizerState,
235+
train_state: Dict[str, Any],
235236
eval_results: List[Tuple[int, float]],
236237
global_step: int,
237238
rng: spec.RandomState) -> spec.UpdateReturn:
238239
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
239240
del current_params_types
240241
del loss_type
242+
del train_state
241243
del eval_results
242244

243245
current_model = current_param_container

prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,14 @@ def update_params(workload: spec.Workload,
272272
batch: Dict[str, spec.Tensor],
273273
loss_type: spec.LossType,
274274
optimizer_state: spec.OptimizerState,
275+
train_state: Dict[str, Any],
275276
eval_results: List[Tuple[int, float]],
276277
global_step: int,
277278
rng: spec.RandomState) -> spec.UpdateReturn:
278279
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
279280
del current_params_types
280281
del loss_type
282+
del train_state
281283
del eval_results
282284
del hyperparameters
283285

prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,14 @@ def update_params(workload: spec.Workload,
272272
batch: Dict[str, spec.Tensor],
273273
loss_type: spec.LossType,
274274
optimizer_state: spec.OptimizerState,
275+
train_state: Dict[str, Any],
275276
eval_results: List[Tuple[int, float]],
276277
global_step: int,
277278
rng: spec.RandomState) -> spec.UpdateReturn:
278279
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
279280
del current_params_types
280281
del loss_type
282+
del train_state
281283
del eval_results
282284
del hyperparameters
283285

prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""
22

33
import math
4-
from typing import Dict, Iterator, List, Tuple
4+
from typing import Dict, Iterator, List, Tuple, Any
55

66
from absl import logging
77
import torch
@@ -244,12 +244,14 @@ def update_params(workload: spec.Workload,
244244
batch: Dict[str, spec.Tensor],
245245
loss_type: spec.LossType,
246246
optimizer_state: spec.OptimizerState,
247+
train_state: Dict[str, Any],
247248
eval_results: List[Tuple[int, float]],
248249
global_step: int,
249250
rng: spec.RandomState) -> spec.UpdateReturn:
250251
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
251252
del current_params_types
252253
del loss_type
254+
del train_state
253255
del eval_results
254256
del hyperparameters
255257

prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""
22

33
import math
4-
from typing import Dict, Iterator, List, Tuple
4+
from typing import Dict, Iterator, List, Tuple, Any
55

66
from absl import logging
77
import torch
@@ -244,12 +244,14 @@ def update_params(workload: spec.Workload,
244244
batch: Dict[str, spec.Tensor],
245245
loss_type: spec.LossType,
246246
optimizer_state: spec.OptimizerState,
247+
train_state: Dict[str, Any],
247248
eval_results: List[Tuple[int, float]],
248249
global_step: int,
249250
rng: spec.RandomState) -> spec.UpdateReturn:
250251
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
251252
del current_params_types
252253
del loss_type
254+
del train_state
253255
del eval_results
254256
del hyperparameters
255257

0 commit comments

Comments
 (0)