Skip to content

fix(optim,losses): audited correctness fixes (Adan, SM3, losses)#838

Merged
chaoming0625 merged 2 commits into
masterfrom
fix/audit-20260619-optim-losses
Jun 18, 2026
Merged

fix(optim,losses): audited correctness fixes (Adan, SM3, losses)#838
chaoming0625 merged 2 commits into
masterfrom
fix/audit-20260619-optim-losses

Conversation

@chaoming0625

@chaoming0625 chaoming0625 commented Jun 18, 2026

Copy link
Copy Markdown
Member

Fixes from fresh review of brainpy/optim + brainpy/losses.

Critical

  • Adan.update crashed on every call (lax.cond operand splat) and its step counter was frozen at 0, disabling bias correction + Nesterov.

High

  • SM3 raised KeyError for scalar (0-dim) variables.
  • multi_margin_loss crashed on bm.Array under JAX>=0.9.

Medium

  • l1_loss functional default reduction changed 'sum'->'mean' to match the L1Loss class + docstring.

Regression tests added; in-scope suite: 145 passed, 1 skipped. Full findings: docs/issues-found-20260619-optim-losses.md.

Summary by Sourcery

Fix optimizer and loss correctness issues uncovered by an audit, and add regression tests and documentation for the findings.

Bug Fixes:

  • Correct Adan optimizer so its update step runs without crashing, maintains its own per-update step counter, and enables proper bias correction and Nesterov behavior.
  • Allow SM3 optimizer to handle scalar (0‑dimensional) trainable variables by registering and using a scalar accumulator instead of erroring.
  • Make multi_margin_loss work with bm.Array inputs under newer JAX versions by converting arguments to JAX arrays before indexing.
  • Align the functional l1_loss default reduction with its class wrapper and documentation by changing the default from 'sum' to 'mean'.
  • Fix l1_loss regression tests to match the actual per-row L1 norm semantics of the underlying metric implementation.

Enhancements:

  • Add dedicated regression tests covering Adan, SM3, l1_loss, and multi_margin_loss behavior to prevent recurrence of audited issues.
  • Document the full set of audit findings and statuses for optimizers and losses in a new issues report document.

- Adan.update crashed every call (lax.cond operand splat) and froze its
  step counter, disabling bias correction + the Nesterov term (2 Critical)
- SM3 raised KeyError on scalar (0-dim) variables (High)
- multi_margin_loss crashed on bm.Array under JAX>=0.9 (High)
- l1_loss functional default reduction 'sum' -> 'mean' to match L1Loss (Medium)

Findings recorded in docs/issues-found-20260619-optim-losses.md
@sourcery-ai

sourcery-ai Bot commented Jun 18, 2026

Copy link
Copy Markdown

Reviewer's Guide

Fixes several audited correctness and API issues in optimizers Adan and SM3 and in l1_loss/multi_margin_loss, and adds focused regression tests plus an audit report document.

File-Level Changes

Change Details Files
Fix Adan optimizer so update() runs correctly, maintains its own per-step counter, and no longer crashes due to lax.cond operand misuse.
  • Introduce a dedicated bm.Variable step counter on Adan that is incremented once per update() and used for bias correction terms instead of lr.last_epoch.
  • Replace the jax.lax.cond-based first-step previous-gradient handling with a jnp.where-based selection that uses g on step 1 and prev_g thereafter, avoiding operand splatting issues.
  • Adjust Adan bias-correction formulas to use the true per-update step t (no +1 offset) so correction factors evolve over training.
  • Add new regression tests that train with Adan (with and without no_prox), assert finite parameter updates, check the step counter advances, and verify expected behavior of the gradient-difference accumulator and Nesterov term; update the previous coverage test that asserted Adan was broken.
brainpy/optim/optimizer.py
brainpy/optim/optimizer_coverage_test.py
brainpy/optim/optimizer_test.py
Fix SM3 optimizer behavior on scalar (0‑dim) trainable variables so it no longer raises KeyError and instead behaves like an Adagrad-style update.
  • In SM3.register_train_vars, detect 0‑dim variables and register a single scalar accumulator m0 so scalar parameters have a valid accumulator; early-continue to skip per-axis cover construction.
  • In SM3.update, clamp ndim to at least 1 (max(p.ndim, 1)) to match the scalar-accumulator registration path and ensure m0 is always available.
  • Add regression tests that train SM3 on scalar variables, asserting finite updates, Adagrad-like behavior after a step, and correct operation when momentum/beta are enabled; update the old coverage test that previously asserted the scalar case was broken.
brainpy/optim/optimizer.py
brainpy/optim/optimizer_coverage_test.py
brainpy/optim/optimizer_test.py
Align l1_loss semantics and tests with the underlying braintools implementation and intended public API, including changing the default reduction to 'mean'.
  • Change the functional l1_loss default reduction from 'sum' to 'mean' to match the L1Loss class, its docstring, and PyTorch behavior.
  • Update TestReductionDefaults to expect l1_loss() to default to 'mean' and to match the L1Loss().update default behavior.
  • Correct regression tests for l1_loss reductions to assert the actual per-row L1 norm behavior of the delegated braintools.metric.l1_loss (none: [3,7], sum:10, mean:5) and adjust the L1Loss class test accordingly.
brainpy/losses/comparison.py
brainpy/losses/comparison_test.py
brainpy/losses/comparison_coverage_test.py
Make multi_margin_loss robust to bm.Array inputs under newer JAX by explicitly converting to JAX arrays before indexing and computation.
  • Convert predicts and targets at the start of multi_margin_loss using bm.as_jax to avoid ValueError from JAX>=0.9 when advanced-indexing bm.Array objects.
  • Add tests ensuring multi_margin_loss accepts bm.Array inputs and that its outputs match those computed on plain JAX arrays for the same data.
brainpy/losses/comparison.py
brainpy/losses/comparison_test.py
Add a documented audit report describing all issues found in the 2026‑06‑19 optim/losses audit, including fixes implemented and recorded-only items.
  • Introduce docs/issues-found-20260619-optim-losses.md summarizing critical, high, medium, and low-priority findings across optimizers and losses, distinguishing fixed vs recorded-only issues.
  • Remove the previous issues-found-20260618.md document which is superseded by the new audit report.
docs/issues-found-20260619-optim-losses.md
docs/issues-found-20260618.md

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@chaoming0625 chaoming0625 merged commit 368e28e into master Jun 18, 2026
8 of 14 checks passed
@github-actions github-actions Bot added documentation Improvements or additions to documentation tests labels Jun 18, 2026

@sourcery-ai sourcery-ai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've reviewed your changes and they look great!


Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant