File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -401,10 +401,10 @@ def init_optimizer_state(workload: Workload,
401401 Dict [str , Tensor ],
402402 LossType ,
403403 OptimizerState ,
404- Dict [str , Any ],
405404 List [Tuple [int , float ]],
406405 int ,
407- RandomState
406+ RandomState ,
407+ Optional [Dict [str , Any ]]
408408],
409409 UpdateReturn ]
410410
@@ -423,10 +423,10 @@ def update_params(workload: Workload,
423423 batch : Dict [str , Tensor ],
424424 loss_type : LossType ,
425425 optimizer_state : OptimizerState ,
426- train_state : Dict [str , Any ],
427426 eval_results : List [Tuple [int , float ]],
428427 global_step : int ,
429- rng : RandomState ) -> UpdateReturn :
428+ rng : RandomState ,
429+ train_state : Optional [Dict [str , Any ]] = None ) -> UpdateReturn :
430430 """Return (updated_optimizer_state, updated_params, updated_model_state)."""
431431 pass
432432
Original file line number Diff line number Diff line change 1717import datetime
1818import gc
1919import importlib
20+ from inspect import signature
2021import itertools
2122import json
2223import os
2324import struct
2425import time
26+ from types import MappingProxyType
2527from typing import Any , Dict , Optional , Tuple
2628
2729from absl import app
@@ -273,6 +275,10 @@ def train_once(
273275 hyperparameters ,
274276 opt_init_rng )
275277 logging .info ('Initializing metrics bundle.' )
278+
279+ # Check if 'train_state' is in the function signature
280+ needs_train_state = 'train_state' in signature (update_params ).parameters
281+
276282 # Bookkeeping.
277283 train_state = {
278284 'validation_goal_reached' : False ,
@@ -357,10 +363,11 @@ def train_once(
357363 batch = batch ,
358364 loss_type = workload .loss_type ,
359365 optimizer_state = optimizer_state ,
360- train_state = train_state .copy (),
361366 eval_results = eval_results ,
362367 global_step = global_step ,
363- rng = update_rng )
368+ rng = update_rng ,
369+ ** ({'train_state' : MappingProxyType (train_state )}
370+ if needs_train_state else {}))
364371 except spec .TrainingCompleteError :
365372 train_state ['training_complete' ] = True
366373 global_step += 1
You can’t perform that action at this time.
0 commit comments