Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion brainpy/math/object_transform/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,15 @@ def __sub__(self, other: Union[Dict, Sequence]):
if isinstance(key, str):
keys_to_remove.append(key)
else:
keys_to_remove.extend(id_to_keys[id(key)])
# Look the value up by identity. A value object that is not
# present must raise the same descriptive ``ValueError`` as
# the other "not found" paths below, rather than a bare
# ``KeyError(id)`` from an unchecked dict access.
matched = id_to_keys.get(id(key))
if matched is None:
raise ValueError(f'Cannot remove {key}, because we do not find it '
f'in {self.keys()}.')
keys_to_remove.extend(matched)

for key in set(keys_to_remove):
if key in gather:
Expand Down
10 changes: 10 additions & 0 deletions brainpy/math/object_transform/collectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ def test_sub_with_list_missing_key_raises():
c - ['nope']


def test_sub_with_list_missing_value_raises():
# P4-M3: removing a *value* object that is not present must raise the same
# descriptive ValueError as the string-key path, not a bare KeyError(id).
present = object()
absent = object()
c = Collector({'a': present})
with pytest.raises(ValueError):
Comment on lines +154 to +160

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Strengthen the regression by checking the ValueError message content for missing values.

Since this already covers the missing-value path for Collector.__sub__, you can make the regression stronger by also asserting on the error message to distinguish it from a generic ValueError:

with pytest.raises(ValueError, match=r"Cannot remove .* do not find it in"):
    c - [absent]

This also keeps the behavior aligned with the string-key path.

c - [absent]


