fix(math/compat): gelu int input, unflatten negative dim, segment_mean Array#844
Conversation
…_mean Array - gelu was silently wrong on integer input (both approximate and exact branches); promote to floating dtype like jax.nn.gelu (High) - compat_pytorch.unflatten ignored negative dim (raised on torch-style dim=-1); normalize and range-check dim (High) - compat_tensorflow segment_mean/unsorted_segment_mean/unsorted_segment_sqrt_n passed an un-converted brainpy Array to jnp.ones_like (relied on the removed implicit __jax_array__); convert once (Medium) Findings recorded in docs/issues-found-20260619-math-compat.md
Reviewer's GuideFixes three math compatibility bugs: GELU now correctly handles non-floating inputs and uses a modern erf implementation, PyTorch-compat unflatten now normalizes and validates negative dims, and TensorFlow-compat segment_mean helpers correctly convert BrainPy Arrays to JAX arrays before building the ones-like denominator; adds focused regression tests and an audit notes doc. File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
There was a problem hiding this comment.
Hey - I've found 2 issues
Prompt for AI Agents
Please address the comments from this code review:
## Individual Comments
### Comment 1
<location path="brainpy/math/math_compat_p3_fixes_test.py" line_range="102-106" />
<code_context>
+ np.asarray(_j(r)), np.asarray(_j(cpt.unflatten(x, 0, (2, 3)))))
+
+
+def test_unflatten_negative_dim_higher_rank():
+ x = bm.asarray(jnp.arange(24.).reshape(2, 12))
+ r = cpt.unflatten(x, -1, (3, 4))
+ assert _j(r).shape == (2, 3, 4)
+ r2 = cpt.unflatten(x, -2, (1, 2))
+ assert _j(r2).shape == (1, 2, 12)
+
</code_context>
<issue_to_address>
**suggestion (testing):** Strengthen `unflatten` negative-dim tests by asserting content equality, not just shapes.
These tests validate shapes for `dim=-1` and `dim=-2` but not that the data layout matches the expected reshape. To better guard against regressions where `new_shape` has the right length but incorrect structure, also assert that `np.asarray(_j(r))` and `np.asarray(_j(r2))` equal the results of `jnp.reshape` with the expected target shapes.
</issue_to_address>
### Comment 2
<location path="docs/issues-found-20260619-math-compat.md" line_range="21-26" />
<code_context>
+- Why it's a bug: `gelu(jnp.array([1,2,3], int32), approximate=True)` returns
+ `[0.5, 1.0, 1.5]` instead of the correct `[0.8412, 1.9546, 2.9964]`
+ (matches `jax.nn.gelu`). `approximate=False` on the same input returns
+ `[0,0,0,1]`. JAX's own `gelu` promotes the argument to inexact first.
+- Repro:
+ ```python
</code_context>
<issue_to_address>
**issue (typo):** Integer GELU example output length seems inconsistent with the input length.
In the P3-H1 section, the `approximate=False` example states that `gelu(jnp.array([1,2,3], jnp.int32), approximate=False)` returns `[0,0,0,1]`, but the input has only three elements. This is likely a typo in the documented output (e.g., `[0,0,1]` or another three-element vector). Please correct the example to match the input shape to avoid confusion when reproducing it.
```suggestion
- `approximate=False`: `jnp.array(..., dtype=x.dtype)` truncates the float
result back to int.
- Why it's a bug: `gelu(jnp.array([1,2,3], int32), approximate=True)` returns
`[0.5, 1.0, 1.5]` instead of the correct `[0.8412, 1.9546, 2.9964]`
(matches `jax.nn.gelu`). `approximate=False` on the same input returns
`[0,0,2]`. JAX's own `gelu` promotes the argument to inexact first.
```
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| def test_unflatten_negative_dim_higher_rank(): | ||
| x = bm.asarray(jnp.arange(24.).reshape(2, 12)) | ||
| r = cpt.unflatten(x, -1, (3, 4)) | ||
| assert _j(r).shape == (2, 3, 4) | ||
| r2 = cpt.unflatten(x, -2, (1, 2)) |
There was a problem hiding this comment.
suggestion (testing): Strengthen unflatten negative-dim tests by asserting content equality, not just shapes.
These tests validate shapes for dim=-1 and dim=-2 but not that the data layout matches the expected reshape. To better guard against regressions where new_shape has the right length but incorrect structure, also assert that np.asarray(_j(r)) and np.asarray(_j(r2)) equal the results of jnp.reshape with the expected target shapes.
| - `approximate=False`: `jnp.array(..., dtype=x.dtype)` truncates the float | ||
| result back to int. | ||
| - Why it's a bug: `gelu(jnp.array([1,2,3], int32), approximate=True)` returns | ||
| `[0.5, 1.0, 1.5]` instead of the correct `[0.8412, 1.9546, 2.9964]` | ||
| (matches `jax.nn.gelu`). `approximate=False` on the same input returns | ||
| `[0,0,0,1]`. JAX's own `gelu` promotes the argument to inexact first. |
There was a problem hiding this comment.
issue (typo): Integer GELU example output length seems inconsistent with the input length.
In the P3-H1 section, the approximate=False example states that gelu(jnp.array([1,2,3], jnp.int32), approximate=False) returns [0,0,0,1], but the input has only three elements. This is likely a typo in the documented output (e.g., [0,0,1] or another three-element vector). Please correct the example to match the input shape to avoid confusion when reproducing it.
| - `approximate=False`: `jnp.array(..., dtype=x.dtype)` truncates the float | |
| result back to int. | |
| - Why it's a bug: `gelu(jnp.array([1,2,3], int32), approximate=True)` returns | |
| `[0.5, 1.0, 1.5]` instead of the correct `[0.8412, 1.9546, 2.9964]` | |
| (matches `jax.nn.gelu`). `approximate=False` on the same input returns | |
| `[0,0,0,1]`. JAX's own `gelu` promotes the argument to inexact first. | |
| - `approximate=False`: `jnp.array(..., dtype=x.dtype)` truncates the float | |
| result back to int. | |
| - Why it's a bug: `gelu(jnp.array([1,2,3], int32), approximate=True)` returns | |
| `[0.5, 1.0, 1.5]` instead of the correct `[0.8412, 1.9546, 2.9964]` | |
| (matches `jax.nn.gelu`). `approximate=False` on the same input returns | |
| `[0,0,2]`. JAX's own `gelu` promotes the argument to inexact first. |
Fresh review of
brainpy/mathNumPy/PyTorch/TF compat + activations.gelusilently wrong on integer input (now promotes to float likejax.nn.gelu).compat_pytorch.unflattenignored negativedim(torch-styledim=-1raised); now normalized + range-checked.segment_mean/unsorted_segment_mean/unsorted_segment_sqrt_npassed an un-convertedbm.Arraytojnp.ones_like(relied on JAX's removed implicit__jax_array__); now converted once.Regression tests added (119 passed in-scope). Findings:
docs/issues-found-20260619-math-compat.md.Summary by Sourcery
Fix math compatibility issues in activations and TensorFlow/PyTorch shims and add regression coverage and documentation for the 2026-06-19 audit.
Bug Fixes:
Documentation:
Tests: