diff --git a/braintrace/_compile.py b/braintrace/_compile.py index 9ede1b1..15b9f9f 100644 --- a/braintrace/_compile.py +++ b/braintrace/_compile.py @@ -32,7 +32,7 @@ # Canonical lowercase name (+ aliases) -> algorithm class. No bare ``ostl`` # alias: the ambiguous OSTL factory was removed in 0.2.0, so callers pick # ``ostl_recurrent`` vs ``ostl_feedforward`` explicitly. -_ALGORITHM_REGISTRY = { +_ALGORITHM_REGISTRY: dict[str, type[ETraceAlgorithm]] = { 'd_rtrl': D_RTRL, 'pp_prop': pp_prop, 'es_d_rtrl': pp_prop, diff --git a/braintrace/_etrace_algorithms/_common.py b/braintrace/_etrace_algorithms/_common.py index cb07ed6..61fb2ff 100644 --- a/braintrace/_etrace_algorithms/_common.py +++ b/braintrace/_etrace_algorithms/_common.py @@ -23,6 +23,8 @@ import jax.numpy as jnp import saiunit as u +from braintrace._typing import PyTree + __all__ = [ 'PresynapticTrace', 'KappaFilter', @@ -388,7 +390,7 @@ def _unit_safe_add(a, b): return u.math.add(a, b) -def _extract_leaf(pytree_val: brainstate.typing.PyTree, leaf_idx: int): +def _extract_leaf(pytree_val: PyTree, leaf_idx: int): """Return the leaf at ``leaf_idx`` in ``jax.tree.leaves(pytree_val)``. Bare arrays (treedef with a single leaf) return the array unchanged. @@ -405,7 +407,7 @@ def _extract_leaf(pytree_val: brainstate.typing.PyTree, leaf_idx: int): def _wrap_leaves_as_pytree( - reference_pytree: brainstate.typing.PyTree, + reference_pytree: PyTree, leaf_grads: Dict[int, jax.Array], ): """Build a pytree matching ``reference_pytree`` with ``leaf_grads`` @@ -441,8 +443,8 @@ def _wrap_leaves_as_pytree( def _route_grads_by_path( relation, per_key_grads: Dict[str, jax.Array], - weight_vals: Dict[Any, brainstate.typing.PyTree], - target_dict: Dict[Any, brainstate.typing.PyTree], + weight_vals: Dict[Any, PyTree], + target_dict: Dict[Any, PyTree], ) -> None: """Route per-key gradients from a dict-API rule into per-path pytrees. @@ -474,7 +476,7 @@ def _route_grads_by_path( def _update_dict( the_dict: Dict, key: Any, - value: brainstate.typing.PyTree, + value: PyTree, error_when_no_key: Optional[bool] = False ): """Update the dictionary. diff --git a/braintrace/_etrace_algorithms/base.py b/braintrace/_etrace_algorithms/base.py index 386a48c..dfcc17a 100644 --- a/braintrace/_etrace_algorithms/base.py +++ b/braintrace/_etrace_algorithms/base.py @@ -90,7 +90,7 @@ def __init__( graph_executor: ETraceGraphExecutor, name: Optional[str] = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # brainstate hides Module.__init__ from type checkers # the model if not isinstance(model, brainstate.nn.Module): @@ -115,9 +115,9 @@ def __init__( self.running_index = brainstate.LongTermState(0) # other states - self._param_states = None - self._hidden_states = None - self._other_states = None + self._param_states: Optional[brainstate.util.FlattedDict] = None + self._hidden_states: Optional[brainstate.util.FlattedDict] = None + self._other_states: Optional[brainstate.util.FlattedDict] = None @property def graph(self) -> ETraceGraph: @@ -144,7 +144,7 @@ def executor(self) -> ETraceGraphExecutor: return self.graph_executor @property - def param_states(self) -> brainstate.util.FlattedDict[Path, brainstate.ParamState]: + def param_states(self) -> brainstate.util.FlattedDict: """ Get the parameter weight states. @@ -155,10 +155,11 @@ def param_states(self) -> brainstate.util.FlattedDict[Path, brainstate.ParamStat """ if self._param_states is None: self._split_state() + assert self._param_states is not None return self._param_states @property - def hidden_states(self) -> brainstate.util.FlattedDict[Path, brainstate.HiddenState]: + def hidden_states(self) -> brainstate.util.FlattedDict: """ Get the hidden states. @@ -169,10 +170,11 @@ def hidden_states(self) -> brainstate.util.FlattedDict[Path, brainstate.HiddenSt """ if self._hidden_states is None: self._split_state() + assert self._hidden_states is not None return self._hidden_states @property - def other_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]: + def other_states(self) -> brainstate.util.FlattedDict: """ Get the other states. @@ -183,6 +185,7 @@ def other_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]: """ if self._other_states is None: self._split_state() + assert self._other_states is not None return self._other_states def _split_state(self): @@ -235,7 +238,7 @@ def compile_graph(self, *args) -> None: self.is_compiled = True @property - def path_to_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]: + def path_to_states(self) -> brainstate.util.FlattedDict: """ Get the path to the states. diff --git a/braintrace/_etrace_algorithms/graph_executor.py b/braintrace/_etrace_algorithms/graph_executor.py index 3a6f25e..e391082 100644 --- a/braintrace/_etrace_algorithms/graph_executor.py +++ b/braintrace/_etrace_algorithms/graph_executor.py @@ -110,7 +110,7 @@ def graph(self) -> ETraceGraph: return self._compiled_graph @property - def states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]: + def states(self) -> brainstate.util.FlattedDict: """ The states for the model. @@ -122,7 +122,7 @@ def states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]: return self.graph.module_info.retrieved_model_states @property - def path_to_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]: + def path_to_states(self) -> brainstate.util.FlattedDict: """ The path to the states. @@ -216,6 +216,8 @@ def show_graph( # other hidden states other_states = [] short_states = self.states.filter(brainstate.ShortTermState) + # a single-type filter returns one FlattedDict (brainstate types it as a union) + assert isinstance(short_states, brainstate.util.FlattedDict) for i, path in enumerate(short_states.keys()): if path not in hidden_paths: other_states.append(path) @@ -239,7 +241,9 @@ def show_graph( msg += '\n\n' # non etrace weights - non_etratce_weight_paths = set(self.states.filter(brainstate.ParamState).keys()) + param_states = self.states.filter(brainstate.ParamState) + assert isinstance(param_states, brainstate.util.FlattedDict) + non_etratce_weight_paths = set(param_states.keys()) non_etratce_weight_paths = non_etratce_weight_paths.difference(etratce_weight_paths) if len(non_etratce_weight_paths): msg += 'The non-etrace weight parameters are:\n\n' diff --git a/braintrace/_etrace_algorithms/oracle.py b/braintrace/_etrace_algorithms/oracle.py index 6fbb587..87bb719 100644 --- a/braintrace/_etrace_algorithms/oracle.py +++ b/braintrace/_etrace_algorithms/oracle.py @@ -59,8 +59,10 @@ def finite_difference_param_gradients( """ template = model_factory() brainstate.nn.init_all_states(template, batch_size=1) + template_params = template.states(brainstate.ParamState) + assert isinstance(template_params, brainstate.util.FlattedDict) base_values = { - k: np.asarray(v.value) for k, v in template.states(brainstate.ParamState).items() + k: np.asarray(v.value) for k, v in template_params.items() } def loss_with(values): diff --git a/braintrace/_etrace_algorithms/param_dim_vjp.py b/braintrace/_etrace_algorithms/param_dim_vjp.py index ab68f32..73be421 100644 --- a/braintrace/_etrace_algorithms/param_dim_vjp.py +++ b/braintrace/_etrace_algorithms/param_dim_vjp.py @@ -533,7 +533,7 @@ def _call_yw_to_w_dict(d, trace_, _rule=yw_to_w_rule, _params=eqn_params): _update_dict(dG_weights, key, val) -def _remove_units(xs_maybe_quantity: brainstate.typing.PyTree): +def _remove_units(xs_maybe_quantity: PyTree): """ Removes units from a PyTree of quantities, returning a unitless PyTree and a function to restore the units. @@ -542,10 +542,10 @@ def _remove_units(xs_maybe_quantity: brainstate.typing.PyTree): original units to the unitless PyTree. Args: - xs_maybe_quantity (brainstate.typing.PyTree): A PyTree structure containing quantities with units. + xs_maybe_quantity (PyTree): A PyTree structure containing quantities with units. Returns: - Tuple[brainstate.typing.PyTree, Callable]: A tuple containing: + Tuple[PyTree, Callable]: A tuple containing: - A PyTree with the same structure as the input, but with units removed from each quantity. - A function that takes a unitless PyTree and restores the original units to it. """ @@ -556,7 +556,7 @@ def _remove_units(xs_maybe_quantity: brainstate.typing.PyTree): new_leaves.append(leaf) units.append(unit) - def restore_units(xs_unitless: brainstate.typing.PyTree): + def restore_units(xs_unitless: PyTree): leaves, treedef2 = jax.tree.flatten(xs_unitless) # jax's PyTreeDef stubs omit __eq__; the comparison is valid at runtime. assert treedef == treedef2, 'The tree structure should be the same. ' # type: ignore[operator] diff --git a/braintrace/_etrace_compiler/hid_param_op.py b/braintrace/_etrace_compiler/hid_param_op.py index 18b044a..6e9029f 100644 --- a/braintrace/_etrace_compiler/hid_param_op.py +++ b/braintrace/_etrace_compiler/hid_param_op.py @@ -800,6 +800,8 @@ def find_hidden_param_op_relations_from_jaxpr( invar, t_path, producers, weight_path_to_invars, ) t_state = path_to_state.get(t_path) + # ``t_path`` is a trainable weight path, so its state is always a present ParamState. + assert isinstance(t_state, brainstate.ParamState) trainable_vars[key] = invar trainable_paths[key] = t_path trainable_leaf_indices[key] = t_leaf diff --git a/braintrace/_etrace_compiler/hid_param_op_test.py b/braintrace/_etrace_compiler/hid_param_op_test.py index 2c21741..ba7883f 100644 --- a/braintrace/_etrace_compiler/hid_param_op_test.py +++ b/braintrace/_etrace_compiler/hid_param_op_test.py @@ -59,7 +59,7 @@ def test_gru_one_layer(self): assert relation.connected_hidden_paths[0] == ('h',) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -87,7 +87,7 @@ def test_snn_single_layer(self, cls): print(relations) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, diff --git a/braintrace/_etrace_compiler/hidden_group.py b/braintrace/_etrace_compiler/hidden_group.py index 265744a..a484c76 100644 --- a/braintrace/_etrace_compiler/hidden_group.py +++ b/braintrace/_etrace_compiler/hidden_group.py @@ -45,7 +45,7 @@ # -*- coding: utf-8 -*- from itertools import combinations -from typing import List, Dict, Sequence, Tuple, Set, Optional, Callable, NamedTuple, Any +from typing import List, Dict, Sequence, Tuple, Set, Optional, Callable, NamedTuple, Any, cast import brainstate import jax.core @@ -989,7 +989,9 @@ def find_hidden_groups_from_jaxpr( weight_invars=weight_invars, invar_to_hidden_path=invar_to_hidden_path, outvar_to_hidden_path=outvar_to_hidden_path, - path_to_state=path_to_state, + # the evaluator only indexes hidden-state paths, whose entries are HiddenStates, + # even though the passed mapping carries every model state. + path_to_state=cast(Dict[Path, brainstate.HiddenState], path_to_state), ) hidden_groups, hid_path_to_group = evaluator.compile() return hidden_groups, brainstate.util.PrettyDict(hid_path_to_group) diff --git a/braintrace/_etrace_compiler/hidden_group_test.py b/braintrace/_etrace_compiler/hidden_group_test.py index c5ed39d..ae252d5 100644 --- a/braintrace/_etrace_compiler/hidden_group_test.py +++ b/braintrace/_etrace_compiler/hidden_group_test.py @@ -146,7 +146,7 @@ def test_gru_one_layer(self, cls): print() @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -180,7 +180,7 @@ def test_snn_single_layer(self, cls): print() @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -556,7 +556,7 @@ def test_gru(self, cls): print(out_vals) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -591,7 +591,7 @@ def test_snn_single_layer(self, cls): print(out_vals) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -689,7 +689,7 @@ def test_gru_accuracy(self, cls): assert (u.math.allclose(diag_jac, jax_jac, atol=1e-5)) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -724,7 +724,7 @@ def test_snn_single_layer(self, cls): print(diag_jac) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -767,7 +767,7 @@ def test_snn_single_layer_accuracy(self, cls): assert (u.math.allclose(diag_jac, jax_jac, atol=1e-5)) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -801,7 +801,7 @@ def test_snn_two_layers(self, cls): print(diag_jac) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, diff --git a/braintrace/_etrace_compiler/hidden_pertubation_test.py b/braintrace/_etrace_compiler/hidden_pertubation_test.py index f08428d..011fb8c 100644 --- a/braintrace/_etrace_compiler/hidden_pertubation_test.py +++ b/braintrace/_etrace_compiler/hidden_pertubation_test.py @@ -62,7 +62,7 @@ def test_rnn_one_layer(self, cls): assert len(states) == len(hidden_perturb.init_perturb_data()) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -94,7 +94,7 @@ def test_snn_single_layer(self, cls): assert len(states) == len(perturb) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, diff --git a/braintrace/_etrace_compiler/module_info.py b/braintrace/_etrace_compiler/module_info.py index b477730..3f59589 100644 --- a/braintrace/_etrace_compiler/module_info.py +++ b/braintrace/_etrace_compiler/module_info.py @@ -33,6 +33,7 @@ from braintrace._state_managment import sequence_split_state_values from braintrace._typing import ( Path, + PyTree, StateID, Inputs, Outputs, @@ -269,11 +270,11 @@ class ModuleInfo(NamedTuple): closed_jaxpr: ClosedJaxpr # states - retrieved_model_states: brainstate.util.FlattedDict[Path, brainstate.State] + retrieved_model_states: brainstate.util.FlattedDict compiled_model_states: Sequence[brainstate.State] state_id_to_path: Dict[StateID, Path] - state_tree_invars: brainstate.typing.PyTree[Var] - state_tree_outvars: brainstate.typing.PyTree[Var] + state_tree_invars: PyTree + state_tree_outvars: PyTree # hidden states hidden_path_to_invar: Dict[Path, Var] @@ -434,7 +435,10 @@ def _process(self, *args, jaxpr_outs: Sequence[jax.Array]): cache_key = self.stateful_model.get_arg_cache_key(*args, compile_if_miss=True) i_start = self.num_var_out i_end = i_start + self.num_var_state - out, new_state_vals = self.stateful_model.get_out_treedef_by_cache(cache_key).unflatten(jaxpr_outs[:i_end]) + # brainstate types the cached treedef as the broad ``PyTree``; at runtime it is a + # ``jax`` ``PyTreeDef`` that exposes ``unflatten``. + out, new_state_vals = self.stateful_model.get_out_treedef_by_cache(cache_key).unflatten( # type: ignore[attr-defined] + jaxpr_outs[:i_end]) # # check state value diff --git a/braintrace/_etrace_compiler/module_info_test.py b/braintrace/_etrace_compiler/module_info_test.py index 90301a2..6671de8 100644 --- a/braintrace/_etrace_compiler/module_info_test.py +++ b/braintrace/_etrace_compiler/module_info_test.py @@ -57,7 +57,7 @@ def test_rnn_one_layer(self, cls): pprint(minfo) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, @@ -85,7 +85,7 @@ def test_snn_single_layer(self, cls): pprint(minfo) @pytest.mark.parametrize( - 'cls,', + 'cls', [ IF_Delta_Dense_Layer, LIF_ExpCo_Dense_Layer, diff --git a/braintrace/_grad_exponential.py b/braintrace/_grad_exponential.py index 9ba8826..60fa7ee 100644 --- a/braintrace/_grad_exponential.py +++ b/braintrace/_grad_exponential.py @@ -18,6 +18,8 @@ import jax.tree import saiunit as u +from braintrace._typing import PyTree + __all__ = [ 'GradExpon', ] @@ -31,7 +33,7 @@ class GradExpon(brainstate.nn.Module): Parameters ---------- - grad_shape : brainstate.typing.PyTree + grad_shape : PyTree A pytree whose leaves give the shape and dtype of the gradients to accumulate. The accumulator is initialised to zeros matching each leaf. @@ -69,7 +71,7 @@ class GradExpon(brainstate.nn.Module): def __init__( self, - grad_shape: brainstate.typing.PyTree, + grad_shape: PyTree, tau_or_decay: u.Quantity | float, ): super().__init__() @@ -90,7 +92,7 @@ def __init__( raise TypeError(f"tau_or_decay must be a Quantity or a float, but got {tau_or_decay}") self.decay = decay - def update(self, grads: brainstate.typing.PyTree): + def update(self, grads: PyTree): r"""Update the accumulated gradients with the exponential decay rule. Applies :math:`g_{t+1} = \mathrm{decay} \cdot g_t + \mathrm{grads}`, @@ -100,7 +102,7 @@ def update(self, grads: brainstate.typing.PyTree): Parameters ---------- - grads : brainstate.typing.PyTree + grads : PyTree The new gradients to incorporate into the accumulated gradients. Must match the pytree structure of the accumulator. diff --git a/braintrace/_state_managment.py b/braintrace/_state_managment.py index 88a5061..dffb718 100644 --- a/braintrace/_state_managment.py +++ b/braintrace/_state_managment.py @@ -18,12 +18,12 @@ import brainstate pass # ParamState removed (primitive-based ETP) -from ._typing import Path +from ._typing import Path, PyTree def assign_dict_state_values( - states: Dict[Path, brainstate.State], - state_values: Dict[Path, brainstate.typing.PyTree], + states: Mapping[Path, brainstate.State], + state_values: Mapping[Path, PyTree], write: bool = True ): """ @@ -38,7 +38,7 @@ def assign_dict_state_values( states : Dict[Path, brainstate.State] A dictionary where keys are paths and values are state objects to which values will be assigned or restored. - state_values : Dict[Path, brainstate.typing.PyTree] + state_values : Dict[Path, PyTree] A dictionary where keys are paths and values are the values corresponding to each state in `states`. write : bool, optional @@ -62,7 +62,7 @@ def assign_dict_state_values( def assign_state_values_v2( states: Mapping[Any, brainstate.State], - state_values: Mapping[Any, brainstate.typing.PyTree], + state_values: Mapping[Any, PyTree], write: bool = True ): """ @@ -77,7 +77,7 @@ def assign_state_values_v2( states : Dict[Hashable, brainstate.State] A dictionary where keys are hashable identifiers and values are state objects to which values will be assigned or restored. - state_values : Dict[Hashable, brainstate.typing.PyTree] + state_values : Dict[Hashable, PyTree] A dictionary where keys are hashable identifiers and values are the values corresponding to each state in `states`. write : bool, optional @@ -106,18 +106,18 @@ def assign_state_values_v2( def sequence_split_state_values( states: Sequence[brainstate.State], - state_values: List[brainstate.typing.PyTree], + state_values: List[PyTree], include_weight: bool = True ) -> ( Tuple[ - Sequence[brainstate.typing.PyTree], - Sequence[brainstate.typing.PyTree], - Sequence[brainstate.typing.PyTree] + Sequence[PyTree], + Sequence[PyTree], + Sequence[PyTree] ] | Tuple[ - Sequence[brainstate.typing.PyTree], - Sequence[brainstate.typing.PyTree] + Sequence[PyTree], + Sequence[PyTree] ] ): """ diff --git a/braintrace/_typing.py b/braintrace/_typing.py index 1b983dd..db9a039 100644 --- a/braintrace/_typing.py +++ b/braintrace/_typing.py @@ -19,6 +19,7 @@ import brainstate import jax +import numpy as np from ._compatible_imports import Var @@ -32,6 +33,31 @@ WeightID: TypeAlias = int Size: TypeAlias = brainstate.typing.Size Axis: TypeAlias = int + + +def as_size_tuple(size: Size) -> Tuple[int, ...]: + """Normalize an ``in_size``/``out_size`` spec to a tuple of ints. + + ``brainstate``'s size setters accept a scalar ``int`` or a sequence, while + the matching getters are typed as the broad :data:`Size` union, which static + type checkers do not treat as indexable. Routing values through this helper + yields a concrete ``tuple[int, ...]`` so both property assignment and + trailing-dimension lookups (``size[-1]``) type-check cleanly, reproducing + ``brainstate``'s own normalization at runtime. + + Parameters + ---------- + size : Size + A scalar ``int`` / numpy integer, or a sequence of them. + + Returns + ------- + tuple of int + The size expressed as a tuple of Python ints. + """ + if isinstance(size, (int, np.integer)): + return (int(size),) + return tuple(int(s) for s in size) Axes: TypeAlias = Union[int, Sequence[int]] Path: TypeAlias = Tuple[str, ...] diff --git a/braintrace/nn/_conv.py b/braintrace/nn/_conv.py index 1927d76..2e00d81 100644 --- a/braintrace/nn/_conv.py +++ b/braintrace/nn/_conv.py @@ -60,17 +60,17 @@ def _etp_conv_op(self, x, params): class Conv1d(brainstate.nn.Conv1d): __module__ = 'braintrace.nn' - __doc__ = brainstate.nn.Conv1d.__doc__.replace('brainstate', 'braintrace') + __doc__ = (brainstate.nn.Conv1d.__doc__ or '').replace('brainstate', 'braintrace') _conv_op = _etp_conv_op class Conv2d(brainstate.nn.Conv2d): __module__ = 'braintrace.nn' - __doc__ = brainstate.nn.Conv2d.__doc__.replace('brainstate', 'braintrace') + __doc__ = (brainstate.nn.Conv2d.__doc__ or '').replace('brainstate', 'braintrace') _conv_op = _etp_conv_op class Conv3d(brainstate.nn.Conv3d): __module__ = 'braintrace.nn' - __doc__ = brainstate.nn.Conv3d.__doc__.replace('brainstate', 'braintrace') + __doc__ = (brainstate.nn.Conv3d.__doc__ or '').replace('brainstate', 'braintrace') _conv_op = _etp_conv_op diff --git a/braintrace/nn/_conv_test.py b/braintrace/nn/_conv_test.py index b9c1e01..f564465 100644 --- a/braintrace/nn/_conv_test.py +++ b/braintrace/nn/_conv_test.py @@ -494,12 +494,12 @@ def test_conv3d_explicit_padding_int(self): assert y.shape[-1] == 64 def test_conv3d_explicit_padding_tuple(self): - """Test Conv3d with explicit tuple padding.""" + """Test Conv3d with explicit tuple padding (one value per spatial dim).""" conv = braintrace.nn.Conv3d( in_size=(16, 16, 16, 3), out_channels=64, kernel_size=3, - padding=(1, 1) + padding=(1, 1, 1) ) x = brainstate.random.randn(2, 16, 16, 16, 3) y = conv(x) @@ -662,17 +662,17 @@ class TestConvEdgeCases: def test_conv_invalid_groups_out_channels(self): """Test that out_channels must be divisible by groups.""" - with pytest.raises(AssertionError): + with pytest.raises(ValueError): braintrace.nn.Conv2d(in_size=(28, 28, 4), out_channels=9, kernel_size=3, groups=2) def test_conv_invalid_groups_in_channels(self): """Test that in_channels must be divisible by groups.""" - with pytest.raises(AssertionError): + with pytest.raises(ValueError): braintrace.nn.Conv2d(in_size=(28, 28, 3), out_channels=8, kernel_size=3, groups=2) def test_conv_invalid_padding_string(self): """Test that only SAME and VALID are accepted as string padding.""" - with pytest.raises(AssertionError): + with pytest.raises(ValueError): braintrace.nn.Conv2d( in_size=(28, 28, 3), out_channels=32, @@ -693,7 +693,7 @@ def test_conv_invalid_padding_wrong_length(self): def test_conv_invalid_in_size_length(self): """Test that in_size must have correct length.""" - with pytest.raises(AssertionError): + with pytest.raises(ValueError): braintrace.nn.Conv2d( in_size=(28, 3), # Should be 3D for Conv2d out_channels=32, diff --git a/braintrace/nn/_linear.py b/braintrace/nn/_linear.py index b18a4d4..70dbe60 100644 --- a/braintrace/nn/_linear.py +++ b/braintrace/nn/_linear.py @@ -31,7 +31,7 @@ class Linear(brainstate.nn.Linear): __module__ = 'braintrace.nn' - __doc__ = brainstate.nn.Linear.__doc__.replace('brainstate', 'braintrace') + __doc__ = (brainstate.nn.Linear.__doc__ or '').replace('brainstate', 'braintrace') def update(self, x): """Apply the linear transform through the ETP ``matmul`` primitive. @@ -59,7 +59,7 @@ def update(self, x): class SignedWLinear(brainstate.nn.SignedWLinear): __module__ = 'braintrace.nn' - __doc__ = brainstate.nn.SignedWLinear.__doc__.replace('brainstate', 'braintrace') + __doc__ = (brainstate.nn.SignedWLinear.__doc__ or '').replace('brainstate', 'braintrace') def update(self, x): """Apply the sign-constrained linear transform through ETP ``matmul``. @@ -86,7 +86,7 @@ def update(self, x): class ScaledWSLinear(brainstate.nn.ScaledWSLinear): __module__ = 'braintrace.nn' - __doc__ = brainstate.nn.ScaledWSLinear.__doc__.replace('brainstate', 'braintrace') + __doc__ = (brainstate.nn.ScaledWSLinear.__doc__ or '').replace('brainstate', 'braintrace') def update(self, x): """Apply the weight-standardized linear transform through ETP ``matmul``. @@ -115,7 +115,7 @@ def update(self, x): class SparseLinear(brainstate.nn.SparseLinear): __module__ = 'braintrace.nn' - __doc__ = brainstate.nn.SparseLinear.__doc__.replace('brainstate', 'braintrace') + __doc__ = (brainstate.nn.SparseLinear.__doc__ or '').replace('brainstate', 'braintrace') def update(self, x): """Apply the sparse linear transform through the ETP ``sparse_matmul``. diff --git a/braintrace/nn/_readout.py b/braintrace/nn/_readout.py index 078a61d..1fe1bb4 100644 --- a/braintrace/nn/_readout.py +++ b/braintrace/nn/_readout.py @@ -15,7 +15,6 @@ # -*- coding: utf-8 -*- -import numbers from typing import Callable, Optional import brainstate @@ -23,7 +22,7 @@ import saiunit as u from braintrace._etrace_op import matmul -from braintrace._typing import Size, ArrayLike +from braintrace._typing import Size, ArrayLike, as_size_tuple __all__ = [ 'LeakyRateReadout', @@ -104,16 +103,16 @@ def __init__( self, in_size: Size, out_size: Size, - tau: ArrayLike = 5. * u.ms, + tau: ArrayLike = 5. * u.ms, # type: ignore[assignment] # saiunit types `float * Unit` as `Unit | Quantity` w_init: Callable = braintools.init.KaimingNormal(), r_init: Callable = braintools.init.ZeroInit(), name: Optional[str] = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # brainstate hides Module.__init__ from type checkers # parameters - self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size) - self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size) + self.in_size = as_size_tuple(in_size) + self.out_size = as_size_tuple(out_size) self.tau = braintools.init.param(tau, self.out_size) # Compute decay handling units properly tau_normalized = u.maybe_decimal(self.tau / brainstate.environ.get_dt()) @@ -122,7 +121,7 @@ def __init__( # weights self.W = brainstate.ParamState( - braintools.init.param(w_init, (self.in_size[0], self.out_size[0])) + braintools.init.param(w_init, (as_size_tuple(self.in_size)[0], as_size_tuple(self.out_size)[0])) ) def init_state(self, batch_size=None, **kwargs): diff --git a/braintrace/nn/_rnn.py b/braintrace/nn/_rnn.py index c194088..e892b30 100644 --- a/braintrace/nn/_rnn.py +++ b/braintrace/nn/_rnn.py @@ -14,14 +14,14 @@ # ============================================================================== # -*- coding: utf-8 -*- -from typing import Callable, Union +from typing import Any, Callable, Union import brainstate import braintools import saiunit as u from braintrace._etrace_op import element_wise -from braintrace._typing import ArrayLike +from braintrace._typing import ArrayLike, as_size_tuple as _as_size_tuple from ._linear import Linear __all__ = [ @@ -89,12 +89,12 @@ def __init__( activation: str | Callable = 'relu', name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters self._state_initializer = state_init - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # activation function if isinstance(activation, str): @@ -105,7 +105,7 @@ def __init__( # weights self.W = Linear( - self.in_size[-1] + self.out_size[-1], self.out_size[-1], + _as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], w_init=w_init, b_init=b_init, ) @@ -188,12 +188,12 @@ def __init__( activation: str | Callable = 'tanh', name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters self._state_initializer = state_init - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # activation function if isinstance(activation, str): @@ -203,10 +203,10 @@ def __init__( self.activation = activation # weights - params = dict(w_init=w_init, b_init=b_init) - self.Wz = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wr = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wh = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) + params: dict[str, Any] = dict(w_init=w_init, b_init=b_init) + self.Wz = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wr = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wh = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) def init_state(self, batch_size: int = None, **kwargs): self.h = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size)) @@ -290,12 +290,12 @@ def __init__( activation: str | Callable = 'tanh', name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters self._state_initializer = state_init - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # activation function if isinstance(activation, str): @@ -305,10 +305,10 @@ def __init__( self.activation = activation # weights - params = dict(w_init=w_init, b_init=b_init) - self.Wf = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wi = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wh = Linear(self.out_size[-1], self.out_size[-1], **params) + params: dict[str, Any] = dict(w_init=w_init, b_init=b_init) + self.Wf = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wi = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wh = Linear(_as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) def init_state(self, batch_size: int = None, **kwargs): self.h = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size)) @@ -406,12 +406,12 @@ def __init__( activation: str | Callable = 'tanh', name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters self._state_initializer = state_init - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # activation function if isinstance(activation, str): @@ -421,9 +421,9 @@ def __init__( self.activation = activation # weights - params = dict(w_init=w_init, b_init=b_init) - self.Wf = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wh = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) + params: dict[str, Any] = dict(w_init=w_init, b_init=b_init) + self.Wf = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wh = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) def init_state(self, batch_size: int = None, **kwargs): self.h = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size)) @@ -536,11 +536,11 @@ def __init__( activation: str | Callable = 'tanh', name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # initializers self._state_initializer = state_init @@ -553,11 +553,11 @@ def __init__( self.activation = activation # weights - params = dict(w_init=w_init, b_init=b_init) - self.Wi = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wg = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wf = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wo = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) + params: dict[str, Any] = dict(w_init=w_init, b_init=b_init) + self.Wi = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wg = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wf = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wo = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) def init_state(self, batch_size: int = None, **kwargs): self.c = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size)) @@ -644,11 +644,11 @@ def __init__( activation: str | Callable = 'tanh', name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # initializers self._state_initializer = state_init @@ -661,15 +661,15 @@ def __init__( self.activation = activation # weights - params = dict(w_init=w_init, b_init=None) - self.Wu = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wf = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wr = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.Wo = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) + params: dict[str, Any] = dict(w_init=w_init, b_init=None) + self.Wu = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wf = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wr = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.Wo = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) self.bias = brainstate.ParamState(self._forget_bias()) def _forget_bias(self): - rand_val = brainstate.random.uniform(1 / self.out_size[-1], 1 - 1 / self.out_size[-1], (self.out_size[-1],)) + rand_val = brainstate.random.uniform(1 / _as_size_tuple(self.out_size)[-1], 1 - 1 / _as_size_tuple(self.out_size)[-1], (_as_size_tuple(self.out_size)[-1],)) return -u.math.log(1 / rand_val - 1) def init_state(self, batch_size: int = None, **kwargs): @@ -780,22 +780,22 @@ def __init__( phi: Callable = None, name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters self._state_initializer = state_init - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # functions - params = dict(w_init=w_init, b_init=b_init) + params: dict[str, Any] = dict(w_init=w_init, b_init=b_init) if phi is None: - phi = Linear(self.in_size[-1], self.out_size[-1], **params) + phi = Linear(_as_size_tuple(self.in_size)[-1], _as_size_tuple(self.out_size)[-1], **params) assert callable(phi), f"The phi function should be a callable function. But got {phi}" self.phi = phi # weights - self.W_u = Linear(self.out_size[-1] * 2, self.out_size[-1], **params) + self.W_u = Linear(_as_size_tuple(self.out_size)[-1] * 2, _as_size_tuple(self.out_size)[-1], **params) def init_state(self, batch_size: int = None, **kwargs): self.h = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size)) @@ -881,19 +881,19 @@ def __init__( state_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(), name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters self._state_initializer = state_init - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # functions - params = dict(w_init=w_init, b_init=b_init) - self.W_x = Linear(self.in_size[-1], self.out_size[-1], **params) + params: dict[str, Any] = dict(w_init=w_init, b_init=b_init) + self.W_x = Linear(_as_size_tuple(self.in_size)[-1], _as_size_tuple(self.out_size)[-1], **params) # weights - self.W_z = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) + self.W_z = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) def init_state(self, batch_size: int = None, **kwargs): self.h = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size)) @@ -978,20 +978,20 @@ def __init__( state_init: Union[ArrayLike, Callable] = braintools.init.ZeroInit(), name: str = None, ): - super().__init__(name=name) + super().__init__(name=name) # type: ignore[call-arg] # parameters self._state_initializer = state_init - self.out_size = out_size - self.in_size = in_size + self.out_size = _as_size_tuple(out_size) + self.in_size = _as_size_tuple(in_size) # functions - params = dict(w_init=w_init, b_init=b_init) - self.W_x = Linear(self.in_size[-1], self.out_size[-1], **params) + params: dict[str, Any] = dict(w_init=w_init, b_init=b_init) + self.W_x = Linear(_as_size_tuple(self.in_size)[-1], _as_size_tuple(self.out_size)[-1], **params) # weights - self.W_f = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) - self.W_i = Linear(self.in_size[-1] + self.out_size[-1], self.out_size[-1], **params) + self.W_f = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) + self.W_i = Linear(_as_size_tuple(self.in_size)[-1] + _as_size_tuple(self.out_size)[-1], _as_size_tuple(self.out_size)[-1], **params) def init_state(self, batch_size: int = None, **kwargs): self.h = brainstate.HiddenState(braintools.init.param(self._state_initializer, self.out_size, batch_size))