def test_sub_rejects_bad_type():
c = Collector({'a': 1})
with pytest.raises(ValueError):
Expand Down
24 changes: 22 additions & 2 deletions brainpy/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,22 @@ def cond(
if not isinstance(operands, (tuple, list)):
operands = (operands,)
operands = _unwrap_state_operands(operands)

# ``true_fun``/``false_fun`` may be constants (array/number), per the
# documented contract. Wrap any non-callable branch into a callable that
# ignores ``*operands`` and returns the (unwrapped) constant, mirroring the
# handling in ``ifelse``. Otherwise brainstate would try to *call* the
# constant and raise ``TypeError: '<type>' object is not callable``.
def _make_branch(branch):
if callable(branch):
return warp_to_no_state_input_output(branch)
const = _unwrap_operand_leaf(branch)
return warp_to_no_state_input_output(lambda *args: const)

return brainstate.transform.cond(
pred,
warp_to_no_state_input_output(true_fun),
warp_to_no_state_input_output(false_fun),
_make_branch(true_fun),
_make_branch(false_fun),
*operands
)

Expand Down Expand Up @@ -230,6 +242,14 @@ def make_callable(branch):

branches = [make_callable(branch) for branch in branches]

# A single condition may be passed as a bare scalar bool/array (the
# docstring types ``conditions`` as ``bool, sequence of bool``). Normalise
# it into a one-element list so it flows through the conversion below;
# otherwise ``brainstate.transform.ifelse`` would call ``len()`` on the
# scalar and raise ``TypeError: object ... has no len()``.
if not isinstance(conditions, (list, tuple)):
conditions = [conditions]
Comment on lines +250 to +251

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Normalising non-sequence conditions to a one-element list changes behaviour for other iterable types (e.g. generators); consider a more precise type check.

The current if not isinstance(conditions, (list, tuple)) branch fixes the scalar case but also wraps generators and other non-list/tuple sequences, changing them from multiple conditions into a single one.

To keep behaviour consistent for genuine sequences, consider checking for “is a scalar/non-sequence” rather than “is not list/tuple”. For example, you could detect collections.abc.Sequence (excluding str/bytes/np.ndarray) or explicitly handle np.bool_ and scalar arrays, so real sequences still flow through unchanged while scalars avoid the len() error.

Suggested implementation:

    branches = [make_callable(branch) for branch in branches]

    # A single condition may be passed as a bare scalar bool/array (the
    # docstring types ``conditions`` as ``bool, sequence of bool``). Normalise
    # true scalars into a one-element list so they flow through the conversion
    # below; otherwise ``brainstate.transform.ifelse`` would call ``len()`` on
    # the scalar and raise ``TypeError: object ... has no len()``. Existing
    # iterable conditions (lists/tuples/generators/etc.) are left unchanged so
    # they still represent multiple conditions.
    is_numpy_scalar_bool = False
    try:
        import numpy as np  # use existing import if already present
        if isinstance(conditions, np.bool_) or (
            isinstance(conditions, np.ndarray) and conditions.ndim == 0 and conditions.dtype == bool
        ):
            is_numpy_scalar_bool = True
    except Exception:
        # NumPy not available or not used; fall back to pure-Python scalars only
        pass

    if isinstance(conditions, bool) or is_numpy_scalar_bool:
        conditions = [conditions]

    # Convert if-elif-else chain to mutually exclusive conditions
    if isinstance(conditions, (list, tuple)) and len(conditions) > 0:
        conditions = list(conditions)
  1. If this file does not already import NumPy as np, you may want to move the import numpy as np to the module top-level instead of keeping the local import inside the function, to match existing style/import conventions.
  2. If other scalar condition types (e.g. np.int_ or numeric scalars) should also be normalised, extend the is_numpy_scalar_bool logic and/or the isinstance(conditions, bool) check accordingly (e.g. using numbers.Integral), but the current change focuses on boolean scalars to match the documented type.


# Convert if-elif-else chain to mutually exclusive conditions
if isinstance(conditions, (list, tuple)) and len(conditions) > 0:
conditions = list(conditions)
Expand Down
47 changes: 47 additions & 0 deletions brainpy/math/object_transform/controls_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from functools import partial

import jax
import jax.numpy as jnp
from absl.testing import parameterized
from jax import vmap

Expand Down Expand Up @@ -528,3 +529,49 @@ def run(a, b):
run(0., 1.)

self.assertIn("cond_fun should not have any write states", str(cm.exception))


class TestCondBranchTypes(parameterized.TestCase):
"""Regression for P4-M1: ``bm.cond`` must accept non-callable (constant)
branches, as advertised by its docstring (``callable, ArrayType, float,
int, bool``)."""

def test_cond_with_constant_branches(self):
# Scalar Python constants as branches.
self.assertEqual(float(bm.cond(True, 1.0, 2.0)), 1.0)
self.assertEqual(float(bm.cond(False, 1.0, 2.0)), 2.0)

def test_cond_with_array_branches(self):
# Array constants as branches (unwrapped before forwarding).
r_true = bm.cond(True, bm.asarray([1., 2.]), bm.asarray([3., 4.]))
r_false = bm.cond(False, bm.asarray([1., 2.]), bm.asarray([3., 4.]))
self.assertTrue(bm.allclose(r_true, bm.asarray([1., 2.])))
self.assertTrue(bm.allclose(r_false, bm.asarray([3., 4.])))

def test_cond_callable_still_works(self):
# Callable branches keep working (and may mutate Variable state).
a = bm.Variable(bm.zeros(2))

def tf(op):
a.value += op

def ff(op):
a.value -= op

bm.cond(True, tf, ff, 5.0)
self.assertTrue(bm.allclose(a.value, bm.asarray([5., 5.])))


class TestIfElseScalarCondition(parameterized.TestCase):
"""Regression for P4-M2: ``bm.ifelse`` must accept a scalar-bool
``conditions`` argument, as advertised by its docstring."""

def test_ifelse_scalar_python_bool(self):
self.assertEqual(int(bm.ifelse(conditions=True, branches=[lambda: 1, lambda: 2])), 1)
self.assertEqual(int(bm.ifelse(conditions=False, branches=[lambda: 1, lambda: 2])), 2)

def test_ifelse_scalar_array_bool(self):
r = bm.ifelse(conditions=jnp.asarray(True), branches=[lambda: 10, lambda: 20])
self.assertEqual(int(r), 10)
r = bm.ifelse(conditions=jnp.asarray(False), branches=[lambda: 10, lambda: 20])
self.assertEqual(int(r), 20)
35 changes: 35 additions & 0 deletions brainpy/math/object_transform/object_transform_fixes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,41 @@ def test_variable_view_setter_shape_and_dtype_checks():
view.value = jnp.zeros(3) # wrong shape


# P4-M4: ``VariableView.value`` setter must accept the same inputs as
# ``Variable.value`` (plain list / numpy / State), instead of crashing or
# silently mismatching dtype.

def test_variable_view_setter_python_list_matches_variable():
# A plain Python list is handled identically to ``Variable.value`` (both
# raise a descriptive MathError rather than the previous opaque
# ``AttributeError: 'list' object has no attribute 'shape'``).
origin = bm.Variable(jnp.arange(5.))
view = bm.VariableView(origin, slice(None, 2, None))
parent_var = bm.Variable(jnp.arange(2.))
with pytest.raises(MathError):
parent_var.value = [10., 11.]
with pytest.raises(MathError):
view.value = [10., 11.]


def test_variable_view_setter_canonicalizes_numpy_dtype():
# A float64 numpy array assigned into a float32 view must be canonicalized,
# not rejected with a dtype MathError.
origin = bm.Variable(jnp.arange(5., dtype=jnp.float32))
view = bm.VariableView(origin, slice(None, 2, None))
view.value = np.array([7., 8.], dtype=np.float64)
assert origin.value.dtype == jnp.float32
assert bm.allclose(origin.value[:2], jnp.asarray([7., 8.]))


def test_variable_view_setter_unwraps_state():
origin = bm.Variable(jnp.arange(5.))
view = bm.VariableView(origin, slice(None, 2, None))
src = bm.Variable(jnp.asarray([20., 21.]))
view.value = src
assert bm.allclose(origin.value[:2], jnp.asarray([20., 21.]))


# ===========================================================================
# Additional coverage for base.py
# ===========================================================================
Expand Down
26 changes: 20 additions & 6 deletions brainpy/math/object_transform/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,23 +329,37 @@ def value(self):

@value.setter
def value(self, v):
# Normalize/unwrap the incoming value *before* validating its
# shape/dtype, mirroring the hardened ``Variable.value`` setter. Without
# this a plain Python ``list``/scalar (no ``.shape``) raises an
# ``AttributeError``, a ``brainstate.State`` is not unwrapped, and a
# ``numpy`` array is never canonicalized to the view's dtype.
if isinstance(v, brainstate.State):
v = v.value
if isinstance(v, Array):
v = v.value
elif isinstance(v, np.ndarray):
v = jnp.asarray(v)

int_shape = self.shape
ext_shape = jnp.shape(v)
if self.batch_axis is None:
ext_shape = v.shape
pass
else:
ext_shape = v.shape[:self.batch_axis] + v.shape[self.batch_axis + 1:]
ext_shape = ext_shape[:self.batch_axis] + ext_shape[self.batch_axis + 1:]
int_shape = int_shape[:self.batch_axis] + int_shape[self.batch_axis + 1:]
if ext_shape != int_shape:
error = f"The shape of the original data is {self.shape}, while we got {v.shape}"
error = f"The shape of the original data is {self.shape}, while we got {jnp.shape(v)}"
if self.batch_axis is None:
error += '. Do you forget to set "batch_axis" when initialize this variable?'
else:
error += f' with batch_axis={self.batch_axis}.'
raise MathError(error)
if v.dtype != self._value.dtype:
ext_dtype = _get_dtype(v)
if ext_dtype != self._value.dtype:
raise MathError(f"The dtype of the original data is {self._value.dtype}, "
f"while we got {v.dtype}.")
self._value[self.index] = v.value if isinstance(v, Array) else v
f"while we got {ext_dtype}.")
self._value[self.index] = v


@register_pytree_node_class
Expand Down
165 changes: 165 additions & 0 deletions docs/issues-found-20260619-math-object-transform.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Audit 2026-06-19 — `brainpy/math/object_transform`

Reviewer: senior Python + JAX expert (P4 slice). Branch
`fix/audit-20260619-math-object-transform`. JAX 0.10.2, brainstate 0.5.1.

## Context

A prior audit (`dev/issues-found-20260618.md`) already fixed the major
Critical/High issues in this package (C-25 `VarDict.tree_unflatten`/`jax.util`,
C-26 `Variable` pytree metadata loss, H-01 `cls_jit` negative argnums, H-02
state-in-operands, H-03 zero-length pytree guard, H-04 `jit` `dyn_vars` kwargs,
H-05 `to()`/`cpu()`, H-06 `Variable.value` setter ordering, H-08
`register_implicit_vars` container flatten, H-09 `Variable.__hash__`, M-02
`cls_jit` `donate_argnums`, M-05 `ifelse` `check_cond=False`, L-04/L-05/L-06).
These were verified present and working in this worktree (all 222 in-scope tests
green at baseline). This document records a *fresh* review; remaining findings
are predominantly documented-contract / edge-case error-handling gaps.

---

### P4-M1 — `bm.cond` crashes on non-callable branches [Medium]
- File: brainpy/math/object_transform/controls.py:96-158
- Category: edge/error
- What: The docstring types both `true_fun` and `false_fun` as
``callable, ArrayType, float, int, bool``, i.e. a constant branch is a
supported input. But `cond` forwards the branches straight to
`warp_to_no_state_input_output(true_fun)` (which just `@wraps` them) and then
to `brainstate.transform.cond`, which calls them. A constant branch is never
wrapped into a callable, so `bm.cond(True, 1.0, 2.0)` raises
``TypeError: 'float' object is not callable``. The sibling `ifelse` handles
this correctly via its `make_callable` helper.
- Why it's a bug: A documented call form crashes. Historical BrainPy `cond`
accepted constant branches.
- Repro: ``bm.cond(True, 1.0, 2.0)`` → ``TypeError: 'float' object is not callable``
- Fix: wrap non-callable `true_fun`/`false_fun` into zero-arg callables before
forwarding (mirroring `ifelse.make_callable`), unwrapping any `Array`/`State`
constant to its raw value so brainstate accepts it as an operand-free branch.
- Tests: `controls_test.py::TestCondBranchTypes` (3 cases)
- Status: fixed

