Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion braintrace/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# Canonical lowercase name (+ aliases) -> algorithm class. No bare ``ostl``
# alias: the ambiguous OSTL factory was removed in 0.2.0, so callers pick
# ``ostl_recurrent`` vs ``ostl_feedforward`` explicitly.
_ALGORITHM_REGISTRY = {
_ALGORITHM_REGISTRY: dict[str, type[ETraceAlgorithm]] = {
'd_rtrl': D_RTRL,
'pp_prop': pp_prop,
'es_d_rtrl': pp_prop,
Expand Down
12 changes: 7 additions & 5 deletions braintrace/_etrace_algorithms/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import jax.numpy as jnp
import saiunit as u

from braintrace._typing import PyTree

__all__ = [
'PresynapticTrace',
'KappaFilter',
Expand Down Expand Up @@ -388,7 +390,7 @@ def _unit_safe_add(a, b):
return u.math.add(a, b)


def _extract_leaf(pytree_val: brainstate.typing.PyTree, leaf_idx: int):
def _extract_leaf(pytree_val: PyTree, leaf_idx: int):
"""Return the leaf at ``leaf_idx`` in ``jax.tree.leaves(pytree_val)``.

Bare arrays (treedef with a single leaf) return the array unchanged.
Expand All @@ -405,7 +407,7 @@ def _extract_leaf(pytree_val: brainstate.typing.PyTree, leaf_idx: int):


def _wrap_leaves_as_pytree(
reference_pytree: brainstate.typing.PyTree,
reference_pytree: PyTree,
leaf_grads: Dict[int, jax.Array],
):
"""Build a pytree matching ``reference_pytree`` with ``leaf_grads``
Expand Down Expand Up @@ -441,8 +443,8 @@ def _wrap_leaves_as_pytree(
def _route_grads_by_path(
relation,
per_key_grads: Dict[str, jax.Array],
weight_vals: Dict[Any, brainstate.typing.PyTree],
target_dict: Dict[Any, brainstate.typing.PyTree],
weight_vals: Dict[Any, PyTree],
target_dict: Dict[Any, PyTree],
) -> None:
"""Route per-key gradients from a dict-API rule into per-path pytrees.

Expand Down Expand Up @@ -474,7 +476,7 @@ def _route_grads_by_path(
def _update_dict(
the_dict: Dict,
key: Any,
value: brainstate.typing.PyTree,
value: PyTree,
error_when_no_key: Optional[bool] = False
):
"""Update the dictionary.
Expand Down
19 changes: 11 additions & 8 deletions braintrace/_etrace_algorithms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(
graph_executor: ETraceGraphExecutor,
name: Optional[str] = None,
):
super().__init__(name=name)
super().__init__(name=name) # type: ignore[call-arg] # brainstate hides Module.__init__ from type checkers

# the model
if not isinstance(model, brainstate.nn.Module):
Expand All @@ -115,9 +115,9 @@ def __init__(
self.running_index = brainstate.LongTermState(0)

# other states
self._param_states = None
self._hidden_states = None
self._other_states = None
self._param_states: Optional[brainstate.util.FlattedDict] = None
self._hidden_states: Optional[brainstate.util.FlattedDict] = None
self._other_states: Optional[brainstate.util.FlattedDict] = None

@property
def graph(self) -> ETraceGraph:
Expand All @@ -144,7 +144,7 @@ def executor(self) -> ETraceGraphExecutor:
return self.graph_executor

@property
def param_states(self) -> brainstate.util.FlattedDict[Path, brainstate.ParamState]:
def param_states(self) -> brainstate.util.FlattedDict:
"""
Get the parameter weight states.

Expand All @@ -155,10 +155,11 @@ def param_states(self) -> brainstate.util.FlattedDict[Path, brainstate.ParamStat
"""
if self._param_states is None:
self._split_state()
assert self._param_states is not None
return self._param_states

@property
def hidden_states(self) -> brainstate.util.FlattedDict[Path, brainstate.HiddenState]:
def hidden_states(self) -> brainstate.util.FlattedDict:
"""
Get the hidden states.

Expand All @@ -169,10 +170,11 @@ def hidden_states(self) -> brainstate.util.FlattedDict[Path, brainstate.HiddenSt
"""
if self._hidden_states is None:
self._split_state()
assert self._hidden_states is not None
return self._hidden_states

@property
def other_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]:
def other_states(self) -> brainstate.util.FlattedDict:
"""
Get the other states.

Expand All @@ -183,6 +185,7 @@ def other_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]:
"""
if self._other_states is None:
self._split_state()
assert self._other_states is not None
return self._other_states

def _split_state(self):
Expand Down Expand Up @@ -235,7 +238,7 @@ def compile_graph(self, *args) -> None:
self.is_compiled = True

@property
def path_to_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]:
def path_to_states(self) -> brainstate.util.FlattedDict:
"""
Get the path to the states.

Expand Down
10 changes: 7 additions & 3 deletions braintrace/_etrace_algorithms/graph_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def graph(self) -> ETraceGraph:
return self._compiled_graph

@property
def states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]:
def states(self) -> brainstate.util.FlattedDict:
"""
The states for the model.

Expand All @@ -122,7 +122,7 @@ def states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]:
return self.graph.module_info.retrieved_model_states

@property
def path_to_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]:
def path_to_states(self) -> brainstate.util.FlattedDict:
"""
The path to the states.

Expand Down Expand Up @@ -216,6 +216,8 @@ def show_graph(
# other hidden states
other_states = []
short_states = self.states.filter(brainstate.ShortTermState)
# a single-type filter returns one FlattedDict (brainstate types it as a union)
assert isinstance(short_states, brainstate.util.FlattedDict)
for i, path in enumerate(short_states.keys()):
if path not in hidden_paths:
other_states.append(path)
Expand All @@ -239,7 +241,9 @@ def show_graph(
msg += '\n\n'

# non etrace weights
non_etratce_weight_paths = set(self.states.filter(brainstate.ParamState).keys())
param_states = self.states.filter(brainstate.ParamState)
assert isinstance(param_states, brainstate.util.FlattedDict)
non_etratce_weight_paths = set(param_states.keys())
non_etratce_weight_paths = non_etratce_weight_paths.difference(etratce_weight_paths)
if len(non_etratce_weight_paths):
msg += 'The non-etrace weight parameters are:\n\n'
Expand Down
4 changes: 3 additions & 1 deletion braintrace/_etrace_algorithms/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ def finite_difference_param_gradients(
"""
template = model_factory()
brainstate.nn.init_all_states(template, batch_size=1)
template_params = template.states(brainstate.ParamState)
assert isinstance(template_params, brainstate.util.FlattedDict)
base_values = {
k: np.asarray(v.value) for k, v in template.states(brainstate.ParamState).items()
k: np.asarray(v.value) for k, v in template_params.items()
}

def loss_with(values):
Expand Down
8 changes: 4 additions & 4 deletions braintrace/_etrace_algorithms/param_dim_vjp.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def _call_yw_to_w_dict(d, trace_, _rule=yw_to_w_rule, _params=eqn_params):
_update_dict(dG_weights, key, val)


def _remove_units(xs_maybe_quantity: brainstate.typing.PyTree):
def _remove_units(xs_maybe_quantity: PyTree):
"""
Removes units from a PyTree of quantities, returning a unitless PyTree and a function to restore the units.

Expand All @@ -542,10 +542,10 @@ def _remove_units(xs_maybe_quantity: brainstate.typing.PyTree):
original units to the unitless PyTree.

Args:
xs_maybe_quantity (brainstate.typing.PyTree): A PyTree structure containing quantities with units.
xs_maybe_quantity (PyTree): A PyTree structure containing quantities with units.

Returns:
Tuple[brainstate.typing.PyTree, Callable]: A tuple containing:
Tuple[PyTree, Callable]: A tuple containing:
- A PyTree with the same structure as the input, but with units removed from each quantity.
- A function that takes a unitless PyTree and restores the original units to it.
"""
Expand All @@ -556,7 +556,7 @@ def _remove_units(xs_maybe_quantity: brainstate.typing.PyTree):
new_leaves.append(leaf)
units.append(unit)

def restore_units(xs_unitless: brainstate.typing.PyTree):
def restore_units(xs_unitless: PyTree):
leaves, treedef2 = jax.tree.flatten(xs_unitless)
# jax's PyTreeDef stubs omit __eq__; the comparison is valid at runtime.
assert treedef == treedef2, 'The tree structure should be the same. ' # type: ignore[operator]
Expand Down
2 changes: 2 additions & 0 deletions braintrace/_etrace_compiler/hid_param_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,8 @@ def find_hidden_param_op_relations_from_jaxpr(
invar, t_path, producers, weight_path_to_invars,
)
t_state = path_to_state.get(t_path)
# ``t_path`` is a trainable weight path, so its state is always a present ParamState.
assert isinstance(t_state, brainstate.ParamState)
trainable_vars[key] = invar
trainable_paths[key] = t_path
trainable_leaf_indices[key] = t_leaf
Expand Down
4 changes: 2 additions & 2 deletions braintrace/_etrace_compiler/hid_param_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_gru_one_layer(self):
assert relation.connected_hidden_paths[0] == ('h',)

@pytest.mark.parametrize(
'cls,',
'cls',
[
IF_Delta_Dense_Layer,
LIF_ExpCo_Dense_Layer,
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_snn_single_layer(self, cls):
print(relations)

@pytest.mark.parametrize(
'cls,',
'cls',
[
IF_Delta_Dense_Layer,
LIF_ExpCo_Dense_Layer,
Expand Down
6 changes: 4 additions & 2 deletions braintrace/_etrace_compiler/hidden_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
# -*- coding: utf-8 -*-

from itertools import combinations
from typing import List, Dict, Sequence, Tuple, Set, Optional, Callable, NamedTuple, Any
from typing import List, Dict, Sequence, Tuple, Set, Optional, Callable, NamedTuple, Any, cast

import brainstate
import jax.core
Expand Down Expand Up @@ -989,7 +989,9 @@ def find_hidden_groups_from_jaxpr(
weight_invars=weight_invars,
invar_to_hidden_path=invar_to_hidden_path,
outvar_to_hidden_path=outvar_to_hidden_path,
path_to_state=path_to_state,
# the evaluator only indexes hidden-state paths, whose entries are HiddenStates,
# even though the passed mapping carries every model state.
path_to_state=cast(Dict[Path, brainstate.HiddenState], path_to_state),
)
hidden_groups, hid_path_to_group = evaluator.compile()
return hidden_groups, brainstate.util.PrettyDict(hid_path_to_group)
Expand Down
16 changes: 8 additions & 8 deletions braintrace/_etrace_compiler/hidden_group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_gru_one_layer(self, cls):
print()

@pytest.mark.parametrize(
'cls,',
'cls',
[
IF_Delta_Dense_Layer,
LIF_ExpCo_Dense_Layer,
Expand Down Expand Up @@ -180,7 +180,7 @@ def test_snn_single_layer(self, cls):
print()

@pytest.mark.parametrize(
'cls,',
'cls',
[
IF_Delta_Dense_Layer,
LIF_ExpCo_Dense_Layer,
Expand Down Expand Up @@ -556,7 +556,7 @@ def test_gru(self, cls):
print(out_vals)

@pytest.mark.parametrize(
'cls,',
'cls',
[
IF_Delta_Dense_Layer,
LIF_ExpCo_Dense_Layer,
Expand Down Expand Up @@ -591,7 +591,7 @@ def test_snn_single_layer(self, cls):
print(out_vals)

@pytest.mark.parametrize(
'cls,',
'cls',
[
IF_Delta_Dense_Layer,
LIF_ExpCo_Dense_Layer,
Expand Down Expand Up @@ -689,7 +689,7 @@ def test_gru_accuracy(self, cls):
assert (u.math.allclose(diag_jac, jax_jac, atol=1e-5))

@pytest.mark.parametrize(
'cls,',
'cls',
[
IF_Delta_Dense_Layer,
LIF_ExpCo_Dense_Layer,
Expand Down Expand Up @@ -724,7 +724,7 @@ def test_snn_single_layer(self, cls):
print(diag_jac)

@pytest.mark.parametrize(
'cls,',
'cls',
[
IF_Delta_Dense_Layer,
LIF_ExpCo_Dense_Layer,
Expand Down Expand Up @@ -767,7 +767,7 @@ def test_snn_single_layer_accuracy(self, cls):
assert (u.math.allclose(diag_jac, jax_jac, atol=1e-5))

@pytest.mark.parametrize(
'cls,',
'cls',
[
IF_Delta_Dense_Layer,
LIF_ExpCo_Dense_Layer,
Expand Down Expand Up @@ -801,7 +801,7 @@ def test_snn_two_layers(self, cls):
print(diag_jac)

@pytest.mark.parametrize(
'cls,',
'cls',
[
IF_Delta_Dense_Layer,
LIF_ExpCo_Dense_Layer,
Expand Down
4 changes: 2 additions & 2 deletions braintrace/_etrace_compiler/hidden_pertubation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_rnn_one_layer(self, cls):
assert len(states) == len(hidden_perturb.init_perturb_data())

@pytest.mark.parametrize(
'cls,',
'cls',
[
IF_Delta_Dense_Layer,
LIF_ExpCo_Dense_Layer,
Expand Down Expand Up @@ -94,7 +94,7 @@ def test_snn_single_layer(self, cls):
assert len(states) == len(perturb)

@pytest.mark.parametrize(
'cls,',
'cls',
[
IF_Delta_Dense_Layer,
LIF_ExpCo_Dense_Layer,
Expand Down
12 changes: 8 additions & 4 deletions braintrace/_etrace_compiler/module_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from braintrace._state_managment import sequence_split_state_values
from braintrace._typing import (
Path,
PyTree,
StateID,
Inputs,
Outputs,
Expand Down Expand Up @@ -269,11 +270,11 @@ class ModuleInfo(NamedTuple):
closed_jaxpr: ClosedJaxpr

# states
retrieved_model_states: brainstate.util.FlattedDict[Path, brainstate.State]
retrieved_model_states: brainstate.util.FlattedDict
compiled_model_states: Sequence[brainstate.State]
state_id_to_path: Dict[StateID, Path]
state_tree_invars: brainstate.typing.PyTree[Var]
state_tree_outvars: brainstate.typing.PyTree[Var]
state_tree_invars: PyTree
state_tree_outvars: PyTree

# hidden states
hidden_path_to_invar: Dict[Path, Var]
Expand Down Expand Up @@ -434,7 +435,10 @@ def _process(self, *args, jaxpr_outs: Sequence[jax.Array]):
cache_key = self.stateful_model.get_arg_cache_key(*args, compile_if_miss=True)
i_start = self.num_var_out
i_end = i_start + self.num_var_state
out, new_state_vals = self.stateful_model.get_out_treedef_by_cache(cache_key).unflatten(jaxpr_outs[:i_end])
# brainstate types the cached treedef as the broad ``PyTree``; at runtime it is a
# ``jax`` ``PyTreeDef`` that exposes ``unflatten``.
out, new_state_vals = self.stateful_model.get_out_treedef_by_cache(cache_key).unflatten( # type: ignore[attr-defined]
jaxpr_outs[:i_end])

#
# check state value
Expand Down
Loading