From 091e1a4730cd17dc5dbee8a65c13a3e252eaaa7c Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 19 Jun 2026 02:09:33 +0800 Subject: [PATCH] fix(dnn): BatchNorm running-var bias, pooling channel_axis bound, LayerNorm error - BatchNorm stored the biased batch variance into running_var; apply Bessel's N/(N-1) correction for the running buffer (PyTorch-consistent), keeping the biased variance for in-batch normalization (Medium) - Pooling rejected the leftmost negative channel_axis (== -x_dim) due to an abs() bound; widen to -x_dim <= axis < x_dim (Pool, _MaxPoolNd, AdaptivePool) (Medium) - LayerNorm wrong-shape error did ", ".join() -> TypeError masking the intended ValueError; map(str, ...) (Medium) Findings recorded in docs/issues-found-20260619-dnn.md --- brainpy/dnn/normalization.py | 16 +++- brainpy/dnn/normalization_test.py | 44 +++++++++++ brainpy/dnn/pooling.py | 12 ++- brainpy/dnn/pooling_layers_test.py | 31 ++++++++ docs/issues-found-20260619-dnn.md | 121 +++++++++++++++++++++++++++++ 5 files changed, 219 insertions(+), 5 deletions(-) create mode 100644 docs/issues-found-20260619-dnn.md diff --git a/brainpy/dnn/normalization.py b/brainpy/dnn/normalization.py index 00b321819..d0c5bd640 100644 --- a/brainpy/dnn/normalization.py +++ b/brainpy/dnn/normalization.py @@ -170,8 +170,20 @@ def update(self, x): 2 ) var = jnp.maximum(0., mean_of_square - _square(mean)) + # ``var`` above is the biased (divisor ``N``) variance used to normalize + # the current batch. The running buffer, however, should track the + # unbiased (Bessel-corrected, divisor ``N - 1``) estimate to match the + # conventional BatchNorm running statistic (e.g. PyTorch); otherwise the + # eval-time variance is systematically too small by ``(N - 1) / N`` (M-25). 
+ num_reduced = 1 + for ax in self.axis: + num_reduced *= x.shape[ax] + if num_reduced > 1: + unbiased_var = var * (num_reduced / (num_reduced - 1)) + else: + unbiased_var = var self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean) - self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var) + self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * unbiased_var) else: mean = self.running_mean.value var = self.running_var.value @@ -533,7 +545,7 @@ def __init__( def update(self, x): if x.shape[-len(self.normalized_shape):] != self.normalized_shape: - raise ValueError(f'Expect the input shape should be (..., {", ".join(self.normalized_shape)}), ' + raise ValueError(f'Expect the input shape should be (..., {", ".join(map(str, self.normalized_shape))}), ' f'but we got {x.shape}') axis = tuple(range(0, x.ndim - len(self.normalized_shape))) mean = jnp.mean(bm.as_jax(x), axis=axis, keepdims=True) diff --git a/brainpy/dnn/normalization_test.py b/brainpy/dnn/normalization_test.py index d1c5de714..74646d269 100644 --- a/brainpy/dnn/normalization_test.py +++ b/brainpy/dnn/normalization_test.py @@ -74,6 +74,50 @@ def test_InstanceNorm(self): net = bp.dnn.InstanceNorm(num_channels=6, mode=bm.training_mode) output = net(input) + def test_LayerNorm_shape_mismatch_raises_valueerror(self): + # Regression for P12-M1: the wrong-shape diagnostic used ``", ".join()`` + # which raised ``TypeError`` and masked the intended ``ValueError``. + net = bp.dnn.LayerNorm(10, mode=bm.training_mode) + bad_input = bm.random.randn(2, 5, 8) # last dim 8 != 10 + with self.assertRaises(ValueError): + net(bad_input) + + def test_BatchNorm_running_var_is_unbiased(self): + # Regression for P12-M3: the running variance buffer must use the unbiased + # (Bessel-corrected, divisor N-1) batch variance, matching PyTorch, instead + # of the biased (divisor N) variance used to normalize the current batch. 
+ import jax.numpy as jnp + import numpy as np + bm.random.seed(123) + net = bp.dnn.BatchNorm1d(num_features=3, affine=False, mode=bm.training_mode) + bp.share.save(fit=True) + x = bm.random.randn(4, 5, 3) * 3.0 + 7.0 # N = 4*5 = 20 reduced elements + net(x) + + xj = bm.as_jax(x) + n = xj.shape[0] * xj.shape[1] + biased = jnp.var(xj, axis=(0, 1)) + unbiased = biased * n / (n - 1) + # After one update: running_var = 0.99 * 1.0 + 0.01 * batch_var. + expected_unbiased = 0.99 * 1.0 + 0.01 * unbiased + expected_biased = 0.99 * 1.0 + 0.01 * biased + rv = bm.as_jax(net.running_var.value) + self.assertTrue(bool(jnp.allclose(rv, expected_unbiased, atol=1e-5))) + # And it must NOT match the biased estimate (the previous behaviour). + self.assertFalse(bool(jnp.allclose(rv, expected_biased, atol=1e-5))) + + def test_BatchNorm_batch_is_biased_normalized(self): + # The normalization of the current batch itself must remain unit-variance + # (biased), unaffected by the running-buffer correction. + import jax.numpy as jnp + bm.random.seed(7) + net = bp.dnn.BatchNorm1d(num_features=3, affine=False, mode=bm.training_mode) + bp.share.save(fit=True) + x = bm.random.randn(8, 6, 3) * 2.0 - 1.0 + out = bm.as_jax(net(x)) + self.assertTrue(bool(jnp.allclose(out.mean(axis=(0, 1)), 0.0, atol=1e-5))) + self.assertTrue(bool(jnp.allclose(out.var(axis=(0, 1)), 1.0, atol=1e-4))) + if __name__ == '__main__': absltest.main() diff --git a/brainpy/dnn/pooling.py b/brainpy/dnn/pooling.py index 40f413851..e06749cc3 100644 --- a/brainpy/dnn/pooling.py +++ b/brainpy/dnn/pooling.py @@ -115,7 +115,9 @@ def _infer_shape(self, # channel axis channel_axis = self.channel_axis - if channel_axis and not 0 <= abs(channel_axis) < x_dim: + # Valid axes are ``-x_dim <= channel_axis < x_dim``. The previous ``abs()`` + # bound wrongly rejected the leftmost negative axis ``-x_dim`` (P12-M2). 
+ if channel_axis and not -x_dim <= channel_axis < x_dim: raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions") if channel_axis and channel_axis < 0: channel_axis = x_dim + channel_axis @@ -387,7 +389,9 @@ def update(self, x): def _infer_shape(self, x_dim, inputs, element): channel_axis = self.channel_axis - if channel_axis and not 0 <= abs(channel_axis) < x_dim: + # Valid axes are ``-x_dim <= channel_axis < x_dim``. The previous ``abs()`` + # bound wrongly rejected the leftmost negative axis ``-x_dim`` (P12-M2). + if channel_axis and not -x_dim <= channel_axis < x_dim: raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions") if channel_axis and channel_axis < 0: channel_axis = x_dim + channel_axis @@ -784,7 +788,9 @@ def update(self, x): channel_axis = self.channel_axis if channel_axis: - if not 0 <= abs(channel_axis) < x.ndim: + # Valid axes are ``-x.ndim <= channel_axis < x.ndim``. The previous + # ``abs()`` bound wrongly rejected the leftmost negative axis (P12-M2). + if not -x.ndim <= channel_axis < x.ndim: raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}") if channel_axis < 0: channel_axis = x.ndim + channel_axis diff --git a/brainpy/dnn/pooling_layers_test.py b/brainpy/dnn/pooling_layers_test.py index e2dc9532a..da2437efb 100644 --- a/brainpy/dnn/pooling_layers_test.py +++ b/brainpy/dnn/pooling_layers_test.py @@ -243,5 +243,36 @@ def test_AdaptiveMaxPool3d_v1(self, axis): output = net(input) +class TestPoolingChannelAxis(parameterized.TestCase): + """Regression for P12-M2: the leftmost negative ``channel_axis`` (== -x_dim) + was wrongly rejected because the bound check used ``abs(channel_axis)``.""" + + def test_maxpool2d_leftmost_negative_channel_axis(self): + bm.random.seed() + # channels-first (C, H, W); channel axis is axis 0 == -3. 
+ x = bm.random.randn(6, 8, 8) + net = bp.dnn.MaxPool2d(2, channel_axis=-3) + out = net(x) + self.assertEqual(out.shape, (6, 4, 4)) + # Must match the output shape of the equivalent positive channel_axis. + net_pos = bp.dnn.MaxPool2d(2, channel_axis=0) + self.assertEqual(out.shape, net_pos(x).shape) + + def test_adaptiveavgpool2d_leftmost_negative_channel_axis(self): + bm.random.seed() + x = bm.random.randn(6, 8, 8) # (C, H, W) + net = bp.dnn.AdaptiveAvgPool2d((4, 4), channel_axis=-3) + out = net(x) + self.assertEqual(out.shape, (6, 4, 4)) + + def test_pool_leftmost_negative_channel_axis(self): + bm.random.seed() + # ``Pool`` family (MaxPool) with a tuple kernel ``(2, 2)`` and channel_axis=-3. + x = bm.random.randn(6, 8, 8) + net = bp.dnn.MaxPool((2, 2), 2, channel_axis=-3) + out = net(x) + self.assertEqual(out.shape, (6, 4, 4)) + + if __name__ == '__main__': absltest.main() diff --git a/docs/issues-found-20260619-dnn.md b/docs/issues-found-20260619-dnn.md new file mode 100644 index 000000000..1ca9d7ab2 --- /dev/null +++ b/docs/issues-found-20260619-dnn.md @@ -0,0 +1,121 @@ +# P12 — `brainpy/dnn` expert review (2026-06-19) + +Branch: `fix/audit-20260619-dnn`. Scope: `brainpy/dnn/{activations,base,conv,dropout,function,interoperation_flax,linear,normalization,pooling}.py` + co-located tests. + +Environment: jax 0.10.2, brainstate 0.5.1, braintools 0.1.10, brainunit 0.5.1 (CPU). + +Severity legend: Critical (silently wrong/crash in default usage), High (wrong in realistic cases / broken public API), Medium (edge/fragility/error-handling), Low (style/docs — record only). + +--- + +### P12-M1 — `LayerNorm` shape-mismatch error does `", ".join()` → masks real error with `TypeError` [Medium] +- File: brainpy/dnn/normalization.py:536 +- Category: edge/error +- What: When the trailing input dims do not match `normalized_shape`, the guard raises + `ValueError(f'... (..., {", ".join(self.normalized_shape)}), but we got {x.shape}')`. 
+ `self.normalized_shape` is a tuple of `int`, so `", ".join(...)` raises + `TypeError: sequence item 0: expected str instance, int found` instead of the intended `ValueError`. +- Why it's a bug: The user-facing diagnostic is replaced by an opaque `TypeError`, hiding the actual + shape problem. Any wrong-shape input to an affine/non-affine `LayerNorm` hits this. +- Repro: + ```python + bp.dnn.LayerNorm(10)(bm.random.randn(2, 5, 8)) # -> TypeError, not ValueError + ``` +- Fix: `", ".join(map(str, self.normalized_shape))` (and format the expected shape readably). +- Tests: `normalization_test.py::Test_Normalization::test_LayerNorm_shape_mismatch_raises_valueerror` +- Status: fixed +- (== prior audit M-26) + +--- + +### P12-M2 — Pooling rejects the leftmost negative `channel_axis` (`channel_axis == -x_dim`) [Medium] +- File: brainpy/dnn/pooling.py:118 (`Pool._infer_shape`), 390 (`_MaxPoolNd._infer_shape`), 787 (`AdaptivePool.update`) +- Category: edge/error +- What: The bound check is `if channel_axis and not 0 <= abs(channel_axis) < x_dim: raise`. + Using `abs(channel_axis)` makes `channel_axis == -x_dim` (the valid leftmost axis, e.g. `-3` for a + 3-D `(C, H, W)` input) fail the check and raise `ValueError`. +- Why it's a bug: A legitimate, common channels-first layout (`channel_axis=-3` on `(C,H,W)`, or + `-4` on `(N,C,H,W)`) is wrongly rejected. The correct numpy-style bound is `-x_dim <= axis < x_dim`. +- Repro: + ```python + bp.dnn.MaxPool2d(2, channel_axis=-3)(bm.random.randn(6, 4, 4)) # ValueError: Invalid channel axis -3 + ``` +- Fix: Replace the `abs()` bound test with `not -x_dim <= channel_axis < x_dim` in all three sites. 
+- Tests: `pooling_layers_test.py::TestPoolingChannelAxis::test_maxpool2d_leftmost_negative_channel_axis`, + `...::test_adaptiveavgpool2d_leftmost_negative_channel_axis`, `...::test_pool_leftmost_negative_channel_axis` +- Status: fixed +- (== prior audit M-27) + +--- + +### P12-M3 — `BatchNorm` stores the *biased* batch variance into `running_var` (PyTorch uses unbiased) [Medium] +- File: brainpy/dnn/normalization.py:172-174 +- Category: numerics +- What: In fit mode the running estimate is updated with the biased population variance + `var = mean_of_square - mean**2` (divisor `N`). PyTorch / the conventional BatchNorm running statistic + uses the *unbiased* sample variance (divisor `N-1`, i.e. Bessel's correction) for the running buffer, + while keeping the biased variance only for normalizing the current batch. +- Why it's a bug: The running variance used at eval time is systematically too small by a factor of + `(N-1)/N`. For small batch/window counts `N` this is a meaningful (a few percent) bias in the + eval-time normalization — i.e. inference results drift from the trained reference. +- Repro: + ```python + bn = bp.dnn.BatchNorm1d(3, affine=False); bp.share.save(fit=True) + bn(bm.random.randn(4, 5, 3)) # N = 20 + # running_var == 0.99*1 + 0.01*biased_var, not unbiased_var + ``` +- Fix: Scale the variance fed into the `running_var` EMA by `N/(N-1)` (with `N` = number of reduced + elements), guarding `N == 1`. The batch normalization itself keeps the biased `var`. +- Tests: `normalization_test.py::Test_Normalization::test_BatchNorm_running_var_is_unbiased`, `...::test_BatchNorm_batch_is_biased_normalized` +- Status: fixed +- (== prior audit M-25) + +--- + +### P12-L1 — `Flatten` default `start_dim=0` contradicts its docstring example and PyTorch (`start_dim=1`) [Low] +- File: brainpy/dnn/function.py:91 (default), 74-87 (docstring example) +- Category: api-drift/style +- What: `Flatten.__init__` defaults `start_dim=0`. 
The class docstring example claims + `Flatten()` on `(32, 1, 5, 5)` yields `(32, 25)` (PyTorch's `start_dim=1` semantics), but in + `NonBatchingMode` the actual default flattens the batch dim too → `(800,)`. +- Why it's a bug: The documented contract and the implemented default disagree. PyTorch's `nn.Flatten` + default is `start_dim=1`. +- Repro: + ```python + bp.dnn.Flatten()(bm.random.randn(32, 1, 5, 5)).shape # (800,), docstring says (32, 25) + ``` +- Fix: recorded only. NOTE: changing the default to `1` would change the documented `NonBatchingMode` + contract and break `function_test.py::test_flatten_non_batching_mode` (asserts `(600,)` from default + start_dim under `NonBatchingMode`). The discrepancy is documentation/default-value drift, not a + silent numeric error, and a default change is a cross-cutting API change. Left for maintainers to + decide (fix docs vs. change default + migrate the test). +- Tests: none +- Status: recorded-only + +--- + +## Cross-check vs `dev/issues-found-20260618.md` (dnn entries) + +- **C-05** (`GroupNorm`/`InstanceNorm` reduce over the group axis): **already fixed** in this tree + (`normalization.py:640` reduces over spatial + within-group channel axis, keeps the group axis). + Verified: `GroupNorm(3,6) != GroupNorm(1,6) != GroupNorm(6,6)`, per-group means ≈ 0. No action. +- **H-51** (`BatchNorm`/affine `LayerNorm`/`GroupNorm` crash out-of-the-box under default mode): + **already fixed** (`BatchNorm` defaults to `training_mode`; affine params wrapped as `Variable` vs + `TrainVar` per mode instead of a hard assert). No action. +- **M-25**: still present → fixed here as **P12-M3**. +- **M-26**: still present → fixed here as **P12-M1**. +- **M-27**: still present → fixed here as **P12-M2**. +- **M-28** (`Flatten` default): still present → recorded as **P12-L1** (see note; default change out of safe scope). 
+ +## Checked and found correct (no action) +- `Dropout`: `prob` = keep-probability; `bernoulli(prob)` keeps with prob `prob`; survivors scaled by + `1/prob == 1/(1-rate)`; eval (`fit=False`) is a no-op. Correct. +- `Conv*` / `ConvTranspose*`: kernel shapes, `feature_group_count=groups`, dimension numbers, + bias broadcast (channels-last), non-batching unsqueeze/squeeze, `SAME`/`VALID`/int/tuple padding + normalization. Correct (smoke-checked shapes & a transpose upsample). +- `AvgPool`/`_AvgPoolNd` non-VALID averaging via a second `reduce_window` count. Correct. +- `GroupNorm` affine broadcast via `lax.broadcast_to_rank` to channels-last. Correct. +- Activation wrappers delegate to `bm.*`; formulas match docstrings. `Softmax2d` uses axis `-3` + (channels-first `(N,C,H,W)`/`(C,H,W)`), matching its documented contract. +- `interoperation_flax`: Flax round-trip param flatten/unflatten, `ToFlaxRNNCell` carry handling. + Correct for flax present/absent.