### P4-M2 — `bm.ifelse` crashes on a scalar-bool `conditions` [Medium]
- File: brainpy/math/object_transform/controls.py:161-267
- Category: edge/error
- What: The docstring types `conditions` as ``bool, sequence of bool``. The
mutually-exclusive-condition conversion is guarded by
``isinstance(conditions, (list, tuple))``; a bare scalar bool falls straight
through to `brainstate.transform.ifelse`, which immediately does
``len(conditions)`` and raises ``TypeError: object of type 'bool' has no
len()``.
- Why it's a bug: A documented single-condition call form crashes.
- Repro: ``bm.ifelse(conditions=True, branches=[lambda: 1, lambda: 2])`` →
``TypeError: object of type 'bool' has no len()``
- Fix: normalize a scalar (non-list/tuple) `conditions` into a one-element list
before the conversion block. The existing ``len(branches) > len(conditions)``
branch then appends the implicit ``else`` condition, giving the correct
two-way dispatch.
- Tests: `controls_test.py::TestIfElseScalarCondition` (2 cases)
- Status: fixed

### P4-M3 — `Collector.__sub__` raises raw `KeyError` on a missing value operand [Medium]
- File: brainpy/math/object_transform/collectors.py:102-122
- Category: edge/error
- What: When subtracting a list/tuple that contains a *value* object (not a
string key) which is not present in the collector, the code does
``id_to_keys[id(key)]`` without a membership check, raising a bare
``KeyError(<int id>)``. Every other "not found" path in `__sub__` raises a
descriptive ``ValueError`` (and the co-located test
`test_sub_with_list_missing_key_raises` asserts ``ValueError`` for the string
case), so this is an inconsistent / unhelpful failure mode.
- Why it's a bug: Contract violation — the documented/observed behaviour for a
missing removal target is `ValueError`, not a cryptic id-keyed `KeyError`.
- Repro:
```python
c = Collector(); c['a'] = some_var
c - [other_var_not_in_c] # -> KeyError(140...id)
```
- Fix: use ``id_to_keys.get(id(key))`` and raise the same descriptive
``ValueError`` used elsewhere when the object is absent.
- Tests: `collectors_test.py::test_sub_with_list_missing_value_raises`
- Status: fixed

