diff --git a/brainpy/math/activations.py b/brainpy/math/activations.py index f3d66408c..37287ff84 100644 --- a/brainpy/math/activations.py +++ b/brainpy/math/activations.py @@ -169,12 +169,21 @@ def gelu(x, approximate=True): whether to use the approximate or exact formulation. """ x = x.value if isinstance(x, Array) else x + # Promote integer / boolean inputs to a floating dtype before computing. + # Without this the ``sqrt(2/pi)`` and ``0.044715`` constants are truncated to + # zero (approximate branch) or the float result is cast back to integer + # (exact branch), silently producing wrong values. This mirrors the + # ``promote_args_inexact`` step in ``jax.nn.gelu``. + x = jnp.asarray(x) + if not jnp.issubdtype(x.dtype, jnp.floating): + x = x.astype(jnp.promote_types(x.dtype, jnp.float32)) if approximate: sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype) cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3)))) y = x * cdf else: - y = jnp.array(x * (jax.lax.erf(x / np.sqrt(2)) + 1) / 2, dtype=x.dtype) + sqrt_2 = np.sqrt(2).astype(x.dtype) + y = jnp.array(x * (jax.scipy.special.erf(x / sqrt_2) + 1) / 2, dtype=x.dtype) return y diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py index d9939a64a..71e8671e6 100644 --- a/brainpy/math/compat_pytorch.py +++ b/brainpy/math/compat_pytorch.py @@ -117,11 +117,18 @@ def unflatten(x: Union[jax.Array, Array], dim: int, sizes: Sequence[int]) -> Arr The returned tensor has one more dimension than the input tensor. The returned tensor shares the same underlying data with this tensor. """ - assert x.ndim > dim, ('The dimension to be unflattened should be less than the tensor dimension. ' - f'Got {dim} and {x.ndim}.') x = _as_jax_array_(x) + ndim = x.ndim + # Normalise a negative ``dim`` to PyTorch semantics, where ``dim`` indexes + # into ``x.shape`` and may be in ``[-ndim, ndim)``. + canon_dim = dim + ndim if dim < 0 else dim + if not 0 <= canon_dim < ndim: + raise ValueError( + f'Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], ' + f'but got {dim}).' + ) shape = x.shape - new_shape = shape[:dim] + tuple(sizes) + shape[dim + 1:] + new_shape = shape[:canon_dim] + tuple(sizes) + shape[canon_dim + 1:] r = jnp.reshape(x, new_shape) return _return(r) diff --git a/brainpy/math/compat_tensorflow.py b/brainpy/math/compat_tensorflow.py index 6f5a07063..89fe406ad 100644 --- a/brainpy/math/compat_tensorflow.py +++ b/brainpy/math/compat_tensorflow.py @@ -137,12 +137,10 @@ def segment_mean(data, segment_ids): See https://tensorflow.google.cn/api_docs/python/tf/math/segment_mean """ - r = jax.ops.segment_sum(_as_jax_array_(data), - _as_jax_array_(segment_ids), - indices_are_sorted=False) - d = jax.ops.segment_sum(jnp.ones_like(data), - _as_jax_array_(segment_ids), - indices_are_sorted=False) + data = _as_jax_array_(data) + segment_ids = _as_jax_array_(segment_ids) + r = jax.ops.segment_sum(data, segment_ids, indices_are_sorted=False) + d = jax.ops.segment_sum(jnp.ones_like(data), segment_ids, indices_are_sorted=False) return _return(jnp.nan_to_num(r / d)) @@ -204,12 +202,12 @@ def unsorted_segment_sqrt_n(data, segment_ids, num_segments): See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_sqrt_n """ - r = jax.ops.segment_sum(_as_jax_array_(data), - _as_jax_array_(segment_ids), + data = _as_jax_array_(data) + segment_ids = _as_jax_array_(segment_ids) + r = jax.ops.segment_sum(data, segment_ids, num_segments=num_segments, indices_are_sorted=False) - d = jax.ops.segment_sum(jnp.ones_like(data), - _as_jax_array_(segment_ids), + d = jax.ops.segment_sum(jnp.ones_like(data), segment_ids, num_segments=num_segments, indices_are_sorted=False) return _return(jnp.nan_to_num(r / jnp.sqrt(d))) @@ -221,12 +219,12 @@ def unsorted_segment_mean(data, segment_ids, num_segments): See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_mean """ - r = jax.ops.segment_sum(_as_jax_array_(data), - _as_jax_array_(segment_ids), + data = _as_jax_array_(data) + segment_ids = _as_jax_array_(segment_ids) + r = jax.ops.segment_sum(data, segment_ids, num_segments=num_segments, indices_are_sorted=False) - d = jax.ops.segment_sum(jnp.ones_like(data), - _as_jax_array_(segment_ids), + d = jax.ops.segment_sum(jnp.ones_like(data), segment_ids, num_segments=num_segments, indices_are_sorted=False) return _return(jnp.nan_to_num(r / d)) diff --git a/brainpy/math/math_compat_p3_fixes_test.py b/brainpy/math/math_compat_p3_fixes_test.py new file mode 100644 index 000000000..522adb1eb --- /dev/null +++ b/brainpy/math/math_compat_p3_fixes_test.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +"""Regression tests for the 2026-06-19 ``math-compat`` audit (P3-* findings). + +Covered findings (see ``docs/issues-found-20260619-math-compat.md``): + +* P3-H1 (``activations.py``) -- ``gelu`` must promote integer inputs to a + floating dtype before computing; both the approximate and exact branches were + silently wrong on integer input. +* P3-H2 (``compat_pytorch.py``) -- ``unflatten`` must honour a negative ``dim`` + (PyTorch semantics) and reject out-of-range dims. +* P3-M1 (``compat_tensorflow.py``) -- ``segment_mean`` / ``unsorted_segment_mean`` + / ``unsorted_segment_sqrt_n`` must convert ``data`` to a jax array before + ``jnp.ones_like`` (do not rely on the deprecated implicit ``__jax_array__``). +""" + +import jax +import jax.nn as jnn +import jax.numpy as jnp +import numpy as np +import pytest + +import brainpy.math as bm +from brainpy.math import ( + activations as act, + compat_pytorch as cpt, + compat_tensorflow as ctf, +) + + +def _j(x): + return bm.as_jax(x) + + +def _finite(x): + return bool(jnp.all(jnp.isfinite(_j(x)))) + + +# --------------------------------------------------------------------------- +# P3-H1: gelu integer-input promotion +# --------------------------------------------------------------------------- + +def test_gelu_integer_input_matches_float_approximate(): + """P3-H1: approximate gelu on int input must equal the float computation.""" + xi = jnp.array([1, 2, 3], dtype=jnp.int32) + xf = jnp.array([1., 2., 3.]) + ri = np.asarray(_j(act.gelu(xi, approximate=True))) + rf = np.asarray(_j(act.gelu(xf, approximate=True))) + np.testing.assert_allclose(ri, rf, atol=1e-6) + # and it must agree with jax.nn.gelu (the reference implementation) + np.testing.assert_allclose(ri, np.asarray(jnn.gelu(xi, approximate=True)), atol=1e-6) + # specifically: NOT the truncated x/2 result the bug produced + assert not np.allclose(ri, np.asarray(xf) / 2.0) + + +def test_gelu_integer_input_matches_float_exact(): + """P3-H1: exact gelu on int input must not be truncated back to int.""" + xi = jnp.array([1, 2, 3], dtype=jnp.int32) + xf = jnp.array([1., 2., 3.]) + ri = act.gelu(xi, approximate=False) + assert jnp.issubdtype(_j(ri).dtype, jnp.floating) + np.testing.assert_allclose( + np.asarray(_j(ri)), + np.asarray(_j(act.gelu(xf, approximate=False))), + atol=1e-6, + ) + # reference parity + np.testing.assert_allclose( + np.asarray(_j(ri)), np.asarray(jnn.gelu(xi, approximate=False)), atol=1e-5) + + +def test_gelu_float_unchanged(): + """The float path must be unchanged by the promotion fix.""" + x = bm.asarray([-1., 0., 1., 2.]) + for approx in (True, False): + np.testing.assert_allclose( + np.asarray(_j(act.gelu(x, approximate=approx))), + np.asarray(jnn.gelu(_j(x), approximate=approx)), + atol=1e-5, + ) + + +def test_gelu_accepts_brainpy_array_and_is_finite(): + x = bm.asarray([-3., -1., 0., 1., 3.]) + assert _finite(act.gelu(x, approximate=True)) + assert _finite(act.gelu(x, approximate=False)) + + +# --------------------------------------------------------------------------- +# P3-H2: unflatten negative dim +# --------------------------------------------------------------------------- + +def test_unflatten_negative_dim(): + """P3-H2: negative dim must be normalised like torch.unflatten.""" + x = bm.asarray(jnp.arange(6.)) + r = cpt.unflatten(x, -1, (2, 3)) + assert _j(r).shape == (2, 3) + # equivalent to the positive-dim call + np.testing.assert_allclose( + np.asarray(_j(r)), np.asarray(_j(cpt.unflatten(x, 0, (2, 3))))) + + +def test_unflatten_negative_dim_higher_rank(): + x = bm.asarray(jnp.arange(24.).reshape(2, 12)) + r = cpt.unflatten(x, -1, (3, 4)) + assert _j(r).shape == (2, 3, 4) + r2 = cpt.unflatten(x, -2, (1, 2)) + assert _j(r2).shape == (1, 2, 12) + + +def test_unflatten_positive_dim_still_works(): + x = bm.asarray(jnp.arange(6.)) + assert _j(cpt.unflatten(x, 0, (2, 3))).shape == (2, 3) + assert _j(cpt.unflatten(x, 0, (-1, 3))).shape == (2, 3) + + +def test_unflatten_dim_out_of_range(): + x = bm.asarray(jnp.arange(6.)) + with pytest.raises((ValueError, AssertionError, IndexError)): + cpt.unflatten(x, 5, (2, 3)) + with pytest.raises((ValueError, AssertionError, IndexError)): + cpt.unflatten(x, -5, (2, 3)) + + +# --------------------------------------------------------------------------- +# P3-M1: TF segment helpers must not lean on implicit __jax_array__ +# --------------------------------------------------------------------------- + +def test_segment_mean_array_input(): + data = bm.asarray([1., 2., 3., 4.]) + seg = bm.asarray([0, 0, 1, 1]) + np.testing.assert_allclose(np.asarray(_j(ctf.segment_mean(data, seg))), [1.5, 3.5]) + + +def test_unsorted_segment_mean_array_input(): + data = bm.asarray([1., 2., 3., 4.]) + seg = bm.asarray([0, 0, 1, 1]) + np.testing.assert_allclose( + np.asarray(_j(ctf.unsorted_segment_mean(data, seg, 2))), [1.5, 3.5]) + + +def test_unsorted_segment_sqrt_n_array_input(): + data = bm.asarray([1., 1., 1., 1.]) + seg = bm.asarray([0, 0, 1, 1]) + # sum over 2-element segments divided by sqrt(2) + np.testing.assert_allclose( + np.asarray(_j(ctf.unsorted_segment_sqrt_n(data, seg, 2))), + [2.0 / np.sqrt(2.0), 2.0 / np.sqrt(2.0)], + atol=1e-6, + ) + + +def test_unsorted_segment_mean_under_jit(): + """The denominator (``jnp.ones_like``) must trace cleanly under jit. + + ``unsorted_segment_mean`` takes a static ``num_segments`` so it is + jit-compatible (unlike ``segment_mean`` which infers it from the data). + """ + data = jnp.array([1., 2., 3., 4.]) + seg = jnp.array([0, 0, 1, 1]) + f = jax.jit(lambda d: bm.as_jax(ctf.unsorted_segment_mean(bm.asarray(d), bm.asarray(seg), 2))) + np.testing.assert_allclose(np.asarray(f(data)), [1.5, 3.5]) diff --git a/docs/issues-found-20260619-math-compat.md b/docs/issues-found-20260619-math-compat.md new file mode 100644 index 000000000..3b0f9229a --- /dev/null +++ b/docs/issues-found-20260619-math-compat.md @@ -0,0 +1,152 @@ +# Audit — brainpy.math compatibility / activations / linalg / fft (2026-06-19) + +Scope: `brainpy/math/{activations,compat_numpy,compat_pytorch,compat_tensorflow,fft,interoperability,linalg}.py` + +Environment: jax 0.10.2, jaxlib 0.10.1, numpy 2.4.6 (CPU). Implicit +`__jax_array__` still honoured by `jax.numpy` in this version, so brainpy +`Array` leaves are still implicitly convertible. + +--- + +### P3-H1 — `gelu` silently wrong for integer (non-floating) inputs [High] +- File: brainpy/math/activations.py:146-178 +- Category: correctness / dtype +- What: `gelu` reads `x.dtype` and forces the result to that dtype without first + promoting the input to an inexact type. For an integer input both branches + are wrong: + - `approximate=True`: `sqrt_2_over_pi = np.sqrt(2/pi).astype(x.dtype)` truncates + the constant `0.7978…` to `0` for an int dtype, and `0.044715` likewise + vanishes, so the whole `tanh(...)` argument collapses to `0`, giving + `cdf = 0.5` and `gelu(x) = x/2`. + - `approximate=False`: `jnp.array(..., dtype=x.dtype)` truncates the float + result back to int. +- Why it's a bug: `gelu(jnp.array([1,2,3], int32), approximate=True)` returns + `[0.5, 1.0, 1.5]` instead of the correct `[0.8412, 1.9546, 2.9964]` + (matches `jax.nn.gelu`). `approximate=False` on the same input returns + `[0,0,0,1]`. JAX's own `gelu` promotes the argument to inexact first. +- Repro: + ```python + import brainpy.math.activations as act, jax.numpy as jnp + act.gelu(jnp.array([1,2,3], jnp.int32), approximate=True) # [0.5,1.,1.5] WRONG + ``` +- Fix: promote `x` to an inexact dtype (`jnp.promote_types(x.dtype, jnp.float32)` + / `jnp.asarray(x) * 1.0` style) before computing, mirroring `jax.nn.gelu`. +- Tests: test_gelu_integer_input_matches_float (both branches) +- Status: fixed + +### P3-H2 — `unflatten` ignores negative `dim` [High] +- File: brainpy/math/compat_pytorch.py:104-126 +- Category: correctness +- What: `unflatten(x, dim, sizes)` builds `shape[:dim] + sizes + shape[dim+1:]` + without normalising a negative `dim`. PyTorch's `torch.unflatten` accepts + negative `dim`. With `dim=-1` on a `(6,)` array the slices become + `shape[:-1]=()` and `shape[0:]=(6,)`, yielding the target shape `(2,3,6)` and a + reshape failure (size 6 → 36). The leading `assert x.ndim > dim` is also wrong + for negative `dim` (always true) and is evaluated on the raw input before + `_as_jax_array_`. +- Why it's a bug: `unflatten(arange(6), -1, (2,3))` raises `TypeError: cannot + reshape array of shape (6,) into shape (2,3,6)` instead of returning `(2,3)`. +- Repro: + ```python + import brainpy.math.compat_pytorch as cpt, brainpy.math as bm, jax.numpy as jnp + cpt.unflatten(bm.asarray(jnp.arange(6.)), -1, (2,3)) # raises, should be (2,3) + ``` +- Fix: normalise `dim` (`dim += x.ndim` when negative) and validate against + `x.ndim` after canonicalisation; convert to jax array before indexing `.shape`. +- Tests: test_unflatten_negative_dim, test_unflatten_dim_out_of_range +- Status: fixed + +### P3-M1 — `jnp.ones_like(data)` passed a brainpy `Array` in TF segment helpers [Medium] +- File: brainpy/math/compat_tensorflow.py:143,211,228 (segment_mean, + unsorted_segment_sqrt_n, unsorted_segment_mean) +- Category: api-drift / fragility +- What: the denominator is computed with `jnp.ones_like(data)` where `data` is + the *un-converted* argument (possibly a brainpy `Array`), while every other + argument is funnelled through `_as_jax_array_`. This only works because + `jax.numpy` still honours the implicit `__jax_array__` protocol; the audit + brief flags that protocol as scheduled for removal (JAX ≥ 0.9). The + inconsistency is a latent break. +- Why it's a bug: once implicit `__jax_array__` is dropped, `jnp.ones_like()` raises and `segment_mean` / `unsorted_segment_mean` / + `unsorted_segment_sqrt_n` crash on brainpy inputs (the common case). +- Repro: static (currently works via the deprecated protocol). +- Fix: convert once (`data = _as_jax_array_(data)`) and use `jnp.ones_like` on + the jax array. +- Tests: test_segment_mean_array_input, test_unsorted_segment_mean_array_input, + test_unsorted_segment_sqrt_n_array_input +- Status: fixed + +### P3-M2 — `gelu(approximate=False)` uses deprecated `jax.lax.erf` [Low] +- File: brainpy/math/activations.py:177 +- Category: api-drift +- What: uses `jax.lax.erf`, which is being phased out in favour of + `jax.scipy.special.erf` / `lax.erfc` (jax.nn.gelu switched to `lax.erfc`). Not + yet a warning on 0.10.2, but a drift risk. +- Why it's a bug: future removal would break the exact-GELU branch. +- Repro: static. +- Fix: recorded only (folded into the P3-H1 rewrite, which uses + `jax.scipy.special.erf`). +- Tests: covered indirectly by test_gelu_integer_input_matches_float. +- Status: fixed (as part of P3-H1) + +### P3-L1 — `one_hot` uses deprecated `jax.core.concrete_or_error` [Low] +- File: brainpy/math/activations.py:444 +- Category: api-drift +- What: `jax.core.concrete_or_error` emits a `DeprecationWarning` ("Use + jax.extend.core.concrete_or_error") on every `one_hot` call. +- Why it's a bug: noisy deprecation; future removal risk. +- Repro: `import warnings; act.one_hot(jnp.array([0,1]), 2)` warns. +- Fix: recorded only (Low). +- Tests: none +- Status: recorded-only + +### P3-L2 — `one_hot` hard-codes `jnp.float64` default dtype [Low] +- File: brainpy/math/activations.py:446 +- Category: numerics / api-drift +- What: `dtype = canonicalize_dtype(jnp.float64 if dtype is None else dtype)`. + With `jax_enable_x64` off (the default) this is canonicalised to float32, but + the literal `float64` path is more roundabout than `jnp.float_`/`None`. + Harmless on current default config (the canonicalize call absorbs it) so + recorded only. +- Why it's a bug: minor; no incorrect output on the default config. +- Repro: static. +- Fix: recorded only (Low). +- Tests: none +- Status: recorded-only + +### P3-L3 — `asfarray` forces `jnp.float64` → spurious truncation warning [Low] +- File: brainpy/math/compat_numpy.py:217-220 +- Category: numerics +- What: the H-13 fix coerces integer input to `jnp.float64`; with x64 off this + emits "Explicitly requested dtype float64 … will be truncated to float32" on + every call. Output is correct (float32), only noisy. +- Why it's a bug: noisy warning; cosmetic. +- Repro: `import warnings; cn.asfarray([1,2,3])` warns. +- Fix: recorded only (Low) — changing the dtype literal is a behaviour tweak + outside the verified-bug remit and the existing H-13 test pins float output. +- Tests: none +- Status: recorded-only + +### P3-L4 — `interoperability.from_numpy` returns a numpy ndarray [Low] +- File: brainpy/math/interoperability.py:119-120 +- Category: api / naming +- What: `from_numpy(arr)` delegates to `as_ndarray`, returning a + `numpy.ndarray`. The PyTorch-style name implies a conversion *into* the + framework array type (jax/brainpy). Long-standing BrainPy behaviour; callers + may depend on it. +- Why it's a bug: misleading name; not a correctness defect. +- Repro: `type(io.from_numpy(np.arange(3))) # numpy.ndarray`. +- Fix: recorded only (Low) — behaviour change with downstream risk. +- Tests: none +- Status: recorded-only + +### P3-L5 — `as_device_array` docstring references removed `DeviceArray` [Low] +- File: brainpy/math/interoperability.py:39-64 +- Category: style / docs +- What: docstring and name reference `jax.numpy.DeviceArray`, a type removed from + modern JAX (now `jax.Array`). Function itself is correct. +- Why it's a bug: stale docs. +- Repro: static. +- Fix: recorded only (Low). +- Tests: none +- Status: recorded-only