Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion brainpy/losses/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def update(self, input: ArrayType, target: ArrayType) -> ArrayType:
return l1_loss(input, target, reduction=self.reduction)


def l1_loss(logits, targets, reduction='sum'):
def l1_loss(logits, targets, reduction='mean'):
r"""Creates a criterion that measures the mean absolute error (MAE) between each element in
the logits :math:`x` and targets :math:`y`. It is useful in regression problems.

Expand Down Expand Up @@ -1045,6 +1045,10 @@ def multi_margin_loss(predicts, targets, margin=1.0, p=1, reduction='mean'):
a scalar representing the multi-class margin loss. If `reduction` is ``'none'``, then :math:`(N)`.
"""
assert p == 1 or p == 2, 'p should be 1 or 2'
# Convert to plain JAX arrays: under JAX >= 0.9 implicit __jax_array__
# coercion was removed, so advanced-indexing a ``bm.Array`` would raise.
predicts = bm.as_jax(predicts)
targets = bm.as_jax(targets)
batch_size = predicts.shape[0]
correct_scores = predicts[jnp.arange(batch_size), targets]
margins = jnp.power(jnp.maximum(0, predicts - correct_scores[:, jnp.newaxis] + margin), p)
Expand Down
19 changes: 12 additions & 7 deletions brainpy/losses/comparison_coverage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,21 +221,26 @@ def test_class_wrapper(self):
# ---------------------------------------------------------------------------
class TestRegressionLosses:
def test_l1_loss_reductions(self):
# NOTE: l1_loss now delegates to braintools.metric.l1_loss, which
# computes the per-row MEAN absolute error (not the L1 norm) for
# reduction='none', then sums / means those per-row values.
# P1-L1: l1_loss delegates to braintools.metric.l1_loss, which for
# reduction='none' returns the per-row L1 *norm* (sum of abs over the
# trailing axes, reshaped to (N, -1)), NOT the per-row mean. So for the
# (2, 2) input below the 'none' output is the per-row sums [3, 7]; 'sum'
# then totals them (10) and 'mean' averages them (5). (The previous
# expectations of [1.5, 3.5]/5/2.5 encoded an incorrect per-row-mean
# assumption about braintools and were pre-existing baseline failures.)
x = jnp.array([[1., 2.], [3., 4.]])
y = jnp.zeros((2, 2))
none = np.asarray(C.l1_loss(x, y, reduction='none'))
assert np.allclose(none, [1.5, 3.5]) # per-row mean abs error
assert float(C.l1_loss(x, y, reduction='sum')) == pytest.approx(5.0)
assert float(C.l1_loss(x, y, reduction='mean')) == pytest.approx(2.5)
assert np.allclose(none, [3.0, 7.0]) # per-row L1 norm (sum of abs)
assert float(C.l1_loss(x, y, reduction='sum')) == pytest.approx(10.0)
assert float(C.l1_loss(x, y, reduction='mean')) == pytest.approx(5.0)

def test_l1_class(self):
x = jnp.array([[1., 2.], [3., 4.]])
y = jnp.zeros((2, 2))
layer = C.L1Loss(reduction='sum')
assert float(layer.update(x, y)) == pytest.approx(5.0)
# sum over per-row L1 norms [3, 7] = 10.0
assert float(layer.update(x, y)) == pytest.approx(10.0)

def test_l2_loss_elementwise(self):
out = np.asarray(C.l2_loss(jnp.array([2.0, 0.0]), jnp.array([0.0, 0.0])))
Expand Down
35 changes: 33 additions & 2 deletions brainpy/losses/comparison_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,16 @@ def test_mean_absolute_error(self):
class TestReductionDefaults(unittest.TestCase):
"""Public default ``reduction`` values must not change."""

