Skip to content

Commit f574bf0

Browse files
committed
Add `use_running_average_bn` argument to the JAX workloads (ResNet and Conformer) so callers can control batch-norm running-average usage independently of `update_batch_norm`.
1 parent b24812f commit f574bf0

6 files changed

Lines changed: 63 additions & 31 deletions

File tree

algorithmic_efficiency/workloads/cifar/cifar_jax/models.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,16 @@ class ResNet(nn.Module):
2828
@nn.compact
2929
def __call__(self,
3030
x: spec.Tensor,
31-
update_batch_norm: bool = True) -> spec.Tensor:
31+
update_batch_norm: bool = True,
32+
use_running_average_bn: bool = None) -> spec.Tensor:
3233
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
34+
35+
# Preserve default behavior for backwards compatibility
36+
if use_running_average_bn is None:
37+
use_running_average_bn = not update_batch_norm
3338
norm = functools.partial(
3439
nn.BatchNorm,
35-
use_running_average=not update_batch_norm,
40+
use_running_average=use_running_average_bn,
3641
momentum=0.9,
3742
epsilon=1e-5,
3843
dtype=self.dtype)

algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ def model_fn(
110110
model_state: spec.ModelAuxiliaryState,
111111
mode: spec.ForwardPassMode,
112112
rng: spec.RandomState,
113-
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
113+
update_batch_norm: bool,
114+
use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
114115
del mode
115116
del rng
116117
variables = {'params': params, **model_state}
@@ -119,14 +120,16 @@ def model_fn(
119120
variables,
120121
augmented_and_preprocessed_input_batch['inputs'],
121122
update_batch_norm=update_batch_norm,
122-
mutable=['batch_stats'])
123+
mutable=['batch_stats'],
124+
use_running_average_bn=use_running_average_bn)
123125
return logits, new_model_state
124126
else:
125127
logits = self._model.apply(
126128
variables,
127129
augmented_and_preprocessed_input_batch['inputs'],
128130
update_batch_norm=update_batch_norm,
129-
mutable=False)
131+
mutable=False,
132+
use_running_average_bn=use_running_average_bn)
130133
return logits, model_state
131134

132135
# Does NOT apply regularization, which is left to the submitter to do in

algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,16 @@ class ResNet(nn.Module):
8484
@nn.compact
8585
def __call__(self,
8686
x: spec.Tensor,
87-
update_batch_norm: bool = True) -> spec.Tensor:
87+
update_batch_norm: bool = True,
88+
use_running_average_bn: Optional[bool] = None) -> spec.Tensor:
8889
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
90+
91+
# Preserve default behavior for backwards compatibility
92+
if use_running_average_bn is None:
93+
use_running_average_bn = not update_batch_norm
8994
norm = functools.partial(
9095
nn.BatchNorm,
91-
use_running_average=not update_batch_norm,
96+
use_running_average=use_running_average_bn,
9297
momentum=0.9,
9398
epsilon=1e-5,
9499
dtype=self.dtype)

algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ def model_fn(
148148
model_state: spec.ModelAuxiliaryState,
149149
mode: spec.ForwardPassMode,
150150
rng: spec.RandomState,
151-
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
151+
update_batch_norm: bool,
152+
use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
152153
del mode
153154
del rng
154155
variables = {'params': params, **model_state}
@@ -157,14 +158,16 @@ def model_fn(
157158
variables,
158159
augmented_and_preprocessed_input_batch['inputs'],
159160
update_batch_norm=update_batch_norm,
160-
mutable=['batch_stats'])
161+
mutable=['batch_stats'],
162+
use_running_average_bn=use_running_average_bn)
161163
return logits, new_model_state
162164
else:
163165
logits = self._model.apply(
164166
variables,
165167
augmented_and_preprocessed_input_batch['inputs'],
166168
update_batch_norm=update_batch_norm,
167-
mutable=False)
169+
mutable=False,
170+
use_running_average_bn=use_running_average_bn)
168171
return logits, model_state
169172

170173
# Does NOT apply regularization, which is left to the submitter to do in

algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -454,15 +454,20 @@ def setup(self):
454454
self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)
455455

456456
@nn.compact
457-
def __call__(self, inputs, input_paddings, train):
457+
def __call__(self, inputs, input_paddings, update_batch_norm, use_running_average_bn):
458458
rank = inputs.ndim
459459
reduce_over_dims = list(range(0, rank - 1))
460460

461461
padding = jnp.expand_dims(input_paddings, -1)
462462
momentum = self.config.batch_norm_momentum
463463
epsilon = self.config.batch_norm_epsilon
464464

465-
if train:
465+
if use_running_average_bn:
466+
mean = self.ra_mean.value
467+
var = self.ra_var.value
468+
469+
else:
470+
# compute batch statistics
466471
mask = 1.0 - padding
467472
sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
468473
count_v = jnp.sum(
@@ -477,17 +482,14 @@ def __call__(self, inputs, input_paddings, train):
477482
keepdims=True)
478483

