
Commit 7ca8365

Commit message: formatting
1 parent: f574bf0

8 files changed: 49 additions & 24 deletions


algorithmic_efficiency/pytorch_utils.py

Lines changed: 2 additions & 1 deletion
@@ -57,6 +57,7 @@ def sync_ddp_time(time: float, device: torch.device) -> float:
   dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX)
   return time_tensor.item()
 
+
 def update_batch_norm_fn(module: spec.ParameterContainer,
                          update_batch_norm: bool) -> None:
   bn_layers = (
@@ -75,4 +76,4 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
     else:
       module.momentum = 0.0
   elif hasattr(module, 'momentum_backup'):
-    module.momentum = module.momentum_backup
+    module.momentum = module.momentum_backup
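
For context, the tail of update_batch_norm_fn restores a previously saved momentum so that batch-norm running statistics can be frozen and later re-enabled. A minimal standalone sketch of that save/restore pattern, assuming standard torch.nn batch-norm layers (the helper name and the branches outside this hunk are assumptions; momentum_backup is the attribute visible in the diff):

import functools

import torch.nn as nn


def freeze_or_restore_bn(module: nn.Module, update_batch_norm: bool) -> None:
  bn_layers = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
  if isinstance(module, bn_layers):
    if not update_batch_norm:
      # Stash the configured momentum and zero it: with momentum == 0.0,
      # PyTorch leaves running_mean/running_var unchanged in forward().
      if not hasattr(module, 'momentum_backup'):
        module.momentum_backup = module.momentum
      module.momentum = 0.0
    elif hasattr(module, 'momentum_backup'):
      # Restore the original momentum once updates are allowed again.
      module.momentum = module.momentum_backup


# Typical use: apply recursively over a model.
# model.apply(functools.partial(freeze_or_restore_bn, update_batch_norm=False))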

algorithmic_efficiency/workloads/cifar/cifar_jax/models.py

Lines changed: 3 additions & 3 deletions
@@ -31,10 +31,10 @@ def __call__(self,
                update_batch_norm: bool = True,
                use_running_average_bn: bool = None) -> spec.Tensor:
     conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
-
-    # Preserve default behavior for backwards compatibility
+
+    # Preserve default behavior for backwards compatibility
     if use_running_average_bn is None:
-      use_running_average_bn = not update_batch_norm
+      use_running_average_bn = not update_batch_norm
     norm = functools.partial(
         nn.BatchNorm,
         use_running_average=use_running_average_bn,
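
The comment in this hunk encodes the compatibility rule for the new flag: when a caller does not pass use_running_average_bn, normalizing with running averages defaults to the opposite of updating batch norm. A minimal sketch of the resolved flag (the function name is hypothetical; the identical rule appears again in the ImageNet ResNet model below):

def resolve_use_running_average_bn(update_batch_norm, use_running_average_bn=None):
  # Preserve default behavior for backwards compatibility.
  if use_running_average_bn is None:
    use_running_average_bn = not update_batch_norm
  return use_running_average_bn


# Training defaults to batch statistics, evaluation to stored running averages.
assert resolve_use_running_average_bn(update_batch_norm=True) is False
assert resolve_use_running_average_bn(update_batch_norm=False) is True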

algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py

Lines changed: 2 additions & 1 deletion
@@ -111,7 +111,8 @@ def model_fn(
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool,
-      use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      use_running_average_bn: Optional[bool] = None
+  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
     variables = {'params': params, **model_state}
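
The reflowed signature threads the new keyword through the workload's model_fn. A hedged usage sketch, reusing names from the hunks in this commit (workload, batch, params, model_state, and rng are assumed to be provided by the surrounding harness, and the positional order follows the fields shown in the diff):

# Evaluate against frozen running statistics without updating them.
logits, model_state = workload.model_fn(
    params,
    batch,
    model_state,
    spec.ForwardPassMode.EVAL,
    rng,
    update_batch_norm=False,
    use_running_average_bn=True)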

algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py

Lines changed: 2 additions & 2 deletions
@@ -88,9 +88,9 @@ def __call__(self,
                use_running_average_bn: Optional[bool] = None) -> spec.Tensor:
     conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
 
-    # Preserve default behavior for backwards compatibility
+    # Preserve default behavior for backwards compatibility
     if use_running_average_bn is None:
-      use_running_average_bn = not update_batch_norm
+      use_running_average_bn = not update_batch_norm
     norm = functools.partial(
         nn.BatchNorm,
         use_running_average=use_running_average_bn,

algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py

Lines changed: 2 additions & 1 deletion
@@ -149,7 +149,8 @@ def model_fn(
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool,
-      use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      use_running_average_bn: Optional[bool] = None
+  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
     variables = {'params': params, **model_state}

algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py

Lines changed: 35 additions & 14 deletions
@@ -454,15 +454,19 @@ def setup(self):
     self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)
 
   @nn.compact
-  def __call__(self, inputs, input_paddings, update_batch_norm, use_running_average_bn):
+  def __call__(self,
+               inputs,
+               input_paddings,
+               update_batch_norm,
+               use_running_average_bn):
     rank = inputs.ndim
     reduce_over_dims = list(range(0, rank - 1))
 
     padding = jnp.expand_dims(input_paddings, -1)
     momentum = self.config.batch_norm_momentum
     epsilon = self.config.batch_norm_epsilon
 
-    if use_running_average_bn:
+    if use_running_average_bn:
       mean = self.ra_mean.value
       var = self.ra_var.value
 
@@ -482,13 +486,13 @@ def __call__(self, inputs, input_paddings, update_batch_norm, use_running_averag
           keepdims=True)
 
       var = sum_vv / count_v
-
+
       if update_batch_norm:
         self.ra_mean.value = momentum * \
           self.ra_mean.value + (1 - momentum) * mean
         self.ra_var.value = momentum * \
           self.ra_var.value + (1 - momentum) * var
-
+
     inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)
     bn_output = (inputs - mean) * inv + self.beta
     bn_output *= 1.0 - padding
@@ -519,7 +523,12 @@ class ConvolutionBlock(nn.Module):
   config: ConformerConfig
 
   @nn.compact
-  def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average_bn):
+  def __call__(self,
+               inputs,
+               input_paddings,
+               train,
+               update_batch_norm,
+               use_running_average_bn):
     config = self.config
     inputs = LayerNorm(dim=config.encoder_dim)(inputs)
 
@@ -548,7 +557,10 @@ def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running
         kernel_init=nn.initializers.xavier_uniform())(
             inputs)
 
-    inputs = BatchNorm(config)(inputs, input_paddings, update_batch_norm, use_running_average_bn)
+    inputs = BatchNorm(config)(inputs,
+                               input_paddings,
+                               update_batch_norm,
+                               use_running_average_bn)
     if config.activation_function_name == 'swish':
       activation_fn = nn.swish
     elif config.activation_function_name == 'gelu':
@@ -588,7 +600,12 @@ class ConformerBlock(nn.Module):
   config: ConformerConfig
 
   @nn.compact
-  def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average):
+  def __call__(self,
+               inputs,
+               input_paddings,
+               train,
+               update_batch_norm,
+               use_running_average):
     config = self.config
     padding_mask = jnp.expand_dims(1 - input_paddings, -1)
 
@@ -631,12 +648,12 @@ def setup(self):
         .use_dynamic_time_mask_max_frames)
 
   @nn.compact
-  def __call__(self,
-               inputs,
-               input_paddings,
-               train,
-               update_batch_norm: Optional[bool] = None,
-               use_running_average_bn: Optional[bool] = None):
+  def __call__(self,
+               inputs,
+               input_paddings,
+               train,
+               update_batch_norm: Optional[bool] = None,
+               use_running_average_bn: Optional[bool] = None):
     config = self.config
 
     outputs = inputs
@@ -673,7 +690,11 @@ def __call__(self,
 
     # Run the conformer encoder layers.
     for _ in range(config.num_encoder_layers):
-      outputs = ConformerBlock(config)(outputs, output_paddings, train, update_batch_norm, use_running_average_bn)
+      outputs = ConformerBlock(config)(outputs,
+                                       output_paddings,
+                                       train,
+                                       update_batch_norm,
+                                       use_running_average_bn)
 
     outputs = LayerNorm(config.encoder_dim)(outputs)
     # Run the decoder which in this case is a trivial projection layer.
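
Two details of this BatchNorm are worth keeping in view while reading the reflowed hunks: statistics are computed only over unpadded frames (the sum_vv / count_v lines), and the exponential moving average weights the running value by momentum. A hedged sketch of the padding-aware statistics, with illustrative shapes (the helper is hypothetical):

import jax.numpy as jnp


def masked_mean_var(inputs, input_paddings):
  # inputs: [batch, time, features]; input_paddings: [batch, time], 1.0 = pad.
  mask = 1.0 - input_paddings[..., None]
  count = jnp.sum(mask, axis=(0, 1), keepdims=True)  # unpadded frames only
  mean = jnp.sum(inputs * mask, axis=(0, 1), keepdims=True) / count
  centered = (inputs - mean) * mask                  # re-mask after centering
  var = jnp.sum(centered * centered, axis=(0, 1), keepdims=True) / count
  return mean, var

With batch_norm_momentum near 1.0, the update ra_mean = momentum * ra_mean + (1 - momentum) * mean moves the running mean only slightly toward each batch: from 0.0 to 0.001 for a batch mean of 1.0 at momentum 0.999.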

algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py

Lines changed: 2 additions & 1 deletion
@@ -108,7 +108,8 @@ def model_fn(
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool,
-      use_running_average_bn: Optional[bool]=None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      use_running_average_bn: Optional[bool] = None
+  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     variables = {'params': params, **model_state}
     inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
     is_train_mode = mode == spec.ForwardPassMode.TRAIN

algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py

Lines changed: 1 addition & 1 deletion
@@ -373,7 +373,7 @@ def forward(self, inputs, input_paddings):
             self.momentum) * mean.detach()
         self.running_var = (1 - self.momentum) * self.running_var + (
             self.momentum) * var.detach()
-
+
     else:
       mean = self.running_mean
       var = self.running_var
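
Note the momentum convention here is the mirror image of the JAX encoder above: this PyTorch forward pass weights the incoming batch statistic by self.momentum, as torch.nn.BatchNorm does, while the Flax-side update weights the running value. A quick numeric check with illustrative values:

m_torch = 0.001        # weights the incoming batch statistic (this file)
m_jax = 1.0 - m_torch  # weights the running value (JAX BatchNorm above)
running, batch = 0.0, 1.0
pt = (1 - m_torch) * running + m_torch * batch
jx = m_jax * running + (1 - m_jax) * batch
assert abs(pt - jx) < 1e-12 and abs(pt - 0.001) < 1e-12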