def test_l1_loss_default_is_sum(self):
def test_l1_loss_default_is_mean(self):
# P1-M1 fix: the functional ``l1_loss`` default reduction is now
# ``'mean'`` (matching the ``L1Loss`` class, the docstring and PyTorch),
# not the surprising ``'sum'`` it previously defaulted to.
pred, tar = _arr(3, 4, seed=1), _arr(3, 4, seed=2)
self.assertTrue(_close(L.l1_loss(pred, tar),
L.l1_loss(pred, tar, reduction='sum')))
L.l1_loss(pred, tar, reduction='mean')))
# and it must equal the OO wrapper's default.
self.assertTrue(_close(L.l1_loss(pred, tar),
L.L1Loss().update(pred, tar)))

def test_mse_default_is_mean(self):
pred, tar = _arr(3, 4, seed=1), _arr(3, 4, seed=2)
Expand Down Expand Up @@ -186,5 +192,30 @@ def test_multi_margin_loss_runs(self):
self.assertIsNotNone(L.multi_margin_loss(logits, targets))


class TestMultiMarginArrayEnvelope(unittest.TestCase):
    """P1-H2: ``multi_margin_loss`` has to work when handed ``bm.Array`` inputs.

    Implicit ``__jax_array__`` coercion was removed under JAX >= 0.9, so
    ``jnp`` advanced indexing into a ``bm.Array`` raised
    ``ValueError: Triggering __jax_array__() ... no longer supported``. The
    other losses in the module accept ``bm.Array``, and so must this one.
    """

    def test_multi_margin_accepts_bm_array(self):
        # Both the scores and the integer class targets are ``bm`` arrays.
        scores = bm.asarray(np.array([[0.2, 0.8], [0.6, 0.4]]))
        labels = bm.asarray(np.array([1, 0]))
        loss = L.multi_margin_loss(
            scores,
            labels,
            margin=1.0,
            p=1,
            reduction='mean',
        )
        self.assertTrue(np.isfinite(float(loss)))

    def test_multi_margin_bm_matches_jax(self):
        # The same data, fed once as ``bm`` arrays and once as a plain JAX
        # array plus a NumPy target vector, must give the same per-sample loss.
        p_np = np.array([[0.2, 0.8, 0.1], [0.6, 0.4, 0.9]])
        t_np = np.array([1, 0])

        via_bm = L.multi_margin_loss(
            bm.asarray(p_np), bm.asarray(t_np), p=2, reduction='none'
        )
        via_jax = L.multi_margin_loss(
            bm.as_jax(bm.asarray(p_np)), t_np, p=2, reduction='none'
        )

        bm_out = np.asarray(via_bm)
        jax_out = np.asarray(via_jax)
        self.assertTrue(_close(bm_out, jax_out))


if __name__ == '__main__':
unittest.main()
36 changes: 28 additions & 8 deletions brainpy/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import Union, Sequence, Dict, Optional, Tuple

import jax.numpy as jnp
from jax.lax import cond

