Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions brainpy/math/delayvars_coverage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,40 @@ def test_update_value_none_without_target_raises():
assert ld.delay_target is None
with pytest.raises(ValueError):
ld.update(None)


# ---------------------------------------------------------------------------
# ring-buffer correctness regressions (guards the ``% num_delay_step`` modulo
# in ``TimeDelay._true_fn`` and the rotate index in ``LengthDelay``)
# ---------------------------------------------------------------------------

def test_time_delay_ring_buffer_wraps_modulo():
"""``_true_fn`` must read ``data[(idx + step) % num_delay_step]``.

Without the modulo, when the read index wraps past the end of the buffer JAX
clamps the out-of-bounds index to the last slot and returns a stale value.
We feed a long ramp (many wraps) and check the exact-step (no-interp) reads.
"""
dt = 0.1
delay_len = 1.0 # exact multiple of dt -> exact-step (``_true_fn``) branch
d = TimeDelay(bm.zeros(1), delay_len=delay_len, dt=dt, before_t0=lambda t: t)
# ``num_delay_step == 11``; iterate well past one full wrap of the buffer.
n = 37
for i in range(n):
d.update(bm.asarray([float(i)]))
ct = float(d.current_time[0])
last = n - 1 # the most recently stored ramp value
# delay d_ms -> value stored ``round(d_ms/dt)`` steps before ``last``.
for d_ms in [0.0, 0.1, 0.3, 0.5, 1.0]:
got = float(d(ct - d_ms)[0])
expected = last - round(d_ms / dt)
assert abs(got - expected) < 1e-4, (d_ms, got, expected)


def test_length_delay_ramp_matches_reference():
for method in (ROTATE_UPDATE, CONCAT_UPDATE):
d = LengthDelay(bm.zeros(1), delay_len=5, update_method=method)
for i in range(23): # many wraps for the rotate buffer (len 6)
d.update(bm.asarray([float(i)]))
got = [float(d(k)[0]) for k in range(6)]
assert got == [22 - k for k in range(6)], (method, got)
58 changes: 58 additions & 0 deletions brainpy/math/event/csr_matmat_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
"""Regression tests for ``brainpy/math/event/csr_matmat.py``.

Guards the event-driven (binary) CSR matmat, especially the ``transpose=True``
branch (must compute ``Aᵀ @ E``), against a dense numpy reference.
"""

import jax.numpy as jnp
import numpy as np

import brainevent

from brainpy.math.event.csr_matmat import csrmm


_ROWS = np.array([0, 0, 1, 2, 2])
_COLS = np.array([1, 3, 0, 1, 3])
_VALS = np.array([2., 4., 1., 3., 2.])
_SHAPE = (3, 4)


def _dense():
m = np.zeros(_SHAPE, dtype=np.float32)
for v, r, c in zip(_VALS, _ROWS, _COLS):
m[r, c] = v
return m


def _csr():
indptr, indices, order = brainevent.coo2csr(_ROWS, _COLS, shape=_SHAPE)
data = jnp.asarray(_VALS)[np.asarray(order)]
return data, np.asarray(indices), np.asarray(indptr)


def test_event_csrmm_no_transpose_matches_dense():
data, indices, indptr = _csr()
E = np.array([[True, False], [False, True], [True, True], [False, False]])
out = np.asarray(csrmm(data, indices, indptr, jnp.asarray(E), shape=_SHAPE, transpose=False))
assert out.shape == (3, 2)
np.testing.assert_allclose(out, _dense() @ E.astype(np.float32), rtol=1e-5, atol=1e-5)


def test_event_csrmm_transpose_matches_dense():
# transpose=True must compute Aᵀ @ E (Aᵀ is (4,3), E is (3,2)).
data, indices, indptr = _csr()
E = np.array([[True, False], [False, True], [True, True]])
out = np.asarray(csrmm(data, indices, indptr, jnp.asarray(E), shape=_SHAPE, transpose=True))
assert out.shape == (4, 2)
np.testing.assert_allclose(out, _dense().T @ E.astype(np.float32), rtol=1e-5, atol=1e-5)


