From 4f3d91bed80ba12dc5709ec1e5a9d6f2576ac635 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 19 Jun 2026 02:06:43 +0800 Subject: [PATCH] fix(math): correct ShardedArray pytree flatten and remove_diag guard - ShardedArray pytree round-trip dropped _keep_sharding, raising AttributeError under every JAX transform (jit/vmap/scan/grad/tree_map); add tree_flatten/tree_unflatten carrying _keep_sharding in aux_data (High) - remove_diag raised an opaque broadcasting error on tall (m>n) matrices; add a clear ValueError guard (Medium) Findings recorded in docs/issues-found-20260619-math-core.md --- brainpy/math/math_core_fixes_test.py | 80 +++++++++++ brainpy/math/ndarray.py | 18 +++ brainpy/math/others.py | 12 ++ docs/issues-found-20260619-math-core.md | 173 ++++++++++++++++++++++++ 4 files changed, 283 insertions(+) create mode 100644 docs/issues-found-20260619-math-core.md diff --git a/brainpy/math/math_core_fixes_test.py b/brainpy/math/math_core_fixes_test.py index 310b050bb..8ef21a61a 100644 --- a/brainpy/math/math_core_fixes_test.py +++ b/brainpy/math/math_core_fixes_test.py @@ -895,3 +895,83 @@ def test_partition_with_axis_name_sequence(): def test_keep_constraint_on_bp_array(): out = sharding.keep_constraint(Array([1., 2., 3.])) np.testing.assert_allclose(np.asarray(out), [1., 2., 3.]) + + +# =========================================================================== +# P2 audit (2026-06-19) regression tests +# =========================================================================== + +# --- ndarray.py : P2-H1 (ShardedArray pytree round-trip) ------------------- + +def test_shardedarray_pytree_round_trip_preserves_value_and_keep_sharding(): + """P2-H1: ``ShardedArray`` reused the base ``Array.tree_unflatten`` which + only set ``_value`` and never ``_keep_sharding``. Any pytree round-trip + (``jit``/``vmap``/``scan``/``grad``/``tree_map``) then made the ``value`` + getter raise ``AttributeError: ... has no attribute '_keep_sharding'``. + The flatten/unflatten pair must round-trip both attributes.""" + from jax.tree_util import tree_flatten, tree_unflatten + + for keep in (True, False): + sa = ShardedArray(jnp.arange(6.), keep_sharding=keep) + flat, treedef = tree_flatten(sa) + back = tree_unflatten(treedef, flat) + assert isinstance(back, ShardedArray) + # The getter must not raise (regression for the missing attribute). + np.testing.assert_allclose(np.asarray(back.value), np.arange(6.)) + # ``keep_sharding`` must survive the round-trip. + assert back._keep_sharding is keep + + +def test_shardedarray_works_under_jit(): + """P2-H1: a ``ShardedArray`` passed through ``jit`` (which pytree-flattens + and unflattens its arguments) must not crash when its value is read.""" + + @jax.jit + def f(x): + return x.value + 1. + + out = f(ShardedArray(jnp.arange(3.))) + np.testing.assert_allclose(np.asarray(out), [1., 2., 3.]) + + +def test_shardedarray_works_under_vmap(): + """P2-H1: the same fix is exercised by ``vmap``.""" + + @jax.vmap + def g(x): + return x.value * 2. + + out = g(ShardedArray(jnp.arange(4.))) + np.testing.assert_allclose(np.asarray(out), [0., 2., 4., 6.]) + + +# --- others.py : P2-M1 (remove_diag m > n) --------------------------------- + +def test_remove_diag_square_and_wide(): + """P2-M1: the working ``m <= n`` path is unchanged.""" + from brainpy.math.others import remove_diag + + square = remove_diag(jnp.arange(9).reshape(3, 3)) + np.testing.assert_array_equal(np.asarray(square), [[1, 2], [3, 5], [6, 7]]) + + wide = remove_diag(jnp.arange(12).reshape(3, 4)) + np.testing.assert_array_equal(np.asarray(wide), + [[1, 2, 3], [4, 6, 7], [8, 9, 11]]) + + +def test_remove_diag_tall_raises_clear_error(): + """P2-M1: a tall matrix (m > n) has no well-defined ``(m, n-1)`` result; the + old code crashed with an opaque broadcasting error. It must now raise a + clear ``ValueError`` mentioning the shape constraint.""" + from brainpy.math.others import remove_diag + + with pytest.raises(ValueError, match=r'm <= n'): + remove_diag(jnp.arange(12).reshape(4, 3)) + + +def test_remove_diag_still_rejects_non_2d(): + """P2-M1: the pre-existing ndim guard is preserved.""" + from brainpy.math.others import remove_diag + + with pytest.raises(ValueError, match=r'2D matrix'): + remove_diag(jnp.arange(8).reshape(2, 2, 2)) diff --git a/brainpy/math/ndarray.py b/brainpy/math/ndarray.py index b41c702df..1d4e66f35 100644 --- a/brainpy/math/ndarray.py +++ b/brainpy/math/ndarray.py @@ -243,6 +243,24 @@ def __init__(self, value, dtype: Any = None, *, keep_sharding: bool = True): super().__init__(value, dtype) self._keep_sharding = keep_sharding + def tree_flatten(self): + # Carry ``_keep_sharding`` in ``aux_data`` so it survives a pytree + # round-trip (``jit``/``vmap``/``scan``/``grad``). Flatten the *raw* + # ``_value`` rather than the ``value`` property: the property inserts a + # ``with_sharding_constraint``, which must not run during the abstract + # flatten step (the leaf may be a tracer/``ShapeDtypeStruct``). + return (self._value,), self._keep_sharding + + @classmethod + def tree_unflatten(cls, aux_data, flat_contents): + # Reconstruct without ``__init__`` (the leaf may be abstract during + # tracing) and restore ``_keep_sharding`` from ``aux_data``; otherwise + # the ``value`` getter raises ``AttributeError`` after any transform. + ins = object.__new__(cls) + ins._value = flat_contents[0] + ins._keep_sharding = True if aux_data is None else aux_data + return ins + @property def value(self): """The value stored in this array. diff --git a/brainpy/math/others.py b/brainpy/math/others.py index a95a34c80..63a349eb4 100644 --- a/brainpy/math/others.py +++ b/brainpy/math/others.py @@ -94,6 +94,18 @@ def remove_diag(arr): raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.') arr = as_jax(arr) m, n = arr.shape + # ``remove_diag`` drops the diagonal element ``[i, i]`` from every row, so it + # only has a well-defined ``(m, n - 1)`` result when every row owns a + # diagonal element, i.e. ``m <= n``. With ``m > n`` the rows ``i >= n`` have + # no diagonal to remove and the off-diagonal element count + # (``m * n - n``) no longer matches ``m * (n - 1)``; the old code crashed + # with an opaque broadcasting/reshape error. Fail fast with a clear message. + if m > n: + raise ValueError( + f'remove_diag requires the number of rows to not exceed the number ' + f'of columns (m <= n), so that every row has a diagonal element to ' + f'remove. But we got a matrix with shape {arr.shape}.' + ) # Static off-diagonal indices (computed with numpy so they are concrete # constants and the gather traces cleanly under jit/vmap). rows = np.repeat(np.arange(m), n - 1) diff --git a/docs/issues-found-20260619-math-core.md b/docs/issues-found-20260619-math-core.md new file mode 100644 index 000000000..b82fe02d1 --- /dev/null +++ b/docs/issues-found-20260619-math-core.md @@ -0,0 +1,173 @@ +# BrainPy math-core fresh audit — 2026-06-19 (P2) + +Scope: `brainpy/math/_utils.py`, `datatypes.py`, `defaults.py`, `environment.py`, +`modes.py`, `ndarray.py`, `scales.py`, `sharding.py`, `others.py`, `remove_vmap.py` +(+ co-located `*_test.py`). + +Environment: jax 0.10.2, brainstate 0.5.1, brainunit 0.5.1, brainevent 0.1.0, +braintools 0.1.10 (CPU-only). `import brainpy` works, so findings tagged +`[verified]` were reproduced at runtime. + +This is a fresh pass. The fixes recorded in `dev/issues-found-20260618.md` +(C-10, M-07, M-08, M-09, M-10, H-10, H-11, H-12, H-14, H-15, L-02, L-03) are all +present in the current tree and were re-verified as correct; they are **not** +re-reported here. The findings below are new. + +## Summary counts +- Critical: 0 +- High: 1 +- Medium: 1 +- Low: 4 +- Fixed: 2 (the High + the Medium) +- Recorded-only: 4 (all Low) + +--- + +### P2-H1 — `ShardedArray` pytree round-trip drops `_keep_sharding` → `AttributeError` under every JAX transform [High] +- File: brainpy/math/ndarray.py:228-288 (root cause: inherited `Array.tree_unflatten` at :110-119; `_keep_sharding` introduced at :242) +- Category: correctness / api-drift +- What: `ShardedArray` adds the slot `_keep_sharding` (set only in `__init__`) but + reuses the base `Array.tree_flatten`/`tree_unflatten`. `tree_flatten` returns + `aux_data=None` and `tree_unflatten` reconstructs via `object.__new__(cls)` + setting only `_value`. So after any pytree round-trip the reconstructed + `ShardedArray` has no `_keep_sharding` attribute, and its `value` getter + (which reads `self._keep_sharding`) raises + `AttributeError: 'ShardedArray' object has no attribute '_keep_sharding'`. +- Why it's a bug: JAX flattens/unflattens pytree leaves on essentially every + transform boundary (`jit`, `vmap`, `scan`/`for_loop`, `grad`, `tree_map`, + `eval_shape`). `ShardedArray` is a registered pytree node and is the wrapper + `brainpy.math.sharding._device_put` returns (so `partition`/`partition_by_*`/ + `device_mesh` all hand back `ShardedArray`s). Passing such an array into any + jitted/vmapped function — the entire point of sharding — crashes. The + `keep_sharding=False` option was also silently lost (reset to the default). +- Repro (verified): + ```python + import jax, jax.numpy as jnp + from brainpy.math.ndarray import ShardedArray + jax.jit(lambda x: x.value + 1.)(ShardedArray(jnp.arange(3.))) + # AttributeError: 'ShardedArray' object has no attribute '_keep_sharding' + ``` +- Fix: Added `ShardedArray.tree_flatten` (returns `(self._value,), self._keep_sharding` + — flattens the raw `_value` to avoid running `with_sharding_constraint` during + the abstract flatten step) and `ShardedArray.tree_unflatten` (reconstructs + `_value` and restores `_keep_sharding` from `aux_data`, defaulting to `True`). +- Tests: `test_shardedarray_pytree_round_trip_preserves_value_and_keep_sharding`, + `test_shardedarray_works_under_jit`, `test_shardedarray_works_under_vmap` + (in `math_core_fixes_test.py`). +- Status: fixed + +--- + +### P2-M1 — `remove_diag` crashes with an opaque error on tall (m > n) matrices [Medium] +- File: brainpy/math/others.py:80-102 +- Category: edge/error +- What: The docstring claims support for any `(M, N)` matrix returning + `(M, N-1)`, but the off-diagonal index construction is inconsistent for + `m > n`: `rows = np.repeat(np.arange(m), n - 1)` yields `m*(n-1)` indices + while `cols` is taken from `~np.eye(m, n)` which has `m*n - n` `True` entries. + When `m > n` these counts differ and the advanced-index gather raises an + opaque `ValueError: Incompatible shapes for broadcasting`. +- Why it's a bug: `remove_diag` removes element `[i, i]` from each row, which is + only well-defined when every row owns a diagonal element, i.e. `m <= n`. The + historical implementation (boolean-mask + reshape) also failed for `m > n`, + just at the reshape step — so this was never supported, but the new error + message is misleading and hard to diagnose. +- Repro (verified): `remove_diag(jnp.arange(12).reshape(4, 3))` → broadcasting + `ValueError` referencing internal gather shapes. +- Fix: Added an explicit guard that raises a clear `ValueError` (matching the + existing `ndim` guard style) explaining the `m <= n` requirement, before the + gather. The `m <= n` path is unchanged. +- Tests: `test_remove_diag_square_and_wide`, + `test_remove_diag_tall_raises_clear_error`, `test_remove_diag_still_rejects_non_2d`. +- Status: fixed + +--- + +### P2-L1 — `IdScaling._reject_overrides` raises a confusing truth-value error for array `bias`/`scale` [Low] +- File: brainpy/math/scales.py:87-98 +- Category: edge/error +- What: `_reject_overrides` does `if bias is not None and bias != 0.` / `scale != 1.`. + When called with a non-scalar array `bias`/`scale`, `bias != 0.` is an array + and the `and`/`if` coerces it to bool, raising + `ValueError: The truth value of an array with more than one element is ambiguous`. +- Why it's a bug: misleading error for an unusual-but-legal input. The intent is + to reject non-default overrides; an array override should be rejected with the + intended "IdScaling ignores bias/scale" message, not a numpy truthiness error. +- Repro (verified): `IdScaling().offset_scaling(jnp.zeros(3), bias=jnp.zeros(3))`. +- Fix: recorded only (Low; out of fix scope). Suggested: compare with + `np.ndim(bias) == 0 and bias != 0.` or `np.any(np.asarray(bias) != 0.)`. +- Tests: none +- Status: recorded-only + +--- + +### P2-L2 — `set()` does not validate `bp_object_as_pytree`, unlike `environment()` [Low] +- File: brainpy/math/environment.py:354-442 (vs `environment.__init__` :217-219) +- Category: edge/error / api-drift +- What: `environment.set()` validates `dt`, `mode`, `x64`, `float_`, `int_`, + `bool_`, `complex_`, `numpy_func_return` up front (M-07 fix) but never checks + that `bp_object_as_pytree` is a `bool`. `environment.__init__` does assert it. + So `bm.set(bp_object_as_pytree='nope')` silently stores a string. +- Why it's a bug: minor API inconsistency; a bad value is stored and only + surfaces later where the flag is consumed. Not silently-wrong numerics. +- Repro (verified): `bm.set(bp_object_as_pytree='not a bool')` stores the string. +- Fix: recorded only (Low). Suggested: add + `if bp_object_as_pytree is not None: assert isinstance(bp_object_as_pytree, bool)` + to the validation block. +- Tests: none +- Status: recorded-only + +--- + +### P2-L3 — `keep_constraint` / `_keep_constraint` do not skip `SingleDeviceSharding` (inconsistent with M-09 fix) [Low] +- File: brainpy/math/sharding.py:227-248 +- Category: perf / style +- What: The M-09 fix made `ShardedArray.value` skip inserting + `with_sharding_constraint` for `SingleDeviceSharding` (pure overhead on a + single device). The standalone `keep_constraint`/`_keep_constraint` helpers + still insert the constraint unconditionally. For symmetry they should apply + the same guard. +- Why it's a bug: only a consistency/perf nit — verified that on a single CPU + device XLA elides the constraint to an empty jaxpr (`jax.make_jaxpr` shows no + equations), so there is no real runtime cost in jax 0.10.2. Recorded for + consistency, not correctness. +- Repro: static / `jax.make_jaxpr(keep_constraint)(jnp.arange(3.))` → no eqns. +- Fix: recorded only (Low). Suggested: mirror the `SingleDeviceSharding` check. +- Tests: none +- Status: recorded-only + +--- + +### P2-L4 — `Scaling.transform` raises bare `ZeroDivisionError` on a degenerate `scaled_V_range` [Low] +- File: brainpy/math/scales.py:29-48 +- Category: edge/error +- What: `scale = (V_max - V_min) / (scaled_V_max - scaled_V_min)` divides by zero + when `scaled_V_min == scaled_V_max`, surfacing as a bare `ZeroDivisionError` + with no context. +- Why it's a bug: invalid user input produces an unhelpful error. Low impact — + the exception is already raised, just not descriptive. +- Repro (verified): `Scaling.transform([0., 10.], scaled_V_range=(1., 1.))`. +- Fix: recorded only (Low). Suggested: validate + `scaled_V_max != scaled_V_min` with a clear message. +- Tests: none +- Status: recorded-only + +--- + +## Re-verified as already-correct (prior 2026-06-18 fixes, no action) +- `enable_x64()` / `disable_x64()` keep brainstate `precision` and JAX + `jax_enable_x64` in sync (C-10) — verified: enable→`(64, True)`, disable→`(32, False)`. +- `set()` validates before mutating (M-07). +- `Mode` is hashable and usable in sets / dict keys (H-10). +- `Array.device` is a property returning a `jax.Device`; `device_buffer`, + `block_host_until_ready`, `block_until_ready`, `at` all work (H-11). +- `Array(scalar)` stores an array, `.shape` works (H-12). +- `_compatible_with_brainpy_array` returns `out` when `out=` is given (H-14). +- `remove_diag` traces cleanly under `jit`/`vmap` for `m <= n` (H-15). +- `ShardedArray.value` skips `with_sharding_constraint` on `SingleDeviceSharding` (M-09). +- `get_sharding` warns on a full axis-name mismatch (M-10). +- `remove_vmap` delegates to `brainstate.transform.unvmap`; global-reduction + semantics documented and verified under `vmap`/`jit` (M-08). +- `IdScaling` rejects non-default scalar `bias`/`scale` (L-02). +- base `Array` vs `ShardedArray` value-setter policy documented (L-03). +