479484
var = sum_vv / count_v
480-
481-
self.ra_mean.value = momentum * \
482-
self.ra_mean.value + (1 - momentum) * mean
483-
self.ra_var.value = momentum * \
484-
self.ra_var.value + (1 - momentum) * var
485-
else:
486-
mean = self.ra_mean.value
487-
var = self.ra_var.value
488-
485+
486+
if update_batch_norm:
487+
self.ra_mean.value = momentum * \
488+
self.ra_mean.value + (1 - momentum) * mean
489+
self.ra_var.value = momentum * \
490+
self.ra_var.value + (1 - momentum) * var
491+
489492
inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)
490-
491493
bn_output = (inputs - mean) * inv + self.beta
492494
bn_output *= 1.0 - padding
493495

@@ -517,7 +519,7 @@ class ConvolutionBlock(nn.Module):
517519
config: ConformerConfig
518520

519521
@nn.compact
520-
def __call__(self, inputs, input_paddings, train):
522+
def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average_bn):
521523
config = self.config
522524
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
523525

@@ -546,7 +548,7 @@ def __call__(self, inputs, input_paddings, train):
546548
kernel_init=nn.initializers.xavier_uniform())(
547549
inputs)
548550

549-
inputs = BatchNorm(config)(inputs, input_paddings, train)
551+
inputs = BatchNorm(config)(inputs, input_paddings, update_batch_norm, use_running_average_bn)
550552
if config.activation_function_name == 'swish':
551553
activation_fn = nn.swish
552554
elif config.activation_function_name == 'gelu':
@@ -586,7 +588,7 @@ class ConformerBlock(nn.Module):
586588
config: ConformerConfig
587589

588590
@nn.compact
589-
def __call__(self, inputs, input_paddings, train):
591+
def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average):
590592
config = self.config
591593
padding_mask = jnp.expand_dims(1 - input_paddings, -1)
592594

@@ -597,7 +599,7 @@ def __call__(self, inputs, input_paddings, train):
597599
inputs, input_paddings, train)
598600

599601
inputs = inputs + \
600-
ConvolutionBlock(config)(inputs, input_paddings, train)
602+
ConvolutionBlock(config)(inputs, input_paddings, train, update_batch_norm, use_running_average)
601603

602604
inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
603605
inputs, padding_mask, train)
@@ -629,12 +631,23 @@ def setup(self):
629631
.use_dynamic_time_mask_max_frames)
630632

631633
@nn.compact
632-
def __call__(self, inputs, input_paddings, train):
634+
def __call__(self,
635+
inputs,
636+
input_paddings,
637+
train,
638+
update_batch_norm: Optional[bool] = None,
639+
use_running_average_bn: Optional[bool] = None):
633640
config = self.config
634641

635642
outputs = inputs
636643
output_paddings = input_paddings
637644

645+
# Set BN args if not supplied for backwards compatibility
646+
if update_batch_norm is None:
647+
update_batch_norm = train
648+
if use_running_average_bn is None:
649+
use_running_average_bn = not train
650+
638651
# Compute normalized log mel spectrograms from input audio signal.
639652
preprocessing_config = preprocessor.LibrispeechPreprocessingConfig()
640653
outputs, output_paddings = preprocessor.MelFilterbankFrontend(
@@ -660,7 +673,7 @@ def __call__(self, inputs, input_paddings, train):
660673

661674
# Run the conformer encoder layers.
662675
for _ in range(config.num_encoder_layers):
663-
outputs = ConformerBlock(config)(outputs, output_paddings, train)
676+
outputs = ConformerBlock(config)(outputs, output_paddings, train, update_batch_norm, use_running_average_bn)
664677

665678
outputs = LayerNorm(config.encoder_dim)(outputs)
666679
# Run the decoder which in this case is a trivial projection layer.

algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def model_fn(
107107
model_state: spec.ModelAuxiliaryState,
108108
mode: spec.ForwardPassMode,
109109
rng: spec.RandomState,
110-
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
110+
update_batch_norm: bool,
111+
use_running_average_bn: Optional[bool]=None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
111112
variables = {'params': params, **model_state}
112113
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
113114
is_train_mode = mode == spec.ForwardPassMode.TRAIN
@@ -118,15 +119,17 @@ def model_fn(
118119
input_paddings,
119120
train=True,
120121
rngs={'dropout' : rng},
121-
mutable=['batch_stats'])
122+
mutable=['batch_stats'],
123+
use_running_average_bn=use_running_average_bn)
122124
return (logits, logit_paddings), new_model_state
123125
else:
124126
logits, logit_paddings = self._model.apply(
125127
variables,
126128
inputs,
127129
input_paddings,
128130
train=False,
129-
mutable=False)
131+
mutable=False,
132+
use_running_average_bn=use_running_average_bn)
130133
return (logits, logit_paddings), model_state
131134

132135
def _build_input_queue(

0 commit comments

Comments (0)