def test_event_csrmm_matches_float_csrmm_with_binary_input():
# An all-True event matrix multiplied by the binary path equals the dense
# product restricted to the selected entries.
data, indices, indptr = _csr()
E = np.array([[True, True], [True, True], [True, True], [True, True]])
out = np.asarray(csrmm(data, indices, indptr, jnp.asarray(E), shape=_SHAPE, transpose=False))
np.testing.assert_allclose(out, _dense() @ E.astype(np.float32), rtol=1e-5, atol=1e-5)
75 changes: 75 additions & 0 deletions brainpy/math/pre_syn_post_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
"""Regression tests for ``brainpy/math/pre_syn_post.py``.

Targets the event-driven CSR routing of ``pre2post_event_sum`` (which delegates
to ``event.csrmv(transpose=True)``) and the empty-group / structural guards of
``syn2post_mean`` and ``syn2post_softmax``.
"""

import numpy as np

import brainpy.math as bm
from brainpy.math.pre_syn_post import (
pre2post_event_sum,
syn2post_sum,
syn2post_mean,
syn2post_softmax,
)


# pre_num=3, post_num=4 CSR: pre0 -> {1,3}, pre1 -> {0}, pre2 -> {1,3}
_INDICES = np.array([1, 3, 0, 1, 3])
_INDPTR = np.array([0, 2, 3, 5])
_POST_NUM = 4


def test_pre2post_event_sum_scalar_value():
events = np.array([True, False, True]) # pre 0 and 2 fire
out = np.asarray(pre2post_event_sum(events, (_INDICES, _INDPTR), _POST_NUM, values=1.))
np.testing.assert_array_equal(out, [0., 2., 0., 2.])


def test_pre2post_event_sum_vector_value():
events = np.array([True, False, True])
vals = np.array([10., 20., 30., 40., 50.])
out = np.asarray(pre2post_event_sum(events, (_INDICES, _INDPTR), _POST_NUM, values=vals))
# pre0: post1+=10, post3+=20; pre2: post1+=40, post3+=50
np.testing.assert_array_equal(out, [0., 50., 0., 70.])


def test_pre2post_event_sum_matches_dense_transpose():
# equivalent dense Aᵀ @ events, with A (pre_num x post_num) of all-ones weights
events = np.array([True, True, False])
A = np.zeros((3, _POST_NUM), dtype=np.float32)
for pre in range(3):
for j in range(_INDPTR[pre], _INDPTR[pre + 1]):
A[pre, _INDICES[j]] = 1.0
out = np.asarray(pre2post_event_sum(events, (_INDICES, _INDPTR), _POST_NUM, values=1.))
np.testing.assert_allclose(out, A.T @ events.astype(np.float32), rtol=1e-5, atol=1e-5)


def test_syn2post_sum_matches_reference():
syn = np.array([1., 2., 3., 4.])
post_ids = np.array([0, 0, 2, 2])
out = np.asarray(syn2post_sum(syn, post_ids, 3))
np.testing.assert_array_equal(out, [3., 0., 7.])


def test_syn2post_mean_empty_group_is_zero_not_nan():
syn = np.array([2., 4., 6.])
post_ids = np.array([0, 0, 2]) # group 1 is empty
out = np.asarray(syn2post_mean(syn, post_ids, 3))
assert not np.any(np.isnan(out))
np.testing.assert_allclose(out, [3., 0., 6.], rtol=1e-6, atol=1e-6)


