Skip to content

Commit ce8eb18

Browse files
ensure backward compatibility
1 parent 1f59285 commit ce8eb18

2 files changed

Lines changed: 13 additions & 6 deletions

File tree

algorithmic_efficiency/spec.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -401,10 +401,10 @@ def init_optimizer_state(workload: Workload,
401401
Dict[str, Tensor],
402402
LossType,
403403
OptimizerState,
404-
Dict[str, Any],
405404
List[Tuple[int, float]],
406405
int,
407-
RandomState
406+
RandomState,
407+
Optional[Dict[str, Any]]
408408
],
409409
UpdateReturn]
410410

@@ -423,10 +423,10 @@ def update_params(workload: Workload,
423423
batch: Dict[str, Tensor],
424424
loss_type: LossType,
425425
optimizer_state: OptimizerState,
426-
train_state: Dict[str, Any],
427426
eval_results: List[Tuple[int, float]],
428427
global_step: int,
429-
rng: RandomState) -> UpdateReturn:
428+
rng: RandomState,
429+
train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn:
430430
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
431431
pass
432432

submission_runner.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
import datetime
1818
import gc
1919
import importlib
20+
from inspect import signature
2021
import itertools
2122
import json
2223
import os
2324
import struct
2425
import time
26+
from types import MappingProxyType
2527
from typing import Any, Dict, Optional, Tuple
2628

2729
from absl import app
@@ -273,6 +275,10 @@ def train_once(
273275
hyperparameters,
274276
opt_init_rng)
275277
logging.info('Initializing metrics bundle.')
278+
279+
# Check if 'train_state' is in the function signature
280+
needs_train_state = 'train_state' in signature(update_params).parameters
281+
276282
# Bookkeeping.
277283
train_state = {
278284
'validation_goal_reached': False,
@@ -357,10 +363,11 @@ def train_once(
357363
batch=batch,
358364
loss_type=workload.loss_type,
359365
optimizer_state=optimizer_state,
360-
train_state=train_state.copy(),
361366
eval_results=eval_results,
362367
global_step=global_step,
363-
rng=update_rng)
368+
rng=update_rng,
369+
**({'train_state': MappingProxyType(train_state)}
370+
if needs_train_state else {}))
364371
except spec.TrainingCompleteError:
365372
train_state['training_complete'] = True
366373
global_step += 1

0 commit comments

Comments
 (0)