Skip to content

fix(math): ShardedArray pytree flatten + remove_diag guard#839

Merged
chaoming0625 merged 1 commit into
masterfrom
fix/audit-20260619-math-core
Jun 18, 2026
Merged

fix(math): ShardedArray pytree flatten + remove_diag guard#839
chaoming0625 merged 1 commit into
masterfrom
fix/audit-20260619-math-core

Conversation

@chaoming0625

@chaoming0625 chaoming0625 commented Jun 18, 2026

Copy link
Copy Markdown
Member

Fresh review of brainpy/math core.

HighShardedArray pytree round-trip dropped _keep_sharding, raising AttributeError under every JAX transform (jit/vmap/scan/grad/tree_map). Added tree_flatten/tree_unflatten carrying _keep_sharding in aux_data.

Mediumremove_diag raised an opaque broadcasting error on tall (m>n) matrices; added a clear ValueError.

Regression tests added (106 passed in-scope). Findings: docs/issues-found-20260619-math-core.md.

Summary by Sourcery

Fix ShardedArray pytree behavior and clarify remove_diag shape constraints, adding regression coverage and documenting findings from a math-core audit.

Bug Fixes:

  • Preserve ShardedArray value and _keep_sharding flags across JAX pytree flatten/unflatten so it works under jit/vmap and other transforms.
  • Guard remove_diag against tall (m > n) matrices by raising a clear ValueError while keeping existing behavior for supported shapes.

Documentation:

  • Document results of the 2026-06-19 BrainPy math-core audit, including newly fixed and recorded issues.

Tests:

  • Add regression tests covering ShardedArray pytree round-trips and use under jit/vmap, and the new remove_diag shape guard plus existing valid/invalid cases.

- ShardedArray pytree round-trip dropped _keep_sharding, raising
  AttributeError under every JAX transform (jit/vmap/scan/grad/tree_map);
  add tree_flatten/tree_unflatten carrying _keep_sharding in aux_data (High)
- remove_diag raised an opaque broadcasting error on tall (m>n) matrices;
  add a clear ValueError guard (Medium)

Findings recorded in docs/issues-found-20260619-math-core.md
@chaoming0625 chaoming0625 merged commit 4f1fbb0 into master Jun 18, 2026
1 of 4 checks passed
@sourcery-ai

sourcery-ai Bot commented Jun 18, 2026

Copy link
Copy Markdown

Reviewer's Guide

Adds proper JAX pytree flatten/unflatten support for ShardedArray so _keep_sharding survives transformations, introduces input-shape validation to remove_diag for tall matrices, and records the math-core audit findings alongside new regression tests.

Flow diagram for remove_diag shape validation on tall matrices

flowchart TD
  A["remove_diag(arr)"] --> B{"arr.ndim == 2"}
  B -->|no| C["raise ValueError: only support 2D matrix"]
  B -->|yes| D["as_jax(arr); m, n = arr.shape"]
  D --> E{"m > n"}
  E -->|yes| F["raise ValueError: m <= n required so each row has a diagonal"]
  E -->|no| G["compute rows, cols off-diagonal indices"]
  G --> H["return arr[rows, cols].reshape(m, n - 1)"]
Loading

File-Level Changes

Change Details Files
Ensure ShardedArray pytree round-trips preserve both the underlying value and the _keep_sharding flag under all JAX transforms.
  • Implement a ShardedArray.tree_flatten method that flattens the raw _value and carries _keep_sharding in aux_data to avoid running sharding constraints during abstract evaluation.
  • Implement a ShardedArray.tree_unflatten classmethod that reconstructs instances without calling init, restoring both _value and _keep_sharding (defaulting to True when aux_data is None).
  • Add regression tests that exercise ShardedArray pytree round-trip behavior directly and under jax.jit and jax.vmap, asserting value correctness and persistence of _keep_sharding.