def test_syn2post_softmax_normalizes_per_group():
syn = np.array([1., 2., 3., 4.])
post_ids = np.array([0, 0, 1, 1])
out = np.asarray(syn2post_softmax(syn, post_ids, 2))
# within each post group the softmax weights sum to 1
np.testing.assert_allclose(out[:2].sum(), 1.0, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(out[2:].sum(), 1.0, rtol=1e-5, atol=1e-5)
# values match a manual softmax of [1,2] and [3,4]
s01 = np.exp([1., 2.] - np.max([1., 2.])); s01 /= s01.sum()
np.testing.assert_allclose(out[:2], s01, rtol=1e-5, atol=1e-5)
65 changes: 65 additions & 0 deletions brainpy/math/sparse/coo_mv_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
"""Regression tests for ``brainpy/math/sparse/coo_mv.py``.

``coomv`` converts COO indices to CSR (``brainevent.coo2csr``) before delegating
to ``brainevent.CSR``. These tests check both orientations and the scalar-weight
broadcast path against a dense numpy reference, with unsorted COO triples (so the
``coo2csr`` permutation of ``data`` is exercised).
"""

import jax
import jax.numpy as jnp
import numpy as np

from brainpy.math.sparse.coo_mv import coomv


# Deliberately UNSORTED COO triples for a 3 x 4 matrix:
# [[0, 2, 0, 4],
# [1, 0, 0, 0],
# [0, 3, 0, 2]]
_ROWS = np.array([2, 0, 1, 0, 2])
_COLS = np.array([1, 1, 0, 3, 3])
_VALS = np.array([3., 2., 1., 4., 2.])
_SHAPE = (3, 4)


def _dense():
m = np.zeros(_SHAPE, dtype=np.float32)
for v, r, c in zip(_VALS, _ROWS, _COLS):
m[r, c] = v
return m


def test_coomv_no_transpose_matches_dense():
v = jnp.arange(4, dtype=jnp.float32)
out = np.asarray(coomv(_VALS, _ROWS, _COLS, v, shape=_SHAPE, transpose=False))
assert out.shape == (3,)
np.testing.assert_allclose(out, _dense() @ np.asarray(v), rtol=1e-5, atol=1e-5)


def test_coomv_transpose_matches_dense():
v = jnp.arange(3, dtype=jnp.float32)
out = np.asarray(coomv(_VALS, _ROWS, _COLS, v, shape=_SHAPE, transpose=True))
assert out.shape == (4,)
np.testing.assert_allclose(out, _dense().T @ np.asarray(v), rtol=1e-5, atol=1e-5)


def test_coomv_scalar_weight_broadcast():
# scalar weight -> every stored entry uses the same value.
v = jnp.arange(4, dtype=jnp.float32)
out = np.asarray(coomv(2.0, _ROWS, _COLS, v, shape=_SHAPE, transpose=False))
ref = np.zeros(_SHAPE, dtype=np.float32)
ref[_ROWS, _COLS] = 2.0
np.testing.assert_allclose(out, ref @ np.asarray(v), rtol=1e-5, atol=1e-5)


def test_coomv_grad_scalar_weight():
v = jnp.arange(4, dtype=jnp.float32)

def f(s):
return coomv(s, _ROWS, _COLS, v, shape=_SHAPE, transpose=False).sum()

g = float(jax.grad(f)(2.0))
# d/ds sum(A(s) @ v) = sum over stored entries of v[col]
np.testing.assert_allclose(g, float(jnp.asarray(v)[_COLS].sum()), rtol=1e-5, atol=1e-5)
76 changes: 76 additions & 0 deletions brainpy/math/sparse/csr_mm_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
"""Regression tests for ``brainpy/math/sparse/csr_mm.py``.

Guards the ``transpose=True`` branch of :func:`csrmm` (must compute ``Aᵀ @ B``,
not ``B @ A``) against a dense numpy reference, including its autodiff.
"""

import jax
import jax.numpy as jnp
import numpy as np

import brainevent

from brainpy.math.sparse.csr_mm import csrmm


# 3 x 4 sparse matrix:
# [[0, 2, 0, 4],
# [1, 0, 0, 0],
# [0, 3, 0, 2]]
_ROWS = np.array([0, 0, 1, 2, 2])
_COLS = np.array([1, 3, 0, 1, 3])
_VALS = np.array([2., 4., 1., 3., 2.])
_SHAPE = (3, 4)


def _dense():
m = np.zeros(_SHAPE, dtype=np.float32)
for v, r, c in zip(_VALS, _ROWS, _COLS):
m[r, c] = v
return m


def _csr():
indptr, indices, order = brainevent.coo2csr(_ROWS, _COLS, shape=_SHAPE)
data = jnp.asarray(_VALS)[np.asarray(order)]
return data, np.asarray(indices), np.asarray(indptr)


def test_csrmm_no_transpose_matches_dense():
data, indices, indptr = _csr()
B = jnp.arange(4 * 2, dtype=jnp.float32).reshape(4, 2)
out = np.asarray(csrmm(data, indices, indptr, B, shape=_SHAPE, transpose=False))
assert out.shape == (3, 2)
np.testing.assert_allclose(out, _dense() @ np.asarray(B), rtol=1e-5, atol=1e-5)


def test_csrmm_transpose_matches_dense():
# transpose=True must compute Aᵀ @ B, where Aᵀ is (4, 3) and B is (3, 2).
data, indices, indptr = _csr()
B = jnp.arange(3 * 2, dtype=jnp.float32).reshape(3, 2)
out = np.asarray(csrmm(data, indices, indptr, B, shape=_SHAPE, transpose=True))
assert out.shape == (4, 2)
np.testing.assert_allclose(out, _dense().T @ np.asarray(B), rtol=1e-5, atol=1e-5)


def test_csrmm_transpose_grad_matches_dense():
data, indices, indptr = _csr()
B = jnp.arange(3 * 2, dtype=jnp.float32).reshape(3, 2)

def f(d):
return csrmm(d, indices, indptr, B, shape=_SHAPE, transpose=True).sum()

g = np.asarray(jax.grad(f)(_csr()[0]))
# dense reference gradient wrt the stored values
dense_ref = _dense()

def fd(flat):
m = jnp.zeros(_SHAPE, dtype=jnp.float32)
m = m.at[_ROWS, _COLS].set(flat)
return (m.T @ B).sum()

# values in CSR order correspond to coo2csr ``order``
_, _, order = brainevent.coo2csr(_ROWS, _COLS, shape=_SHAPE)
g_ref = np.asarray(jax.grad(fd)(jnp.asarray(_VALS)))[np.asarray(order)]
np.testing.assert_allclose(g, g_ref, rtol=1e-5, atol=1e-5)
49 changes: 48 additions & 1 deletion brainpy/math/sparse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,57 @@ def coo_to_csr(
*,
num_row: int
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""convert pre_ids, post_ids to (indices, indptr)."""
"""Convert COO ``(pre_ids, post_ids)`` connectivity to CSR ``(indices, indptr)``.

Parameters
----------
pre_ids : ndarray
Row (pre-synaptic) index of each non-zero entry. Every value must be in
``[0, num_row)``.
post_ids : ndarray
Column (post-synaptic) index of each non-zero entry, aligned with
``pre_ids``.
num_row : int
Number of rows of the sparse matrix (``shape[0]``).

Returns
-------
indices : ndarray
CSR column indices of shape ``(nse,)``.
indptr : ndarray
CSR row pointers of shape ``(num_row + 1,)`` and dtype ``int32``.

Raises
------
ValueError
If any ``pre_ids`` falls outside ``[0, num_row)``. Such an entry would
otherwise be silently dropped from ``indptr`` (its scatter index is
out-of-bounds), producing a structurally invalid CSR in which
``indptr[-1] != len(indices)``.

Notes
-----
This is an eager preprocessing helper: it relies on ``jnp.unique`` (whose
output size is data-dependent) and therefore cannot be traced under
``jit``/``vmap``.
"""
pre_ids = as_jax(pre_ids)
post_ids = as_jax(post_ids)

# Validate the pre (row) indices eagerly. An out-of-range ``pre_id`` would be
# silently dropped by the out-of-bounds ``.at[].set`` scatter below, yielding
# a corrupt CSR (``indptr[-1] != nse``) instead of an error. ``coo_to_csr``
# already cannot be ``jit``-traced (``jnp.unique``), so this concrete check
# does not regress any JAX transformation behaviour.
if pre_ids.size > 0:
pre_min = int(jnp.min(pre_ids))
pre_max = int(jnp.max(pre_ids))
if pre_min < 0 or pre_max >= num_row:
raise ValueError(
f'"pre_ids" must lie in [0, num_row) = [0, {num_row}), '
f'but got values in [{pre_min}, {pre_max}].'
)

# sorting
sort_ids = jnp.argsort(pre_ids, stable=True)
post_ids = post_ids[sort_ids]
Expand Down
Loading
Loading