fix(math): ShardedArray pytree flatten + remove_diag guard#839
Merged
Conversation
- ShardedArray pytree round-trip dropped _keep_sharding, raising AttributeError under every JAX transform (jit/vmap/scan/grad/tree_map); add tree_flatten/tree_unflatten carrying _keep_sharding in aux_data (High) - remove_diag raised an opaque broadcasting error on tall (m>n) matrices; add a clear ValueError guard (Medium) Findings recorded in docs/issues-found-20260619-math-core.md
Reviewer's GuideAdds proper JAX pytree flatten/unflatten support for ShardedArray so _keep_sharding survives transformations, introduces input-shape validation to remove_diag for tall matrices, and records the math-core audit findings alongside new regression tests. Flow diagram for remove_diag shape validation on tall matricesflowchart TD
A["remove_diag(arr)"] --> B{"arr.ndim == 2"}
B -->|no| C["raise ValueError: only support 2D matrix"]
B -->|yes| D["as_jax(arr); m, n = arr.shape"]
D --> E{"m > n"}
E -->|yes| F["raise ValueError: m <= n required so each row has a diagonal"]
E -->|no| G["compute rows, cols off-diagonal indices"]
G --> H["return arr[rows, cols].reshape(m, n - 1)"]
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
There was a problem hiding this comment.
Hey - I've found 1 issue, and left some high level feedback:
- In
ShardedArray.tree_unflatten, consider delegating toArray.tree_unflatten(e.g., viains = super(ShardedArray, cls).tree_unflatten(...)or refactoring into a shared helper) to avoid duplicating the base reconstruction logic and to stay resilient to future changes inArray’s internal layout. - The new
remove_diagshape check is good, but the error message is quite long; you might shorten it while preserving the key constraint (e.g., mention onlym <= nand the actual shape) to keep exception messages concise and easier to scan in logs.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In `ShardedArray.tree_unflatten`, consider delegating to `Array.tree_unflatten` (e.g., via `ins = super(ShardedArray, cls).tree_unflatten(...)` or refactoring into a shared helper) to avoid duplicating the base reconstruction logic and to stay resilient to future changes in `Array`’s internal layout.
- The new `remove_diag` shape check is good, but the error message is quite long; you might shorten it while preserving the key constraint (e.g., mention only `m <= n` and the actual shape) to keep exception messages concise and easier to scan in logs.
## Individual Comments
### Comment 1
<location path="docs/issues-found-20260619-math-core.md" line_range="90-91" />
<code_context>
+### P2-L1 — `IdScaling._reject_overrides` raises a confusing truth-value error for array `bias`/`scale` [Low]
+- File: brainpy/math/scales.py:87-98
+- Category: edge/error
+- What: `_reject_overrides` does `if bias is not None and bias != 0.` / `scale != 1.`.
+ When called with a non-scalar array `bias`/`scale`, `bias != 0.` is an array
+ and the `and`/`if` coerces it to bool, raising
+ `ValueError: The truth value of an array with more than one element is ambiguous`.
+- Why it's a bug: misleading error for an unusual-but-legal input. The intent is
+ to reject non-default overrides; an array override should be rejected with the
+ intended "IdScaling ignores bias/scale" message, not a numpy truthiness error.
</code_context>
<issue_to_address>
**nitpick (typo):** Minor subject-verb agreement issue in the description of `_reject_overrides`.
Since the compound subject is "`and`/`if`" (plural), this should read "coerce it to bool" rather than "coerces it to bool".
```suggestion
When called with a non-scalar array `bias`/`scale`, `bias != 0.` is an array
and the `and`/`if` coerce it to bool, raising
```
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
Comment on lines
+90
to
+91
| When called with a non-scalar array `bias`/`scale`, `bias != 0.` is an array | ||
| and the `and`/`if` coerces it to bool, raising |
There was a problem hiding this comment.
nitpick (typo): Minor subject-verb agreement issue in the description of _reject_overrides.
Since the compound subject is "and/if" (plural), this should read "coerce it to bool" rather than "coerces it to bool".
Suggested change
| When called with a non-scalar array `bias`/`scale`, `bias != 0.` is an array | |
| and the `and`/`if` coerces it to bool, raising | |
| When called with a non-scalar array `bias`/`scale`, `bias != 0.` is an array | |
| and the `and`/`if` coerce it to bool, raising |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fresh review of
brainpy/mathcore.High —
ShardedArraypytree round-trip dropped_keep_sharding, raisingAttributeErrorunder every JAX transform (jit/vmap/scan/grad/tree_map). Addedtree_flatten/tree_unflattencarrying_keep_shardingin aux_data.Medium —
remove_diagraised an opaque broadcasting error on tall (m>n) matrices; added a clearValueError.Regression tests added (106 passed in-scope). Findings:
docs/issues-found-20260619-math-core.md.Summary by Sourcery
Fix ShardedArray pytree behavior and clarify remove_diag shape constraints, adding regression coverage and documenting findings from a math-core audit.
Bug Fixes:
Documentation:
Tests: