From 311000857b269f5027cca13bd1deb737dd938e3f Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 18 Jun 2026 10:31:22 +0800 Subject: [PATCH] fix: restore daily CI under pytest 9.1 + brainstate 0.5 typed API The scheduled daily CI began failing on the frozen commit d0627d4 due to dependency drift, not source regressions. Three independent root causes are addressed here (all braintrace-side); a fourth was an upstream saiunit bug fixed in saiunit/brainunit 0.5.1 and needs no change here. 1. pytest 9.1.0 parametrize collection error pytest 9.1.0 parses a trailing comma in a single-arg parametrize id ('cls,') as two values, tripping ParameterSet length validation and erroring collection with "GraphNodeMeta has no len()". Drop the trailing commas (module_info_test, hidden_group_test, hid_param_op_test, hidden_pertubation_test). 2. 154 mypy errors from brainstate's new py.typed API brainstate 0.4+ ships type information, exposing its typed surface to mypy. Adopt it properly in source: route PyTree through braintrace's existing alias, centralize an as_size_tuple() helper in _typing, drop FlattedDict subscripts, and add boundary asserts/casts. Minimal # type: ignore only where brainstate's typing makes it unavoidable (hidden Module.__init__, saiunit Unit|Quantity, treedef-as-PyTree). 3. conv test expectations for brainstate 0.5.0 brainstate 0.5.0 hardened conv validation from bare assert to ValueError and changed padding-tuple semantics to one value per spatial dim. Update the affected _conv_test expectations accordingly. Verified locally: full suite 1367 passed / 0 failed (2 xfailed), mypy clean (51 files), wheel+sdist build with py.typed shipped (PEP 561). --- braintrace/_compile.py | 2 +- braintrace/_etrace_algorithms/_common.py | 12 +- braintrace/_etrace_algorithms/base.py | 19 +-- .../_etrace_algorithms/graph_executor.py | 10 +- braintrace/_etrace_algorithms/oracle.py | 4 +- .../_etrace_algorithms/param_dim_vjp.py | 8 +- braintrace/_etrace_compiler/hid_param_op.py | 2 + .../_etrace_compiler/hid_param_op_test.py | 4 +- braintrace/_etrace_compiler/hidden_group.py | 6 +- .../_etrace_compiler/hidden_group_test.py | 16 +-- .../hidden_pertubation_test.py | 4 +- braintrace/_etrace_compiler/module_info.py | 12 +- .../_etrace_compiler/module_info_test.py | 4 +- braintrace/_grad_exponential.py | 10 +- braintrace/_state_managment.py | 24 ++-- braintrace/_typing.py | 26 ++++ braintrace/nn/_conv.py | 6 +- braintrace/nn/_conv_test.py | 12 +- braintrace/nn/_linear.py | 8 +- braintrace/nn/_readout.py | 13 +- braintrace/nn/_rnn.py | 124 +++++++++--------- 21 files changed, 186 insertions(+), 140 deletions(-) diff --git a/braintrace/_compile.py b/braintrace/_compile.py index 9ede1b1..15b9f9f 100644 --- a/braintrace/_compile.py +++ b/braintrace/_compile.py @@ -32,7 +32,7 @@ # Canonical lowercase name (+ aliases) -> algorithm class. No bare ``ostl`` # alias: the ambiguous OSTL factory was removed in 0.2.0, so callers pick # ``ostl_recurrent`` vs ``ostl_feedforward`` explicitly. -_ALGORITHM_REGISTRY = { +_ALGORITHM_REGISTRY: dict[str, type[ETraceAlgorithm]] = { 'd_rtrl': D_RTRL, 'pp_prop': pp_prop, 'es_d_rtrl': pp_prop, diff --git a/braintrace/_etrace_algorithms/_common.py b/braintrace/_etrace_algorithms/_common.py index cb07ed6..61fb2ff 100644 --- a/braintrace/_etrace_algorithms/_common.py +++ b/braintrace/_etrace_algorithms/_common.py @@ -23,6 +23,8 @@ import jax.numpy as jnp import saiunit as u +from braintrace._typing import PyTree + __all__ = [ 'PresynapticTrace', 'KappaFilter', @@ -388,7 +390,7 @@ def _unit_safe_add(a, b): return u.math.add(a, b) -def _extract_leaf(pytree_val: brainstate.typing.PyTree, leaf_idx: int): +def _extract_leaf(pytree_val: PyTree, leaf_idx: int): """Return the leaf at ``leaf_idx`` in ``jax.tree.leaves(pytree_val)``. Bare arrays (treedef with a single leaf) return the array unchanged. @@ -405,7 +407,7 @@ def _extract_leaf(pytree_val: brainstate.typing.PyTree, leaf_idx: int): def _wrap_leaves_as_pytree( - reference_pytree: brainstate.typing.PyTree, + reference_pytree: PyTree, leaf_grads: Dict[int, jax.Array], ): """Build a pytree matching ``reference_pytree`` with ``leaf_grads`` @@ -441,8 +443,8 @@ def _wrap_leaves_as_pytree( def _route_grads_by_path( relation, per_key_grads: Dict[str, jax.Array], - weight_vals: Dict[Any, brainstate.typing.PyTree], - target_dict: Dict[Any, brainstate.typing.PyTree], + weight_vals: Dict[Any, PyTree], + target_dict: Dict[Any, PyTree], ) -> None: """Route per-key gradients from a dict-API rule into per-path pytrees. @@ -474,7 +476,7 @@ def _route_grads_by_path( def _update_dict( the_dict: Dict, key: Any, - value: brainstate.typing.PyTree, + value: PyTree, error_when_no_key: Optional[bool] = False ): """Update the dictionary. diff --git a/braintrace/_etrace_algorithms/base.py b/braintrace/_etrace_algorithms/base.py index 386a48c..dfcc17a 100644 --- a/braintrace/_etrace_algorithms/base.py +++ b/braintrace/_etrace_algorithms/base.py @@ -90,7 +90,7 @@ def __init__( graph_executor: ETraceGraphExecutor, name: Optional[str] = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # brainstate hides Module.__init__ from type checkers # the model if not isinstance(model, brainstate.nn.Module): @@ -115,9 +115,9 @@ def __init__( self.running_index = brainstate.LongTermState(0) # other states - self._param_states = None - self._hidden_states = None - self._other_states = None + self._param_states: Optional[brainstate.util.FlattedDict] = None + self._hidden_states: Optional[brainstate.util.FlattedDict] = None + self._other_states: Optional[brainstate.util.FlattedDict] = None @property def graph(self) -> ETraceGraph: @@ -144,7 +144,7 @@ def executor(self) -> ETraceGraphExecutor: return self.graph_executor @property - def param_states(self) -> brainstate.util.FlattedDict[Path, brainstate.ParamState]: + def param_states(self) -> brainstate.util.FlattedDict: """ Get the parameter weight states. @@ -155,10 +155,11 @@ def param_states(self) -> brainstate.util.FlattedDict[Path, brainstate.ParamStat """ if self._param_states is None: self._split_state() + assert self._param_states is not None return self._param_states @property - def hidden_states(self) -> brainstate.util.FlattedDict[Path, brainstate.HiddenState]: + def hidden_states(self) -> brainstate.util.FlattedDict: """ Get the hidden states. @@ -169,10 +170,11 @@ def hidden_states(self) -> brainstate.util.FlattedDict[Path, brainstate.HiddenSt """ if self._hidden_states is None: self._split_state() + assert self._hidden_states is not None return self._hidden_states @property - def other_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]: + def other_states(self) -> brainstate.util.FlattedDict: """ Get the other states. @@ -183,6 +185,7 @@ def other_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]: """ if self._other_states is None: self._split_state() + assert self._other_states is not None return self._other_states def _split_state(self): @@ -235,7 +238,7 @@ def compile_graph(self, *args) -> None: self.is_compiled = True @property - def path_to_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]: + def path_to_states(self) -> brainstate.util.FlattedDict: """ Get the path to the states. diff --git a/braintrace/_etrace_algorithms/graph_executor.py b/braintrace/_etrace_algorithms/graph_executor.py index 3a6f25e..e391082 100644 --- a/braintrace/_etrace_algorithms/graph_executor.py +++ b/braintrace/_etrace_algorithms/graph_executor.py @@ -110,7 +110,7 @@ def graph(self) -> ETraceGraph: return self._compiled_graph @property - def states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]: + def states(self) -> brainstate.util.FlattedDict: """ The states for the model. @@ -122,7 +122,7 @@ def states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]: return self.graph.module_info.retrieved_model_states @property - def path_to_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]: + def path_to_states(self) -> brainstate.util.FlattedDict: """ The path to the states. @@ -216,6 +216,8 @@ def show_graph( # other hidden states other_states = [] short_states = self.states.filter(brainstate.ShortTermState) + # a single-type filter returns one FlattedDict (brainstate types it as a union) + assert isinstance(short_states, brainstate.util.FlattedDict) for i, path in enumerate(short_states.keys()): if path not in hidden_paths: other_states.append(path) @@ -239,7 +241,9 @@ def show_graph( msg += '\n\n' # non etrace weights - non_etratce_weight_paths = set(self.states.filter(brainstate.ParamState).keys()) + param_states = self.states.filter(brainstate.ParamState) + assert isinstance(param_states, brainstate.util.FlattedDict) + non_etratce_weight_paths = set(param_states.keys()) non_etratce_weight_paths = non_etratce_weight_paths.difference(etratce_weight_paths) if len(non_etratce_weight_paths): msg += 'The non-etrace weight parameters are:\n\n' diff --git a/braintrace/_etrace_algorithms/oracle.py b/braintrace/_etrace_algorithms/oracle.py index 6fbb587..87bb719 100644 --- a/braintrace/_etrace_algorithms/oracle.py +++ b/braintrace/_etrace_algorithms/oracle.py @@ -59,8 +59,10 @@ def finite_difference_param_gradients( """ template = model_factory() brainstate.nn.init_all_states(template, batch_size=1) + template_params = template.states(brainstate.ParamState) + assert isinstance(template_params, brainstate.util.FlattedDict) base_values = { - k: np.asarray(v.value) for k, v in template.states(brainstate.ParamState).items() + k: np.asarray(v.value) for k, v in template_params.items() } def loss_with(values): diff --git a/braintrace/_etrace_algorithms/param_dim_vjp.py b/braintrace/_etrace_algorithms/param_dim_vjp.py index ab68f32..73be421 100644 --- a/braintrace/_etrace_algorithms/param_dim_vjp.py +++ b/braintrace/_etrace_algorithms/param_dim_vjp.py @@ -533,7 +533,7 @@ def _call_yw_to_w_dict(d, trace_, _rule=yw_to_w_rule, _params=eqn_params): _update_dict(dG_weights, key, val) -def _remove_units(xs_maybe_quantity: brainstate.typing.PyTree): +def _remove_units(xs_maybe_quantity: PyTree): """ Removes units from a PyTree of quantities, returning a unitless PyTree and a function to restore the units. @@ -542,10 +542,10 @@ def _remove_units(xs_maybe_quantity: brainstate.typing.PyTree): original units to the unitless PyTree. Args: - xs_maybe_quantity (brainstate.typing.PyTree): A PyTree structure containing quantities with units. + xs_maybe_quantity (PyTree): A PyTree structure containing quantities with units. Returns: - Tuple[brainstate.typing.PyTree, Callable]: A tuple containing: + Tuple[PyTree, Callable]: A tuple containing: - A PyTree with the same structure as the input, but with units removed from each quantity. - A function that takes a unitless PyTree and restores the original units to it. """ @@ -556,7 +556,7 @@ def _remove_units(xs_maybe_quantity: brainstate.typing.PyTree): new_leaves.append(leaf) units.append(unit) - def restore_units(xs_unitless: brainstate.typing.PyTree): + def restore_units(xs_unitless: PyTree): leaves, treedef2 = jax.tree.flatten(xs_unitless) # jax's PyTreeDef stubs omit __eq__; the comparison is valid at runtime. assert treedef == treedef2, 'The tree structure should be the same. ' # type: ignore[operator] diff --git a/braintrace/_etrace_compiler/hid_param_op.py b/braintrace/_etrace_compiler/hid_param_op.py index 18b044a..6e9029f 100644 --- a/braintrace/_etrace_compiler/hid_param_op.py +++ b/braintrace/_etrace_compiler/hid_param_op.py @@ -800,6 +800,8 @@ def find_hidden_param_op_relations_from_jaxpr( invar, t_path, producers, weight_path_to_invars, ) t_state = path_to_state.get(t_path) + # ``t_path`` is a trainable weight path, so its state is always a present ParamState. + assert isinstance(t_state, brainstate.ParamState) trainable_vars[key] = invar trainable_paths[key] = t_path trainable_leaf_indices[key] = t_leaf diff --git a/braintrace/_etrace_compiler/hid_param_op_test.py b/braintrace/_etrace_compiler/hid_param_op_test.py index 2c21741..ba7883f 100644 --- a/braintrace/_etrace_compiler/hid_param_op_test.py +++ b/braintrace/_etrace_compiler/hid_param_op_test.py @@ -59,7 +59,7 @@ def test_gru_one_layer(self): assert relation.connected_hidden_paths[0] == ('h',) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -87,7 +87,7 @@ def test_snn_single_layer(self, cls): print(relations) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, diff --git a/braintrace/_etrace_compiler/hidden_group.py b/braintrace/_etrace_compiler/hidden_group.py index 265744a..a484c76 100644 --- a/braintrace/_etrace_compiler/hidden_group.py +++ b/braintrace/_etrace_compiler/hidden_group.py @@ -45,7 +45,7 @@ # -*- coding: utf-8 -*- from itertools import combinations -from typing import List, Dict, Sequence, Tuple, Set, Optional, Callable, NamedTuple, Any +from typing import List, Dict, Sequence, Tuple, Set, Optional, Callable, NamedTuple, Any, cast import brainstate import jax.core @@ -989,7 +989,9 @@ def find_hidden_groups_from_jaxpr( weight_invars=weight_invars, invar_to_hidden_path=invar_to_hidden_path, outvar_to_hidden_path=outvar_to_hidden_path, - path_to_state=path_to_state, + # the evaluator only indexes hidden-state paths, whose entries are HiddenStates, + # even though the passed mapping carries every model state. + path_to_state=cast(Dict[Path, brainstate.HiddenState], path_to_state), ) hidden_groups, hid_path_to_group = evaluator.compile() return hidden_groups, brainstate.util.PrettyDict(hid_path_to_group) diff --git a/braintrace/_etrace_compiler/hidden_group_test.py b/braintrace/_etrace_compiler/hidden_group_test.py index c5ed39d..ae252d5 100644 --- a/braintrace/_etrace_compiler/hidden_group_test.py +++ b/braintrace/_etrace_compiler/hidden_group_test.py @@ -146,7 +146,7 @@ def test_gru_one_layer(self, cls): print() @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -180,7 +180,7 @@ def test_snn_single_layer(self, cls): print() @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -556,7 +556,7 @@ def test_gru(self, cls): print(out_vals) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -591,7 +591,7 @@ def test_snn_single_layer(self, cls): print(out_vals) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -689,7 +689,7 @@ def test_gru_accuracy(self, cls): assert (u.math.allclose(diag_jac, jax_jac, atol=1e-5)) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -724,7 +724,7 @@ def test_snn_single_layer(self, cls): print(diag_jac) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -767,7 +767,7 @@ def test_snn_single_layer_accuracy(self, cls): assert (u.math.allclose(diag_jac, jax_jac, atol=1e-5)) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -801,7 +801,7 @@ def test_snn_two_layers(self, cls): print(diag_jac) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, diff --git a/braintrace/_etrace_compiler/hidden_pertubation_test.py b/braintrace/_etrace_compiler/hidden_pertubation_test.py index f08428d..011fb8c 100644 --- a/braintrace/_etrace_compiler/hidden_pertubation_test.py +++ b/braintrace/_etrace_compiler/hidden_pertubation_test.py @@ -62,7 +62,7 @@ def test_rnn_one_layer(self, cls): assert len(states) == len(hidden_perturb.init_perturb_data()) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -94,7 +94,7 @@ def test_snn_single_layer(self, cls): assert len(states) == len(perturb) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, diff --git a/braintrace/_etrace_compiler/module_info.py b/braintrace/_etrace_compiler/module_info.py index b477730..3f59589 100644 --- a/braintrace/_etrace_compiler/module_info.py +++ b/braintrace/_etrace_compiler/module_info.py @@ -33,6 +33,7 @@ from braintrace._state_managment import sequence_split_state_values from braintrace._typing import ( Path, + PyTree, StateID, Inputs, Outputs, @@ -269,11 +270,11 @@ class ModuleInfo(NamedTuple): closed_jaxpr: ClosedJaxpr # states - retrieved_model_states: brainstate.util.FlattedDict[Path, brainstate.State] + retrieved_model_states: brainstate.util.FlattedDict compiled_model_states: Sequence[brainstate.State] state_id_to_path: Dict[StateID, Path] - state_tree_invars: brainstate.typing.PyTree[Var] - state_tree_outvars: brainstate.typing.PyTree[Var] + state_tree_invars: PyTree + state_tree_outvars: PyTree # hidden states hidden_path_to_invar: Dict[Path, Var] @@ -434,7 +435,10 @@ def _process(self, *args, jaxpr_outs: Sequence[jax.Array]): cache_key = self.stateful_model.get_arg_cache_key(*args, compile_if_miss=True) i_start = self.num_var_out i_end = i_start + self.num_var_state - out, new_state_vals = self.stateful_model.get_out_treedef_by_cache(cache_key).unflatten(jaxpr_outs[:i_end]) + # brainstate types the cached treedef as the broad ``PyTree``; at runtime it is a + # ``jax`` ``PyTreeDef`` that exposes ``unflatten``. + out, new_state_vals = self.stateful_model.get_out_treedef_by_cache(cache_key).unflatten( # type: ignore[attr-defined] + jaxpr_outs[:i_end]) # # check state value diff --git a/braintrace/_etrace_compiler/module_info_test.py b/braintrace/_etrace_compiler/module_info_test.py index 90301a2..6671de8 100644 --- a/braintrace/_etrace_compiler/module_info_test.py +++ b/braintrace/_etrace_compiler/module_info_test.py @@ -57,7 +57,7 @@ def test_rnn_one_layer(self, cls): pprint(minfo) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -85,7 +85,7 @@ def test_snn_single_layer(self, cls): pprint(minfo) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, diff --git a/braintrace/_grad_exponential.py b/braintrace/_grad_exponential.py index 9ba8826..60fa7ee 100644 --- a/braintrace/_grad_exponential.py +++ b/braintrace/_grad_exponential.py @@ -18,6 +18,8 @@ import jax.tree import saiunit as u +from braintrace._typing import PyTree + __all__ = [ 'GradExpon', ] @@ -31,7 +33,7 @@ class GradExpon(brainstate.nn.Module): Parameters ---------- - grad_shape : brainstate.typing.PyTree + grad_shape : PyTree A pytree whose leaves give the shape and dtype of the gradients to accumulate. The accumulator is initialised to zeros matching each leaf. @@ -69,7 +71,7 @@ class GradExpon(brainstate.nn.Module): def __init__( self, - grad_shape: brainstate.typing.PyTree, + grad_shape: PyTree, tau_or_decay: u.Quantity | float, ): super().__init__() @@ -90,7 +92,7 @@ def __init__( raise TypeError(f"tau_or_decay must be a Quantity or a float, but got {tau_or_decay}") self.decay = decay - def update(self, grads: brainstate.typing.PyTree): + def update(self, grads: PyTree): r"""Update the accumulated gradients with the exponential decay rule. Applies :math:`g_{t+1} = \mathrm{decay} \cdot g_t + \mathrm{grads}`, @@ -100,7 +102,7 @@ def update(self, grads: brainstate.typing.PyTree): Parameters ---------- - grads : brainstate.typing.PyTree + grads : PyTree The new gradients to incorporate into the accumulated gradients. Must match the pytree structure of the accumulator. diff --git a/braintrace/_state_managment.py b/braintrace/_state_managment.py index 88a5061..dffb718 100644 --- a/braintrace/_state_managment.py +++ b/braintrace/_state_managment.py @@ -18,12 +18,12 @@ import brainstate pass # ParamState removed (primitive-based ETP) -from ._typing import Path +from ._typing import Path, PyTree def assign_dict_state_values( - states: Dict[Path, brainstate.State], - state_values: Dict[Path, brainstate.typing.PyTree], + states: Mapping[Path, brainstate.State], + state_values: Mapping[Path, PyTree], write: bool = True ): """ @@ -38,7 +38,7 @@ def assign_dict_state_values( states : Dict[Path, brainstate.State] A dictionary where keys are paths and values are state objects to which values will be assigned or restored. - state_values : Dict[Path, brainstate.typing.PyTree] + state_values : Dict[Path, PyTree] A dictionary where keys are paths and values are the values corresponding to each state in `states`. write : bool, optional @@ -62,7 +62,7 @@ def assign_dict_state_values( def assign_state_values_v2( states: Mapping[Any, brainstate.State], - state_values: Mapping[Any, brainstate.typing.PyTree], + state_values: Mapping[Any, PyTree], write: bool = True ): """ @@ -77,7 +77,7 @@ def assign_state_values_v2( states : Dict[Hashable, brainstate.State] A dictionary where keys are hashable identifiers and values are state objects to which values will be assigned or restored. - state_values : Dict[Hashable, brainstate.typing.PyTree] + state_values : Dict[Hashable, PyTree] A dictionary where keys are hashable identifiers and values are the values corresponding to each state in `states`. write : bool, optional @@ -106,18 +106,18 @@ def assign_state_values_v2( def sequence_split_state_values( states: Sequence[brainstate.State], - state_values: List[brainstate.typing.PyTree], + state_values: List[PyTree], include_weight: bool = True ) -> ( Tuple[ - Sequence[brainstate.typing.PyTree], - Sequence[brainstate.typing.PyTree], - Sequence[brainstate.typing.PyTree] + Sequence[PyTree], + Sequence[PyTree], + Sequence[PyTree] ] | Tuple[ - Sequence[brainstate.typing.PyTree], - Sequence[brainstate.typing.PyTree] + Sequence[PyTree], + Sequence[PyTree] ] ): """ diff --git a/braintrace/_typing.py b/braintrace/_typing.py index 1b983dd..db9a039 100644 --- a/braintrace/_typing.py +++ b/braintrace/_typing.py @@ -19,6 +19,7 @@ import brainstate import jax +import numpy as np from ._compatible_imports import Var @@ -32,6 +33,31 @@ WeightID: TypeAlias = int Size: TypeAlias = brainstate.typing.Size Axis: TypeAlias = int + + +def as_size_tuple(size: Size) -> Tuple[int, ...]: + """Normalize an ``in_size``/``out_size`` spec to a tuple of ints. + + ``brainstate``'s size setters accept a scalar ``int`` or a sequence, while + the matching getters are typed as the broad :data:`Size` union, which static + type checkers do not treat as indexable. Routing values through this helper + yields a concrete ``tuple[int, ...]`` so both property assignment and + trailing-dimension lookups (``size[-1]``) type-check cleanly, reproducing + ``brainstate``'s own normalization at runtime. + + Parameters + ---------- + size : Size + A scalar ``int`` / numpy integer, or a sequence of them. + + Returns + ------- + tuple of int + The size expressed as a tuple of Python ints. + """ + if isinstance(size, (int, np.integer)): + return (int(size),) + return tuple(int(s) for s in size) Axes: TypeAlias = Union[int, Sequence[int]] Path: TypeAlias = Tuple[str, ...] diff --git a/braintrace/nn/_conv.py b/braintrace/nn/_conv.py index 1927d76..2e00d81 100644 --- a/braintrace/nn/_conv.py +++ b/braintrace/nn/_conv.py @@ -60,17 +60,17 @@ def _etp_conv_op(self, x, params): class Conv1d(brainstate.nn.Conv1d): __module__ = 'braintrace.nn' - __doc__ = brainstate.nn.Conv1d.__doc__.replace('brainstate', 'braintrace') + __doc__ = (brainstate.nn.Conv1d.__doc__ or '').replace('brainstate', 'braintrace') _conv_op = _etp_conv_op class Conv2d(brainstate.nn.Conv2d): __module__ = 'braintrace.nn' - __doc__ = brainstate.nn.Conv2d.__doc__.replace('brainstate', 'braintrace') + __doc__ = (brainstate.nn.Conv2d.__doc__ or '').replace('brainstate', 'braintrace') _conv_op = _etp_conv_op class Conv3d(brainstate.nn.Conv3d): __module__ = 'braintrace.nn' - __doc__ = brainstate.nn.Conv3d.__doc__.replace('brainstate', 'braintrace') + __doc__ = (brainstate.nn.Conv3d.__doc__ or '').replace('brainstate', 'braintrace') _conv_op = _etp_conv_op diff --git a/braintrace/nn/_conv_test.py b/braintrace/nn/_conv_test.py index b9c1e01..f564465 100644 --- a/braintrace/nn/_conv_test.py +++ b/braintrace/nn/_conv_test.py @@ -494,12 +494,12 @@ def test_conv3d_explicit_padding_int(self): assert y.shape[-1] == 64 def test_conv3d_explicit_padding_tuple(self): - """Test Conv3d with explicit tuple padding.""" + """Test Conv3d with explicit tuple padding (one value per spatial dim).""" conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, - padding=(1, 1) + padding=(1, 1, 1) ) x = brainstate.random.randn(2, 16, 16, 16, 3) y = conv(x) @@ -662,17 +662,17 @@ class TestConvEdgeCases: def test_conv_invalid_groups_out_channels(self): """Test that out_channels must be divisible by groups.""" - with pytest.raises(AssertionError): + with pytest.raises(ValueError): braintrace.nn.Conv2d(in_size=(28, 28, 4), out_channels=9, kernel_size=3, groups=2) def test_conv_invalid_groups_in_channels(self): """Test that in_channels must be divisible by groups.""" - with pytest.raises(AssertionError): + with pytest.raises(ValueError): braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=8, kernel_size=3, groups=2) def test_conv_invalid_padding_string(self): """Test that only SAME and VALID are accepted as string padding.""" - with pytest.raises(AssertionError): + with pytest.raises(ValueError): braintrace.nn.Conv2d( in_size=(28, 28, 3), out_channels=32, @@ -693,7 +693,7 @@ def test_conv_invalid_padding_wrong_length(self): def test_conv_invalid_in_size_length(self): """Test that in_size must have correct length.""" - with pytest.raises(AssertionError): + with pytest.raises(ValueError): braintrace.nn.Conv2d( in_size=(28, 3), # Should be 3D for Conv2d out_channels=32, diff --git a/braintrace/nn/_linear.py b/braintrace/nn/_linear.py index b18a4d4..70dbe60 100644 --- a/braintrace/nn/_linear.py +++ b/braintrace/nn/_linear.py @@ -31,7 +31,7 @@ class Linear(brainstate.nn.Linear): __module__ = 'braintrace.nn' - __doc__ = brainstate.nn.Linear.__doc__.replace('brainstate', 'braintrace') + __doc__ = (brainstate.nn.Linear.__doc__ or '').replace('brainstate', 'braintrace') def update(self, x): """Apply the linear transform through the ETP ``matmul`` primitive. @@ -59,7 +59,7 @@ def update(self, x): class SignedWLinear(brainstate.nn.SignedWLinear): __module__ = 'braintrace.nn' - __doc__ = brainstate.nn.SignedWLinear.__doc__.replace('brainstate', 'braintrace') + __doc__ = (brainstate.nn.SignedWLinear.__doc__ or '').replace('brainstate', 'braintrace') def update(self, x): """Apply the sign-constrained linear transform through ETP ``matmul``. @@ -86,7 +86,7 @@ def update(self, x): class ScaledWSLinear(brainstate.nn.ScaledWSLinear): __module__ = 'braintrace.nn' - __doc__ = brainstate.nn.ScaledWSLinear.__doc__.replace('brainstate', 'braintrace') + __doc__ = (brainstate.nn.ScaledWSLinear.__doc__ or '').replace('brainstate', 'braintrace') def update(self, x): """Apply the weight-standardized linear transform through ETP ``matmul``. @@ -115,7 +115,7 @@ def update(self, x): class SparseLinear(brainstate.nn.SparseLinear): __module__ = 'braintrace.nn' - __doc__ = brainstate.nn.SparseLinear.__doc__.replace('brainstate', 'braintrace') + __doc__ = (brainstate.nn.SparseLinear.__doc__ or '').replace('brainstate', 'braintrace') def update(self, x): """Apply the sparse linear transform through the ETP ``sparse_matmul``. diff --git a/braintrace/nn/_readout.py b/braintrace/nn/_readout.py index 078a61d..1fe1bb4 100644 --- a/braintrace/nn/_readout.py +++ b/braintrace/nn/_readout.py @@ -15,7 +15,6 @@ # -*- coding: utf-8 -*- -import numbers from typing import Callable, Optional import brainstate @@ -23,7 +22,7 @@ import saiunit as u from braintrace._etrace_op import matmul -from braintrace._typing import Size, ArrayLike +from braintrace._typing import Size, ArrayLike, as_size_tuple __all__ = [ 'LeakyRateReadout', @@ -104,16 +103,16 @@ def __init__( self, in_size: Size, out_size: Size, - tau: ArrayLike = 5. * u.ms, + tau: ArrayLike = 5. * u.ms, # type: ignore[assignment] # saiunit types `float * Unit` as `Unit | Quantity` w_init: Callable = braintools.init.KaimingNormal(), r_init: Callable = braintools.init.ZeroInit(), name: Optional[str] = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # brainstate hides Module.__init__ from type checkers # parameters - self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size) - self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size) + self.in_size = as_size_tuple(in_size) + self.out_size = as_size_tuple(out_size) self.tau = braintools.init.param(tau, self.out_size) # Compute decay handling units properly tau_normalized = u.maybe_decimal(self.tau / brainstate.environ.get_dt()) @@ -122,7 +121,7 @@ def __init__( # weights self.W = brainstate.ParamState( - braintools.init.param(w_init, (self.in_size[0], self.out_size[0])) + braintools.init.param(w_init, (as_size_tuple(self.in_size)[0], as_size_tuple(self.out_size)[0])) ) def init_state(self, batch_size=None, **kwargs): diff --git a/braintrace/nn/_rnn.py b/braintrace/nn/_rnn.py index c194088..e892b30 100644 --- a/braintrace/nn/_rnn.py +++ b/braintrace/nn/_rnn.py @@ -14,14 +14,14 @@ # ============================================================================== # -*- coding: utf-8 -*- -from typing import Callable, Union +from typing import Any, Callable, Union import brainstate import braintools import saiunit as u from braintrace._etrace_op import element_wise -from braintrace._typing import ArrayLike +from braintrace._typing import ArrayLike, as_size_tuple as _as_size_tuple from ._linear import Linear __all__ = [ @@ -89,12 +89,12 @@ def __init__( activation: str | Callable = 'relu', name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters self._state_initializer = state_init - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # activation function if isinstance(activation, str): @@ -105,7 +105,7 @@ def __init__( # weights self.W = Linear( - self.in_size[-1] + self.out_size[-1], self.out_size[-1], + _as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], w_init=w_init, b_init=b_init, ) @@ -188,12 +188,12 @@ def __init__( activation: str | Callable = 'tanh', name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters self._state_initializer = state_init - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # activation function if isinstance(activation, str): @@ -203,10 +203,10 @@ def __init__( self.activation = activation # weights - params = dict(w_init=w_init, b_init=b_init) - self.Wz = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wr = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wh = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) + params: dict[str, Any] = dict(w_init=w_init, b_init=b_init) + self.Wz = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wr = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wh = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) def init_state(self, batch_size: int = None, **kwargs): self.h = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size)) @@ -290,12 +290,12 @@ def __init__( activation: str | Callable = 'tanh', name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters self._state_initializer = state_init - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # activation function if isinstance(activation, str): @@ -305,10 +305,10 @@ def __init__( self.activation = activation # weights - params = dict(w_init=w_init, b_init=b_init) - self.Wf = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wi = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wh = Linear(self.out_size[-1], self.out_size[-1], **params) + params: dict[str, Any] = dict(w_init=w_init, b_init=b_init) + self.Wf = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wi = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wh = Linear(_as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) def init_state(self, batch_size: int = None, **kwargs): self.h = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size)) @@ -406,12 +406,12 @@ def __init__( activation: str | Callable = 'tanh', name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters self._state_initializer = state_init - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # activation function if isinstance(activation, str): @@ -421,9 +421,9 @@ def __init__( self.activation = activation # weights - params = dict(w_init=w_init, b_init=b_init) - self.Wf = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wh = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) + params: dict[str, Any] = dict(w_init=w_init, b_init=b_init) + self.Wf = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wh = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) def init_state(self, batch_size: int = None, **kwargs): self.h = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size)) @@ -536,11 +536,11 @@ def __init__( activation: str | Callable = 'tanh', name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # initializers self._state_initializer = state_init @@ -553,11 +553,11 @@ def __init__( self.activation = activation # weights - params = dict(w_init=w_init, b_init=b_init) - self.Wi = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wg = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wf = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wo = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) + params: dict[str, Any] = dict(w_init=w_init, b_init=b_init) + self.Wi = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wg = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wf = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wo = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) def init_state(self, batch_size: int = None, **kwargs): self.c = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size)) @@ -644,11 +644,11 @@ def __init__( activation: str | Callable = 'tanh', name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # initializers self._state_initializer = state_init @@ -661,15 +661,15 @@ def __init__( self.activation = activation # weights - params = dict(w_init=w_init, b_init=None) - self.Wu = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wf = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wr = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wo = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) + params: dict[str, Any] = dict(w_init=w_init, b_init=None) + self.Wu = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wf = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wr = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wo = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) self.bias = brainstate.ParamState(self._forget_bias()) def _forget_bias(self): - rand_val = brainstate.random.uniform(1 / self.out_size[-1], 1 - 1 / self.out_size[-1], (self.out_size[-1],)) + rand_val = brainstate.random.uniform(1 / _as_size_tuple(self.out_size)[-1], 1 - 1 / _as_size_tuple(self.out_size)[-1], (_as_size_tuple(self.out_size)[-1],)) return -u.math.log(1 / rand_val - 1) def init_state(self, batch_size: int = None, **kwargs): @@ -780,22 +780,22 @@ def __init__( phi: Callable = None, name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters self._state_initializer = state_init - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # functions - params = dict(w_init=w_init, b_init=b_init) + params: dict[str, Any] = dict(w_init=w_init, b_init=b_init) if phi is None: - phi = Linear(self.in_size[-1], self.out_size[-1], **params) + phi = Linear(_as_size_tuple(self.in_size)[-1], _as_size_tuple(self.out_size)[-1], **params) assert callable(phi), f"The phi function should be a callable function. But got {phi}" self.phi = phi # weights - self.W_u = Linear(self.out_size[-1] * 2, self.out_size[-1], **params) + self.W_u = Linear(_as_size_tuple(self.out_size)[-1] * 2, _as_size_tuple(self.out_size)[-1], **params) def init_state(self, batch_size: int = None, **kwargs): self.h = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size)) @@ -881,19 +881,19 @@ def __init__( state_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(), name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters self._state_initializer = state_init - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # functions - params = dict(w_init=w_init, b_init=b_init) - self.W_x = Linear(self.in_size[-1], self.out_size[-1], **params) + params: dict[str, Any] = dict(w_init=w_init, b_init=b_init) + self.W_x = Linear(_as_size_tuple(self.in_size)[-1], _as_size_tuple(self.out_size)[-1], **params) # weights - self.W_z = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) + self.W_z = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) def init_state(self, batch_size: int = None, **kwargs): self.h = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size)) @@ -978,20 +978,20 @@ def __init__( state_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(), name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters self._state_initializer = state_init - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # functions - params = dict(w_init=w_init, b_init=b_init) - self.W_x = Linear(self.in_size[-1], self.out_size[-1], **params) + params: dict[str, Any] = dict(w_init=w_init, b_init=b_init) + self.W_x = Linear(_as_size_tuple(self.in_size)[-1], _as_size_tuple(self.out_size)[-1], **params) # weights - self.W_f = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.W_i = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) + self.W_f = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.W_i = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) def init_state(self, batch_size: int = None, **kwargs): self.h = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size))