DOCUMENTATION.md (+2 lines: 2 additions & 0 deletions)
@@ -199,6 +199,7 @@ def update_params(
     batch: Dict[str, Tensor],
     loss_type: LossType,
     optimizer_state: OptimizerState,
+    train_state: Dict[str, Any],
     eval_results: List[Tuple[int, float]],
     global_step: int,
     rng: RandomState
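For orientation, the hunk above only shows the neighborhood of the new argument. Below is a minimal, self-contained stub consistent with it; the placeholder type aliases and the elision of the signature's leading parameters are assumptions made purely for illustration, not the canonical benchmark API.

```python
from typing import Any, Dict, List, Tuple

# Placeholder aliases standing in for the spec's real types; these are
# illustrative stand-ins, not the benchmark's actual type definitions.
Tensor = Any
LossType = Any
OptimizerState = Any
RandomState = Any


def update_params(
    workload,                        # workload object referenced in the notes below
    # ... the signature's leading parameters are elided in this sketch ...
    batch: Dict[str, Tensor],
    loss_type: LossType,
    optimizer_state: OptimizerState,
    train_state: Dict[str, Any],     # the argument added by this change
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: RandomState,
):
    """One training step; the body is intentionally left as a placeholder."""
    ...
```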
@@ -212,6 +213,7 @@ def update_params(
 - The `loss_fn` produces a per-example loss and a summed loss (both only for one device), both of which can be used.
 - Allowed to update state for the optimizer.
 - Uses the `model_fn` of the `workload` in order to decouple the loss from the model, so that model outputs (forward passes) can be reused (by storing them in the optimizer state).
+- The submission can access the elapsed training time and get further information about the evaluations through `train_state`.
 - The submission can access the target evaluation metric via the `workload` variable.
 - **A call to this function will be considered a step.**
 - The time between a call to this function and the next call to this function will be considered the per-step time.
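To illustrate the new bullet, here is a hedged sketch of how a submission might consume `train_state` inside `update_params`. The key name `accumulated_submission_time`, the helper `_remaining_time_fraction`, and the per-workload `time_budget_seconds` value are hypothetical, introduced only for this example; consult the spec for the actual contents of `train_state`.

```python
from typing import Any, Dict


def _remaining_time_fraction(train_state: Dict[str, Any],
                             time_budget_seconds: float) -> float:
    """Return the fraction of a wall-clock budget that is still available.

    The key 'accumulated_submission_time' is an assumed name for the elapsed
    training time exposed via train_state; time_budget_seconds is a value the
    submission itself supplies, not part of the benchmark API.
    """
    elapsed = float(train_state.get('accumulated_submission_time', 0.0))
    return max(0.0, 1.0 - elapsed / time_budget_seconds)


# Inside update_params, a submission could, for example, decay its learning
# rate as the time budget is consumed:
#
#   lr = base_lr * _remaining_time_fraction(train_state, time_budget_seconds=3600.0)
```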