brainpy/math/ndarray.py
brainpy/math/math_core_fixes_test.py
Make remove_diag fail fast with a clear ValueError on tall matrices while keeping behavior unchanged for valid inputs.
  • Add an m > n guard before index construction in remove_diag that raises a descriptive ValueError explaining the m <= n requirement and echoing the input shape.
  • Keep the existing off-diagonal index computation and behavior unchanged for square and wide matrices that satisfy m <= n.
  • Extend tests to cover the valid square and wide cases, the new tall-matrix error path, and the existing ndim validation for non-2D inputs.
brainpy/math/others.py
brainpy/math/math_core_fixes_test.py
Document results of a fresh audit of the math core, including fixed and recorded-only issues.
  • Add a markdown report summarizing the 2026-06-19 math-core audit scope, environment, issue classifications, and statuses.
  • Record detailed descriptions of the ShardedArray pytree and remove_diag issues as fixed, and several low-priority findings as recorded-only without code changes.
  • List previously fixed issues from an earlier audit as re-verified without new actions.
docs/issues-found-20260619-math-core.md

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@github-actions github-actions Bot added documentation Improvements or additions to documentation tests labels Jun 18, 2026

@sourcery-ai sourcery-ai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've found 1 issue, and left some high level feedback:

  • In ShardedArray.tree_unflatten, consider delegating to Array.tree_unflatten (e.g., via ins = super(ShardedArray, cls).tree_unflatten(...) or refactoring into a shared helper) to avoid duplicating the base reconstruction logic and to stay resilient to future changes in Array’s internal layout.
  • The new remove_diag shape check is good, but the error message is quite long; you might shorten it while preserving the key constraint (e.g., mention only m <= n and the actual shape) to keep exception messages concise and easier to scan in logs.
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- In `ShardedArray.tree_unflatten`, consider delegating to `Array.tree_unflatten` (e.g., via `ins = super(ShardedArray, cls).tree_unflatten(...)` or refactoring into a shared helper) to avoid duplicating the base reconstruction logic and to stay resilient to future changes in `Array`’s internal layout.
- The new `remove_diag` shape check is good, but the error message is quite long; you might shorten it while preserving the key constraint (e.g., mention only `m <= n` and the actual shape) to keep exception messages concise and easier to scan in logs.

## Individual Comments

### Comment 1
<location path="docs/issues-found-20260619-math-core.md" line_range="90-91" />
<code_context>
+### P2-L1 — `IdScaling._reject_overrides` raises a confusing truth-value error for array `bias`/`scale`  [Low]
+- File: brainpy/math/scales.py:87-98
+- Category: edge/error
+- What: `_reject_overrides` does `if bias is not None and bias != 0.` / `scale != 1.`.
+  When called with a non-scalar array `bias`/`scale`, `bias != 0.` is an array
+  and the `and`/`if` coerces it to bool, raising
+  `ValueError: The truth value of an array with more than one element is ambiguous`.
+- Why it's a bug: misleading error for an unusual-but-legal input. The intent is
+  to reject non-default overrides; an array override should be rejected with the
+  intended "IdScaling ignores bias/scale" message, not a numpy truthiness error.
</code_context>
<issue_to_address>
**nitpick (typo):** Minor subject-verb agreement issue in the description of `_reject_overrides`.

Since the compound subject is "`and`/`if`" (plural), this should read "coerce it to bool" rather than "coerces it to bool".

```suggestion
  When called with a non-scalar array `bias`/`scale`, `bias != 0.` is an array
  and the `and`/`if` coerce it to bool, raising
```
</issue_to_address>

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Comment on lines +90 to +91
When called with a non-scalar array `bias`/`scale`, `bias != 0.` is an array
and the `and`/`if` coerces it to bool, raising

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick (typo): Minor subject-verb agreement issue in the description of _reject_overrides.

Since the compound subject is "and/if" (plural), this should read "coerce it to bool" rather than "coerces it to bool".

Suggested change
When called with a non-scalar array `bias`/`scale`, `bias != 0.` is an array
and the `and`/`if` coerces it to bool, raising
When called with a non-scalar array `bias`/`scale`, `bias != 0.` is an array
and the `and`/`if` coerce it to bool, raising

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant