Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion brainpy/math/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,21 @@ def gelu(x, approximate=True):
whether to use the approximate or exact formulation.
"""
x = x.value if isinstance(x, Array) else x
# Promote integer / boolean inputs to a floating dtype before computing.
# Without this the ``sqrt(2/pi)`` and ``0.044715`` constants are truncated to
# zero (approximate branch) or the float result is cast back to integer
# (exact branch), silently producing wrong values. This mirrors the
# ``promote_args_inexact`` step in ``jax.nn.gelu``.
x = jnp.asarray(x)
if not jnp.issubdtype(x.dtype, jnp.floating):
x = x.astype(jnp.promote_types(x.dtype, jnp.float32))
if approximate:
sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3))))
y = x * cdf
else:
y = jnp.array(x * (jax.lax.erf(x / np.sqrt(2)) + 1) / 2, dtype=x.dtype)
sqrt_2 = np.sqrt(2).astype(x.dtype)
y = jnp.array(x * (jax.scipy.special.erf(x / sqrt_2) + 1) / 2, dtype=x.dtype)
return y


Expand Down
13 changes: 10 additions & 3 deletions brainpy/math/compat_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,18 @@ def unflatten(x: Union[jax.Array, Array], dim: int, sizes: Sequence[int]) -> Arr
The returned tensor has one more dimension than the input tensor.
The returned tensor shares the same underlying data with this tensor.
"""
assert x.ndim > dim, ('The dimension to be unflattened should be less than the tensor dimension. '
f'Got {dim} and {x.ndim}.')
x = _as_jax_array_(x)
ndim = x.ndim
# Normalise a negative ``dim`` to PyTorch semantics, where ``dim`` indexes
# into ``x.shape`` and may be in ``[-ndim, ndim)``.
canon_dim = dim + ndim if dim < 0 else dim
if not 0 <= canon_dim < ndim:
raise ValueError(
f'Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], '
f'but got {dim}).'
)
shape = x.shape
new_shape = shape[:dim] + tuple(sizes) + shape[dim + 1:]
new_shape = shape[:canon_dim] + tuple(sizes) + shape[canon_dim + 1:]
r = jnp.reshape(x, new_shape)
return _return(r)

Expand Down
26 changes: 12 additions & 14 deletions brainpy/math/compat_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,10 @@ def segment_mean(data, segment_ids):
See https://tensorflow.google.cn/api_docs/python/tf/math/segment_mean

