Skip to content

fix(math/compat): gelu int input, unflatten negative dim, segment_mean Array#844

Merged
chaoming0625 merged 1 commit into
masterfrom
fix/audit-20260619-math-compat
Jun 18, 2026
Merged

fix(math/compat): gelu int input, unflatten negative dim, segment_mean Array#844
chaoming0625 merged 1 commit into
masterfrom
fix/audit-20260619-math-compat

Conversation

@chaoming0625

@chaoming0625 chaoming0625 commented Jun 18, 2026

Copy link
Copy Markdown
Member

Fresh review of brainpy/math NumPy/PyTorch/TF compat + activations.

  • Highgelu silently wrong on integer input (now promotes to float like jax.nn.gelu).
  • Highcompat_pytorch.unflatten ignored negative dim (torch-style dim=-1 raised); now normalized + range-checked.
  • Medium — TF-compat segment_mean/unsorted_segment_mean/unsorted_segment_sqrt_n passed an un-converted bm.Array to jnp.ones_like (relied on JAX's removed implicit __jax_array__); now converted once.

Regression tests added (119 passed in-scope). Findings: docs/issues-found-20260619-math-compat.md.

Summary by Sourcery

Fix math compatibility issues in activations and TensorFlow/PyTorch shims and add regression coverage and documentation for the 2026-06-19 audit.

Bug Fixes:

  • Correct gelu to promote integer and boolean inputs to floating dtypes and align the exact formulation with current JAX APIs.
  • Update compat_pytorch.unflatten to honor negative dimensions with PyTorch-style semantics and to validate dimension ranges.
  • Fix TensorFlow-compat segment_mean and unsorted_segment_* helpers to consistently convert inputs to JAX arrays before constructing denominators, avoiding reliance on deprecated implicit array conversion.

Documentation:

  • Add an audit report documenting math compatibility and activation issues, their impact, and fix status for the 2026-06-19 review.

Tests:

  • Add regression test suite covering gelu integer-input behavior, compat_pytorch.unflatten negative and out-of-range dims, and TensorFlow segment helpers on brainpy Array inputs and under jit.

…_mean Array

- gelu was silently wrong on integer input (both approximate and exact
  branches); promote to floating dtype like jax.nn.gelu (High)
- compat_pytorch.unflatten ignored negative dim (raised on torch-style
  dim=-1); normalize and range-check dim (High)
- compat_tensorflow segment_mean/unsorted_segment_mean/unsorted_segment_sqrt_n
  passed an un-converted brainpy Array to jnp.ones_like (relied on the removed
  implicit __jax_array__); convert once (Medium)

Findings recorded in docs/issues-found-20260619-math-compat.md
@chaoming0625 chaoming0625 merged commit efc7dc9 into master Jun 18, 2026
@sourcery-ai

sourcery-ai Bot commented Jun 18, 2026

Copy link
Copy Markdown

Reviewer's Guide

Fixes three math compatibility bugs: GELU now correctly handles non-floating inputs and uses a modern erf implementation, PyTorch-compat unflatten now normalizes and validates negative dims, and TensorFlow-compat segment_mean helpers correctly convert BrainPy Arrays to JAX arrays before building the ones-like denominator; adds focused regression tests and an audit notes doc.

File-Level Changes

Change Details Files
Ensure TensorFlow-compat segment_* mean helpers explicitly convert inputs to JAX arrays and avoid relying on deprecated implicit jax_array conversion.
  • Convert data and segment_ids once at the top of segment_mean, unsorted_segment_mean, and unsorted_segment_sqrt_n using as_jax_array.
  • Call jax.ops.segment_sum with already-converted data and segment_ids instead of re-wrapping each argument.
  • Compute denominators using jnp.ones_like(data) on the JAX array to keep JIT tracing and future JAX versions working.
brainpy/math/compat_tensorflow.py
Align PyTorch-compat unflatten with torch semantics by supporting negative dim indices and performing explicit range checking.
  • Convert x to a JAX array before inspecting ndim/shape.
  • Canonicalize dim into canon_dim in [0, ndim) by adding ndim when dim is negative.
  • Raise a ValueError with a PyTorch-style error message when the (possibly negative) dim is out of range instead of relying on a simple assertion.
  • Build the reshaped target shape using canon_dim, preserving behavior for non-negative dims while enabling negative dims.
brainpy/math/compat_pytorch.py
Make GELU numerically correct for integer/boolean inputs and modernize the exact branch implementation.
  • Normalize x to a JAX array and, when the dtype is not floating, promote to an inexact dtype via jnp.promote_types(..., jnp.float32) before computing GELU.
  • In the approximate branch, keep using the tanh-based formulation but with constants computed in the promoted floating dtype.
  • In the exact branch, switch from jax.lax.erf to jax.scipy.special.erf and precompute sqrt(2) in x.dtype for stability, constructing the final result as a float array rather than truncating back to integer.
brainpy/math/activations.py
Add regression tests and documentation for the math-compat audit findings around GELU, PyTorch unflatten, and TF segment helpers.
  • Introduce math_compat_p3_fixes_test.py with regression tests that pin GELU behavior for int inputs, including parity with jax.nn.gelu and unchanged float behavior.
  • Add tests verifying unflatten with negative dims, higher-rank tensors, valid positive dims, and out-of-range dim error handling.
  • Add tests for TF-compat segment_mean and unsorted_segment_* handling of BrainPy Array inputs plus a JIT path for unsorted_segment_mean.
  • Document the math-compat audit findings, their impact, repros, and fix status in docs/issues-found-20260619-math-compat.md.
brainpy/math/math_compat_p3_fixes_test.py
docs/issues-found-20260619-math-compat.md

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@sourcery-ai sourcery-ai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've found 2 issues

Prompt for AI Agents
Please address the comments from this code review:

## Individual Comments

### Comment 1
<location path="brainpy/math/math_compat_p3_fixes_test.py" line_range="102-106" />
<code_context>
+        np.asarray(_j(r)), np.asarray(_j(cpt.unflatten(x, 0, (2, 3)))))
+
+
+def test_unflatten_negative_dim_higher_rank():
+    x = bm.asarray(jnp.arange(24.).reshape(2, 12))
+    r = cpt.unflatten(x, -1, (3, 4))
+    assert _j(r).shape == (2, 3, 4)
+    r2 = cpt.unflatten(x, -2, (1, 2))
+    assert _j(r2).shape == (1, 2, 12)
+
</code_context>
<issue_to_address>
**suggestion (testing):** Strengthen `unflatten` negative-dim tests by asserting content equality, not just shapes.

