Skip to content

Commit b24812f

Browse files
committed
BN Fixes
1 parent bdece3b commit b24812f

3 files changed

Lines changed: 19 additions & 16 deletions

File tree

algorithmic_efficiency/pytorch_utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ def sync_ddp_time(time: float, device: torch.device) -> float:
   dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX)
   return time_tensor.item()
 
-
 def update_batch_norm_fn(module: spec.ParameterContainer,
                          update_batch_norm: bool) -> None:
   bn_layers = (
@@ -67,10 +66,13 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
   )
   if isinstance(module, bn_layers):
     if not update_batch_norm:
-      module.eval()
-      module.momentum_backup = module.momentum
+      if not hasattr(module, 'momentum_backup'):
+        module.momentum_backup = module.momentum
+
       # module.momentum can be float or torch.Tensor.
-      module.momentum = 0. * module.momentum_backup
+      if torch.is_tensor(module.momentum_backup):
+        module.momentum = torch.zeros_like(module.momentum_backup)
+      else:
+        module.momentum = 0.0
     elif hasattr(module, 'momentum_backup'):
-      module.momentum = module.momentum_backup
-    module.track_running_stats = update_batch_norm
+      module.momentum = module.momentum_backup

algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class ConformerConfig:
   time_masks_per_frame: float = 0.0
   use_dynamic_time_mask_max_frames: bool = True
   input_dropout_rate: float = 0.1
-  batch_norm_momentum: float = 0.999
+  batch_norm_momentum: float = 1 - 0.999
   batch_norm_epsilon: float = 0.001
   use_specaug: bool = True
   attention_temperature: float = 1.0
@@ -369,10 +369,11 @@ def forward(self, inputs, input_paddings):
       mean = (masked_inp).sum(dim=(0, 1)) / count
       var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count
 
-      self.running_mean = self.momentum * self.running_mean + (
-          1 - self.momentum) * mean.detach()
-      self.running_var = self.momentum * self.running_var + (
-          1 - self.momentum) * var.detach()
+      self.running_mean = (1 - self.momentum) * self.running_mean + (
+          self.momentum) * mean.detach()
+      self.running_var = (1 - self.momentum) * self.running_var + (
+          self.momentum) * var.detach()
+
     else:
       mean = self.running_mean
       var = self.running_var

algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class DeepspeechConfig:
   time_mask_max_ratio: float = 0.05
   time_masks_per_frame: float = 0.0
   use_dynamic_time_mask_max_frames: bool = True
-  batch_norm_momentum: float = 0.999
+  batch_norm_momentum: float = 1 - 0.999
   batch_norm_epsilon: float = 0.001
   # If None, defaults to 0.1.
   input_dropout_rate: Optional[float] = 0.1
@@ -264,10 +264,10 @@ def forward(self, inputs, input_paddings):
       sum_ = dist_nn.all_reduce(sum_)
       var = sum_ / count
 
-      self.running_mean = self.momentum * self.running_mean + (
-          1 - self.momentum) * mean.detach()
-      self.running_var = self.momentum * self.running_var + (
-          1 - self.momentum) * var.detach()
+      self.running_mean = (1 - self.momentum) * self.running_mean + (
+          self.momentum) * mean.detach()
+      self.running_var = (1 - self.momentum) * self.running_var + (
+          self.momentum) * var.detach()
     else:
       mean = self.running_mean
       var = self.running_var

0 commit comments

Comments (0)