import brainpy.math as bm
from brainpy import check
Expand Down Expand Up @@ -772,6 +771,10 @@ def __init__(
self.eps = eps
self.weight_decay = weight_decay
self.no_prox = no_prox
# Per-update step counter for bias correction (see ``Adam``). It must be
# independent of the LR scheduler's ``last_epoch`` (which the optimizer
# never advances) and advance exactly once per ``update()``.
self.step = bm.Variable(jnp.asarray(0))

def __repr__(self):
return (f"{self.__class__.__name__}(lr={self.lr}, "
Expand Down Expand Up @@ -804,18 +807,26 @@ def _update_moments(self, m, n, v, pre_g, g):

def update(self, grads: dict):
self.check_grads(grads)
lr = self.lr()
step = self.lr.last_epoch.value + 1
correct_m = 1 / (1 - (1 - self.betas[0]) ** (step + 1))
correct_v = 1 / (1 - (1 - self.betas[1]) ** (step + 1))
correct_n = 1 / (1 - (1 - self.betas[2]) ** (step + 1))
lr = bm.as_jax(self.lr())
# Advance the per-update step counter (t = 1 on the first update). The
# bias-correction terms use ``(1 - beta) ** t`` (matching the reference
# Adan), so they actually evolve over training instead of being frozen.
self.step.value = self.step.value + 1
step = self.step.value
correct_m = 1 / (1 - (1 - self.betas[0]) ** step)
correct_v = 1 / (1 - (1 - self.betas[1]) ** step)
correct_n = 1 / (1 - (1 - self.betas[2]) ** step)
for key, p_var in self.vars_to_train.items():
m_var = self.implicit_vars[key + '_m']
n_var = self.implicit_vars[key + '_n']
v_var = self.implicit_vars[key + '_v']
prev_g_var = self.implicit_vars[key + '_prev_grad']
g = grads[key]
pre_g = cond(step == 0, lambda pg, g: g, lambda pg, g: pg, (prev_g_var.value, g))
# On the first update there is no previous gradient, so the gradient
# difference must be 0 (i.e. ``pre_g := g``). Use a value-level
# ``where`` rather than ``lax.cond`` (whose operand is splatted into
# the branch functions, which was the source of the crash).
pre_g = jnp.where(step == 1, g, prev_g_var.value)
diff = g - pre_g
m = m_var.value * (1 - self.betas[0]) + self.betas[0] * g
v = v_var.value * (1 - self.betas[1]) + self.betas[1] * diff
Expand Down Expand Up @@ -1082,6 +1093,13 @@ def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = Non
vs = dict()
for k, v in train_vars.items():
rank, ndim = v.shape, v.ndim
if ndim == 0:
# A 0-dim (scalar) variable has no axes to build a cover over.
# Register a single scalar accumulator so SM3 degenerates to an
# Adagrad-like update (otherwise ``update`` would ``KeyError`` on
# the missing ``{k}_m0`` accumulator).
vs[f'{k}_m0'] = bm.Variable(bm.zeros((), dtype=v.dtype))
continue
for i in range(ndim):
shape = [1] * ndim
shape[i] = rank[i]
Expand All @@ -1098,7 +1116,9 @@ def update(self, grads: dict):

for k, p in self.vars_to_train.items():
g = grads[k]
ndim = p.ndim
# Match the rank-1 fallback used when registering accumulators for
# scalar variables (see ``register_train_vars``).
ndim = max(p.ndim, 1)
update = self.implicit_vars[f'{k}_m0']
for i in range(1, ndim):
update = bm.minimum(update, self.implicit_vars[f'{k}_m{i}'])
Expand Down
28 changes: 13 additions & 15 deletions brainpy/optim/optimizer_coverage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,16 @@ def test_zero_norm_branch(self):


class TestAdan:
def test_update_is_currently_broken(self):
# NOTE (defect): Adan.update calls
# cond(step == 0, lambda pg, g: g, lambda pg, g: pg, (prev_g_var.value, g))
# jax.lax.cond unpacks *operands, so passing the single 2-tuple binds
# pg=(prev, g) and leaves g unbound -> TypeError. Adan.update therefore
# cannot run as written (both no_prox branches are unreachable).
def test_update_runs(self):
# P1-C1/P1-C2 fix: Adan.update used to crash (jax.lax.cond operand
# splatting) and its step counter was frozen at 0. It now runs and the
# per-update step counter advances.
v = _make_var(2.0)
opt = O.Adan(lr=1e-2, train_vars={'w': v})
assert 'no_prox' in repr(opt)
with pytest.raises(TypeError):
opt.update({'w': bm.as_jax(v.value)})
out = _train(opt, {'w': v}, lambda val: val)
assert np.isfinite(out['w'])
assert int(bm.as_jax(opt.step.value)) == 5

def test_invalid_eps(self):
with pytest.raises(ValueError):
Expand Down Expand Up @@ -245,16 +244,15 @@ def test_invalid_hyperparams(self):


class TestSM3:
def test_scalar_var_is_broken(self):
# NOTE (defect): SM3.register_train_vars loops ``for i in range(ndim)``;
# for a scalar (0-dim) variable no ``_m{i}`` accumulator gets created,
# yet ``update`` reads ``{k}_m0`` -> KeyError. SM3 only works for
# >=1-dim variables.
def test_scalar_var_runs(self):
# P1-H1 fix: SM3 used to KeyError('w_m0') for a scalar (0-dim) variable
# because no accumulator was registered. It now registers a single
# scalar accumulator (Adagrad-like) and updates correctly.
v = _make_var(2.0)
opt = O.SM3(lr=0.1, train_vars={'w': v})
assert 'beta' in repr(opt)
with pytest.raises(KeyError):
opt.update({'w': bm.as_jax(v.value)})
out = _train(opt, {'w': v}, lambda val: np.ones_like(val))
assert np.isfinite(out['w'])

def test_1d_var(self):
v = bm.Variable(bm.asarray(np.array([1.0, 2.0], dtype=np.float32)))
Expand Down
119 changes: 119 additions & 0 deletions brainpy/optim/optimizer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Regression tests for the 2026-06-19 optim/losses audit.

These pin the bug fixes for:

- P1-C1/P1-C2: ``Adan.update`` crashed on every call (``jax.lax.cond`` operand
mis-binding) and its step counter was frozen at 0 (bias correction / Nesterov
term disabled).
- P1-H1: ``SM3`` raised ``KeyError`` for scalar (0-dim) trainable variables.
"""

import numpy as np

import brainpy.math as bm
from brainpy.optim import optimizer as O


def _vec_var(values):
    """Return a ``bm.Variable`` holding ``values`` converted to float32."""
    data = np.asarray(values, dtype=np.float32)
    return bm.Variable(bm.asarray(data))


def _scalar_var(value=2.0):
    """Return a ``bm.Variable`` holding ``value`` converted to float32."""
    data = np.asarray(value, dtype=np.float32)
    return bm.Variable(bm.asarray(data))


def _train(opt, var_dict, grad_fn, steps=5):
    """Apply ``steps`` optimizer updates and return the resulting values.

    Before every update, ``grad_fn`` is called on the current JAX value of each
    variable in ``var_dict``; the resulting ``{name: gradient}`` dict is handed
    to ``opt.update``. The return value maps each name to its final value as a
    NumPy array.
    """

    def _current(var):
        # Read the variable's present value as a plain JAX array.
        return bm.as_jax(var.value)

    for _ in range(steps):
        step_grads = {name: grad_fn(_current(var)) for name, var in var_dict.items()}
        opt.update(step_grads)

    return {name: np.asarray(_current(var)) for name, var in var_dict.items()}


# ---------------------------------------------------------------------------
# Adan (P1-C1, P1-C2)
# ---------------------------------------------------------------------------
class TestAdanFixed:
    """Regression checks for ``Adan.update`` (P1-C1 / P1-C2)."""

    def test_adan_runs_and_updates(self):
        # This used to raise TypeError on the very first update.
        param = _vec_var([2.0, -3.0])
        adan = O.Adan(lr=1e-2, train_vars={'w': param})
        result = _train(adan, {'w': param}, lambda val: val, steps=5)
        w_final = result['w']
        assert np.all(np.isfinite(w_final))
        # The gradient equals the value, so every component must shrink in
        # magnitude relative to where it started.
        start_magnitude = np.array([2.0, 3.0])
        assert np.all(np.abs(w_final) < start_magnitude)

    def test_adan_no_prox_runs(self):
        # Same scenario, exercising the ``no_prox=True`` code path.
        param = _vec_var([2.0, -3.0])
        adan = O.Adan(lr=1e-2, train_vars={'w': param}, no_prox=True)
        result = _train(adan, {'w': param}, lambda val: val, steps=5)
        assert np.all(np.isfinite(result['w']))

    def test_adan_step_counter_advances(self):
        # One call to ``update`` must bump ``opt.step`` by exactly one.
        param = _vec_var([1.0])
        adan = O.Adan(lr=1e-2, train_vars={'w': param})
        n_updates = 4
        for _ in range(n_updates):
            adan.update({'w': bm.as_jax(param.value)})
        counted = int(bm.as_jax(adan.step.value))
        assert counted == n_updates

    def test_adan_first_step_diff_is_zero(self):
        # The first update has no previous gradient, so the gradient difference
        # (g - g_prev) is treated as 0 and the exp_avg_diff (``_v``)
        # accumulator must still be 0 afterwards, whatever the gradient size.
        param = _vec_var([5.0])
        adan = O.Adan(lr=1e-2, train_vars={'w': param})
        adan.update({'w': bm.as_jax(param.value)})
        diff_acc = adan.implicit_vars['w_v']
        v_diff = np.asarray(bm.as_jax(diff_acc.value))
        assert np.allclose(v_diff, 0.0)

    def test_adan_nesterov_term_active_after_two_steps(self):
        # While the step counter was frozen the diff term stayed at 0 forever;
        # feeding two different gradients must now leave a non-zero
        # exp_avg_diff after the second step.
        param = _vec_var([1.0])
        adan = O.Adan(lr=1e-2, train_vars={'w': param})
        first_grad = np.asarray([1.0], dtype=np.float32)
        second_grad = np.asarray([3.0], dtype=np.float32)  # gradient changed
        adan.update({'w': first_grad})
        adan.update({'w': second_grad})
        v_diff = np.asarray(bm.as_jax(adan.implicit_vars['w_v'].value))
        assert not np.allclose(v_diff, 0.0)


# ---------------------------------------------------------------------------
# SM3 (P1-H1)
# ---------------------------------------------------------------------------
class TestSM3Fixed:
    """Regression checks for ``SM3`` with scalar (0-dim) trainable variables (P1-H1)."""

    def test_sm3_scalar_var_runs(self):
        # Previously raised KeyError('w_m0') for a 0-dim variable.
        v = _scalar_var(2.0)
        opt = O.SM3(lr=0.1, train_vars={'w': v})
        out = _train(opt, {'w': v}, lambda val: np.ones_like(val), steps=4)
        assert np.all(np.isfinite(out['w']))
        # gradient is +1 each step -> scalar parameter must decrease.
        assert float(out['w']) < 2.0

    def test_sm3_scalar_matches_adagrad_like_step(self):
        # For a scalar with constant gradient g=1, SM3 reduces to an Adagrad-like
        # update: cache accumulates g^2, step = lr * g / sqrt(cache + eps).
        v = _scalar_var(0.0)
        opt = O.SM3(lr=0.1, train_vars={'w': v}, eps=1e-30)
        opt.update({'w': np.asarray(1.0, dtype=np.float32)})
        # after one step cache = 1, update = 0.1 * 1 / sqrt(1) = 0.1, so w = -0.1.
        # Compare with a relative tolerance instead of exact ``==``: the value is
        # computed in float32, and bit-exact equality would depend on the
        # backend's rounding and on NumPy's scalar-promotion rules.
        result = float(bm.as_jax(v.value))
        assert np.isclose(result, -0.1, rtol=1e-6, atol=0.0)

    def test_sm3_scalar_still_works_with_momentum(self):
        # Same scalar scenario with non-default ``momentum`` and ``beta``.
        v = _scalar_var(2.0)
        opt = O.SM3(lr=0.1, train_vars={'w': v}, momentum=0.5, beta=0.5)
        out = _train(opt, {'w': v}, lambda val: np.ones_like(val), steps=3)
        assert np.all(np.isfinite(out['w']))
Loading
Loading