"""
r = jax.ops.segment_sum(_as_jax_array_(data),
_as_jax_array_(segment_ids),
indices_are_sorted=False)
d = jax.ops.segment_sum(jnp.ones_like(data),
_as_jax_array_(segment_ids),
indices_are_sorted=False)
data = _as_jax_array_(data)
segment_ids = _as_jax_array_(segment_ids)
r = jax.ops.segment_sum(data, segment_ids, indices_are_sorted=False)
d = jax.ops.segment_sum(jnp.ones_like(data), segment_ids, indices_are_sorted=False)
return _return(jnp.nan_to_num(r / d))


Expand Down Expand Up @@ -204,12 +202,12 @@ def unsorted_segment_sqrt_n(data, segment_ids, num_segments):
See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_sqrt_n

"""
r = jax.ops.segment_sum(_as_jax_array_(data),
_as_jax_array_(segment_ids),
data = _as_jax_array_(data)
segment_ids = _as_jax_array_(segment_ids)
r = jax.ops.segment_sum(data, segment_ids,
num_segments=num_segments,
indices_are_sorted=False)
d = jax.ops.segment_sum(jnp.ones_like(data),
_as_jax_array_(segment_ids),
d = jax.ops.segment_sum(jnp.ones_like(data), segment_ids,
num_segments=num_segments,
indices_are_sorted=False)
return _return(jnp.nan_to_num(r / jnp.sqrt(d)))
Expand All @@ -221,12 +219,12 @@ def unsorted_segment_mean(data, segment_ids, num_segments):
See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_mean

"""
r = jax.ops.segment_sum(_as_jax_array_(data),
_as_jax_array_(segment_ids),
data = _as_jax_array_(data)
segment_ids = _as_jax_array_(segment_ids)
r = jax.ops.segment_sum(data, segment_ids,
num_segments=num_segments,
indices_are_sorted=False)
d = jax.ops.segment_sum(jnp.ones_like(data),
_as_jax_array_(segment_ids),
d = jax.ops.segment_sum(jnp.ones_like(data), segment_ids,
num_segments=num_segments,
indices_are_sorted=False)
return _return(jnp.nan_to_num(r / d))
Expand Down
161 changes: 161 additions & 0 deletions brainpy/math/math_compat_p3_fixes_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-
"""Regression tests for the 2026-06-19 ``math-compat`` audit (P3-* findings).

Covered findings (see ``docs/issues-found-20260619-math-compat.md``):

* P3-H1 (``activations.py``) -- ``gelu`` must promote integer inputs to a
floating dtype before computing; both the approximate and exact branches were
silently wrong on integer input.
* P3-H2 (``compat_pytorch.py``) -- ``unflatten`` must honour a negative ``dim``
(PyTorch semantics) and reject out-of-range dims.
* P3-M1 (``compat_tensorflow.py``) -- ``segment_mean`` / ``unsorted_segment_mean``
/ ``unsorted_segment_sqrt_n`` must convert ``data`` to a jax array before
``jnp.ones_like`` (do not rely on the deprecated implicit ``__jax_array__``).
"""

import jax
import jax.nn as jnn
import jax.numpy as jnp
import numpy as np
import pytest

import brainpy.math as bm
from brainpy.math import (
activations as act,
compat_pytorch as cpt,
compat_tensorflow as ctf,
)


def _j(x):
    """Convert ``x`` to a JAX array through ``bm.as_jax``."""
    converted = bm.as_jax(x)
    return converted


def _finite(x):
    """Return ``True`` when every element of ``x`` is finite."""
    finite_mask = jnp.isfinite(_j(x))
    return bool(jnp.all(finite_mask))


# ---------------------------------------------------------------------------
# P3-H1: gelu integer-input promotion
# ---------------------------------------------------------------------------

def test_gelu_integer_input_matches_float_approximate():
    """P3-H1: approximate gelu on int input must match the float computation."""
    int_input = jnp.array([1, 2, 3], dtype=jnp.int32)
    float_input = jnp.array([1., 2., 3.])

    from_int = np.asarray(_j(act.gelu(int_input, approximate=True)))
    from_float = np.asarray(_j(act.gelu(float_input, approximate=True)))

    # Integer input must give the same values as the equivalent float input ...
    np.testing.assert_allclose(from_int, from_float, atol=1e-6)

    # ... and must agree with jax.nn.gelu, the reference implementation.
    reference = np.asarray(jnn.gelu(int_input, approximate=True))
    np.testing.assert_allclose(from_int, reference, atol=1e-6)

    # Regression guard: the bug produced the truncated ``x / 2`` result.
    halved = np.asarray(float_input) / 2.0
    assert not np.allclose(from_int, halved)


def test_gelu_integer_input_matches_float_exact():
    """P3-H1: exact gelu on int input must return a floating-dtype result."""
    int_input = jnp.array([1, 2, 3], dtype=jnp.int32)
    float_input = jnp.array([1., 2., 3.])

    result = _j(act.gelu(int_input, approximate=False))
    # The output must not be cast back to the integer input dtype.
    assert jnp.issubdtype(result.dtype, jnp.floating)

    values = np.asarray(result)
    expected = np.asarray(_j(act.gelu(float_input, approximate=False)))
    np.testing.assert_allclose(values, expected, atol=1e-6)

    # Parity with the jax.nn.gelu reference implementation.
    reference = np.asarray(jnn.gelu(int_input, approximate=False))
    np.testing.assert_allclose(values, reference, atol=1e-5)


def test_gelu_float_unchanged():
    """Float input must still match ``jax.nn.gelu`` in both formulations."""
    x = bm.asarray([-1., 0., 1., 2.])
    for approx in (True, False):
        ours = np.asarray(_j(act.gelu(x, approximate=approx)))
        reference = np.asarray(jnn.gelu(_j(x), approximate=approx))
        np.testing.assert_allclose(ours, reference, atol=1e-5)


def test_gelu_accepts_brainpy_array_and_is_finite():
    """Both gelu formulations accept a brainpy array and return finite values."""
    x = bm.asarray([-3., -1., 0., 1., 3.])
    for approx in (True, False):
        assert _finite(act.gelu(x, approximate=approx))


# ---------------------------------------------------------------------------
# P3-H2: unflatten negative dim
# ---------------------------------------------------------------------------

def test_unflatten_negative_dim():
    """P3-H2: ``dim=-1`` on a 1-D input must behave exactly like ``dim=0``."""
    x = bm.asarray(jnp.arange(6.))

    via_negative = _j(cpt.unflatten(x, -1, (2, 3)))
    assert via_negative.shape == (2, 3)

    # The negative-dim call must be equivalent to the positive-dim call.
    via_positive = _j(cpt.unflatten(x, 0, (2, 3)))
    np.testing.assert_allclose(np.asarray(via_negative), np.asarray(via_positive))


def test_unflatten_negative_dim_higher_rank():
    """P3-H2: negative ``dim`` on a rank-2 input gives the right shape and layout.

    Review feedback on this test (kept for context): asserting shapes alone
    would miss a regression where ``new_shape`` has the right length but an
    incorrect structure, so the contents are also compared against a plain
    reshape of the same data to the expected target shape.
    """
    flat = np.arange(24.)
    x = bm.asarray(jnp.arange(24.).reshape(2, 12))

    # Unflatten the last axis: (2, 12) -> (2, 3, 4).
    r = cpt.unflatten(x, -1, (3, 4))
    assert _j(r).shape == (2, 3, 4)
    np.testing.assert_array_equal(np.asarray(_j(r)), flat.reshape(2, 3, 4))

    # Unflatten the first axis: (2, 12) -> (1, 2, 12).
    r2 = cpt.unflatten(x, -2, (1, 2))
    assert _j(r2).shape == (1, 2, 12)
    np.testing.assert_array_equal(np.asarray(_j(r2)), flat.reshape(1, 2, 12))


def test_unflatten_positive_dim_still_works():
    """Non-negative ``dim`` keeps working, including a ``-1`` entry in ``sizes``."""
    x = bm.asarray(jnp.arange(6.))
    for sizes in ((2, 3), (-1, 3)):
        unflattened = _j(cpt.unflatten(x, 0, sizes))
        assert unflattened.shape == (2, 3)


def test_unflatten_dim_out_of_range():
    """P3-H2: a ``dim`` outside ``[-ndim, ndim)`` must raise ``ValueError``.

    Only ``ValueError`` is accepted: also allowing ``AssertionError`` would let
    an ``assert``-based range check pass, and ``assert`` statements are
    stripped when Python runs with ``-O``.
    """
    x = bm.asarray(jnp.arange(6.))  # 1-D, so the only valid dims are 0 and -1
    # 1 and -2 sit immediately outside the valid range (off-by-one guards);
    # 5 and -5 are the far-out-of-range cases.
    for bad_dim in (1, 5, -2, -5):
        with pytest.raises(ValueError):
            cpt.unflatten(x, bad_dim, (2, 3))


# ---------------------------------------------------------------------------
# P3-M1: TF segment helpers must not lean on implicit __jax_array__
# ---------------------------------------------------------------------------

def test_segment_mean_array_input():
    """P3-M1: ``segment_mean`` accepts brainpy arrays for data and segment ids."""
    values = bm.asarray([1., 2., 3., 4.])
    ids = bm.asarray([0, 0, 1, 1])
    means = np.asarray(_j(ctf.segment_mean(values, ids)))
    np.testing.assert_allclose(means, [1.5, 3.5])


def test_unsorted_segment_mean_array_input():
    """P3-M1: ``unsorted_segment_mean`` accepts brainpy arrays as input."""
    values = bm.asarray([1., 2., 3., 4.])
    ids = bm.asarray([0, 0, 1, 1])
    means = np.asarray(_j(ctf.unsorted_segment_mean(values, ids, 2)))
    np.testing.assert_allclose(means, [1.5, 3.5])


def test_unsorted_segment_sqrt_n_array_input():
    """P3-M1: ``unsorted_segment_sqrt_n`` accepts brainpy arrays as input."""
    values = bm.asarray([1., 1., 1., 1.])
    ids = bm.asarray([0, 0, 1, 1])

    # Each segment holds two ones: the segment sum (2) divided by sqrt(2).
    per_segment = 2.0 / np.sqrt(2.0)
    expected = [per_segment, per_segment]

    actual = np.asarray(_j(ctf.unsorted_segment_sqrt_n(values, ids, 2)))
    np.testing.assert_allclose(actual, expected, atol=1e-6)


def test_unsorted_segment_mean_under_jit():
    """The ``jnp.ones_like`` denominator must trace cleanly under ``jax.jit``.

    ``unsorted_segment_mean`` is used here because it receives
    ``num_segments`` as an explicit argument, which makes it jit-compatible
    (unlike ``segment_mean``, which infers it from the data).
    """
    data = jnp.array([1., 2., 3., 4.])
    seg = jnp.array([0, 0, 1, 1])

    def segment_mean_of(d):
        # Wrap the traced input in brainpy arrays, then unwrap the result.
        out = ctf.unsorted_segment_mean(bm.asarray(d), bm.asarray(seg), 2)
        return bm.as_jax(out)

    f = jax.jit(segment_mean_of)
    np.testing.assert_allclose(np.asarray(f(data)), [1.5, 3.5])
Loading
Loading