Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions brainpy/dnn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,20 @@ def update(self, x):
2
)
var = jnp.maximum(0., mean_of_square - _square(mean))
# ``var`` above is the biased (divisor ``N``) variance used to normalize
# the current batch. The running buffer, however, should track the
# unbiased (Bessel-corrected, divisor ``N - 1``) estimate to match the
# conventional BatchNorm running statistic (e.g. PyTorch); otherwise the
# eval-time variance is systematically too small by ``(N - 1) / N`` (M-25).
num_reduced = 1
for ax in self.axis:
num_reduced *= x.shape[ax]
if num_reduced > 1:
unbiased_var = var * (num_reduced / (num_reduced - 1))
else:
unbiased_var = var
self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean)
self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var)
self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * unbiased_var)
else:
mean = self.running_mean.value
var = self.running_var.value
Expand Down Expand Up @@ -533,7 +545,7 @@ def __init__(

def update(self, x):
if x.shape[-len(self.normalized_shape):] != self.normalized_shape:
raise ValueError(f'Expect the input shape should be (..., {", ".join(self.normalized_shape)}), '
raise ValueError(f'Expect the input shape should be (..., {", ".join(map(str, self.normalized_shape))}), '
f'but we got {x.shape}')
axis = tuple(range(0, x.ndim - len(self.normalized_shape)))
mean = jnp.mean(bm.as_jax(x), axis=axis, keepdims=True)
Expand Down
44 changes: 44 additions & 0 deletions brainpy/dnn/normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,50 @@ def test_InstanceNorm(self):
net = bp.dnn.InstanceNorm(num_channels=6, mode=bm.training_mode)
output = net(input)

def test_LayerNorm_shape_mismatch_raises_valueerror(self):
# Regression for P12-M1: the wrong-shape diagnostic used ``", ".join(<ints>)``
# which raised ``TypeError`` and masked the intended ``ValueError``.
net = bp.dnn.LayerNorm(10, mode=bm.training_mode)
bad_input = bm.random.randn(2, 5, 8) # last dim 8 != 10
with self.assertRaises(ValueError):
net(bad_input)

def test_BatchNorm_running_var_is_unbiased(self):
# Regression for P12-M3: the running variance buffer must use the unbiased
# (Bessel-corrected, divisor N-1) batch variance, matching PyTorch, instead
# of the biased (divisor N) variance used to normalize the current batch.
import jax.numpy as jnp
import numpy as np
bm.random.seed(123)
net = bp.dnn.BatchNorm1d(num_features=3, affine=False, mode=bm.training_mode)
bp.share.save(fit=True)
x = bm.random.randn(4, 5, 3) * 3.0 + 7.0 # N = 4*5 = 20 reduced elements
net(x)

xj = bm.as_jax(x)
n = xj.shape[0] * xj.shape[1]
biased = jnp.var(xj, axis=(0, 1))
unbiased = biased * n / (n - 1)
# After one update: running_var = 0.99 * 1.0 + 0.01 * <var>.
expected_unbiased = 0.99 * 1.0 + 0.01 * unbiased
expected_biased = 0.99 * 1.0 + 0.01 * biased
rv = bm.as_jax(net.running_var.value)
self.assertTrue(bool(jnp.allclose(rv, expected_unbiased, atol=1e-5)))
# And it must NOT match the biased estimate (the previous behaviour).
self.assertFalse(bool(jnp.allclose(rv, expected_biased, atol=1e-5)))

def test_BatchNorm_batch_is_biased_normalized(self):
# The normalization of the current batch itself must remain unit-variance
# (biased), unaffected by the running-buffer correction.
import jax.numpy as jnp
bm.random.seed(7)
net = bp.dnn.BatchNorm1d(num_features=3, affine=False, mode=bm.training_mode)
bp.share.save(fit=True)
x = bm.random.randn(8, 6, 3) * 2.0 - 1.0
out = bm.as_jax(net(x))
self.assertTrue(bool(jnp.allclose(out.mean(axis=(0, 1)), 0.0, atol=1e-5)))
self.assertTrue(bool(jnp.allclose(out.var(axis=(0, 1)), 1.0, atol=1e-4)))


if __name__ == '__main__':
absltest.main()
12 changes: 9 additions & 3 deletions brainpy/dnn/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ def _infer_shape(self,

# channel axis
channel_axis = self.channel_axis
if channel_axis and not 0 <= abs(channel_axis) < x_dim:
# Valid axes are ``-x_dim <= channel_axis < x_dim``. The previous ``abs()``
# bound wrongly rejected the leftmost negative axis ``-x_dim`` (P12-M2).
if channel_axis and not -x_dim <= channel_axis < x_dim:
raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
if channel_axis and channel_axis < 0:
channel_axis = x_dim + channel_axis
Expand Down Expand Up @@ -387,7 +389,9 @@ def update(self, x):

def _infer_shape(self, x_dim, inputs, element):
channel_axis = self.channel_axis
if channel_axis and not 0 <= abs(channel_axis) < x_dim:
# Valid axes are ``-x_dim <= channel_axis < x_dim``. The previous ``abs()``
# bound wrongly rejected the leftmost negative axis ``-x_dim`` (P12-M2).
if channel_axis and not -x_dim <= channel_axis < x_dim:
raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions")
if channel_axis and channel_axis < 0:
channel_axis = x_dim + channel_axis
Expand Down Expand Up @@ -784,7 +788,9 @@ def update(self, x):
channel_axis = self.channel_axis

if channel_axis:
if not 0 <= abs(channel_axis) < x.ndim:
# Valid axes are ``-x.ndim <= channel_axis < x.ndim``. The previous
# ``abs()`` bound wrongly rejected the leftmost negative axis (P12-M2).
if not -x.ndim <= channel_axis < x.ndim:
raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
if channel_axis < 0:
channel_axis = x.ndim + channel_axis
Expand Down
31 changes: 31 additions & 0 deletions brainpy/dnn/pooling_layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,5 +243,36 @@ def test_AdaptiveMaxPool3d_v1(self, axis):
output = net(input)


class TestPoolingChannelAxis(parameterized.TestCase):
"""Regression for P12-M2: the leftmost negative ``channel_axis`` (== -x_dim)
was wrongly rejected because the bound check used ``abs(channel_axis)``."""

def test_maxpool2d_leftmost_negative_channel_axis(self):
bm.random.seed()
# channels-first (C, H, W); channel axis is axis 0 == -3.
x = bm.random.randn(6, 8, 8)
net = bp.dnn.MaxPool2d(2, channel_axis=-3)
out = net(x)
self.assertEqual(out.shape, (6, 4, 4))
# Must equal the equivalent positive channel_axis result.
net_pos = bp.dnn.MaxPool2d(2, channel_axis=0)
self.assertEqual(out.shape, net_pos(x).shape)

def test_adaptiveavgpool2d_leftmost_negative_channel_axis(self):
bm.random.seed()
x = bm.random.randn(6, 8, 8) # (C, H, W)
net = bp.dnn.AdaptiveAvgPool2d((4, 4), channel_axis=-3)
out = net(x)
self.assertEqual(out.shape, (6, 4, 4))

def test_pool_leftmost_negative_channel_axis(self):
bm.random.seed()
# ``Pool`` family (MaxPool) with an integer kernel and channel_axis=-3.
x = bm.random.randn(6, 8, 8)
net = bp.dnn.MaxPool((2, 2), 2, channel_axis=-3)
out = net(x)
self.assertEqual(out.shape, (6, 4, 4))


if __name__ == '__main__':
absltest.main()
121 changes: 121 additions & 0 deletions docs/issues-found-20260619-dnn.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# P12 — `brainpy/dnn` expert review (2026-06-19)

Branch: `fix/audit-20260619-dnn`. Scope: `brainpy/dnn/{activations,base,conv,dropout,function,interoperation_flax,linear,normalization,pooling}.py` + co-located tests.

Environment: jax 0.10.2, brainstate 0.5.1, braintools 0.1.10, brainunit 0.5.1 (CPU).

Severity legend: Critical (silently wrong/crash in default usage), High (wrong in realistic cases / broken public API), Medium (edge/fragility/error-handling), Low (style/docs — record only).

---

### P12-M1 — `LayerNorm` shape-mismatch error does `", ".join(<ints>)` → masks real error with `TypeError` [Medium]
- File: brainpy/dnn/normalization.py:536
- Category: edge/error
- What: When the trailing input dims do not match `normalized_shape`, the guard raises
`ValueError(f'... (..., {", ".join(self.normalized_shape)}), but we got {x.shape}')`.
`self.normalized_shape` is a tuple of `int`, so `", ".join(...)` raises
`TypeError: sequence item 0: expected str instance, int found` instead of the intended `ValueError`.
- Why it's a bug: The user-facing diagnostic is replaced by an opaque `TypeError`, hiding the actual
shape problem. Any wrong-shape input to an affine/non-affine `LayerNorm` hits this.
- Repro:
```python
bp.dnn.LayerNorm(10)(bm.random.randn(2, 5, 8)) # -> TypeError, not ValueError
```
- Fix: `", ".join(map(str, self.normalized_shape))` (and format the expected shape readably).
- Tests: `normalization_test.py::Test_Normalization::test_LayerNorm_shape_mismatch_raises_valueerror`
- Status: fixed
- (== prior audit M-26)

---

### P12-M2 — Pooling rejects the leftmost negative `channel_axis` (`channel_axis == -x_dim`) [Medium]
- File: brainpy/dnn/pooling.py:118 (`Pool._infer_shape`), 390 (`_MaxPoolNd._infer_shape`), 787 (`AdaptivePool.update`)
- Category: edge/error
- What: The bound check is `if channel_axis and not 0 <= abs(channel_axis) < x_dim: raise`.
Using `abs(channel_axis)` makes `channel_axis == -x_dim` (the valid leftmost axis, e.g. `-3` for a
3-D `(C, H, W)` input) fail the check and raise `ValueError`.
- Why it's a bug: A legitimate, common channels-first layout (`channel_axis=-3` on `(C,H,W)`, or
`-4` on `(N,C,H,W)`) is wrongly rejected. The correct numpy-style bound is `-x_dim <= axis < x_dim`.
- Repro:
```python
bp.dnn.MaxPool2d(2, channel_axis=-3)(bm.random.randn(6, 4, 4)) # ValueError: Invalid channel axis -3
```
- Fix: Replace the `abs()` bound test with `not -x_dim <= channel_axis < x_dim` in all three sites.
- Tests: `pooling_layers_test.py::TestPoolingChannelAxis::test_maxpool2d_leftmost_negative_channel_axis`,
`...::test_adaptiveavgpool2d_leftmost_negative_channel_axis`
- Status: fixed
- (== prior audit M-27)

---

### P12-M3 — `BatchNorm` stores the *biased* batch variance into `running_var` (PyTorch uses unbiased) [Medium]
- File: brainpy/dnn/normalization.py:172-174
- Category: numerics
- What: In fit mode the running estimate is updated with the biased population variance
`var = mean_of_square - mean**2` (divisor `N`). PyTorch / the conventional BatchNorm running statistic
uses the *unbiased* sample variance (divisor `N-1`, i.e. Bessel's correction) for the running buffer,
while keeping the biased variance only for normalizing the current batch.
- Why it's a bug: The running variance used at eval time is systematically too small by a factor of
`(N-1)/N`. For small batch/window counts `N` this is a meaningful (a few percent) bias in the
eval-time normalization — i.e. inference results drift from the trained reference.
- Repro:
```python
bn = bp.dnn.BatchNorm1d(3, affine=False); bp.share.save(fit=True)
bn(bm.random.randn(4, 5, 3)) # N = 20
# running_var == 0.99*1 + 0.01*biased_var, not unbiased_var
```
- Fix: Scale the variance fed into the `running_var` EMA by `N/(N-1)` (with `N` = number of reduced
elements), guarding `N == 1`. The batch normalization itself keeps the biased `var`.
- Tests: `normalization_test.py::Test_Normalization::test_BatchNorm_running_var_is_unbiased`
- Status: fixed
- (== prior audit M-25)

---

### P12-L1 — `Flatten` default `start_dim=0` contradicts its docstring example and PyTorch (`start_dim=1`) [Low]
- File: brainpy/dnn/function.py:91 (default), 74-87 (docstring example)
- Category: api-drift/style
- What: `Flatten.__init__` defaults `start_dim=0`. The class docstring example claims
`Flatten()` on `(32, 1, 5, 5)` yields `(32, 25)` (PyTorch's `start_dim=1` semantics), but in
`NonBatchingMode` the actual default flattens the batch dim too → `(800,)`.
- Why it's a bug: The documented contract and the implemented default disagree. PyTorch's `nn.Flatten`
default is `start_dim=1`.
- Repro:
```python
bp.dnn.Flatten()(bm.random.randn(32, 1, 5, 5)).shape # (800,), docstring says (32, 25)
```
- Fix: recorded only. NOTE: changing the default to `1` would change the documented `NonBatchingMode`
contract and break `function_test.py::test_flatten_non_batching_mode` (asserts `(600,)` from default
start_dim under `NonBatchingMode`). The discrepancy is documentation/default-value drift, not a
silent numeric error, and a default change is a cross-cutting API change. Left for maintainers to
decide (fix docs vs. change default + migrate the test).
- Tests: none
- Status: recorded-only

---

## Cross-check vs `dev/issues-found-20260618.md` (dnn entries)

- **C-05** (`GroupNorm`/`InstanceNorm` reduce over the group axis): **already fixed** in this tree
(`normalization.py:640` reduces over spatial + within-group channel axis, keeps the group axis).
Verified: `GroupNorm(3,6) != GroupNorm(1,6) != GroupNorm(6,6)`, per-group means ≈ 0. No action.
- **H-51** (`BatchNorm`/affine `LayerNorm`/`GroupNorm` crash out-of-the-box under default mode):
**already fixed** (`BatchNorm` defaults to `training_mode`; affine params wrapped as `Variable` vs
`TrainVar` per mode instead of a hard assert). No action.
- **M-25**: still present → fixed here as **P12-M3**.
- **M-26**: still present → fixed here as **P12-M1**.
- **M-27**: still present → fixed here as **P12-M2**.
- **M-28** (`Flatten` default): still present → recorded as **P12-L1** (see note; default change out of safe scope).

## Checked and found correct (no action)
- `Dropout`: `prob` = keep-probability; `bernoulli(prob)` keeps with prob `prob`; survivors scaled by
`1/prob == 1/(1-rate)`; eval (`fit=False`) is a no-op. Correct.
- `Conv*` / `ConvTranspose*`: kernel shapes, `feature_group_count=groups`, dimension numbers,
bias broadcast (channels-last), non-batching unsqueeze/squeeze, `SAME`/`VALID`/int/tuple padding
normalization. Correct (smoke-checked shapes & a transpose upsample).
- `AvgPool`/`_AvgPoolNd` non-VALID averaging via a second `reduce_window` count. Correct.
- `GroupNorm` affine broadcast via `lax.broadcast_to_rank` to channels-last. Correct.
- Activation wrappers delegate to `bm.*`; formulas match docstrings. `Softmax2d` uses axis `-3`
(channels-first `(N,C,H,W)`/`(C,H,W)`), matching its documented contract.
- `interoperation_flax`: Flax round-trip param flatten/unflatten, `ToFlaxRNNCell` carry handling.
Correct for flax present/absent.
Loading