Skip to content

Commit c5c36c2

Browse files
committed
add seperate model_fn for deepspeech jax without use_running_average_bn
1 parent 087fd5c commit c5c36c2

3 files changed

Lines changed: 32 additions & 1 deletion

File tree

algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ def model_fn(
113113
variables = {'params': params, **model_state}
114114
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
115115
is_train_mode = mode == spec.ForwardPassMode.TRAIN
116-
print(type(use_running_average_bn))
117116
if update_batch_norm or is_train_mode:
118117
(logits, logit_paddings), new_model_state = self._model.apply(
119118
variables,

algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,37 @@ def init_model_fn(
5555
model_state = jax_utils.replicate(model_state)
5656
params = jax_utils.replicate(params)
5757
return params, model_state
58+
59+
def model_fn(
60+
self,
61+
params: spec.ParameterContainer,
62+
augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
63+
model_state: spec.ModelAuxiliaryState,
64+
mode: spec.ForwardPassMode,
65+
rng: spec.RandomState,
66+
update_batch_norm: bool,
67+
use_running_average_bn: Optional[bool] = None
68+
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
69+
variables = {'params': params, **model_state}
70+
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
71+
is_train_mode = mode == spec.ForwardPassMode.TRAIN
72+
if update_batch_norm or is_train_mode:
73+
(logits, logit_paddings), new_model_state = self._model.apply(
74+
variables,
75+
inputs,
76+
input_paddings,
77+
train=True,
78+
rngs={'dropout' : rng},
79+
mutable=['batch_stats'])
80+
return (logits, logit_paddings), new_model_state
81+
else:
82+
logits, logit_paddings = self._model.apply(
83+
variables,
84+
inputs,
85+
input_paddings,
86+
train=False,
87+
mutable=False)
88+
return (logits, logit_paddings), model_state
5889

5990
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
6091
return param_key == 'Dense_0'

tests/reference_algorithm_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ def _test_submission(workload_name,
408408
workload_path=workload_metadata['workload_path'],
409409
workload_class_name=workload_metadata['workload_class_name'],
410410
return_class=True)
411+
print(f'Workload class for {workload_name} is {workload_class}')
411412

412413
submission_module_path = workloads.convert_filepath_to_module(submission_path)
413414
submission_module = importlib.import_module(submission_module_path)

0 commit comments

Comments
 (0)