These tests validate shapes for `dim=-1` and `dim=-2` but not that the data layout matches the expected reshape. To better guard against regressions where `new_shape` has the right length but incorrect structure, also assert that `np.asarray(_j(r))` and `np.asarray(_j(r2))` equal the results of `jnp.reshape` with the expected target shapes.
</issue_to_address>

### Comment 2
<location path="docs/issues-found-20260619-math-compat.md" line_range="21-26" />
<code_context>
+- Why it's a bug: `gelu(jnp.array([1,2,3], int32), approximate=True)` returns
+  `[0.5, 1.0, 1.5]` instead of the correct `[0.8412, 1.9546, 2.9964]`
+  (matches `jax.nn.gelu`). `approximate=False` on the same input returns
+  `[0,0,0,1]`. JAX's own `gelu` promotes the argument to inexact first.
+- Repro:
+  ```python
</code_context>
<issue_to_address>
**issue (typo):** Integer GELU example output length seems inconsistent with the input length.

In the P3-H1 section, the `approximate=False` example states that `gelu(jnp.array([1,2,3], jnp.int32), approximate=False)` returns `[0,0,0,1]`, but the input has only three elements. This is likely a typo in the documented output (e.g., `[0,0,1]` or another three-element vector). Please correct the example to match the input shape to avoid confusion when reproducing it.

```suggestion
  - `approximate=False`: `jnp.array(..., dtype=x.dtype)` truncates the float
    result back to int.
- Why it's a bug: `gelu(jnp.array([1,2,3], int32), approximate=True)` returns
  `[0.5, 1.0, 1.5]` instead of the correct `[0.8412, 1.9546, 2.9964]`
  (matches `jax.nn.gelu`). `approximate=False` on the same input returns
  `[0,0,2]`. JAX's own `gelu` promotes the argument to inexact first.
```
</issue_to_address>

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Comment on lines +102 to +106
def test_unflatten_negative_dim_higher_rank():
x = bm.asarray(jnp.arange(24.).reshape(2, 12))
r = cpt.unflatten(x, -1, (3, 4))
assert _j(r).shape == (2, 3, 4)
r2 = cpt.unflatten(x, -2, (1, 2))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Strengthen unflatten negative-dim tests by asserting content equality, not just shapes.

These tests validate shapes for dim=-1 and dim=-2 but not that the data layout matches the expected reshape. To better guard against regressions where new_shape has the right length but incorrect structure, also assert that np.asarray(_j(r)) and np.asarray(_j(r2)) equal the results of jnp.reshape with the expected target shapes.

Comment on lines +21 to +26
- `approximate=False`: `jnp.array(..., dtype=x.dtype)` truncates the float
result back to int.
- Why it's a bug: `gelu(jnp.array([1,2,3], int32), approximate=True)` returns
`[0.5, 1.0, 1.5]` instead of the correct `[0.8412, 1.9546, 2.9964]`
(matches `jax.nn.gelu`). `approximate=False` on the same input returns
`[0,0,0,1]`. JAX's own `gelu` promotes the argument to inexact first.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (typo): Integer GELU example output length seems inconsistent with the input length.

In the P3-H1 section, the approximate=False example states that gelu(jnp.array([1,2,3], jnp.int32), approximate=False) returns [0,0,0,1], but the input has only three elements. This is likely a typo in the documented output (e.g., [0,0,1] or another three-element vector). Please correct the example to match the input shape to avoid confusion when reproducing it.

Suggested change
- `approximate=False`: `jnp.array(..., dtype=x.dtype)` truncates the float
result back to int.
- Why it's a bug: `gelu(jnp.array([1,2,3], int32), approximate=True)` returns
`[0.5, 1.0, 1.5]` instead of the correct `[0.8412, 1.9546, 2.9964]`
(matches `jax.nn.gelu`). `approximate=False` on the same input returns
`[0,0,0,1]`. JAX's own `gelu` promotes the argument to inexact first.
- `approximate=False`: `jnp.array(..., dtype=x.dtype)` truncates the float
result back to int.
- Why it's a bug: `gelu(jnp.array([1,2,3], int32), approximate=True)` returns
`[0.5, 1.0, 1.5]` instead of the correct `[0.8412, 1.9546, 2.9964]`
(matches `jax.nn.gelu`). `approximate=False` on the same input returns
`[0,0,2]`. JAX's own `gelu` promotes the argument to inexact first.

@github-actions github-actions Bot added documentation Improvements or additions to documentation tests labels Jun 18, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant