diff --git a/brainpy/math/object_transform/collectors.py b/brainpy/math/object_transform/collectors.py index 010c8b797..7ca342e22 100644 --- a/brainpy/math/object_transform/collectors.py +++ b/brainpy/math/object_transform/collectors.py @@ -112,7 +112,15 @@ def __sub__(self, other: Union[Dict, Sequence]): if isinstance(key, str): keys_to_remove.append(key) else: - keys_to_remove.extend(id_to_keys[id(key)]) + # Look the value up by identity. A value object that is not + # present must raise the same descriptive ``ValueError`` as + # the other "not found" paths below, rather than a bare + # ``KeyError(id)`` from an unchecked dict access. + matched = id_to_keys.get(id(key)) + if matched is None: + raise ValueError(f'Cannot remove {key}, because we do not find it ' + f'in {self.keys()}.') + keys_to_remove.extend(matched) for key in set(keys_to_remove): if key in gather: diff --git a/brainpy/math/object_transform/collectors_test.py b/brainpy/math/object_transform/collectors_test.py index 7116dd2df..78443eced 100644 --- a/brainpy/math/object_transform/collectors_test.py +++ b/brainpy/math/object_transform/collectors_test.py @@ -151,6 +151,16 @@ def test_sub_with_list_missing_key_raises(): c - ['nope'] +def test_sub_with_list_missing_value_raises(): + # P4-M3: removing a *value* object that is not present must raise the same + # descriptive ValueError as the string-key path, not a bare KeyError(id). + present = object() + absent = object() + c = Collector({'a': present}) + with pytest.raises(ValueError): + c - [absent] + + def test_sub_rejects_bad_type(): c = Collector({'a': 1}) with pytest.raises(ValueError): diff --git a/brainpy/math/object_transform/controls.py b/brainpy/math/object_transform/controls.py index 4253dc48b..dd3f0aeee 100644 --- a/brainpy/math/object_transform/controls.py +++ b/brainpy/math/object_transform/controls.py @@ -150,10 +150,22 @@ def cond( if not isinstance(operands, (tuple, list)): operands = (operands,) operands = _unwrap_state_operands(operands) + + # ``true_fun``/``false_fun`` may be constants (array/number), per the + # documented contract. Wrap any non-callable branch into a callable that + # ignores ``*operands`` and returns the (unwrapped) constant, mirroring the + # handling in ``ifelse``. Otherwise brainstate would try to *call* the + # constant and raise ``TypeError: '' object is not callable``. + def _make_branch(branch): + if callable(branch): + return warp_to_no_state_input_output(branch) + const = _unwrap_operand_leaf(branch) + return warp_to_no_state_input_output(lambda *args: const) + return brainstate.transform.cond( pred, - warp_to_no_state_input_output(true_fun), - warp_to_no_state_input_output(false_fun), + _make_branch(true_fun), + _make_branch(false_fun), *operands ) @@ -230,6 +242,14 @@ def make_callable(branch): branches = [make_callable(branch) for branch in branches] + # A single condition may be passed as a bare scalar bool/array (the + # docstring types ``conditions`` as ``bool, sequence of bool``). Normalise + # it into a one-element list so it flows through the conversion below; + # otherwise ``brainstate.transform.ifelse`` would call ``len()`` on the + # scalar and raise ``TypeError: object ... has no len()``. + if not isinstance(conditions, (list, tuple)): + conditions = [conditions] + # Convert if-elif-else chain to mutually exclusive conditions if isinstance(conditions, (list, tuple)) and len(conditions) > 0: conditions = list(conditions) diff --git a/brainpy/math/object_transform/controls_test.py b/brainpy/math/object_transform/controls_test.py index 66446d136..1504ac65c 100644 --- a/brainpy/math/object_transform/controls_test.py +++ b/brainpy/math/object_transform/controls_test.py @@ -17,6 +17,7 @@ from functools import partial import jax +import jax.numpy as jnp from absl.testing import parameterized from jax import vmap @@ -528,3 +529,49 @@ def run(a, b): run(0., 1.) self.assertIn("cond_fun should not have any write states", str(cm.exception)) + + +class TestCondBranchTypes(parameterized.TestCase): + """Regression for P4-M1: ``bm.cond`` must accept non-callable (constant) + branches, as advertised by its docstring (``callable, ArrayType, float, + int, bool``).""" + + def test_cond_with_constant_branches(self): + # Scalar Python constants as branches. + self.assertEqual(float(bm.cond(True, 1.0, 2.0)), 1.0) + self.assertEqual(float(bm.cond(False, 1.0, 2.0)), 2.0) + + def test_cond_with_array_branches(self): + # Array constants as branches (unwrapped before forwarding). + r_true = bm.cond(True, bm.asarray([1., 2.]), bm.asarray([3., 4.])) + r_false = bm.cond(False, bm.asarray([1., 2.]), bm.asarray([3., 4.])) + self.assertTrue(bm.allclose(r_true, bm.asarray([1., 2.]))) + self.assertTrue(bm.allclose(r_false, bm.asarray([3., 4.]))) + + def test_cond_callable_still_works(self): + # Callable branches keep working (and may mutate Variable state). + a = bm.Variable(bm.zeros(2)) + + def tf(op): + a.value += op + + def ff(op): + a.value -= op + + bm.cond(True, tf, ff, 5.0) + self.assertTrue(bm.allclose(a.value, bm.asarray([5., 5.]))) + + +class TestIfElseScalarCondition(parameterized.TestCase): + """Regression for P4-M2: ``bm.ifelse`` must accept a scalar-bool + ``conditions`` argument, as advertised by its docstring.""" + + def test_ifelse_scalar_python_bool(self): + self.assertEqual(int(bm.ifelse(conditions=True, branches=[lambda: 1, lambda: 2])), 1) + self.assertEqual(int(bm.ifelse(conditions=False, branches=[lambda: 1, lambda: 2])), 2) + + def test_ifelse_scalar_array_bool(self): + r = bm.ifelse(conditions=jnp.asarray(True), branches=[lambda: 10, lambda: 20]) + self.assertEqual(int(r), 10) + r = bm.ifelse(conditions=jnp.asarray(False), branches=[lambda: 10, lambda: 20]) + self.assertEqual(int(r), 20) diff --git a/brainpy/math/object_transform/object_transform_fixes_test.py b/brainpy/math/object_transform/object_transform_fixes_test.py index ef1a7f9ae..9fb29e967 100644 --- a/brainpy/math/object_transform/object_transform_fixes_test.py +++ b/brainpy/math/object_transform/object_transform_fixes_test.py @@ -928,6 +928,41 @@ def test_variable_view_setter_shape_and_dtype_checks(): view.value = jnp.zeros(3) # wrong shape +# P4-M4: ``VariableView.value`` setter must accept the same inputs as +# ``Variable.value`` (plain list / numpy / State), instead of crashing or +# silently mismatching dtype. + +def test_variable_view_setter_python_list_matches_variable(): + # A plain Python list is handled identically to ``Variable.value`` (both + # raise a descriptive MathError rather than the previous opaque + # ``AttributeError: 'list' object has no attribute 'shape'``). + origin = bm.Variable(jnp.arange(5.)) + view = bm.VariableView(origin, slice(None, 2, None)) + parent_var = bm.Variable(jnp.arange(2.)) + with pytest.raises(MathError): + parent_var.value = [10., 11.] + with pytest.raises(MathError): + view.value = [10., 11.] + + +def test_variable_view_setter_canonicalizes_numpy_dtype(): + # A float64 numpy array assigned into a float32 view must be canonicalized, + # not rejected with a dtype MathError. + origin = bm.Variable(jnp.arange(5., dtype=jnp.float32)) + view = bm.VariableView(origin, slice(None, 2, None)) + view.value = np.array([7., 8.], dtype=np.float64) + assert origin.value.dtype == jnp.float32 + assert bm.allclose(origin.value[:2], jnp.asarray([7., 8.])) + + +def test_variable_view_setter_unwraps_state(): + origin = bm.Variable(jnp.arange(5.)) + view = bm.VariableView(origin, slice(None, 2, None)) + src = bm.Variable(jnp.asarray([20., 21.])) + view.value = src + assert bm.allclose(origin.value[:2], jnp.asarray([20., 21.])) + + # =========================================================================== # Additional coverage for base.py # =========================================================================== diff --git a/brainpy/math/object_transform/variables.py b/brainpy/math/object_transform/variables.py index b15f33eb4..74543abfe 100644 --- a/brainpy/math/object_transform/variables.py +++ b/brainpy/math/object_transform/variables.py @@ -329,23 +329,37 @@ def value(self): @value.setter def value(self, v): + # Normalize/unwrap the incoming value *before* validating its + # shape/dtype, mirroring the hardened ``Variable.value`` setter. Without + # this a plain Python ``list``/scalar (no ``.shape``) raises an + # ``AttributeError``, a ``brainstate.State`` is not unwrapped, and a + # ``numpy`` array is never canonicalized to the view's dtype. + if isinstance(v, brainstate.State): + v = v.value + if isinstance(v, Array): + v = v.value + elif isinstance(v, np.ndarray): + v = jnp.asarray(v) + int_shape = self.shape + ext_shape = jnp.shape(v) if self.batch_axis is None: - ext_shape = v.shape + pass else: - ext_shape = v.shape[:self.batch_axis] + v.shape[self.batch_axis + 1:] + ext_shape = ext_shape[:self.batch_axis] + ext_shape[self.batch_axis + 1:] int_shape = int_shape[:self.batch_axis] + int_shape[self.batch_axis + 1:] if ext_shape != int_shape: - error = f"The shape of the original data is {self.shape}, while we got {v.shape}" + error = f"The shape of the original data is {self.shape}, while we got {jnp.shape(v)}" if self.batch_axis is None: error += '. Do you forget to set "batch_axis" when initialize this variable?' else: error += f' with batch_axis={self.batch_axis}.' raise MathError(error) - if v.dtype != self._value.dtype: + ext_dtype = _get_dtype(v) + if ext_dtype != self._value.dtype: raise MathError(f"The dtype of the original data is {self._value.dtype}, " - f"while we got {v.dtype}.") - self._value[self.index] = v.value if isinstance(v, Array) else v + f"while we got {ext_dtype}.") + self._value[self.index] = v @register_pytree_node_class diff --git a/docs/issues-found-20260619-math-object-transform.md b/docs/issues-found-20260619-math-object-transform.md new file mode 100644 index 000000000..5b10dffb0 --- /dev/null +++ b/docs/issues-found-20260619-math-object-transform.md @@ -0,0 +1,165 @@ +# Audit 2026-06-19 — `brainpy/math/object_transform` + +Reviewer: senior Python + JAX expert (P4 slice). Branch +`fix/audit-20260619-math-object-transform`. JAX 0.10.2, brainstate 0.5.1. + +## Context + +A prior audit (`dev/issues-found-20260618.md`) already fixed the major +Critical/High issues in this package (C-25 `VarDict.tree_unflatten`/`jax.util`, +C-26 `Variable` pytree metadata loss, H-01 `cls_jit` negative argnums, H-02 +state-in-operands, H-03 zero-length pytree guard, H-04 `jit` `dyn_vars` kwargs, +H-05 `to()`/`cpu()`, H-06 `Variable.value` setter ordering, H-08 +`register_implicit_vars` container flatten, H-09 `Variable.__hash__`, M-02 +`cls_jit` `donate_argnums`, M-05 `ifelse` `check_cond=False`, L-04/L-05/L-06). +These were verified present and working in this worktree (all 222 in-scope tests +green at baseline). This document records a *fresh* review; remaining findings +are predominantly documented-contract / edge-case error-handling gaps. + +--- + +### P4-M1 — `bm.cond` crashes on non-callable branches [Medium] +- File: brainpy/math/object_transform/controls.py:96-158 +- Category: edge/error +- What: The docstring types both `true_fun` and `false_fun` as + ``callable, ArrayType, float, int, bool``, i.e. a constant branch is a + supported input. But `cond` forwards the branches straight to + `warp_to_no_state_input_output(true_fun)` (which just `@wraps` them) and then + to `brainstate.transform.cond`, which calls them. A constant branch is never + wrapped into a callable, so `bm.cond(True, 1.0, 2.0)` raises + ``TypeError: 'float' object is not callable``. The sibling `ifelse` handles + this correctly via its `make_callable` helper. +- Why it's a bug: A documented call form crashes. Historical BrainPy `cond` + accepted constant branches. +- Repro: ``bm.cond(True, 1.0, 2.0)`` → ``TypeError: 'float' object is not callable`` +- Fix: wrap non-callable `true_fun`/`false_fun` into zero-arg callables before + forwarding (mirroring `ifelse.make_callable`), unwrapping any `Array`/`State` + constant to its raw value so brainstate accepts it as an operand-free branch. +- Tests: `controls_test.py::TestCondBranchTypes` (3 cases) +- Status: fixed + +### P4-M2 — `bm.ifelse` crashes on a scalar-bool `conditions` [Medium] +- File: brainpy/math/object_transform/controls.py:161-267 +- Category: edge/error +- What: The docstring types `conditions` as ``bool, sequence of bool``. The + mutually-exclusive-condition conversion is guarded by + ``isinstance(conditions, (list, tuple))``; a bare scalar bool falls straight + through to `brainstate.transform.ifelse`, which immediately does + ``len(conditions)`` and raises ``TypeError: object of type 'bool' has no + len()``. +- Why it's a bug: A documented single-condition call form crashes. +- Repro: ``bm.ifelse(conditions=True, branches=[lambda: 1, lambda: 2])`` → + ``TypeError: object of type 'bool' has no len()`` +- Fix: normalize a scalar (non-list/tuple) `conditions` into a one-element list + before the conversion block. The existing ``len(branches) > len(conditions)`` + branch then appends the implicit ``else`` condition, giving the correct + two-way dispatch. +- Tests: `controls_test.py::TestIfElseScalarCondition` (2 cases) +- Status: fixed + +### P4-M3 — `Collector.__sub__` raises raw `KeyError` on a missing value operand [Medium] +- File: brainpy/math/object_transform/collectors.py:102-122 +- Category: edge/error +- What: When subtracting a list/tuple that contains a *value* object (not a + string key) which is not present in the collector, the code does + ``id_to_keys[id(key)]`` without a membership check, raising a bare + ``KeyError()``. Every other "not found" path in `__sub__` raises a + descriptive ``ValueError`` (and the co-located test + `test_sub_with_list_missing_key_raises` asserts ``ValueError`` for the string + case), so this is an inconsistent / unhelpful failure mode. +- Why it's a bug: Contract violation — the documented/observed behaviour for a + missing removal target is `ValueError`, not a cryptic id-keyed `KeyError`. +- Repro: + ```python + c = Collector(); c['a'] = some_var + c - [other_var_not_in_c] # -> KeyError(140...id) + ``` +- Fix: use ``id_to_keys.get(id(key))`` and raise the same descriptive + ``ValueError`` used elsewhere when the object is absent. +- Tests: `collectors_test.py::test_sub_with_list_missing_value_raises` +- Status: fixed + +### P4-M4 — `VariableView.value` setter is non-robust and asymmetric with `Variable` [Medium] +- File: brainpy/math/object_transform/variables.py:330-348 +- Category: edge/error +- What: The setter accesses ``v.shape`` / ``v.dtype`` on the raw input *before* + unwrapping, and only unwraps `Array` (not `brainstate.State`/`np.ndarray`). + Consequences: ``view.value = [1., 2.]`` raises + ``AttributeError: 'list' object has no attribute 'shape'`` and a numpy array + is never canonicalized to the view's dtype. The parent `Variable.value` + setter was already hardened (H-06) to unwrap `State`/`Array`/`np.ndarray` + first; the view setter was left behind, so the two diverge. +- Why it's a bug: Assigning a plain list/number/State to a `VariableView` + (a documented, public update path) crashes or silently mismatches dtype, + unlike the equivalent assignment to a `Variable`. +- Repro: ``bm.VariableView(bm.Variable(bm.arange(5.)), slice(0, 2)).value = [1., 2.]`` + → ``AttributeError`` +- Fix: unwrap `State`/`Array`/`np.ndarray` first (as the parent does), then use + ``jnp.shape``/`_get_dtype` for validation. This makes `VariableView` accept + the same inputs as `Variable` (numpy canonicalization, `State` unwrap) and, + for a plain Python list, fail with the *same* descriptive ``MathError`` as + the parent rather than an opaque ``AttributeError`` (a bare list remains + rejected for both, consistent with the parent — see P4-L1). +- Tests: `object_transform_fixes_test.py::test_variable_view_setter_python_list_matches_variable`, + `...::test_variable_view_setter_canonicalizes_numpy_dtype`, + `...::test_variable_view_setter_unwraps_state` +- Status: fixed + +### P4-L1 — `Variable.value = ` yields a confusing "object" dtype error [Low] +- File: brainpy/math/object_transform/variables.py:142-170 +- Category: edge/error +- What: A plain Python list is not unwrapped/`jnp.asarray`-ed, so the dtype + check computes ``canonicalize_dtype(list)`` → object dtype and raises + ``MathError: ... while we got object`` instead of either accepting the list + or giving a clear message. (Lists are not a documented input, hence Low.) +- Why it's a bug: Misleading diagnostic for a near-miss usage. +- Repro: ``bm.Variable(bm.arange(2.)).value = [1., 2.]`` +- Fix: recorded only. +- Tests: none +- Status: recorded-only + +### P4-L2 — `Variable.tree_unflatten` invokes `record_state_init` on every unflatten [Low] +- File: brainpy/math/object_transform/variables.py:199-214 +- Category: perf/correctness (latent) +- What: `tree_unflatten` calls ``brainstate.State.__init__`` to rebuild + bookkeeping. That runs ``source_info_util.current()`` (non-trivial) and + ``record_state_init(self)``, which appends the reconstructed state to every + active ``TRACE_CONTEXT.new_state_catcher``. A `Variable` is reconstructed on + *every* pytree round-trip (each jit/vmap/scan boundary, every `tree_map`). + If such a round-trip happens inside a brainstate "new-state catcher" context + (model-construction time), the rebuilt-but-not-actually-new state could be + spuriously caught. Not reproducible through the normal brainstate transform + paths (they close over states rather than passing Variables as pytree args), + so left as Low. +- Why it's a bug: Theoretical state-leak / minor per-unflatten cost. +- Repro: static (no observable failure in normal usage; verified jit/tree_map + round-trips do not leak). +- Fix: recorded only. (Reverting to full ``Variable.__init__`` would be worse — + it re-runs batch-axis validation + naming. A clean fix needs a brainstate + "rehydrate without recording" entry point, which is out of scope.) +- Tests: none +- Status: recorded-only + +### P4-L3 — auto name counter can collide with a manually supplied name [Low] +- File: brainpy/math/object_transform/naming.py:68-74 +- Category: edge/error +- What: ``get_unique_name`` hands out ``f'{type}{counter}'`` and bumps the + counter, ignoring names already taken manually. Creating ``Foo(name='Foo1')`` + before the auto counter reaches 1 makes the next auto-named ``Foo()`` raise + ``UniqueNameError``. Long-standing historical BrainPy behaviour. +- Why it's a bug: Surprising collision; mitigated by `clear_name_cache()`. +- Repro: ``Foo(); Foo(name='Foo1'); Foo()`` → ``UniqueNameError`` +- Fix: recorded only (historical contract; would change naming semantics). +- Tests: none +- Status: recorded-only + +--- + +## Cross-check vs `dev/issues-found-20260618.md` + +All object_transform / variables / transforms entries from the prior audit were +verified **already fixed** in this worktree and confirmed working: +C-25, C-26, H-01, H-02, H-03, H-04, H-05, H-06, H-08, H-09, M-02, M-03 (docstring +now says ``(final_carry, stacked_ys)``), M-04 (now documented), M-05, M-06 (now +documented intentional carry-passthrough), L-04, L-05, L-06. No still-present +verified bug from that list remained in scope.