### P4-M4 — `VariableView.value` setter is non-robust and asymmetric with `Variable` [Medium]
- File: brainpy/math/object_transform/variables.py:330-348
- Category: edge/error
- What: The setter accesses ``v.shape`` / ``v.dtype`` on the raw input *before*
unwrapping, and only unwraps `Array` (not `brainstate.State`/`np.ndarray`).
Consequences: ``view.value = [1., 2.]`` raises
``AttributeError: 'list' object has no attribute 'shape'`` and a numpy array
is never canonicalized to the view's dtype. The parent `Variable.value`
setter was already hardened (H-06) to unwrap `State`/`Array`/`np.ndarray`
first; the view setter was left behind, so the two diverge.
- Why it's a bug: Assigning a plain list/number/State to a `VariableView`
(a documented, public update path) crashes or silently mismatches dtype,
unlike the equivalent assignment to a `Variable`.
- Repro: ``bm.VariableView(bm.Variable(bm.arange(5.)), slice(0, 2)).value = [1., 2.]``
→ ``AttributeError``
- Fix: unwrap `State`/`Array`/`np.ndarray` first (as the parent does), then use
``jnp.shape``/`_get_dtype` for validation. This makes `VariableView` accept
the same inputs as `Variable` (numpy canonicalization, `State` unwrap) and,
for a plain Python list, fail with the *same* descriptive ``MathError`` as
the parent rather than an opaque ``AttributeError`` (a bare list remains
rejected for both, consistent with the parent — see P4-L1).
- Tests: `object_transform_fixes_test.py::test_variable_view_setter_python_list_matches_variable`,
`...::test_variable_view_setter_canonicalizes_numpy_dtype`,
`...::test_variable_view_setter_unwraps_state`
- Status: fixed

### P4-L1 — `Variable.value = <python list>` yields a confusing "object" dtype error [Low]
- File: brainpy/math/object_transform/variables.py:142-170
- Category: edge/error
- What: A plain Python list is not unwrapped/`jnp.asarray`-ed, so the dtype
check computes ``canonicalize_dtype(list)`` → object dtype and raises
``MathError: ... while we got object`` instead of either accepting the list
or giving a clear message. (Lists are not a documented input, hence Low.)
- Why it's a bug: Misleading diagnostic for a near-miss usage.
- Repro: ``bm.Variable(bm.arange(2.)).value = [1., 2.]``
- Fix: recorded only.
- Tests: none
- Status: recorded-only

### P4-L2 — `Variable.tree_unflatten` invokes `record_state_init` on every unflatten [Low]
- File: brainpy/math/object_transform/variables.py:199-214
- Category: perf/correctness (latent)
- What: `tree_unflatten` calls ``brainstate.State.__init__`` to rebuild
bookkeeping. That runs ``source_info_util.current()`` (non-trivial) and
``record_state_init(self)``, which appends the reconstructed state to every
active ``TRACE_CONTEXT.new_state_catcher``. A `Variable` is reconstructed on
*every* pytree round-trip (each jit/vmap/scan boundary, every `tree_map`).
If such a round-trip happens inside a brainstate "new-state catcher" context
(model-construction time), the rebuilt-but-not-actually-new state could be
spuriously caught. Not reproducible through the normal brainstate transform
paths (they close over states rather than passing Variables as pytree args),
so left as Low.
- Why it's a bug: Theoretical state-leak / minor per-unflatten cost.
- Repro: static (no observable failure in normal usage; verified jit/tree_map
round-trips do not leak).
- Fix: recorded only. (Reverting to full ``Variable.__init__`` would be worse —
it re-runs batch-axis validation + naming. A clean fix needs a brainstate
"rehydrate without recording" entry point, which is out of scope.)
- Tests: none
- Status: recorded-only

### P4-L3 — auto name counter can collide with a manually supplied name [Low]
- File: brainpy/math/object_transform/naming.py:68-74
- Category: edge/error
- What: ``get_unique_name`` hands out ``f'{type}{counter}'`` and bumps the
counter, ignoring names already taken manually. Creating ``Foo(name='Foo1')``
before the auto counter reaches 1 makes the next auto-named ``Foo()`` raise
``UniqueNameError``. Long-standing historical BrainPy behaviour.
- Why it's a bug: Surprising collision; mitigated by `clear_name_cache()`.
- Repro: ``Foo(); Foo(name='Foo1'); Foo()`` → ``UniqueNameError``
- Fix: recorded only (historical contract; would change naming semantics).
- Tests: none
- Status: recorded-only

---

## Cross-check vs `dev/issues-found-20260618.md`

All object_transform / variables / transforms entries from the prior audit were
verified **already fixed** in this worktree and confirmed working:
C-25, C-26, H-01, H-02, H-03, H-04, H-05, H-06, H-08, H-09, M-02, M-03 (docstring
now says ``(final_carry, stacked_ys)``), M-04 (now documented), M-05, M-06 (now
documented intentional carry-passthrough), L-04, L-05, L-06. No still-present
verified bug from that list remained in scope.
Loading