Skip to content

Commit 56ab820

Browse files
dropout piping dev JAX workloads
1 parent 76180d9 commit 56ab820

2 files changed

Lines changed: 4 additions & 0 deletions

File tree

algoperf/workloads/cifar/cifar_jax/workload.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,11 @@ def model_fn(
105105
rng: spec.RandomState,
106106
update_batch_norm: bool,
107107
use_running_average_bn: Optional[bool] = None,
108+
dropout_rate: float = 0.0,
108109
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
109110
del mode
110111
del rng
112+
del dropout_rate
111113
variables = {'params': params, **model_state}
112114
if update_batch_norm:
113115
logits, new_model_state = self._model.apply(

algoperf/workloads/mnist/mnist_jax/workload.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@ def model_fn(
4848
mode: spec.ForwardPassMode,
4949
rng: spec.RandomState,
5050
update_batch_norm: bool,
51+
dropout_rate: float = 0.0,
5152
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
5253
del model_state
5354
del rng
5455
del update_batch_norm
56+
del dropout_rate
5557
train = mode == spec.ForwardPassMode.TRAIN
5658
logits_batch = self._model.apply(
5759
{'params': params},

0 commit comments

Comments
 (0)