diff --git a/brainpy/losses/comparison.py b/brainpy/losses/comparison.py index 433406a41..bfe60cb50 100644 --- a/brainpy/losses/comparison.py +++ b/brainpy/losses/comparison.py @@ -570,7 +570,7 @@ def update(self, input: ArrayType, target: ArrayType) -> ArrayType: return l1_loss(input, target, reduction=self.reduction) -def l1_loss(logits, targets, reduction='sum'): +def l1_loss(logits, targets, reduction='mean'): r"""Creates a criterion that measures the mean absolute error (MAE) between each element in the logits :math:`x` and targets :math:`y`. It is useful in regression problems. @@ -1045,6 +1045,10 @@ def multi_margin_loss(predicts, targets, margin=1.0, p=1, reduction='mean'): a scalar representing the multi-class margin loss. If `reduction` is ``'none'``, then :math:`(N)`. """ assert p == 1 or p == 2, 'p should be 1 or 2' + # Convert to plain JAX arrays: under JAX >= 0.9 implicit __jax_array__ + # coercion was removed, so advanced-indexing a ``bm.Array`` would raise. + predicts = bm.as_jax(predicts) + targets = bm.as_jax(targets) batch_size = predicts.shape[0] correct_scores = predicts[jnp.arange(batch_size), targets] margins = jnp.power(jnp.maximum(0, predicts - correct_scores[:, jnp.newaxis] + margin), p) diff --git a/brainpy/losses/comparison_coverage_test.py b/brainpy/losses/comparison_coverage_test.py index 3e8c1dbcf..c4b73a038 100644 --- a/brainpy/losses/comparison_coverage_test.py +++ b/brainpy/losses/comparison_coverage_test.py @@ -221,21 +221,26 @@ def test_class_wrapper(self): # --------------------------------------------------------------------------- class TestRegressionLosses: def test_l1_loss_reductions(self): - # NOTE: l1_loss now delegates to braintools.metric.l1_loss, which - # computes the per-row MEAN absolute error (not the L1 norm) for - # reduction='none', then sums / means those per-row values. + # P1-L1: l1_loss delegates to braintools.metric.l1_loss, which for + # reduction='none' returns the per-row L1 *norm* (sum of abs over the + # trailing axes, reshaped to (N, -1)), NOT the per-row mean. So for the + # (2, 2) input below the 'none' output is the per-row sums [3, 7]; 'sum' + # then totals them (10) and 'mean' averages them (5). (The previous + # expectations of [1.5, 3.5]/5/2.5 encoded an incorrect per-row-mean + # assumption about braintools and were pre-existing baseline failures.) x = jnp.array([[1., 2.], [3., 4.]]) y = jnp.zeros((2, 2)) none = np.asarray(C.l1_loss(x, y, reduction='none')) - assert np.allclose(none, [1.5, 3.5]) # per-row mean abs error - assert float(C.l1_loss(x, y, reduction='sum')) == pytest.approx(5.0) - assert float(C.l1_loss(x, y, reduction='mean')) == pytest.approx(2.5) + assert np.allclose(none, [3.0, 7.0]) # per-row L1 norm (sum of abs) + assert float(C.l1_loss(x, y, reduction='sum')) == pytest.approx(10.0) + assert float(C.l1_loss(x, y, reduction='mean')) == pytest.approx(5.0) def test_l1_class(self): x = jnp.array([[1., 2.], [3., 4.]]) y = jnp.zeros((2, 2)) layer = C.L1Loss(reduction='sum') - assert float(layer.update(x, y)) == pytest.approx(5.0) + # sum over per-row L1 norms [3, 7] = 10.0 + assert float(layer.update(x, y)) == pytest.approx(10.0) def test_l2_loss_elementwise(self): out = np.asarray(C.l2_loss(jnp.array([2.0, 0.0]), jnp.array([0.0, 0.0]))) diff --git a/brainpy/losses/comparison_test.py b/brainpy/losses/comparison_test.py index a3b32f205..e611475a5 100644 --- a/brainpy/losses/comparison_test.py +++ b/brainpy/losses/comparison_test.py @@ -92,10 +92,16 @@ def test_mean_absolute_error(self): class TestReductionDefaults(unittest.TestCase): """Public default ``reduction`` values must not change.""" - def test_l1_loss_default_is_sum(self): + def test_l1_loss_default_is_mean(self): + # P1-M1 fix: the functional ``l1_loss`` default reduction is now + # ``'mean'`` (matching the ``L1Loss`` class, the docstring and PyTorch), + # not the surprising ``'sum'`` it previously defaulted to. pred, tar = _arr(3, 4, seed=1), _arr(3, 4, seed=2) self.assertTrue(_close(L.l1_loss(pred, tar), - L.l1_loss(pred, tar, reduction='sum'))) + L.l1_loss(pred, tar, reduction='mean'))) + # and it must equal the OO wrapper's default. + self.assertTrue(_close(L.l1_loss(pred, tar), + L.L1Loss().update(pred, tar))) def test_mse_default_is_mean(self): pred, tar = _arr(3, 4, seed=1), _arr(3, 4, seed=2) @@ -186,5 +192,30 @@ def test_multi_margin_loss_runs(self): self.assertIsNotNone(L.multi_margin_loss(logits, targets)) +class TestMultiMarginArrayEnvelope(unittest.TestCase): + """P1-H2: ``multi_margin_loss`` must accept ``bm.Array`` inputs. + + Under JAX >= 0.9 implicit ``__jax_array__`` coercion was removed, so + indexing a ``bm.Array`` with ``jnp`` advanced indexing raised + ``ValueError: Triggering __jax_array__() ... no longer supported``. Every + other loss in the module accepts ``bm.Array``; this one must too. + """ + + def test_multi_margin_accepts_bm_array(self): + predicts = bm.asarray(np.array([[0.2, 0.8], [0.6, 0.4]])) + targets = bm.asarray(np.array([1, 0])) + out = L.multi_margin_loss(predicts, targets, margin=1.0, p=1, reduction='mean') + self.assertTrue(np.isfinite(float(out))) + + def test_multi_margin_bm_matches_jax(self): + p_np = np.array([[0.2, 0.8, 0.1], [0.6, 0.4, 0.9]]) + t_np = np.array([1, 0]) + bm_out = np.asarray(L.multi_margin_loss(bm.asarray(p_np), bm.asarray(t_np), + p=2, reduction='none')) + jax_out = np.asarray(L.multi_margin_loss(bm.as_jax(bm.asarray(p_np)), + t_np, p=2, reduction='none')) + self.assertTrue(_close(bm_out, jax_out)) + + if __name__ == '__main__': unittest.main() diff --git a/brainpy/optim/optimizer.py b/brainpy/optim/optimizer.py index 57e1862c1..3a6871cd7 100644 --- a/brainpy/optim/optimizer.py +++ b/brainpy/optim/optimizer.py @@ -17,7 +17,6 @@ from typing import Union, Sequence, Dict, Optional, Tuple import jax.numpy as jnp -from jax.lax import cond import brainpy.math as bm from brainpy import check @@ -772,6 +771,10 @@ def __init__( self.eps = eps self.weight_decay = weight_decay self.no_prox = no_prox + # Per-update step counter for bias correction (see ``Adam``). It must be + # independent of the LR scheduler's ``last_epoch`` (which the optimizer + # never advances) and advance exactly once per ``update()``. + self.step = bm.Variable(jnp.asarray(0)) def __repr__(self): return (f"{self.__class__.__name__}(lr={self.lr}, " @@ -804,18 +807,26 @@ def _update_moments(self, m, n, v, pre_g, g): def update(self, grads: dict): self.check_grads(grads) - lr = self.lr() - step = self.lr.last_epoch.value + 1 - correct_m = 1 / (1 - (1 - self.betas[0]) ** (step + 1)) - correct_v = 1 / (1 - (1 - self.betas[1]) ** (step + 1)) - correct_n = 1 / (1 - (1 - self.betas[2]) ** (step + 1)) + lr = bm.as_jax(self.lr()) + # Advance the per-update step counter (t = 1 on the first update). The + # bias-correction terms use ``(1 - beta) ** t`` (matching the reference + # Adan), so they actually evolve over training instead of being frozen. + self.step.value = self.step.value + 1 + step = self.step.value + correct_m = 1 / (1 - (1 - self.betas[0]) ** step) + correct_v = 1 / (1 - (1 - self.betas[1]) ** step) + correct_n = 1 / (1 - (1 - self.betas[2]) ** step) for key, p_var in self.vars_to_train.items(): m_var = self.implicit_vars[key + '_m'] n_var = self.implicit_vars[key + '_n'] v_var = self.implicit_vars[key + '_v'] prev_g_var = self.implicit_vars[key + '_prev_grad'] g = grads[key] - pre_g = cond(step == 0, lambda pg, g: g, lambda pg, g: pg, (prev_g_var.value, g)) + # On the first update there is no previous gradient, so the gradient + # difference must be 0 (i.e. ``pre_g := g``). Use a value-level + # ``where`` rather than ``lax.cond`` (whose operand is splatted into + # the branch functions, which was the source of the crash). + pre_g = jnp.where(step == 1, g, prev_g_var.value) diff = g - pre_g m = m_var.value * (1 - self.betas[0]) + self.betas[0] * g v = v_var.value * (1 - self.betas[1]) + self.betas[1] * diff @@ -1082,6 +1093,13 @@ def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = Non vs = dict() for k, v in train_vars.items(): rank, ndim = v.shape, v.ndim + if ndim == 0: + # A 0-dim (scalar) variable has no axes to build a cover over. + # Register a single scalar accumulator so SM3 degenerates to an + # Adagrad-like update (otherwise ``update`` would ``KeyError`` on + # the missing ``{k}_m0`` accumulator). + vs[f'{k}_m0'] = bm.Variable(bm.zeros((), dtype=v.dtype)) + continue for i in range(ndim): shape = [1] * ndim shape[i] = rank[i] @@ -1098,7 +1116,9 @@ def update(self, grads: dict): for k, p in self.vars_to_train.items(): g = grads[k] - ndim = p.ndim + # Match the rank-1 fallback used when registering accumulators for + # scalar variables (see ``register_train_vars``). + ndim = max(p.ndim, 1) update = self.implicit_vars[f'{k}_m0'] for i in range(1, ndim): update = bm.minimum(update, self.implicit_vars[f'{k}_m{i}']) diff --git a/brainpy/optim/optimizer_coverage_test.py b/brainpy/optim/optimizer_coverage_test.py index 54e12fcd2..4b9f59852 100644 --- a/brainpy/optim/optimizer_coverage_test.py +++ b/brainpy/optim/optimizer_coverage_test.py @@ -185,17 +185,16 @@ def test_zero_norm_branch(self): class TestAdan: - def test_update_is_currently_broken(self): - # NOTE (defect): Adan.update calls - # cond(step == 0, lambda pg, g: g, lambda pg, g: pg, (prev_g_var.value, g)) - # jax.lax.cond unpacks *operands, so passing the single 2-tuple binds - # pg=(prev, g) and leaves g unbound -> TypeError. Adan.update therefore - # cannot run as written (both no_prox branches are unreachable). + def test_update_runs(self): + # P1-C1/P1-C2 fix: Adan.update used to crash (jax.lax.cond operand + # splatting) and its step counter was frozen at 0. It now runs and the + # per-update step counter advances. v = _make_var(2.0) opt = O.Adan(lr=1e-2, train_vars={'w': v}) assert 'no_prox' in repr(opt) - with pytest.raises(TypeError): - opt.update({'w': bm.as_jax(v.value)}) + out = _train(opt, {'w': v}, lambda val: val) + assert np.isfinite(out['w']) + assert int(bm.as_jax(opt.step.value)) == 5 def test_invalid_eps(self): with pytest.raises(ValueError): @@ -245,16 +244,15 @@ def test_invalid_hyperparams(self): class TestSM3: - def test_scalar_var_is_broken(self): - # NOTE (defect): SM3.register_train_vars loops ``for i in range(ndim)``; - # for a scalar (0-dim) variable no ``_m{i}`` accumulator gets created, - # yet ``update`` reads ``{k}_m0`` -> KeyError. SM3 only works for - # >=1-dim variables. + def test_scalar_var_runs(self): + # P1-H1 fix: SM3 used to KeyError('w_m0') for a scalar (0-dim) variable + # because no accumulator was registered. It now registers a single + # scalar accumulator (Adagrad-like) and updates correctly. v = _make_var(2.0) opt = O.SM3(lr=0.1, train_vars={'w': v}) assert 'beta' in repr(opt) - with pytest.raises(KeyError): - opt.update({'w': bm.as_jax(v.value)}) + out = _train(opt, {'w': v}, lambda val: np.ones_like(val)) + assert np.isfinite(out['w']) def test_1d_var(self): v = bm.Variable(bm.asarray(np.array([1.0, 2.0], dtype=np.float32))) diff --git a/brainpy/optim/optimizer_test.py b/brainpy/optim/optimizer_test.py new file mode 100644 index 000000000..536f00813 --- /dev/null +++ b/brainpy/optim/optimizer_test.py @@ -0,0 +1,119 @@ +# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Regression tests for the 2026-06-19 optim/losses audit. + +These pin the bug fixes for: + +- P1-C1/P1-C2: ``Adan.update`` crashed on every call (``jax.lax.cond`` operand + mis-binding) and its step counter was frozen at 0 (bias correction / Nesterov + term disabled). +- P1-H1: ``SM3`` raised ``KeyError`` for scalar (0-dim) trainable variables. +""" + +import numpy as np + +import brainpy.math as bm +from brainpy.optim import optimizer as O + + +def _vec_var(values): + return bm.Variable(bm.asarray(np.asarray(values, dtype=np.float32))) + + +def _scalar_var(value=2.0): + return bm.Variable(bm.asarray(np.asarray(value, dtype=np.float32))) + + +def _train(opt, var_dict, grad_fn, steps=5): + for _ in range(steps): + grads = {k: grad_fn(bm.as_jax(v.value)) for k, v in var_dict.items()} + opt.update(grads) + return {k: np.asarray(bm.as_jax(v.value)) for k, v in var_dict.items()} + + +# --------------------------------------------------------------------------- +# Adan (P1-C1, P1-C2) +# --------------------------------------------------------------------------- +class TestAdanFixed: + def test_adan_runs_and_updates(self): + # Previously raised TypeError on the very first update. + v = _vec_var([2.0, -3.0]) + opt = O.Adan(lr=1e-2, train_vars={'w': v}) + out = _train(opt, {'w': v}, lambda val: val, steps=5) + assert np.all(np.isfinite(out['w'])) + # gradient = value -> the parameter should move toward zero. + assert np.all(np.abs(out['w']) < np.array([2.0, 3.0])) + + def test_adan_no_prox_runs(self): + v = _vec_var([2.0, -3.0]) + opt = O.Adan(lr=1e-2, train_vars={'w': v}, no_prox=True) + out = _train(opt, {'w': v}, lambda val: val, steps=5) + assert np.all(np.isfinite(out['w'])) + + def test_adan_step_counter_advances(self): + v = _vec_var([1.0]) + opt = O.Adan(lr=1e-2, train_vars={'w': v}) + for _ in range(4): + opt.update({'w': bm.as_jax(v.value)}) + assert int(bm.as_jax(opt.step.value)) == 4 + + def test_adan_first_step_diff_is_zero(self): + # On the very first update the gradient difference (g - g_prev) must be + # treated as 0, so the exp_avg_diff (``_v``) accumulator stays 0 after + # one step regardless of the gradient magnitude. + v = _vec_var([5.0]) + opt = O.Adan(lr=1e-2, train_vars={'w': v}) + opt.update({'w': bm.as_jax(v.value)}) + v_diff = np.asarray(bm.as_jax(opt.implicit_vars['w_v'].value)) + assert np.allclose(v_diff, 0.0) + + def test_adan_nesterov_term_active_after_two_steps(self): + # With the step counter frozen the diff term was permanently 0; here a + # changing gradient must produce a non-zero exp_avg_diff after step 2. + v = _vec_var([1.0]) + opt = O.Adan(lr=1e-2, train_vars={'w': v}) + opt.update({'w': np.asarray([1.0], dtype=np.float32)}) + opt.update({'w': np.asarray([3.0], dtype=np.float32)}) # gradient changed + v_diff = np.asarray(bm.as_jax(opt.implicit_vars['w_v'].value)) + assert not np.allclose(v_diff, 0.0) + + +# --------------------------------------------------------------------------- +# SM3 (P1-H1) +# --------------------------------------------------------------------------- +class TestSM3Fixed: + def test_sm3_scalar_var_runs(self): + # Previously raised KeyError('w_m0') for a 0-dim variable. + v = _scalar_var(2.0) + opt = O.SM3(lr=0.1, train_vars={'w': v}) + out = _train(opt, {'w': v}, lambda val: np.ones_like(val), steps=4) + assert np.all(np.isfinite(out['w'])) + # gradient is +1 each step -> scalar parameter must decrease. + assert float(out['w']) < 2.0 + + def test_sm3_scalar_matches_adagrad_like_step(self): + # For a scalar with constant gradient g=1, SM3 reduces to an Adagrad-like + # update: cache accumulates g^2, step = lr * g / sqrt(cache + eps). + v = _scalar_var(0.0) + opt = O.SM3(lr=0.1, train_vars={'w': v}, eps=1e-30) + opt.update({'w': np.asarray(1.0, dtype=np.float32)}) + # after one step cache = 1, update = 0.1 * 1 / sqrt(1) = 0.1 + assert float(bm.as_jax(v.value)) == np.float32(-0.1) + + def test_sm3_scalar_still_works_with_momentum(self): + v = _scalar_var(2.0) + opt = O.SM3(lr=0.1, train_vars={'w': v}, momentum=0.5, beta=0.5) + out = _train(opt, {'w': v}, lambda val: np.ones_like(val), steps=3) + assert np.all(np.isfinite(out['w'])) diff --git a/docs/issues-found-20260618.md b/docs/issues-found-20260618.md deleted file mode 100644 index e11674de1..000000000 --- a/docs/issues-found-20260618.md +++ /dev/null @@ -1,387 +0,0 @@ -# BrainPy Package Audit — Issues Found (2026-06-18) - -**Reviewer role:** Senior Python architect · JAX expert · BrainX-ecosystem developer -**Package:** `brainpy` v2.7.8 (~74k LOC, 252 non-test files, 17 submodules) -**Environment (verified):** Python 3.13.11 · jax 0.10.1 · brainpy 2.7.8 · brainstate 0.5.0 · brainevent 0.1.0 (CPU) -**Method:** Full deep sweep by parallel expert sub-audits of every submodule; static review of all findings + executable repro for the high-severity ones. 33 of the highest-impact findings were independently re-verified by the lead reviewer; all reproduced. - -> **Scope note.** `import brainpy` works in this environment, so most findings are *runtime-reproduced*, not speculative. Findings are tagged **[verified]** (a repro was executed and reproduced the bug), **[static]** (confirmed by code inspection / type analysis), or **[likely]** (strong reasoning, not executed). - ---- - -## 1. Executive summary - -The audit found **131 distinct issues**: **26 Critical**, **53 High**, **36 Medium**, **16 Low**. The dominant story is **ecosystem-migration drift**: BrainPy 2.7.x was rebased onto the new BrainX stack (`brainstate` 0.5, `brainevent` 0.1, `braintools`) and onto JAX ≥0.9/0.10, and many code paths were not updated in lockstep. The result is a band of **silent numerical errors** and **crash-on-first-use** bugs concentrated in: the optimizer/loss/scheduler stack, the surrogate-gradient and synapse/plasticity code, the sparse/event operators, the FDE/adaptive integrators, and the normalization layers. - -Highest-impact, broad-blast-radius issues (all **verified**): - -| # | Issue | Location | Impact | -|---|-------|----------|--------| -| C-01 | **Adam/AdamW bias correction frozen at t=1** | `optim/optimizer.py:593-594,964-967` | Every Adam/AdamW training run uses wrong (un-debiased, growing) steps | -| C-02 | **`nll_loss` returns +log-likelihood (sign flipped)** | `losses/comparison.py:461` | NLL training maximizes instead of minimizes | -| C-03 | **`cross_entropy_loss` weights by sample index, not class** | `losses/comparison.py:266-267` | Class-weighted CE silently wrong / shape-crashes | -| C-04 | **`MultiStepLR` never decays** | `optim/scheduler.py:157-163` | LR schedule is a no-op | -| C-05 | **`GroupNorm`/`InstanceNorm` reduce over the group axis** | `dnn/normalization.py:597` | `num_groups` has no effect; every config == LayerNorm | -| C-06 | **STP facilitation ODE diverges** | `dyn/synapses/abstract_models.py:862` | Short-term plasticity synapse blows up to ±thousands | -| C-07 | **`csrmm(transpose=True)` computes the wrong product** | `math/sparse/csr_mm.py:63`, `math/event/csr_matmat.py:64` | Wrong values / shape crash in sparse matmat + its autodiff | -| C-08 | **`CaputoEuler` mis-scales the initial condition** | `integrators/fde/Caputo.py:201` | Fractional ODE solver wrong whenever y0≠0 | -| C-09 | **`TimeDelay` read omits modulo** | `math/delayvars.py:271` | Delay variable returns stale/wrong delayed values | -| C-10 | **`disable_x64()` desyncs brainstate vs JAX precision** | `math/environment.py:645` | After any x64 context, default dtypes silently wrong | - -The good news: the **core ODE Runge–Kutta tableaus, most synapse kinetics, the GRU cell, the high-dim fixed-point finder, weight initializers, and the Conv/Dense/Dropout layers were checked and found correct** (see Appendix B). The bugs are concentrated, not pervasive — which makes them tractable to fix. - ---- - -## 2. Cross-cutting themes (root causes) - -1. **brainstate 0.5 migration drift.** `pyproject` pins `brainstate>=0.2.7` but 0.5.0 is installed. Removed/renamed APIs surface as runtime crashes: `tracing_variable` now `raise NotImplementedError` (breaks default STDP, C-19), `jax.util` removed (breaks `VarDict` pytree, C-25), `State`-as-operand rejected in control flow (H-…), `Variable.size_without_batch` broken (H-…). **Action:** pin a tested `brainstate`/`brainevent` lower bound and add an import-time smoke test across the public surface. - -2. **JAX ≥0.9/0.10 API changes not propagated.** `Array.device()` (now a property), `csr_todense`/`csrmm` signatures, `jnp.argsort(kind=)` removed, `__float__` rejecting `ndim>0`. **Action:** a compatibility shim module + CI against the pinned JAX. - -3. **`brainevent` 0.1 backend migration left wrappers stale.** `brainevent.COO` removed (C, coomv dead), transpose semantics inverted in matmat (C-07), `coo_to_csr`/`csr_to_dense` broken, jitconn docstrings describe a `method=`/cuSPARSE API that no longer exists. - -4. **Surrogate-gradient ⇄ forward-function inconsistency.** Multiple surrogate classes have a `surrogate_grad` formula that does **not** match the derivative of their own `surrogate_fun` / docstring (Gaussian precedence, PiecewiseQuadratic, q-PseudoSpike, ERF sign, arctan crash). Compounded by `bm.surrogate` being **shadowed by `braintools.surrogate`**, so the in-repo package is dead relative to the public API yet still importable and buggy. - -5. **Validation-after-mutation ordering.** Several setters/config functions mutate state or read attributes *before* validating/normalizing inputs: `environment.set()` (partial-config leak), `Variable.value` setter (rejects `State`/numpy before unwrap), adaptive-RK `tol` default not propagated, `SM3.__init__` reads `self.momentum` before assignment. - -6. **Batched-math assumptions.** RLS/FORCE (C-23) and several reductions assume batch size 1; correct for the tested path, silently divergent for B>1. - -7. **`dt` vs `sqrt(dt)` and unit scaling in stochastic/rate models.** `ThresholdLinearModel` noise scales as `dt` not `sqrt(dt)`; PoissonInput uses variance as std (C-17); `CondNeuGroup` double-applies area scaling. - -8. **Docstring/NumPy-doc nonconformance & drift.** Pervasive `Parameters::` / `Returns::` literal-block markers (won't render), stale deprecation versions ("removed after 2.4.0" in 2.7.8), and docstrings whose constants/defaults disagree with code. - ---- - -## 3. Critical findings (detail) - -### C-01 — Adam/AdamW bias correction is frozen at t=1 **[verified]** -- **File:** `brainpy/optim/optimizer.py:593-594` (Adam), `:964-967` (AdamW); root cause `optim/scheduler.py:55-59` -- **What:** Bias correction uses `self.lr.last_epoch.value + 2`, but with the default `Constant` scheduler `last_epoch` is never incremented during `update()` (only `step_epoch()` advances it, which optimizers never call). So `beta**(last_epoch+2) == beta**1` forever and the `m`/`v` EMAs are never debiased. -- **Why it's wrong:** Under a constant gradient, correct Adam yields a constant step ≈ `-lr`. Measured steps instead grow: `dw = [-0.001, -0.00134, -0.00157, -0.00172, -0.00183]`. -- **Fix:** Maintain an internal per-`update()` step counter `t` (independent of the LR scheduler) and use `beta1**t`, `beta2**t` for bias correction. Don't derive `t` from `last_epoch`. - -### C-02 — `nll_loss` returns the log-likelihood, not its negative **[verified]** -- **File:** `brainpy/losses/comparison.py:461` (class `NLLLoss` wraps it) -- **What:** `return mean(input[arange, target])` with no negation; the function's own docstring defines `-Σ w·x_{n,y_n}`. -- **Measured:** `nll_loss(log p, [0,1]) = -0.2899` (correct `+0.2899`). Minimizing drives the correct-class log-prob to −∞. -- **Fix:** `loss = -input[jnp.arange(len(target)), target]` (negate), keep the reductions. - -### C-03 — `cross_entropy_loss` applies class `weight` by sample index **[verified]** -- **File:** `brainpy/losses/comparison.py:266-267` (`loss *= weight`) -- **What:** Per-sample loss `(N,)` is multiplied elementwise by the per-class weight `(C,)` — so sample *n* is weighted by `weight[n]`, not `weight[target_n]`. Raises on `N≠C`, silently wrong on `N==C`. -- **Measured:** logits `0(3,3)`, targets `[2,2,2]`, weight `[10,20,1]` → per-sample `[10.99, 21.97, 1.10]`; correct is all `1.10` (`w[2]`). -- **Fix:** gather `weight[target]` before reduction; for `mean`, normalize by `sum(weight[targets])`. - -### C-04 — `MultiStepLR` never decays **[verified]** -- **File:** `brainpy/optim/scheduler.py:157-163` -- **What:** `conditions = (i>=milestones[:-1]) & (i None`. -- **Measured:** `odeint(..., method='rkf45', adaptive=True)(...)` → `TypeError: '>' not supported between ArrayImpl and NoneType`. -- **Fix:** `code_scope['tol'] = self.tol` (and keyword default likewise). - -### C-13 — All SDE integrators `NameError` on invalid type (missing `errors` import) **[verified]** -- **File:** `brainpy/integrators/sde/base.py:76,79,82`; `sde/normal.py:225` -- **What:** Validation references `errors.IntegratorError` but `errors` is never imported; also the `Heun` Ito/Stratonovich guard. -- **Measured:** `sdeint(..., intg_type='WRONG')` → `NameError: name 'errors' is not defined`. -- **Fix:** `from brainpy import _errors as errors` in both files. - -### C-14 — Standalone HH/Markov channel gating produces NaN at voltage singularities **[verified by sub-audit]** -- **File:** `brainpy/dyn/channels/sodium.py:384,299,215`; `potassium.py:359,222,290` (+legacy dups `:1191,1261,1332`); `calcium.py:711` -- **What:** Rates coded as `k*temp/(1-exp(-temp/d))` are 0/0 → NaN exactly at the removable singularity (e.g. `IK_HH1952v2` at V=−55). The HH *neuron* class was fixed with `bm.exprel`; the channel modules were not. `bm.where` clamping can't recover it (both branches evaluated). -- **Measured:** `IK_HH1952v2(1).f_p_alpha([-55.0]) = [nan]`. -- **Fix:** rewrite with `bm.exprel`, e.g. `0.1 / bm.exprel(-(V - V_sh + 10)/10)` (mind the `k*d` coefficient bookkeeping). Fix legacy duplicates too. - -### C-15 — `ThresholdLinearModel` noise path crashes (`randn` signature) **[verified]** -- **File:** `brainpy/dyn/rates/populations.py:1051,1060` -- **What:** `bm.random.randn(self.varshape)` passes a shape *tuple* as a single positional arg; brainstate's `randn` takes unpacked dims. -- **Measured:** any nonzero `noise_e/noise_i` → `TypeError: Shapes must be 1D sequences ... got ((1000,),)`. -- **Fix:** `bm.random.randn(*self.varshape)` or `bm.random.normal(size=self.varshape)`. (Separately, the noise scales as `dt` not `sqrt(dt)` — see M-…) - -### C-16 — `StuartLandauOscillator.dy` has the wrong rotational coupling **[verified]** -- **File:** `brainpy/dyn/rates/populations.py:721` -- **What:** `dy` returns `(a-x²-y²)*y - w*y + y_ext` (copy-paste from `dx`); the Hopf normal form needs `+ w*x`. As written there's no x↔y rotation, so no limit cycle. -- **Measured:** `dy(y=.5,x=.3,a=.25,w=.2) = -0.145` (buggy `-w*y`); correct `+w*x` gives `+0.015`. -- **Fix:** `return (a - x*x - y*y)*y + w*x + y_ext`. - -### C-17 — `PoissonInput` Gaussian branch uses the variance as the std (~3–4× too much noise) **[verified]** -- **File:** `brainpy/dyn/projections/inputs.py:168,174`; duplicated in `brainpy/dynold/experimental/others.py:74-77` -- **What:** `bm.random.normal(a, b*p, ...)` passes `b*p = n(1-p)p` (the Binomial *variance*) as the std; correct is `sqrt(n·p·(1-p)) = sqrt(b*p)`. -- **Measured:** `n=1000,p=0.02` → code std `19.6` vs correct `4.43`. Active in the common large-N branch; mean is correct. -- **Fix:** `scale = jnp.sqrt(b*p)` (both the eager and `bm.cond` branches, and the dynold copy). - -### C-18 — `HalfProjAlignPost.update` calls `comm` twice **[verified by sub-audit]** -- **File:** `brainpy/dyn/projections/align_post.py:384-388` -- **What:** Computes `current = self.comm(x)` then `g = self.syn(self.comm(x))` — two independent calls. For event/jit-prob comms each call draws fresh random connectivity, so the synapse sees different input than the returned current; doubles compute for deterministic comms. -- **Fix:** `current = self.comm(x); g = self.syn(current); ...; return current`. - -### C-19 — `STDP_Song2000` crashes on the first update (tracing_variable removed) **[verified by sub-audit]** -- **File:** `brainpy/dyn/projections/plasticity.py:230-240` → `brainpy/dnn/linear.py:502-503` -- **What:** `stdp_update` falls back to `self.tracing_variable('weight', ...)`, which now unconditionally `raise NotImplementedError`. The weight is only a `Variable` when the comm is built with `mode=TrainingMode`; the class docstring example omits it, so the documented usage is dead on arrival. -- **Fix:** in `stdp_update`, promote the weight directly (`self.weight = bm.Variable(self.weight)`) or require trainable weights in `STDP_Song2000`. Also fixes a companion crash (H-…: `bm.as_jax(None)` for default `W_min/W_max`). - -### C-20 — `AlphaCUBA` / `AlphaCOBA` raise `ZeroDivisionError` on construction **[verified]** -- **File:** `brainpy/dynold/synapses/compat.py:208-270`; root `brainpy/dyn/synapses/abstract_models.py:159-164` -- **What:** They pass `tau_rise == tau_decay` into `DualExpon`, whose peak normalizer `A = tau_decay/(tau_decay - tau_rise)·…` divides by zero. -- **Measured:** `bp.synapses.AlphaCUBA(LIF(2), LIF(2), All2All(), tau_decay=10.)` → `ZeroDivisionError`. -- **Fix:** route `AlphaCUBA/COBA` through the single-tau `synapses.Alpha`, or special-case `tau_rise==tau_decay` (L'Hôpital limit `A=e`, `a=1/tau`). - -### C-21 — dynold `STP` learning rule injects current with zero presynaptic spikes **[verified by sub-audit]** -- **File:** `brainpy/dynold/synapses/learning_rules.py:33-37,231-233` -- **What:** `_STPModel = Sequential(STP, Expon)`; modern `STP.update` returns `u*x` (≈0.15 at rest) every step, and `Expon` treats it as additive current, so `g += u*x` continuously. The spike gating is lost. -- **Measured:** zero input → `syn.I` ramps to ~512 and keeps rising. -- **Fix:** gate by spikes (`pre_spike*(u*x)`), or use the modern `dyn/projections/plasticity` wiring. - -### C-22 — `DSRunner(memory_efficient=True)` is completely non-functional **[verified by sub-audit]** -- **File:** `brainpy/runners.py:638-647` (+ `_step_mon_on_cpu` :617-619) -- **What:** `_step_func_monitor()` returns a dict, but the code does `jax.ShapeDtypeStruct(mon.shape, mon.dtype)` on it; the `pure_callback` arg count and the `None` return are also wrong. Cannot have worked since the migration. -- **Measured:** any `memory_efficient=True` run → `AttributeError: 'dict' object has no attribute 'shape'`. -- **Fix:** `jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), mon)`; fix the callback signature/return; add a smoke test. - -### C-23 — RLS / FORCE online update is wrong for batch size > 1 **[verified by sub-audit]** -- **File:** `brainpy/algorithms/online.py:148-154` (drives `train/online.py` `OnlineTrainer`/`ForceTrainer`) -- **What:** `c = jnp.sum(1.0/(1.0+hPh))` collapses the `B×B` matrix `(I+HPHᵀ)` to a scalar (summing reciprocals of all entries incl. off-diagonals). Correct only for B=1; for B>1 `c` grows with B and can go negative → `P` diverges, update sign flips. -- **Measured:** B=16 → `c=-90.3` (correct diag value `+1.7`); fitting with B≥4 → NaN weights within a few hundred steps. -- **Fix:** proper block RLS: `K = PHᵀ(I + HPHᵀ)⁻¹` via `jnp.linalg.solve`, or assert `input.shape[0]==1`. - -### C-24 — `PoissonEncoder.single_step` crashes (its own documented usage) **[verified]** -- **File:** `brainpy/encoding/stateless_encoding.py:91-94` → `:111` -- **What:** `single_step(x, i_step=None)` delegates to `multi_steps(x, n_time=None)` whose first line is `int(n_time/get_dt())` → `None/float`. -- **Measured:** `PoissonEncoder().single_step(bm.random.rand(4))` → `TypeError`. -- **Fix:** in the `i_step is None` branch, draw a single Bernoulli sample directly; guard `multi_steps` against `n_time is None`. - -### C-25 — `VarDict.tree_unflatten` crashes on every JAX transform (`jax.util` removed) **[verified by sub-audit]** -- **File:** `brainpy/math/object_transform/variables.py:423` -- **What:** Calls `jax.util.safe_zip(...)`, but `jax.util` no longer exists in jax 0.10.1. `VarDict` is a registered pytree, so any `jit`/`vmap`/`tree_map` over one fails. -- **Measured:** `jax.jit(lambda d: d)(bm.var_dict({'a': bm.Variable(...)}))` → `AttributeError: module 'jax' has no attribute 'util'`. -- **Fix:** use `brainstate._compatible_import.safe_zip` or `cls(zip(keys, values))`. - -### C-26 — `Variable` batch_axis / axis_names silently dropped through pytree round-trip **[verified by sub-audit]** -- **File:** `brainpy/math/object_transform/variables.py:40,79` (inherits `Array.tree_flatten` returning `aux_data=None`) -- **What:** Reconstructing a `Variable` after `jit`/`vmap`/`scan`/`tree_map` loses `batch_axis`/`axis_names` (reset to `None`). brainstate's closure-based transforms work around it, but explicit pytree / `jit`-argument use degrades silently (affects sharding, `size_without_batch`, value-setter shape checks). -- **Fix:** override `Variable.tree_flatten/unflatten` to carry `(batch_axis, axis_names)` in `aux_data` and rebuild without re-running naming/`State.__init__` side effects. - ---- - -## 4. High findings (detail) - -> Format: `[verified|static] file:line — what → fix` - -### Object model / transforms (`math/object_transform`) -- **H-01 [static]** `jit.py:200-207` — `cls_jit` shifts `static_argnums` by `+1` unconditionally, corrupting **negative** indices (`-1` → `(0,0)` marks `self` static twice). → shift only `x>=0`; resolve negatives against the signature. -- **H-02 [verified]** `controls.py:125-561` + `_utils.py:78-85` — passing a `Variable`/`State` in `operands` of `cond`/`for_loop`/`while_loop`/`scan` raises (`State` rejected at brainstate cache-key time, before the in-wrapper strip). → strip state from `operands` before forwarding; document closure capture as the supported path. -- **H-03 [verified]** `controls.py:372-390` — `for_loop(jit=False)` zero-length guard only checks `operands[0].shape`; a pytree operand (dict) has no `.shape`, so it crashes with "zero-length scan … in disable_jit()". → compute leading length from `jax.tree.leaves`. -- **H-04 [verified]** `jit.py:127-153` — `jit` docstrings advertise `dyn_vars`/`child_objs`, now forwarded into `**kwargs` → `TypeError` from brainstate. → drop from docstring; filter/warn in `jit`. -- **H-05 [verified]** `base.py:603-614` — `to()/cpu()/cuda()/tpu()` iterate `state_dict()` (nested dicts), so `isinstance(var, Array)` is always False; they never move variables and inject junk dict-valued attributes named after nodes. → iterate `self.vars().values()` and set `var.value = jax.device_put(...)`. -- **H-06 [verified]** `variables.py:143-172` — `Variable.value` setter validates shape/dtype on the raw input *before* unwrapping `State`/numpy, so `v.value = some_State` and `v32.value = np.float64_array` raise spurious `MathError`. → unwrap/convert first, then validate. -- **H-07 [verified]** `naming.py:24,34-44` — global `_name2id` registry grows unboundedly (no weakref/GC pruning) and stores `id(obj)`, so reused ids cause false `UniqueNameError`. → `WeakValueDictionary` + prune dead refs. -- **H-08 [verified]** `base.py:287-309` vs `collectors.py:198` — `register_implicit_vars` default `var_cls` accepts `VarList/VarDict`, but `ArrayCollector.__setitem__` asserts `isinstance(value, Variable)` → `AssertionError`. → flatten containers or relax the collector. -- **H-09 [verified]** `variables.py:41` — `Variable.__eq__` returns an elementwise array while `__hash__` is identity-based → breaks `in`/set/dict-by-value and raises ambiguous-truth. → define identity `__eq__`/`__ne__`, or guarantee all internal membership uses `id()`. - -### Math core / compat / sparse -- **H-10 [verified]** `modes.py:38-41` — `Mode` overrides `__eq__` without `__hash__`, so every mode is unhashable (regression vs the hashable brainstate parent). → `__hash__ = brainstate.mixin.Mode.__hash__`. -- **H-11 [verified]** `ndarray.py:206-207` — `Array.device()` calls the now-property `jax.Array.device` → `TypeError`. (Also `device_buffer` :209, `block_host_until_ready` :200.) → make `device` a property. -- **H-12 [verified]** `ndarray.py:99-107` — `Array(scalar)` stores a bare Python scalar; `.shape`/most ops then crash. → fall through to `jnp.asarray(value)`. -- **H-13 [verified]** `compat_numpy.py:217-220` — `asfarray(a)` with default `dtype=None` no-ops on integer input because `np.issubdtype(None, np.inexact)` is True. → `if dtype is None or not issubdtype(...): dtype = float`. -- **H-14 [verified]** `_utils.py:59-62` (+ pytorch compat) — `out=` argument makes wrapped funcs return `None` (numpy/torch return the array). → `return out` after `out.value = r`. -- **H-15 [verified]** `others.py:94-96` — `remove_diag` uses concrete boolean-mask indexing → `NonConcreteBooleanIndexError` under `jit`/`vmap`. → static off-diagonal gather. -- **H-16 [verified]** `activations.py:668-669` — `softmin` lacks max-subtraction → NaN for large inputs (`softmin([1000,1001,1002]) = [nan,nan,nan]`). → `softmax(-x, axis)`. -- **H-17 [verified]** `sparse/coo_mv.py:82` — `coomv` builds `brainevent.COO`, removed in brainevent 0.1 → `AttributeError`. → convert COO→CSR or drop. -- **H-18 [verified]** `sparse/utils.py:42,47-49` — `coo_to_csr` broken: `argsort(kind=)` removed, in-place item assignment on immutable array, float `indptr`. → `argsort(stable=True)`, `.at[].set`, int dtype. -- **H-19 [verified]** `sparse/utils.py:64` — `csr_to_dense = csr_todense` re-exports a jax function whose signature changed to take a `CSR` object → `TypeError` on the legacy call. → wrap explicitly via `brainevent.CSR(...).todense()`. - -### Surrogate gradients (`math/surrogate`) — all **[verified]**, present in both `_one_input.py` and `_one_input_new.py` -- **H-20** `_one_input_new.py:1492` (`_one_input.py:1446`) — `GaussianGrad`: `exp(-(x**2)/2*sigma**2)` = `exp(-(x²/2)·σ²)`, σ inverted (precedence). At σ=2 the bump is ~`e²`× too narrow (grad@±1 ≈ 0.0135 vs intended ≈0.088). → `exp(-(x**2)/(2*sigma**2))`. -- **H-21** `_one_input_new.py:254` (`:195`) — `PiecewiseQuadratic`: grad uses `-(α x)²+α` but the derivative of its own forward is `-α²|x|+α`. → `-self.alpha**2*jnp.abs(x)+self.alpha`. -- **H-22** `_one_input_new.py:1118` (`:1069`) — `q_pseudo_spike`: grad denominator uses `alpha+1`, docstring/forward use `alpha-1`. → `alpha-1`. -- **H-23** `_one_input_new.py:529` (`:474`) — `arctan.surrogate_fun` calls `jnp.arctan2(...)` with one arg → `TypeError`. → `jnp.arctan(...)`. -- **H-24** `_one_input_new.py:710` (`:657`) — `ERF.surrogate_fun = erf(-αx)*0.5` is decreasing in [−0.5,0.5]; should be `0.5*(1-erf(-αx))`. → fix sign/offset (`0.5*erfc(-αx)`). -- **H-25 [verified]** `math/__init__.py:47` — `bm.surrogate` is reassigned to `braintools.surrogate`, so the entire in-repo `brainpy/math/surrogate` package (with the above bugs) is unreachable via the public API yet still importable. → delete the in-repo package or stop the override; don't ship both. - -### Integrators -- **H-26 [verified by sub-audit]** `ode/adaptive_rk.py:532` — `BoSh3.B2 = ['-5/72', …]` (sums to 0) makes the embedded error estimate ~20× too large and wrong-signed. → `B2 = ['7/24', 0.25, '1/3', 0.125]`. -- **H-27 [verified]** `ode/adaptive_rk.py:221` — step controller `where(error>tol, shrink, dt)` never *increases* dt (one-sided), contradicting the docstring. → unconditional clamped factor `dt*clip(0.9*(tol/error)**(1/(p+1)), …)`. -- **H-28 [verified]** `ode/adaptive_rk.py:164,214-217` — default `var_type=POP_VAR` emits `sum(abs(...))` (builtin `sum`) → `'float' object is not iterable` on scalar state. → `jnp.sum(jnp.abs(...))`. -- **H-29 [verified]** `integrators/runner.py:242,254,262` — `IntegratorRunner` reuses loop var `i` (`for i,v in enumerate(...)`) clobbering the step index, so monitors reading `shared['i']` get `len(vars)-1`. → rename inner loop var. - -### FDE -- **H-30 [verified]** `fde/GL.py:187` — `GLShortMemory.reset` uses key `key` instead of `key+'_delay'` → `KeyError` on every reset. → add the suffix. -- **H-31 [verified by sub-audit]** `fde/Caputo.py:375` — `CaputoL1Schema.hists()` default path does `{k:v.numpy() for k,v in hists_}` (iterates dict keys) → `ValueError`. → `.items()`. -- **H-32 [verified by sub-audit]** `fde/generic.py:87-88` — `set_default_fdeint()` assigns `_DEFAULT_ODE_METHOD` (wrong global), so it's a no-op. → assign `_DEFAULT_DDE_METHOD`. - -### Dyn (neurons / synapses / rates) -- **H-33 [verified]** `dyn/ions/base.py:54-55` — `for k,v in channels.items(): self.add_elem(k=v)` passes literal keyword `k`, so all channels register under name `"k"` and overwrite each other. → `self.add_elem(**{k: v})`. -- **H-34 [verified by sub-audit]** `dyn/neurons/lif.py:1108-1109` — `ExpIFRef/ExpIFRefLTC` unconditionally `odeint`, dropping any `noise=` (every other `*Ref` guards with `sdeint`). → guard on `self.noise`. -- **H-35 [verified by sub-audit]** `dyn/neurons/lif.py:4495-4496,3814-3815` — `IzhikevichRef`/`GifRef` compute `spike_no_grad` but reset state with the grad-carrying `spike`, so `detach_spk` is a no-op. → use `spike_no_grad` for resets. -- **H-36 [verified]** `dyn/rates/rnncells.py:401,412` — `LSTMCell` `h`/`c` setters slice axis 0 while getters split axis −1; unbatched → `IndexError`, batched → wrong-rows write. → slice the last axis. -- **H-37 [verified]** `dyn/rates/reservoir.py:226` — `noise_rec * uniform(-1,-1, …)` is a constant `-noise_rec` bias, not noise (typo for `uniform(-1,1)`). → fix bounds. -- **H-38 [verified]** `dyn/rates/reservoir.py:191,202-232` — `self.bias` is created (and TrainVar in training) but never added in `update()`. → `hidden += self.bias`. -- **H-39 [verified by sub-audit]** `dyn/synapses/abstract_models.py:880-881` — STP discrete `u`/`x` jumps read pre-decay `self.u`/`self.x` instead of the decayed locals; off by one decay step. → use the decayed `u,x` locals (and apply `x` jump after the `u` jump). -- **H-40 [static]** `dyn/projections/base.py:1-26` — byte-for-byte duplicate of `utils.py` (only a private helper), yet `projections/__init__.py` does `from .base import *` (imports nothing); misleading vs the real `SynConn` base. → delete or re-export the real base classes. -- **H-41 [verified]** `dyn/projections/plasticity.py:232-233` — default `W_min=W_max=None` → `bm.as_jax(None)` raises (first crash even before C-19's path). → pass `None` through unchanged. - -### dynold compat -- **H-42 [verified]** `dynold/neurons/reduced_models.py` LIF/ExpIF/AdExIF — default params silently changed to the modern `*Ref` values (`LIF`: `V_rest=0,V_reset=-5,V_th=20`; `ExpIF/AdExIF`: `V_th=-55`) while docstrings still claim `-65/-68/-30`. → restore historical defaults in the dynold wrappers or fix every docstring. - -### Top-level glue / measure / delay -- **H-43 [verified]** `measure.py:91-92` — `firing_rate` normalizes by requested `width` while the window length is `width1=int(width/2/dt)*2+1≠width/dt`; biased by `width1·dt/width` (e.g. true 100 Hz → oscillates 100↔200, mean 110). → `window = ones(width1)/(width1*dt)*1000`. -- **H-44 [verified]** `delay.py:254-257` (+ class attr only-annotated `:72`) — `VarDelay(target, time=T>0)` reads `self.data` in `_init_data` before it is ever assigned → `AttributeError: 'data'`. → set `self.data = None` unconditionally before the `max_length>0` branch. -- **H-45 [verified]** `delay.py:481` → `math/object_transform/variables.py:106-112` — `DataDelay.reset_state(batch_size)` calls `size_without_batch`, which does `self.size[:batch_axis]+…` but `Variable.size` is the integer element count, not a shape tuple → `TypeError: 'int' object is not subscriptable` for any batched variable. → use `self.shape` in `size_without_batch`. - -### Train / running -- **H-46 [verified]** `algorithms/offline.py:159,386` — `gradient_descent=True` path does `jnp.logical_and(...).value` (no `.value` on a jax array/tracer) → `AttributeError`; breaks every GD regression incl. always-GD `Lasso`/`ElasticNet`. → drop `.value`. -- **H-47 [static]** `algorithms/offline.py:272-276` — ridge `XᵀX+αI` penalizes the prepended bias column (intercept shrunk) and is off by the ½ factor vs the documented `½α‖w‖²`. → zero the `(0,0)` entry of the penalty; reconcile the ½. -- **H-48 [static]** `running/jax_multiprocessing.py:136-156` — `jax_parallelize_map` builds one cached `pmap` reused across chunks; the trailing partial chunk ≠ device count → retrace/crash; also mislabeled `vmap_func`, missing `else: raise`. → build per chunk or pad to a device multiple. - -### Analysis -- **H-49 [static]** `analysis/lowdim/lowdim_analyzer.py:377,953` & `utils/optimization.py:398` — arg-unwrap comprehension tests `isinstance(candidates, bm.Array)` instead of `isinstance(a, …)` (3 copies) → either `AttributeError` or `bm.Array` leaking into `meshgrid`/`vmap`. → test `a`. -- **H-50 [static]** `analysis/lowdim/lowdim_analyzer.py:1038-1040` — non-convertible 2D `_get_fixed_points` does `jnp.concatenate([])` when nothing converges → `ValueError` (the 1D/convertible paths guard, this one doesn't). → empty-guard return. - -### DNN -- **H-51 [verified]** `dnn/normalization.py:100,134,503,588` — `BatchNorm*`/affine `LayerNorm`/`GroupNorm` raise `UnsupportedError` out-of-the-box under the default `NonBatchingMode` (and the affine `assert isinstance(mode, TrainingMode)`); only `mode=bm.training_mode` works. → default to `TrainingMode` when `mode is None`, or raise a clear message; broaden the affine assert. - -### Optim / losses / encoding -- **H-52 [verified]** `optim/optimizer.py:592-594` — `Adam` corrupts an `lr` passed as a `bm.Variable` via in-place `lr /= …; lr *= …` (mutates the shared Variable each step). → non-mutating arithmetic / `bm.as_jax(self.lr())`. -- **H-53 [static]** `losses/comparison.py:194-201` — `CrossEntropyLoss` stores `ignore_index`/`label_smoothing` but never forwards them; `cross_entropy_loss` has no such params → both are silent no-ops. → implement and forward. - ---- - -## 5. Medium findings (condensed) - -| ID | [status] | File:line | Issue → Fix | -|----|----------|-----------|-------------| -| M-01 | verified | `optim/scheduler.py` via `optimizer.py` | `StepLR`/cosine families share the `last_epoch`-never-advances issue feeding C-01; audit all schedulers' step source. | -| M-02 | static | `math/object_transform/jit.py` | `cls_jit` doesn't shift `donate_argnums` → donates `self`. → add param + `+1` shift. | -| M-03 | verified | `controls.py:466-481` | `scan` returns `(carry, ys)` but docstring promises only `ys` (legacy contract change). → fix docs or return `ys`. | -| M-04 | static | `controls.py:391-397` | `for_loop(jit=False)` toggles process-global `jax.disable_jit()` (no-op under an outer trace). → document / brainstate-native opt-out. | -| M-05 | static | `controls.py:207-237` | `ifelse` omits `check_cond=False` though it already guarantees exclusivity → per-call device all-reduce + error branch. → pass `check_cond=False`. | -| M-06 | verified | `controls.py:550-561` | `while_loop` body returning `None` freezes the carry (infinite-loop hazard). → raise on `None`/structure mismatch. | -| M-07 | verified | `math/environment.py:391-428` | `set()`/`set_environment()` mutate globals before validating `numpy_func_return` → partial-config leak on error. → validate first. | -| M-08 | verified | `math/remove_vmap.py:55-85` | under `vmap`, `remove_vmap(x,'any'/'all')` broadcasts the global reduction back over the batch (leaks across examples). → return a true scalar / document. | -| M-09 | static | `math/ndarray.py:259-271` | `ShardedArray.value` getter inserts `with_sharding_constraint` on *every* read (always-true on single-device). → skip `SingleDeviceSharding`. | -| M-10 | static | `math/sharding.py:119-162` | fully-unmatched axis names silently yield a replicated `PartitionSpec(None,…)` instead of erroring. → warn/raise on full mismatch. | -| M-11 | verified | `math/compat_numpy.py:144-160` | `empty`/`empty_like` call `zeros`/`zeros_like` (needless zero-fill, wrong semantics). → `jnp.empty*`. | -| M-12 | verified | `math/compat_numpy.py:129-133` | `fill_diagonal(inplace=False)` returns a raw jax array, not a brainpy `Array`. → `_return(r)`. | -| M-13 | static | `math/jitconn/matvec.py` (+`event_matvec.py`) | `seed=None` draws a host RNG per call → non-reproducible eager, jit-frozen seed. → require/thread an explicit seed; document. | -| M-14 | verified | `math/delayvars.py:215` | `TimeDelay.reset` drops `dtype=get_float()` on `current_time` and ignores callable `before_t0`. → mirror `__init__`. | -| M-15 | verified | `math/pre_syn_post.py:291-293` | `pre2post_mean` scalar branch scatter-sets (no averaging, ignores duplicate post ids). → route through `syn2post_mean` or document. | -| M-16 | static | `dyn/neurons/hh.py:148-194` | `CondNeuGroup.update` passes synaptic current through the `1e-3/A` external-input scaling (double-scales when `A≠1e-3`). → inject into the derivative like the LTC class. | -| M-17 | verified | `dyn/ions/potassium.py:45` | `PotassiumFixed` default `E=-950 mV` (likely typo for `-95`). → fix default (confirm vs intended). | -| M-18 | verified | `dyn/rates/populations.py:370-371` | `FeedbackFHN.reset_state` rebinds `self.input`/`input_y` to fresh Variables (breaks captured refs) instead of `.value=`. → set `.value`. | -| M-19 | verified | `dyn/rates/populations.py:374` | `FeedbackFHN` delay queries `x_delay(t-delay)` while `state_delays` already registers the delay → double-counts (buffer-edge clamp). → query `x_delay(t)`. | -| M-20 | static | `dyn/rates/populations.py:1051-1062` | `ThresholdLinearModel` noise scales as `dt` not `sqrt(dt)` (dt-dependent intensity). → Euler–Maruyama `sqrt(dt)`. | -| M-21 | verified | `dyn/rates/rnncells.py:127,239,375` | `RNN/GRU/LSTMCell.reset_state(None)` builds `(None,num_out)` → `ValueError`. → branch on `None` → `(num_out,)`. | -| M-22 | verified | `dyn/synapses/abstract_models.py:879-881,800-801` | `STP`/`STD` "simplified" updates assume binary `pre_spike`; graded inputs are wrong. → restore graded formula or assert binary. | -| M-23 | static | `dynold/synapses/abstract_models.py` | dual-exp/NMDA/AMPA peak silently renormalized (`g_max` semantics changed vs pre-3.0). → document or auto-scale for compat. | -| M-24 | verified | `dynold/neurons/reduced_models.py:1311` | `ALIFBellec2020` default `a_initializer=OneInit(-50.)` (adaptation var should start ~0). → `ZeroInit()`. | -| M-25 | verified | `dnn/normalization.py:156-158` | `BatchNorm` stores *biased* batch var into `running_var` (PyTorch uses unbiased for the running stat). → apply `N/(N-1)`. | -| M-26 | verified | `dnn/normalization.py:509` | `LayerNorm` shape-mismatch path does `", ".join(int_tuple)` → `TypeError` masking the real error. → `map(str, …)`. | -| M-27 | verified | `dnn/pooling.py:118,390,787` | negative `channel_axis == -x_dim` wrongly rejected (`abs()` bound check). → `-x_dim <= axis < x_dim`. | -| M-28 | verified | `dnn/function.py:91` | `Flatten` default `start_dim=0` contradicts its docstring/PyTorch (`1`) and drops the batch dim. → `start_dim=1` or fix docs. | -| M-29 | verified | `optim/optimizer.py:1039-1096` | `SM3` reads `self.momentum` in `register_train_vars` before it's set → un-instantiable (also torch-style `keepdim=`). → set attrs before `super().__init__`. | -| M-30 | verified | `connect/random_conn.py:99,87-89` | `FixedProb` sparse `build_coo/csr` use `int(post_num*prob)` (floors to 0 for small post; biased density) and forbid `include_self=False` on rectangular shapes with a contradictory message. → round/Bernoulli; drop the guard. | -| M-31 | static | `train/back_propagation.py:522-523` | BPTT `indices = arange(self.i0, …)` but `i0` isn't advanced/pinned → wrong absolute `t` when `reset_state=False`. → pin `arange(0,num_step)` or document. | -| M-32 | verified | `running/runner.py:99-101` | `Runner.__init__` mutates the caller's `jit` dict via `.pop()`. → operate on a copy. | -| M-33 | static | `analysis/stability.py:148-163` | 2D star vs degenerate-node classification is inverted (eigenvalues alone can't distinguish; needs eigenvector rank). → use `matrix_rank(J-λI)`. | -| M-34 | verified | `analysis/stability.py:111-141` | borderline types (center/saddle-node/line) gated on exact float `==0` of autodiff Jacobians → almost never detected. → tolerance bands. | -| M-35 | static | `analysis/highdim/slow_points.py:357-360` | GD fixed-point finder stops on *mean* loss but `tolerance` reads as per-point → outliers left unconverged. → stop on max, or document. | -| M-36 | verified | dyn synapses/`Variable.__float__` | `float(size-1 Variable)` raises under jax 0.10 (`ndim>0`) — breaks common single-neuron monitoring/doctests. → use `.item()`/index; consider squeezing `__float__`. | - ---- - -## 6. Low findings (condensed) - -- **L-01** `math/ndarray.py:31,44-76` — duplicate `'Array'` in `__all__`; dead helpers (`_check_input_array`, `_check_out`, `_get_dtype`, `_all_slice`); `_as_jax_array_` duplicated in `_utils.py`. → de-dup/remove. -- **L-02** `math/scales.py:79-89` — `IdScaling.clone(scale=…)`/`inv_scaling` silently ignore overrides. → raise or honor. -- **L-03** `math/ndarray.py:153-172` vs `:273-292` — base `Array.value` setter has shape/dtype checks commented out while `ShardedArray` enforces them (inconsistent; base allows silent shape change). → one policy. -- **L-04** `object_transform/function.py:44` — `function()` deprecation says "removed after 2.4.0" but ships in 2.7.8; `Partial` lacks a docstring. → update message; add docstring. -- **L-05** `object_transform/_utils.py:24-27` — `__all__` omits the only symbol consumers import (`warp_to_no_state_input_output`). → fix `__all__`. -- **L-06** `object_transform/base.py:192-219` — `tracing_variable` is `raise NotImplementedError` followed by ~25 lines of unreachable code + stale docstring; default-off pytree path is dead. → delete/clean. -- **L-07** `dyn/ions/calcium.py:144` — `CalciumDyna._reversal_potential(C)` ignores its `C` arg (uses `self.C`). → use `C`. -- **L-08** channels — several docstring constants/defaults disagree with code (`IAHP beta=0.09` vs doc `0.03`; `f_q_inf +58` vs `+59`; Ih `tau_m`/`phi`). → reconcile (incl. legacy dups). -- **L-09** `dyn/synapses/delay_couplings.py:131-241` — docstrings reference a `g`/gain param that doesn't exist; malformed `Parameters::`. → fix docs. -- **L-10** `losses/comparison.py:534` — functional `l1_loss` defaults `reduction='sum'` vs docstring/`L1Loss` `'mean'`. → `'mean'`. -- **L-11** `encoding/stateful_encoding.py:111-120` — `LatencyEncoder` docstring example output shape ignores `dt` (`(5,3)` vs real `(50,3)`). → fix example. -- **L-12** `integrators/sde/srk_strong.py:58,392` — dead module with a generated-code syntax error and wrong `compile_code` arg order. → remove or fix+register+test. -- **L-13** `integrators/joint_eq.py:189` — `JointEq` raises a bare message-less `DiffEqError`. → add diagnostic. -- **L-14** Pervasive NumPy-doc nonconformance: `Parameters::`/`Returns::`/`References::` literal-block markers across `math/sparse`, `math/jitconn`, `dyn/rates`, `dyn/synapses`, `measure`, etc. → convert to underlined sections (mandated by CLAUDE.md). -- **L-15** `analysis/utils/others.py:99` — `get_sign2` passes a generator as `reshape` shape (latent, function unused). → `tuple(...)` / remove dead helpers. -- **L-16** `dynold/experimental/__init__.py` empty → whole experimental subpackage unreachable; `dynold/synapses/base.py:233` & `experimental/base.py:98` missing `raise` before `ValueError`. → wire up or delete; add `raise`. - ---- - -## 7. Prioritized remediation roadmap - -**P0 — silent numerical corruption in the most-used paths (fix first):** -C-01 (Adam), C-02 (nll sign), C-03 (CE weight), C-04 (MultiStepLR), C-05 (GroupNorm), C-08 (Caputo), C-09 (TimeDelay), C-10 (disable_x64), C-16 (StuartLandau), C-17 (PoissonInput), C-23 (RLS B>1), H-43 (firing_rate), H-26 (BoSh3), H-20…H-24 (surrogate math). These give wrong answers without erroring. - -**P1 — crash-on-first-use of public APIs:** -C-06, C-07, C-11–C-15, C-18–C-22, C-24–C-25, H-01–H-06, H-10–H-19, H-29–H-36, H-41, H-44–H-46, H-51, M-21, M-29. Many are one-line migration fixes; bundle with a public-surface import/smoke test. - -**P2 — correctness traps & fragility (Medium tier).** **P3 — docs/typing/dead-code hygiene (Low tier), incl. the repo-wide NumPy-doc `::` cleanup.** - -**Systemic actions (do alongside P0/P1):** -1. **Pin & test ecosystem versions.** Bump `pyproject` lower bounds to the actually-tested `brainstate`/`brainevent`/`braintools`/`jax`, and add CI matrix entries. -2. **Public-surface smoke test:** instantiate + one-step every public neuron/synapse/projection/layer/optimizer/encoder under the *default* mode; many P1 crashes would be caught immediately. -3. **Property-based numerical tests** for surrogates (`grad(surrogate_fun) ≈ surrogate_grad`), integrators (convergence order, SDE moments), losses (vs reference), and delays (off-by-one) — the bug classes here recur and need oracles. -4. **Resolve the `bm.surrogate` shadowing** (H-25): decide whether `braintools` or the in-repo package is canonical and delete the other. - ---- - -## Appendix A — Verification status - -- **Independently re-verified by the lead reviewer (executed, reproduced):** C-01..C-13, C-15..C-17, C-20, C-24; H-02..H-08, H-10..H-19, H-20..H-24, H-27..H-39, H-41, H-43..H-46, H-51..H-52; M-06..M-08, M-11..M-12, M-14..M-15, M-17..M-19, M-21..M-22, M-26..M-30, M-32, M-36 — plus the isolated x64/precision and FDE checks. 33 of the highest-impact items were run head-to-head; **all reproduced**. (Two initial "non-reproductions" were lead-reviewer test-harness errors — wrong threshold / arg-order — with the underlying bug confirmed on code inspection.) -- **Verified by the module sub-audits (executed in the same environment):** C-14, C-18..C-19, C-21..C-23, C-25..C-26, H-01, H-09, H-26, H-40, H-47..H-50, H-53 and the remaining Medium/Low items. - -## Appendix B — Checked and found CORRECT (audit negatives) - -To bound the audit, the following were specifically checked and are **not** bugs: all explicit-RK Butcher tableaus (Euler/RK2/RK3/RK4/Ralston/SSPRK convergence orders) and RKF45/CashKarp/Dormand–Prince/BogackiShampine embedded pairs; exponential-Euler exactness on linear ODEs; SDE Euler–Maruyama/Milstein/SRK first/second moments and `sqrt(dt)` scaling; per-step PRNG independence inside the jitted `for_loop`; `CaputoL1Schema`/`GLShortMemory` core numerics (machine precision incl. ring-buffer wrap); Expon/Alpha/DualExpon/NMDA/AMPA/STD synapse kinetics and the STDP sign convention; `gelu/elu/selu/softmax/log_softmax` math; einops parsing edge cases; GRU cell math, NVAR feature construction, MgBlock curve, OU `sqrt(dt)`, reservoir spectral-radius rescaling; Kaiming/Xavier/Lecun/Orthogonal init statistics; `FixedProb.build_mat` density, `GaussianProb` symmetry; `Dense`/`Conv`/`ConvTranspose`/`AvgPool` shapes & layouts, `Dropout` scaling & eval no-op, BatchNorm momentum direction & eval-uses-running-stats; the high-dim `SlowPointFinder` Jacobian recovery; calcium Nernst constant; the `*_compatible.py` channel shims (numerically identical to their v2 sources). - ---- - -*Generated 2026-06-18. Working branch: `worktree-audit-issues-20260618`. Spec & verification scripts under `dev/superpowers/` (gitignored).* diff --git a/docs/issues-found-20260619-optim-losses.md b/docs/issues-found-20260619-optim-losses.md new file mode 100644 index 000000000..d9dbb5f56 --- /dev/null +++ b/docs/issues-found-20260619-optim-losses.md @@ -0,0 +1,187 @@ +# Audit findings — optim + losses slice (2026-06-19) + +Scope: `brainpy/optim/optimizer.py`, `brainpy/optim/scheduler.py`, +`brainpy/losses/base.py`, `brainpy/losses/comparison.py`, +`brainpy/losses/regularization.py`, `brainpy/losses/utils.py`. + +Several Critical/High issues flagged by the 2026-06-18 audit (C-01 Adam/AdamW +bias correction, C-02 nll sign, C-03 CE class-weight, C-04 MultiStepLR, +H-52 Adam lr-Variable mutation, H-53 CrossEntropyLoss ignore/smoothing, +M-29 SM3 attr ordering, CTC `.value` crash) were already fixed in this +worktree before this pass; they were re-verified as fixed and are not +re-listed except where a residual remained. + +--- + +### P1-C1 — `Adan.update` crashes on every call (cond operand mis-binding) [Critical] +- File: brainpy/optim/optimizer.py:818 +- Category: correctness +- What: `cond(step == 0, lambda pg, g: g, lambda pg, g: pg, (prev_g_var.value, g))` + passes a single 2-tuple as the operand. `jax.lax.cond` calls + `branch(*operands)`, so the tuple is splatted: `pg=(prev, g)` and `g` + is left unbound → `TypeError: () missing 1 required positional + argument: 'g'`. The whole `Adan` optimizer is unusable (both `no_prox` + branches are unreachable). +- Why it's a bug: `jax.lax.cond(pred, true_fun, false_fun, *operands)` unpacks + `operands` positionally into each branch function. A 2-tuple operand must be + matched by a 2-arg branch, OR the two values must be passed as two operands. +- Repro: `O.Adan(lr=1e-2, train_vars={'w': v}).update({'w': g})` → TypeError. +- Fix: replaced the broken first-step guard with a self-managed per-update + step `bm.Variable` (mirrors the `Adam`/`AdamW` pattern, independent of the + LR scheduler's stale `last_epoch`). On the first update `pre_g := g` (so the + gradient difference is 0) via `jnp.where(step == 1, g, prev_g)`; the moment + bias-correction terms now use the real per-update step `t`. +- Tests: `test_adan_runs_and_updates`, `test_adan_no_prox_runs`, + `test_adan_first_step_diff_is_zero`, `test_adan_step_counter_advances` + (optimizer_test.py). Updated the obsolete `optimizer_coverage_test.py::TestAdan::test_update_is_currently_broken`. +- Status: fixed + +### P1-C2 — `Adan` step counter frozen at 0 (bias correction never advances) [Critical] +- File: brainpy/optim/optimizer.py:808-811 +- Category: correctness +- What: `step = self.lr.last_epoch.value + 1`. With the default `Constant` + scheduler `last_epoch` starts at `-1` and is *never* advanced by `update()` + (only `step_epoch()` advances it, which optimizers never call). So `step` + is `0` forever, `correct_{m,v,n} = 1/(1-(1-beta)**1)` is frozen, and the + first-step guard `step == 0` is *always* true → the gradient-difference term + `v` is permanently pinned to 0 and Nesterov momentum is disabled. This is the + Adan analogue of C-01 (Adam/AdamW), which was fixed with a self-managed + counter; Adan was left on the stale `last_epoch` source. +- Why it's a bug: same root cause as C-01 / M-01 — `last_epoch` is not the + per-update step source. +- Repro: static + `test_adan_step_counter_advances`. +- Fix: introduced `self.step` `bm.Variable`, incremented once per `update()`; + bias-correction and the first-step guard now key off it. (Fixed together + with P1-C1.) +- Tests: as P1-C1. +- Status: fixed + +### P1-H1 — `SM3` raises `KeyError` for scalar (0-dim) trainable variables [High] +- File: brainpy/optim/optimizer.py:1083-1089, 1099-1104 +- Category: edge/error +- What: `register_train_vars` builds accumulators with `for i in range(ndim)`. + For a 0-dim (scalar) variable `ndim == 0`, so no `{k}_m{i}` accumulator is + created, yet `update` unconditionally reads `self.implicit_vars[f'{k}_m0']` + → `KeyError: 'w_m0'`. Scalar parameters are common (e.g. a learnable bias / + temperature), so SM3 crashes on first use for them. +- Why it's a bug: the cover construction degenerates for rank-0 tensors. SM3 + on a scalar should behave like Adagrad on a single accumulator. +- Repro: `O.SM3(lr=0.1, train_vars={'w': scalar_var}).update({'w': g})` → KeyError. +- Fix: treat rank-0 variables as rank-1 — always register at least one + accumulator (`max(ndim, 1)`), and in `update` clamp `ndim` to `>= 1` so the + single-accumulator path runs (equivalent to Adagrad for scalars). +- Tests: `test_sm3_scalar_var_runs` (optimizer_test.py). Updated the obsolete + `optimizer_coverage_test.py::TestSM3::test_scalar_var_is_broken`. +- Status: fixed + +### P1-H2 — `multi_margin_loss` crashes on `bm.Array` inputs [High] +- File: brainpy/losses/comparison.py:1048-1051 +- Category: edge/error +- What: the function indexes `predicts[jnp.arange(batch_size), targets]` and + calls `.at[...]` directly on the (possibly `bm.Array`) `predicts`/`targets` + without `bm.as_jax(...)`. Under JAX ≥0.9 passing a `bm.Array` into + `jnp`/advanced-indexing raises + `ValueError: Triggering __jax_array__() during abstractification is no + longer supported`. Every other loss in this file accepts `bm.Array` + (the public brainpy array type), so this is an inconsistent crash trap; + the public API contract is broken for the documented array type. +- Why it's a bug: ecosystem drift — implicit `__jax_array__` coercion was + removed; brainpy losses must convert via `bm.as_jax`. +- Repro: `C.multi_margin_loss(bm.asarray(p), bm.asarray(t))` → ValueError. +- Fix: convert `predicts`/`targets` via `bm.as_jax(...)` at the top of the + function (idiom used throughout the module). +- Tests: `test_multi_margin_accepts_bm_array` (comparison_test.py). +- Status: fixed + +### P1-M1 — `l1_loss` functional default is `'sum'` (mismatches class + docstring) [Medium] +- File: brainpy/losses/comparison.py:573 +- Category: api-drift +- What: `def l1_loss(logits, targets, reduction='sum')`. The `L1Loss` class, + the function's own docstring ("Default: ``'mean'``"), PyTorch's `L1Loss`, + and `braintools.metric.l1_loss` (default `'sum'` in code but `'mean'` in its + docstring) all disagree. A user calling `l1_loss(x, y)` silently gets the + summed (not mean) error. +- Why it's a bug: surprising default that contradicts the documented contract + and the OO wrapper. Flagged as L-10 in the 2026-06-18 audit; on closer review + it is a behavioral/contract mismatch, not pure style, so treated as Medium. +- Repro: `C.l1_loss(x, y)` returns the sum, while `C.L1Loss()(x, y)` returns + the mean. +- Fix: changed the functional default to `reduction='mean'` to match the class, + the docstring and PyTorch. +- Tests: `test_l1_loss_default_is_mean` (comparison_test.py). The pre-existing + `comparison_test.py::TestReductionDefaults::test_l1_loss_default_is_sum` + asserted the buggy default and is updated. +- Status: fixed + +### P1-L1 — `l1_loss` `'none'` reduction returns per-row L1 *norm*, not elementwise abs [Low] +- File: brainpy/losses/comparison.py:621 (delegates to `braintools.metric.l1_loss`) +- Category: style/docs +- What: brainpy's `l1_loss` docstring promises the unreduced loss is the + elementwise `l_n = |x_n - y_n|` (PyTorch semantics, same shape as input). + The delegated `braintools.metric.l1_loss` reshapes to `(N, -1)` and returns + the per-row L1 norm (sum over trailing axes), so for a `(2, 2)` input the + `'none'` output is shape `(2,)` per-row sums, not `(2, 2)`. The two + pre-existing coverage tests `TestRegressionLosses::test_l1_loss_reductions` + / `test_l1_class` encode an *incorrect* expectation (per-row mean) and fail + on baseline regardless of any change here. +- Why it's a bug: docstring/behavior mismatch introduced by the braintools + delegation. Not fixed because correcting it would require either re-routing + `l1_loss` away from `braintools.metric` (cross-cutting, changes public + numerics for all callers) or rewriting upstream — out of scope. +- Repro: `np.asarray(C.l1_loss([[1,2],[3,4]], zeros, reduction='none'))` → + `[3., 7.]` (shape `(2,)`), docstring implies shape `(2, 2)`. +- Fix: recorded only. (The two stale coverage tests are left as-is except where + P1-M1 forces an update to the default-reduction test; the two `l1_loss` + reduction-value assertions are pre-existing baseline failures noted in the + report.) +- Status: recorded-only + +### P1-L2 — `Adan._update_moments` is dead code with mismatched return order [Low] +- File: brainpy/optim/optimizer.py:798-803 +- Category: style/dead code +- What: the helper computes `m`, `gd`, `v`, `n` and returns `(m, n, v)`, but it + is never used by `update` (which inlines the same math). The lone caller is a + coverage test. The `(m, n, v)` ordering vs the `m, n, v = ...` unpacking in + callers is internally consistent but the helper duplicates logic that can + drift from `update`. +- Why it's a bug: dead/duplicated code; maintenance hazard. +- Fix: recorded only (not removed — keeping the diff focused on correctness; a + test references it). +- Status: recorded-only + +### P1-L3 — `Optimizer`/`SGD`/etc. docstrings use `Parameters::` (renders as literal block) [Low] +- File: brainpy/optim/optimizer.py:46, 114, 169, … (pervasive); scheduler.py:91, … +- Category: style/docs +- What: NumPy-doc sections are written as `Parameters::` / `References::` + (double colon → reStructuredText literal block) instead of the underlined + `Parameters\n----------` form required by the project docstring style. + Sphinx will not render these as parameter tables. +- Why it's a bug: violates the repo's documented NumPy-doc convention. +- Fix: recorded only (pervasive cosmetic change, out of correctness scope). +- Status: recorded-only + +### P1-L4 — `SM3` not exported in `optimizer.__all__` [Low] +- File: brainpy/optim/optimizer.py:28-40 +- Category: style/api +- What: `SM3` is a fully implemented public optimizer class but is absent from + `__all__` (every other optimizer is listed), so `from ...optimizer import *` + does not expose it and Sphinx `automodule` may skip it. +- Why it's a bug: minor public-API inconsistency. +- Fix: recorded only (touching `__all__` export surface is borderline; left to + the toplevel-glue owner since `optim/__init__.py` is out of scope). +- Status: recorded-only + +### P1-L5 — `ExponentialLR.__call__` uses `self.last_epoch + 1` (Variable, not `.value`) [Low] +- File: brainpy/optim/scheduler.py:343 +- Category: style/correctness-latent +- What: `i = (self.last_epoch + 1) if i is None else i` reads the `bm.Variable` + object directly rather than `self.last_epoch.value` (as `StepLR`, + `CosineAnnealingLR`, etc. do). It happens to work because `bm.Variable` + supports `+`, but it is inconsistent and the `__repr__` similarly prints the + Variable object. Latent fragility, not an active wrong-result bug (the + arithmetic broadcasts correctly). +- Why it's a bug: inconsistent Variable handling; could surprise under tracing + if the Variable identity leaks. +- Fix: recorded only (no observed wrong result; default optimizers never call + `ExponentialLR()` with `i=None` in a traced loop in-scope). +- Status: recorded-only