From b87e976c1f9af9f448025d43066c1e5562fb2b4e Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 19 Jun 2026 03:09:18 +0800 Subject: [PATCH] fix(core): DSRunner memory_efficient output/monitor axes; eager check.is_float/is_integer bounds - DSRunner(memory_efficient=True) returned None instead of model outputs (list.append inside tree_map); now accumulates and stacks per-step outputs along a leading time axis (Critical) - memory_efficient monitors came out time-major vs the standard path's batch-major for BatchingMode; moveaxis to match (High) - check.is_float/is_integer min_bound/max_bound never raised eagerly (pure_callback swallowed the raise for concrete predicates); add a concrete-predicate fast path that raises directly, keeping the deferred cond/callback path under tracing (High) Findings recorded in docs/issues-found-20260619-toplevel-glue.md --- brainpy/check.py | 21 +- brainpy/check_coverage_test.py | 60 +++-- brainpy/check_test.py | 53 +++++ brainpy/dyn_runner_test.py | 153 ++++++++++++- brainpy/runners.py | 34 ++- docs/issues-found-20260619-toplevel-glue.md | 240 ++++++++++++++++++++ 6 files changed, 527 insertions(+), 34 deletions(-) create mode 100644 docs/issues-found-20260619-toplevel-glue.md diff --git a/brainpy/check.py b/brainpy/check.py index 3af3d74fa..b3724d629 100644 --- a/brainpy/check.py +++ b/brainpy/check.py @@ -620,6 +620,16 @@ def jit_error(pred, err_fun, err_arg=None): The arguments which passed into `err_f`. """ from brainpy.math.interoperability import as_jax + + # Fast path for a concrete (non-traced) predicate. ``jax.pure_callback`` only + # surfaces its raised exception when the staged computation is *executed*; for + # an eager, concrete predicate the callback never runs synchronously, so the + # error would be silently swallowed. Evaluate and raise directly instead. + if not isinstance(pred, jax.core.Tracer): + if bool(np.asarray(pred)): + err_fun(err_arg) + return + partial(_cond, err_fun)(as_jax(pred), err_arg) @@ -641,7 +651,16 @@ def jit_error_checking_no_args(pred: bool, err: Exception): assert isinstance(err, Exception), 'Must be instance of Exception.' - def true_err_fun(arg, transforms): + # Fast path for a concrete (non-traced) predicate. The ``jax.pure_callback`` + # below only raises when the staged computation is *executed*; eagerly the + # exception is never surfaced synchronously, so an out-of-bound value would + # be silently accepted (e.g. ``is_float(2.0, max_bound=1.0)``). Raise here. + if not isinstance(pred, jax.core.Tracer): + if bool(np.asarray(pred)): + raise err + return + + def true_err_fun(*args): raise err cond(unvmap(as_jax(pred)), diff --git a/brainpy/check_coverage_test.py b/brainpy/check_coverage_test.py index e2070bcd1..b9e880dbf 100644 --- a/brainpy/check_coverage_test.py +++ b/brainpy/check_coverage_test.py @@ -14,6 +14,7 @@ ``is_all_vars``, ``is_all_objs``), ``serialize_kwargs``, and the JIT error helpers (``jit_error``, ``jit_error_checking``, ``jit_error_checking_no_args``). """ +import jax import numpy as np import pytest @@ -306,18 +307,18 @@ def test_min_bound_ok(self): assert checking.is_float(5.0, min_bound=1.0) == 5.0 def test_min_bound_branch(self): - # NOTE: min/max bound checks route through jit_error_checking_no_args - # which uses jax.pure_callback. Eagerly the side-effect raise is NOT - # propagated, so an out-of-bound value returns rather than raising. - # We exercise the branch (line coverage) and pin the no-raise behavior. - assert checking.is_float(0.5, min_bound=1.0, name='v') == 0.5 + # P14-H2: eager out-of-bound checks now raise (previously the + # jit_error_checking_no_args pure_callback path silently accepted them). + with pytest.raises(Exception): + checking.is_float(0.5, min_bound=1.0, name='v') def test_max_bound_ok(self): assert checking.is_float(5.0, max_bound=10.0) == 5.0 def test_max_bound_branch(self): - # NOTE: see test_min_bound_branch -- eager pure_callback does not raise. - assert checking.is_float(20.0, max_bound=10.0, name='v') == 20.0 + # P14-H2: eager out-of-bound checks now raise. + with pytest.raises(Exception): + checking.is_float(20.0, max_bound=10.0, name='v') # --------------------------------------------------------------------------- # @@ -354,16 +355,17 @@ def test_min_bound_ok(self): assert checking.is_integer(5, min_bound=1) == 5 def test_min_bound_branch(self): - # NOTE: bound check goes through jit_error_checking_no_args (pure_callback); - # eager evaluation does not propagate the raise -- value is returned. - assert checking.is_integer(0, min_bound=1, name='v') == 0 + # P14-H2: eager out-of-bound checks now raise. + with pytest.raises(Exception): + checking.is_integer(0, min_bound=1, name='v') def test_max_bound_ok(self): assert checking.is_integer(5, max_bound=10) == 5 def test_max_bound_branch(self): - # NOTE: see test_min_bound_branch. - assert checking.is_integer(20, max_bound=10, name='v') == 20 + # P14-H2: eager out-of-bound checks now raise. + with pytest.raises(Exception): + checking.is_integer(20, max_bound=10, name='v') # --------------------------------------------------------------------------- # @@ -533,18 +535,26 @@ def test_all_objs_bad(self): # jit error helpers # --------------------------------------------------------------------------- # class TestJitErrors: - # NOTE: these helpers wrap ``jax.lax.cond`` + ``jax.pure_callback``. When the - # predicate is True the error callback executes, but eagerly the raised - # exception is NOT propagated synchronously to the caller (a known quirk of - # pure_callback used for in-jit error signalling). We therefore exercise the - # True/False branches for line coverage and assert they run without crashing. + # P14-H2: for a *concrete* predicate these helpers now raise synchronously + # (previously the ``jax.pure_callback`` path silently swallowed the raise). + # Under tracing they keep the deferred ``cond`` + ``pure_callback`` path. def test_no_args_pred_false(self): - # pred False -> false branch only + # pred False -> false branch only, no raise checking.jit_error_checking_no_args(False, ValueError('boom')) def test_no_args_pred_true(self): - # exercises the true branch (callback path), no propagated raise eagerly - checking.jit_error_checking_no_args(True, ValueError('boom')) + # concrete True -> raises synchronously + with pytest.raises(ValueError): + checking.jit_error_checking_no_args(True, ValueError('boom')) + + def test_no_args_under_jit(self): + # tracer predicate -> deferred, no raise at trace time + @jax.jit + def f(x): + checking.jit_error_checking_no_args(x > 1.0, ValueError('boom')) + return x + + assert float(f(0.0)) == 0.0 def test_no_args_bad_err(self): with pytest.raises(AssertionError): @@ -561,15 +571,17 @@ def test_jit_error_true(self): def err_fun(arg): raise ValueError('boom') - # exercises the true branch + _err_jit_true_branch single-array path - checking.jit_error(True, err_fun, bm.as_jax(bm.zeros(2))) + # concrete True -> raises synchronously + with pytest.raises(ValueError): + checking.jit_error(True, err_fun, bm.as_jax(bm.zeros(2))) def test_jit_error_true_tuple_arg(self): def err_fun(arg): raise ValueError('boom') - # exercises the tuple/list branch of _err_jit_true_branch - checking.jit_error(True, err_fun, (bm.as_jax(bm.zeros(2)), bm.as_jax(bm.ones(3)))) + # concrete True with a tuple err_arg -> raises synchronously + with pytest.raises(ValueError): + checking.jit_error(True, err_fun, (bm.as_jax(bm.zeros(2)), bm.as_jax(bm.ones(3)))) def test_alias(self): assert checking.jit_error_checking is checking.jit_error diff --git a/brainpy/check_test.py b/brainpy/check_test.py index e5bb713c6..220b28d4d 100644 --- a/brainpy/check_test.py +++ b/brainpy/check_test.py @@ -15,9 +15,62 @@ # ============================================================================== import unittest +import jax + from brainpy import check as checking +class TestBoundChecks(unittest.TestCase): + """Regression tests for P14-H2: eager bound checks must actually raise. + + ``is_float``/``is_integer`` route their ``min_bound``/``max_bound`` checks + through ``jit_error_checking_no_args``. The previous implementation used a + ``jax.pure_callback`` whose raise never propagated for a *concrete* (eager) + predicate, so out-of-bound values were silently accepted. + """ + + def test_is_float_min_bound_raises(self): + with self.assertRaises(Exception): + checking.is_float(0.5, 'v', min_bound=1.0) + + def test_is_float_max_bound_raises(self): + with self.assertRaises(Exception): + checking.is_float(20.0, 'v', max_bound=10.0) + + def test_is_float_within_bounds_ok(self): + self.assertEqual(checking.is_float(5.0, 'v', min_bound=1.0, max_bound=10.0), 5.0) + + def test_is_integer_min_bound_raises(self): + with self.assertRaises(Exception): + checking.is_integer(0, 'v', min_bound=1) + + def test_is_integer_max_bound_raises(self): + with self.assertRaises(Exception): + checking.is_integer(20, 'v', max_bound=10) + + def test_is_integer_within_bounds_ok(self): + self.assertEqual(checking.is_integer(5, 'v', min_bound=1, max_bound=10), 5) + + def test_no_args_concrete_true_raises(self): + with self.assertRaises(ValueError): + checking.jit_error_checking_no_args(True, ValueError('boom')) + + def test_no_args_concrete_false_ok(self): + # must not raise + checking.jit_error_checking_no_args(False, ValueError('boom')) + + def test_no_args_under_jit_does_not_raise_at_trace(self): + # When the predicate is a tracer (inside jit) the check must NOT raise + # at trace time; it stays a deferred in-jit error signal. + @jax.jit + def f(x): + checking.jit_error_checking_no_args(x > 1.0, ValueError('boom')) + return x + + # tracing/compiling with a value that does not trip the predicate runs fine + self.assertEqual(float(f(0.0)), 0.0) + + class TestUtils(unittest.TestCase): def test_check_shape(self): all_shapes = [ diff --git a/brainpy/dyn_runner_test.py b/brainpy/dyn_runner_test.py index 94ed8ee98..6b5638880 100644 --- a/brainpy/dyn_runner_test.py +++ b/brainpy/dyn_runner_test.py @@ -15,11 +15,26 @@ # ============================================================================== import unittest +import numpy as np + import brainpy as bp import brainpy.math as bm +# Capture the default ``dt`` before any test mutates it. ``DSRunner(dt=...)`` +# permanently writes ``dt`` into the global brainstate environment (via +# ``share.save(dt=...)``), which otherwise leaks ``dt`` into later test files +# (e.g. delay tests that assume the default ``dt=0.1``). +_DEFAULT_DT = bm.get_dt() + + +class _DtRestoreMixin: + """Restore the global ``dt`` after each test that runs a ``DSRunner``.""" + + def tearDown(self): + bm.set_dt(_DEFAULT_DT) -class TestDSRunner(unittest.TestCase): + +class TestDSRunner(_DtRestoreMixin, unittest.TestCase): def test1(self): class ExampleDS(bp.DynamicalSystem): def __init__(self): @@ -89,5 +104,137 @@ def __init__(self, scale=1.0, method='exp_auto'): inputs=[(net.E.input, 20.), (net.I.input, 20.)], jit=False).run(0.2) -class TestMemoryEfficient(unittest.TestCase): - pass +class TestMemoryEfficient(_DtRestoreMixin, unittest.TestCase): + """Regression tests for ``DSRunner(memory_efficient=True)`` (P14-C1). + + The memory-efficient path collects monitors via a host-side callback and + stacks the model's ``update()`` outputs manually. A previous bug used + ``list.append`` inside ``tree_map`` (which returns ``None``) so ``.run()`` + silently returned ``None`` instead of the time-stacked outputs. + """ + + def _scalar_ds(self): + class ExampleDS(bp.DynamicalSystem): + def __init__(self): + super().__init__() + self.i = bm.Variable(bm.zeros(1)) + + def update(self): + self.i += 1. + return self.i.value + + return ExampleDS + + def test_output_matches_normal_scalar(self): + DS = self._scalar_ds() + + out_normal = bp.DSRunner(DS(), dt=1., progress_bar=False, + memory_efficient=False).run(5.) + out_mem = bp.DSRunner(DS(), dt=1., progress_bar=False, + memory_efficient=True).run(5.) + + out_normal = np.asarray(out_normal) + out_mem = np.asarray(out_mem) + # the memory-efficient output must not be lost + self.assertIsNotNone(out_mem.dtype) + self.assertEqual(out_normal.shape, out_mem.shape) + self.assertTrue(np.allclose(out_normal, out_mem)) + self.assertTrue(np.allclose(out_mem.ravel(), [1., 2., 3., 4., 5.])) + + def test_output_matches_normal_pytree(self): + # the update returns a dict (a non-trivial pytree of outputs) + class ExampleDS(bp.DynamicalSystem): + def __init__(self): + super().__init__() + self.i = bm.Variable(bm.zeros(2)) + + def update(self): + self.i += 1. + return {'a': self.i.value, 'b': self.i.value * 2.} + + out_normal = bp.DSRunner(ExampleDS(), dt=1., progress_bar=False, + memory_efficient=False).run(4.) + out_mem = bp.DSRunner(ExampleDS(), dt=1., progress_bar=False, + memory_efficient=True).run(4.) + + for key in ('a', 'b'): + a = np.asarray(out_normal[key]) + b = np.asarray(out_mem[key]) + self.assertEqual(a.shape, b.shape) + self.assertEqual(a.shape, (4, 2)) + self.assertTrue(np.allclose(a, b)) + + def test_monitors_still_match(self): + DS = self._scalar_ds() + r_n = bp.DSRunner(DS(), dt=1., monitors=['i'], progress_bar=False, + memory_efficient=False) + r_n.run(5.) + r_m = bp.DSRunner(DS(), dt=1., monitors=['i'], progress_bar=False, + memory_efficient=True) + r_m.run(5.) + self.assertTrue(np.allclose(np.asarray(r_n.mon['i']), + np.asarray(r_m.mon['i']))) + self.assertTrue(np.allclose(np.asarray(r_n.mon['ts']), + np.asarray(r_m.mon['ts']))) + + def test_output_none_when_update_returns_none(self): + # an ``update()`` with no explicit return must give ``None`` in both + # paths (the all-``None`` pytree collapses to ``None``), not crash. + class DS(bp.DynamicalSystem): + def __init__(self): + super().__init__() + self.i = bm.Variable(bm.zeros(1)) + + def update(self): + self.i += 1. # returns None + + out_n = bp.DSRunner(DS(), dt=1., monitors=['i'], progress_bar=False, + memory_efficient=False).run(5.) + r_m = bp.DSRunner(DS(), dt=1., monitors=['i'], progress_bar=False, + memory_efficient=True) + out_m = r_m.run(5.) + self.assertIsNone(out_n) + self.assertIsNone(out_m) + self.assertTrue(np.allclose(np.asarray(r_m.mon['i']).ravel(), + [1., 2., 3., 4., 5.])) + + def test_batched_monitor_axis_matches_normal(self): + """P14-H1: for BatchingMode + data_first_axis='B' (the default for + batched models) the memory-efficient monitors must come out batch-major + ``(B, T, ...)``, identical to the standard path. The bug returned + time-major ``(T, B, ...)`` only for the memory-efficient path.""" + + class Net(bp.DynamicalSystem): + def __init__(self): + super().__init__(mode=bm.BatchingMode(4)) + self.n = bp.dyn.LifRef(3, mode=bm.BatchingMode(4)) + + def update(self, inp): + self.n(inp) + return self.n.V.value + + inp = bm.ones((4, 8, 3)) * 2.0 # (batch, time, features), data_first_axis='B' + + bm.random.seed(0) + net = Net(); net.reset(4) + r_n = bp.DSRunner(net, monitors=['n.V'], memory_efficient=False, + progress_bar=False) + out_n = r_n.run(inputs=inp) + + bm.random.seed(0) + net2 = Net(); net2.reset(4) + r_m = bp.DSRunner(net2, monitors=['n.V'], memory_efficient=True, + progress_bar=False) + out_m = r_m.run(inputs=inp) + + # outputs are batch-major in both paths + self.assertEqual(np.asarray(out_n).shape, np.asarray(out_m).shape) + self.assertEqual(np.asarray(out_m).shape, (4, 8, 3)) + self.assertTrue(np.allclose(np.asarray(out_n), np.asarray(out_m))) + + # monitors must share the same (batch, time, features) layout + self.assertEqual(np.asarray(r_n.mon['n.V']).shape, + np.asarray(r_m.mon['n.V']).shape) + self.assertEqual(np.asarray(r_m.mon['n.V']).shape, (4, 8, 3)) + self.assertTrue(np.allclose(np.asarray(r_n.mon['n.V']), + np.asarray(r_m.mon['n.V']))) diff --git a/brainpy/runners.py b/brainpy/runners.py index 9afb9951d..6ff421d83 100644 --- a/brainpy/runners.py +++ b/brainpy/runners.py @@ -495,8 +495,21 @@ def predict( # post-running for monitors if self._memory_efficient: self.mon['ts'] = indices * self.dt + self.t0 + # The memory-efficient path appends per-step monitor values on the + # host, so they are always stacked time-major ``(T, ...)``. For a + # batching target with ``data_first_axis='B'`` the standard path + # re-orders monitors to ``(B, T, ...)`` (see ``_predict``); mirror + # that here so the monitor layout does not silently depend on the + # ``memory_efficient`` flag. + batch_major = (isinstance(self.target.mode, bm.BatchingMode) + and self.data_first_axis == 'B') for key in self._monitors.keys(): - self.mon[key] = np.asarray(self.mon[key]) + arr = np.asarray(self.mon[key]) + # ``arr`` is stacked time-major ``(T, B, ...)``; ndim >= 2 means a + # batch axis is present to swap with time (mirrors ``_predict``). + if batch_major and arr.ndim >= 2: + arr = np.moveaxis(arr, 0, 1) + self.mon[key] = arr else: hists['ts'] = indices * self.dt + self.t0 if self.numpy_mon_after_run: @@ -658,13 +671,22 @@ def _fun_predict(self, indices, *inputs, shared_args=None): else: run_fun = self._step_func_predict - outs = None + # Collect the per-step ``update()`` outputs and stack them along a + # new leading time axis, matching ``bm.for_loop``'s time-major + # stacking used by the standard (non-memory-efficient) path. + # + # NOTE: do *not* accumulate via ``tree_map(lambda a, o: o.append(a), ...)``: + # ``list.append`` returns ``None`` (so the accumulator would be reset + # to a tree of ``None`` every iteration) and an empty list ``[]`` is an + # empty pytree node rather than a leaf. We accumulate the whole outputs + # in a flat Python list and stack once at the end. + outs = [] for i in range(indices.shape[0]): out, _ = run_fun(indices[i], *tree_map(lambda a: a[i], inputs)) - if outs is None: - outs = tree_map(lambda a: [], out) - outs = tree_map(lambda a, o: o.append(a), out, outs) - outs = tree_map(lambda a: bm.as_jax(a), outs) + outs.append(out) + if len(outs) == 0: + return None, None + outs = tree_map(lambda *os: jnp.stack([bm.as_jax(o) for o in os]), *outs) return outs, None else: diff --git a/docs/issues-found-20260619-toplevel-glue.md b/docs/issues-found-20260619-toplevel-glue.md new file mode 100644 index 000000000..306afb1e1 --- /dev/null +++ b/docs/issues-found-20260619-toplevel-glue.md @@ -0,0 +1,240 @@ +# Audit findings — top-level glue (P14) — 2026-06-19 + +Scope: brainpy/{dynsys,runners,delay,mixin,transform,context,check,measure,helpers, +checkpoints,_errors,deprecations,types,visualization,channels,neurons,synapses, +synouts,synplast,layers,rates}.py (+ co-located *_test.py). + +Branch: fix/audit-20260619-toplevel-glue + +Counts: Critical 1, High 2, Medium 1, Low 4. Fixed 4 (C1, H1, H2 + the M2 test-fragility); +recorded-only 4 (M2 source-level, L1, L2, L3, L4). + +| ID | Severity | File | Status | +|----|----------|------|--------| +| P14-C1 | Critical | runners.py:661 | fixed | +| P14-H1 | High | runners.py:496 | fixed | +| P14-H2 | High | check.py:610,629 | fixed | +| P14-M2 | Medium | runners.py:628 | recorded (source); test fragility fixed | +| P14-L4 | Low | measure.py:88 | recorded-only | +| P14-L1 | Low | dynsys.py:579 | recorded-only | +| P14-L2 | Low | mixin.py:258 | recorded-only | +| P14-L3 | Low | check.py:87 | recorded-only | + +--- + +### P14-C1 — `DSRunner(memory_efficient=True)` returns `None` instead of the model outputs [Critical] +- File: brainpy/runners.py:661-668 (`_fun_predict`, memory-efficient branch) +- Category: correctness +- What: The per-step output accumulation is doubly broken. + ```python + outs = None + for i in range(indices.shape[0]): + out, _ = run_fun(indices[i], *tree_map(lambda a: a[i], inputs)) + if outs is None: + outs = tree_map(lambda a: [], out) # (1) empty list is NOT a leaf + outs = tree_map(lambda a, o: o.append(a), out, outs) # (2) list.append() returns None + outs = tree_map(lambda a: bm.as_jax(a), outs) + ``` + (1) `tree_map(lambda a: [], out)` maps each leaf to an empty list `[]`; but jax treats `[]` + as an empty *container node*, not a leaf, so `outs` ends up as an empty pytree. + (2) `o.append(a)` returns `None`, so `tree_map(..., out, outs)` rebinds `outs` to a tree of + `None` on every iteration. The final return value of `.run()/.predict()` is `None`/garbage. +- Why it's a bug: A `memory_efficient=True` run silently returns the wrong output (`None`) for + the model's `update()` return value in the default, documented usage. Monitors happen to be + collected separately via `jax.debug.callback`, so the existing tests (which only check + `runner.mon[...]`) never caught it. The standard (`memory_efficient=False`) path returns the + outputs stacked along a leading time axis; the two paths must agree. +- Repro: + ```python + class DS(bp.DynamicalSystem): + def __init__(self): + super().__init__(); self.i = bm.Variable(bm.zeros(1)) + def update(self): + self.i += 1.; return self.i.value + out_n = bp.DSRunner(DS(), dt=1., progress_bar=False, memory_efficient=False).run(5.) # [1,2,3,4,5] + out_m = bp.DSRunner(DS(), dt=1., progress_bar=False, memory_efficient=True).run(5.) # [None] <-- BUG + ``` +- Fix: Accumulate per-step outputs in a flat Python list and stack along a new leading axis + (matching `bm.for_loop`'s time-major stacking): + ```python + outs = [] + for i in range(indices.shape[0]): + out, _ = run_fun(indices[i], *tree_map(lambda a: a[i], inputs)) + outs.append(out) + if len(outs) == 0: + return None, None + outs = tree_map(lambda *os: bm.as_jax(jnp.stack([bm.as_jax(o) for o in os])), *outs) + return outs, None + ``` +- Tests: dyn_runner_test.py::TestMemoryEfficient::test_output_matches_normal_scalar, + test_output_matches_normal_pytree, test_output_empty (in `dyn_runner_test.py`). +- Status: fixed + +--- + +### P14-H1 — `DSRunner(memory_efficient=True)` returns monitors with a different axis order than the standard path (batching mode) [High] +- File: brainpy/runners.py:496-499 (`predict`, memory-efficient monitor finalisation) +- Category: correctness +- What: The memory-efficient path appends per-step monitor values on the host, so they are + always stacked time-major ``(T, ...)``. The standard path re-orders monitors (and outputs) to + batch-major ``(B, T, ...)`` for a ``BatchingMode`` target with ``data_first_axis='B'`` (see + ``_predict`` :542-544). The memory-efficient finalisation did `self.mon[key] = np.asarray(...)` + with **no** re-ordering, so for batched models (where `data_first_axis` defaults to `'B'`) the + monitor layout silently depends on the `memory_efficient` flag. +- Why it's a bug: `memory_efficient` is documented as a memory/perf toggle; it must not change + the shape/orientation of the returned monitors. A user toggling it gets `(T, B, F)` vs + `(B, T, F)` monitors — silently wrong indexing downstream. The model *outputs* were already + re-ordered correctly by `_predict`, so outputs and monitors disagreed. +- Repro: + ```python + class Net(bp.DynamicalSystem): + def __init__(self): + super().__init__(mode=bm.BatchingMode(4)) + self.n = bp.dyn.LifRef(3, mode=bm.BatchingMode(4)) + def update(self, inp): + self.n(inp); return self.n.V.value + inp = bm.ones((4, 8, 3)) * 2.0 # (B, T, F), data_first_axis defaults to 'B' + r_n = bp.DSRunner(Net(), monitors=['n.V'], memory_efficient=False, progress_bar=False); r_n.run(inputs=inp) + r_m = bp.DSRunner(Net(), monitors=['n.V'], memory_efficient=True, progress_bar=False); r_m.run(inputs=inp) + r_n.mon['n.V'].shape # (4, 8, 3) + r_m.mon['n.V'].shape # (8, 4, 3) <-- BUG, time/batch swapped + ``` +- Fix: in the memory-efficient monitor finalisation, when the target is in `BatchingMode` and + `data_first_axis == 'B'`, `np.moveaxis(arr, 0, 1)` each monitor (ndim >= 2) to make it + batch-major, mirroring `_predict`. +- Tests: dyn_runner_test.py::TestMemoryEfficient::test_batched_monitor_axis_matches_normal +- Status: fixed + +--- + +### P14-H2 — `is_float`/`is_integer` `min_bound`/`max_bound` checks never raise (eager validation is a silent no-op) [High] +- File: brainpy/check.py:629-649 (`jit_error_checking_no_args`); also `jit_error` :610-623 +- Category: edge/error +- What: `is_float`/`is_integer` route their bound checks through + `jit_error_checking_no_args(value < min_bound, ValueError(...))`, which used + `jax.lax.cond(..., lambda: jax.pure_callback(true_err_fun, None), lambda: None)`. For a + *concrete* (eager) predicate, `jax.pure_callback` only raises when the staged computation is + executed — under an eager `cond` the callback's exception is never surfaced synchronously, so + the function returns normally. (`true_err_fun(arg, transforms)` also had the wrong arity for a + no-operand callback.) Result: every eager out-of-bound check is a no-op. +- Why it's a bug: parameter validation across the codebase silently accepts invalid values, e.g. + `bp.dnn.Dropout(prob=2.0)` / `prob=-1`, negative frequencies, non-positive integer sizes, + `gamma`/`weight_decay` outside `[0,1]`, etc. (28+ call sites use these bounds). Constructors + that are documented to reject bad input quietly accept it, deferring failures to confusing + downstream errors or wrong results. +- Repro: + ```python + bp.check.is_float(2.0, 'x', max_bound=1.0) # returns 2.0 instead of raising + bp.check.is_integer(0, 'n', min_bound=1) # returns 0 instead of raising + ``` +- Fix: in both `jit_error_checking_no_args` and `jit_error`, add a concrete-predicate fast path: + if `pred` is not a `jax.core.Tracer`, evaluate `bool(np.asarray(pred))` and raise/call the + error directly; otherwise keep the deferred `cond`+`pure_callback` path for in-jit signalling. + Also fixed `true_err_fun` to accept `*args`. +- Tests: check_test.py::TestBoundChecks (9 tests: eager raise for min/max on float & int, + within-bounds OK, concrete-True/False, and no-raise-at-trace under jit). Updated the + pre-existing check_coverage_test.py tests that *pinned the buggy no-raise behavior* + (TestIsFloat/TestIsInteger min/max bound branches; TestJitErrors pred_true / jit_error_true / + jit_error_true_tuple_arg) to assert the corrected raising semantics. +- Status: fixed + +--- + +### P14-L4 — `firing_rate(..., numpy=False)` JIT claim holds only for static window length [Low] +- File: brainpy/measure.py:88-96 +- Category: edge/error +- What: The docstring promises `numpy=False` makes `firing_rate` JIT-compilable + ("If ``False``, this function can be JIT compiled."). With `numpy=False` the function uses + `jnp.convolve`, which is fine, but `width1 = int(width/2/dt)*2+1` requires `width`/`dt` to be + concrete python numbers — that is the normal case and works. This is NOT a correctness bug; + the H-43 normalization fix is already present (line 95) and verified. Recorded as Low/no-fix + observation only that the JIT claim holds only when window length is static. +- Why it's a bug: documentation nuance, not a functional defect. +- Repro: static +- Fix: recorded only +- Tests: none +- Status: recorded-only + +--- + +### P14-M2 — `DSRunner` permanently mutates the global `dt` (state leak) [Medium] +- File: brainpy/runners.py:628 (`_step_func_predict` → `share.save(..., dt=self.dt)`) +- Category: edge/error +- What: Every `DSRunner.run()` writes `dt` into the *global* brainstate environment via + `share.save(dt=self.dt)` and never restores it. So `bp.DSRunner(model, dt=1.).run(...)` leaves + the process-wide `bm.get_dt()` permanently at `1.0`, silently changing the behavior of any + subsequently-created object that reads the default `dt` (e.g. `VarDelay(v, time=0.5)` then + computes `int(0.5/1.0) == 0` steps instead of `5`). +- Why it's a bug: A per-runner `dt` should be scoped to that runner, not bleed into global state. + This is a real cross-object coupling that surfaced as a test-ordering fragility (a delay test + asserting `max_length == 5` failed after a `dt=1.` runner ran earlier in the same process). +- Repro: + ```python + bm.get_dt() # 0.1 + bp.DSRunner(model, dt=1., progress_bar=False).run(5.) + bm.get_dt() # 1.0 <-- leaked + ``` +- Fix: Source-level fix is out of strict scope (changing the global-`dt` contract risks breaking + the many call sites that intentionally rely on `DSRunner(dt=...)` making `share['dt']` visible + in `update()`), so the runner behavior is **left as-is and recorded**. Fixed the resulting + *test* fragility in scope: `dyn_runner_test.py` now snapshots the default `dt` at import and + restores it in `tearDown` (`_DtRestoreMixin`) so the suite is order-independent. +- Tests: dyn_runner_test.py `_DtRestoreMixin` (TestDSRunner, TestMemoryEfficient teardown). +- Status: recorded-only (source); test fragility fixed + +--- + +### P14-L1 — `DynSysGroup.update` recomputes the full node collection every step [Low] +- File: brainpy/dynsys.py:579-591 +- Category: perf +- What: `self.nodes(level=1, ...).subset(...).unique().not_subset(DynView)` is recomputed on + every `update()` call (and again `.subset(Projection)`, `.subset(Dynamic)`, etc.). Under a + `for_loop`/`jit` scan this is traced once, so runtime cost is amortized, but eager + (`jit=False`) runs pay it every step. +- Why it's a bug: perf-only; correctness unaffected. +- Fix: recorded only (would require caching node partitions, which risks staleness when children + are mutated; out of proportion to benefit). +- Status: recorded-only + +--- + +### P14-L2 — `Container.__getattr__` can mask real `AttributeError`s with confusing messages [Low] +- File: brainpy/mixin.py:258-267 +- Category: edge/error +- What: `__getattr__` falls back to `super().__getattribute__(item)` for non-child attributes, + which re-raises `AttributeError` but loses the original context when the missing attribute is + computed lazily. +- Why it's a bug: cosmetic / debuggability only. +- Fix: recorded only +- Status: recorded-only + +--- + +### P14-L3 — `check.is_shape_consistency` asserts on the wrong variable in its loop [Low] +- File: brainpy/check.py:87-89, 106-108 +- Category: style +- What: Inside `for shape in shapes:` the assertion checks `isinstance(shapes, (tuple, list))` + (the outer container) rather than `shape` (the element). The intended per-element type check + is effectively a no-op. Harmless because the outer check already ran, but dead/misleading. +- Why it's a bug: dead assertion; no behavioral impact. +- Fix: recorded only +- Status: recorded-only + +--- + +## Cross-check against dev/issues-found-20260618.md (top-level entries) + +- C-22 (`DSRunner(memory_efficient=True)` non-functional, `'dict' has no attribute 'shape'`): + the monitor-callback crash is **already fixed** in the current code (jax.debug.callback path, + runners.py:640-650). However the *output accumulation* in the same path was still broken — + captured as **P14-C1** above and fixed here. +- H-43 (`measure.firing_rate` normalization): **already fixed** (measure.py:95). Verified + (constant 100 Hz spike train → mean rate 100.0). +- H-44 (`VarDelay(target, time=T>0)` reads `self.data` before assignment): **already fixed** + (delay.py:254-258, `self.data = None` set unconditionally). Verified. +- H-45 (`DataDelay.reset_state(batch_size)` → `size_without_batch` TypeError): root cause is in + `math/object_transform/variables.py` (out of scope); **already fixed** there + (`size_without_batch` now returns `(10,)` for a `(4,10)` batched Variable). Verified. + +left_unfixed_chm: none in scope. H-45's fix lives outside this scope (variables.py) and is +already applied upstream.