Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions brainpy/dynold/experimental/abstract_synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,14 @@ def update(self, pre_spike, post_v=None):
self.conn_mask[1],
s,
shape=(self.pre_num, self.post_num),
transpose=True,
method='cusparse')
transpose=True)
if isinstance(self.mode, bm.BatchingMode):
f = vmap(f)
post_vs = f(pre_spike)
# ``f`` is fed ``syn_value`` (the STP-filtered drive when an stp
# component is present) so short-term plasticity actually affects
# the sparse conductance; the event-based no-stp branch above
# consumes the boolean ``pre_spike`` directly.
post_vs = f(syn_value)
else:
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

Expand Down Expand Up @@ -294,8 +297,7 @@ def update(self, pre_spike, post_v=None):
self.conn_mask[1],
s,
shape=(self.conn.pre_num, self.conn.post_num),
transpose=True,
method='cusparse'
transpose=True
)
if isinstance(self.mode, bm.BatchingMode):
f = vmap(f)
Expand Down
49 changes: 36 additions & 13 deletions brainpy/dynold/experimental/abstract_synapses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,36 @@ def test_sparse_no_stp_batching(self):
r = bm.as_jax(syn.update(bm.ones((2, 5))))
self.assertEqual(r.shape, (2, 4))

def test_sparse_with_stp_defect(self):
# NOTE: DEFECT -- the sparse + stp path calls
# bm.sparse.csrmv(..., method='cusparse'); csrmv no longer accepts
# a `method` kwarg, so this raises TypeError.
def test_sparse_with_stp(self):
# P11-H1/H2 regression: the sparse + stp path used to (a) call
# bm.sparse.csrmv(..., method='cusparse') -> TypeError, and (b) feed the
# raw pre_spike (not the STP-filtered syn_value) into csrmv. It must now
# run and return a finite, post-shaped conductance.
import numpy as np
conn = bp.conn.FixedProb(0.5)(pre_size=5, post_size=4)
syn = asyn.Exponential(conn, comp_method='sparse', stp=syn_plasticity.STD(5))
share.save(t=0.0, dt=bm.get_dt())
with self.assertRaises(TypeError):
syn.update(bm.ones(5, dtype=bool))
r = _step(syn, bm.ones(5, dtype=bool))
self.assertEqual(r.shape, (4,))
self.assertTrue(np.all(np.isfinite(np.asarray(r))))

def test_sparse_stp_filters_conductance(self):
# P11-H2 regression: with STD attached, the depressed (filtered) drive
# must produce a strictly smaller conductance than with no STP, on the
# same sparse connectivity / spikes. (Previously STP was ignored on the
# sparse path, so the two were identical.)
import numpy as np
conn = bp.conn.FixedProb(1.0)(pre_size=5, post_size=4)
spikes = bm.ones(5, dtype=bool)

syn_no = asyn.Exponential(conn, comp_method='sparse')
r_no = np.asarray(_step(syn_no, spikes, n=1))

syn_stp = asyn.Exponential(conn, comp_method='sparse',
stp=syn_plasticity.STD(5, U=0.5))
r_stp = np.asarray(_step(syn_stp, spikes, n=1))

self.assertTrue(np.all(r_stp <= r_no + 1e-6))
self.assertTrue(np.any(r_stp < r_no - 1e-6))

def test_reset_state(self):
conn = bp.conn.All2All()(pre_size=4, post_size=4)
Expand Down Expand Up @@ -188,15 +209,17 @@ def test_dh_dg_rhs(self):
g = bm.ones(3)
np.testing.assert_allclose(bm.as_jax(syn.dg(g, 0., h)), -bm.as_jax(g) / 10. + bm.as_jax(h))

def test_sparse_defect(self):
# NOTE: DEFECT -- DualExponential's sparse path always calls
def test_sparse(self):
# P11-H1 regression: DualExponential's sparse path used to always call
# bm.sparse.csrmv(..., method='cusparse'); csrmv no longer accepts a
# `method` kwarg, so this raises TypeError.
# `method` kwarg, so this raised TypeError. It must now run and return a
# finite, post-shaped conductance.
import numpy as np
conn = bp.conn.FixedProb(0.5)(pre_size=5, post_size=4)
syn = asyn.DualExponential(conn, comp_method='sparse')
share.save(t=0.0, dt=bm.get_dt())
with self.assertRaises(TypeError):
syn.update(bm.ones(5))
r = _step(syn, bm.ones(5))
self.assertEqual(r.shape, (4,))
self.assertTrue(np.all(np.isfinite(np.asarray(r))))

def test_reset_state(self):
conn = bp.conn.All2All()(pre_size=4, post_size=4)
Expand Down
16 changes: 11 additions & 5 deletions brainpy/dynold/experimental/syn_plasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ def reset_state(self, batch_size=None):

def update(self, pre_spike):
x = self.integral(self.x.value, share.load('t'), share.load('dt'))
self.x.value = bm.where(pre_spike, x - self.U * self.x, x)
# The depression jump must be applied to the value *at spike arrival*,
# i.e. the recovered/decayed local ``x`` (= x^-), not the pre-decay
# ``self.x`` from the previous step (P11-M1).
self.x.value = bm.where(pre_spike, x - self.U * x, x)
return self.x.value


Expand Down Expand Up @@ -166,16 +169,19 @@ def __init__(
self.reset_state(self.mode)

def reset_state(self, batch_size=None):
self.x = variable_(jnp.ones, batch_size, self.num)
self.u = variable_(OneInit(self.U), batch_size, self.num)
self.x = variable_(jnp.ones, self.num, batch_size)
self.u = variable_(OneInit(self.U), self.num, batch_size)

du = lambda self, u, t: self.U - u / self.tau_f
dx = lambda self, x, t: (1 - x) / self.tau_d

def update(self, pre_spike):
u, x = self.integral(self.u.value, self.x.value, share.load('t'), bm.get_dt())
u = bm.where(pre_spike, u + self.U * (1 - self.u), u)
x = bm.where(pre_spike, x - u * self.x, x)
# Tsodyks-Markram jumps act on the values *at spike arrival* (the decayed
# locals u^-/x^-), and the depression of x uses the facilitated u^+
# (P11-M1): u^+ = u^- + U(1 - u^-); x^+ = x^- - u^+ x^-.
u = bm.where(pre_spike, u + self.U * (1 - u), u)
x = bm.where(pre_spike, x - u * x, x)
self.x.value = x
self.u.value = u
return self.x.value * self.u.value
61 changes: 29 additions & 32 deletions brainpy/dynold/experimental/syn_plasticity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,10 @@
Exercises the experimental short-term plasticity components ``STD`` (fully
functional) and ``STP``.

.. note::

``STP`` is currently **unconstructable**: its ``reset_state`` calls
``variable_(jnp.ones, batch_size, self.num)`` with the ``batch_or_mode``
and ``sizes`` arguments swapped relative to ``STD``. When ``__init__``
calls ``reset_state(self.mode)`` the ``Mode`` object lands in the
``sizes`` slot and ``to_size`` raises ``ValueError: Cannot make a size
for NonBatchingMode``. The DEFECT is pinned in
``TestSTP.test_stp_construction_is_broken`` below; the rest of STP's
behaviour (the ``du``/``dx`` ODE RHS and ``update``) is exercised through
a manually corrected instance.
``STP`` previously could not be constructed: its ``reset_state`` called
``variable_(jnp.ones, batch_size, self.num)`` with the ``batch_or_mode`` and
``sizes`` arguments swapped relative to ``STD`` (P11-C1). That is now fixed and
the construction / update behaviour is exercised directly below.
"""

import unittest
Expand Down Expand Up @@ -88,29 +81,33 @@ def setUp(self):
bm.random.seed(123)
bm.set_dt(0.1)

def test_stp_construction_is_broken(self):
# NOTE: DEFECT -- STP.reset_state has swapped (batch_or_mode, sizes)
# arguments to variable_, so constructing STP with the default
# NonBatchingMode raises ValueError ("Cannot make a size for ...Mode").
# STD.reset_state uses the correct order; STP should mirror it.
with self.assertRaises(ValueError):
sp.STP(4)
def test_stp_construction_ok(self):
# P11-C1 regression: STP.reset_state used to pass (batch_or_mode, sizes)
# to variable_ in the wrong order, so constructing STP with the default
# NonBatchingMode raised ValueError ("Cannot make a size for ...Mode").
# It must now construct cleanly with x=ones, u=U.
stp = sp.STP(4, U=0.15, tau_f=1500., tau_d=200.)
self.assertEqual(stp.num, 4)
self.assertEqual(stp.x.shape, (4,))
self.assertEqual(stp.u.shape, (4,))
np.testing.assert_allclose(bm.as_jax(stp.x.value), np.ones(4))
np.testing.assert_allclose(bm.as_jax(stp.u.value), np.full(4, 0.15))

def test_stp_update_state_changes(self):
# P11-C1 regression: a constructed STP must update without error and
# respond to a presynaptic spike (u facilitates, x depresses).
stp = sp.STP(4, U=0.15, tau_f=1500., tau_d=200.)
share.save(t=0.0, dt=bm.dt)
x_before = bm.as_jax(stp.x.value).copy()
u_before = bm.as_jax(stp.u.value).copy()
r = bm.as_jax(stp.update(bm.ones(4, dtype=bool)))
self.assertEqual(r.shape, (4,))
self.assertTrue(np.all(bm.as_jax(stp.u.value) >= u_before - 1e-6))
Comment on lines +96 to +105

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (testing): The regression test for STP update only checks weak inequalities and could pass even if the spike has no effect.

In test_stp_update_state_changes, the checks u >= u_before - 1e-6 and x <= x_before + 1e-6 would still pass if u and x never change, so the test doesn’t actually guarantee that spikes modify the state. To make this a true regression, consider either (1) asserting that at least one element satisfies u > u_before + eps and x < x_before - eps for a small eps, or (2) checking against the analytically expected Tsodyks–Markram jump given U, tau_f, and tau_d, so the test will fail if STP stops responding to presynaptic spikes.

self.assertTrue(np.all(bm.as_jax(stp.x.value) <= x_before + 1e-6))

def _make_stp(self, num=4, U=0.15, tau_f=1500., tau_d=200.):
"""Build a working STP instance, working around the reset_state defect."""
stp = sp.STP.__new__(sp.STP)
SynSTPNS.__init__(stp)
stp.pre_size = tools.to_size(num)
stp.num = tools.size2num(stp.pre_size)
stp.tau_f = parameter(tau_f, stp.num)
stp.tau_d = parameter(tau_d, stp.num)
stp.U = parameter(U, stp.num)
stp.method = 'exp_auto'
stp.integral = odeint(JointEq([stp.du, stp.dx]), method=stp.method)
# correct argument order (mirrors STD.reset_state)
stp.x = variable_(jnp.ones, stp.num, None)
stp.u = variable_(OneInit(stp.U), stp.num, None)
return stp
"""Build an STP instance directly (now that the constructor works)."""
return sp.STP(num, U=U, tau_f=tau_f, tau_d=tau_d)

def test_du_dx_rhs(self):
stp = self._make_stp()
Expand Down
4 changes: 2 additions & 2 deletions brainpy/dynold/neurons/reduced_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,7 +1309,7 @@ def __init__(

# initializers
V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-70.),
a_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-50.),
a_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),

# parameter for training
spike_fun: Callable = bm.surrogate.relu_grad,
Expand Down Expand Up @@ -1472,7 +1472,7 @@ def __init__(

# initializers
V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-70.),
a_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-50.),
a_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),

# parameter for training
spike_fun: Callable = bm.surrogate.relu_grad,
Expand Down
23 changes: 23 additions & 0 deletions brainpy/dynold/neurons/reduced_neurons_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,26 @@ def test_training_shape(self, neuron):
progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10))


class TestBellecAdaptation(parameterized.TestCase):
"""P11-M2 regression: the SFA adaptation variable ``a`` must start at rest.

The threshold adaptation contributes ``beta * a`` to the effective firing
threshold (``V_th + beta * a``). The historical default ``OneInit(-50.)``
started ``a`` deeply negative, dropping the effective threshold by tens of
mV for thousands of ms and making a cold-started neuron fire spuriously.
The default must be a rest value (~0).
"""

@parameterized.named_parameters(
{'testcase_name': 'ALIFBellec2020', 'neuron': 'ALIFBellec2020'},
{'testcase_name': 'LIF_SFA_Bellec2020', 'neuron': 'LIF_SFA_Bellec2020'},
)
def test_default_adaptation_starts_at_rest(self, neuron):
bm.random.seed(0)
model = getattr(reduced_models, neuron)(size=4)
model.reset_state()
a0 = bm.as_jax(model.a.value)
# adaptation starts at zero (no spurious sub-threshold offset)
self.assertTrue(bool(bm.all(a0 == 0.)))
12 changes: 9 additions & 3 deletions brainpy/dynold/synplast/short_term_plasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ def reset_state(self, batch_size=None):

def update(self, pre_spike):
x = self.integral(self.x.value, share['t'], share['dt'])
self.x.value = jnp.where(pre_spike, x - self.U * self.x, x)
# The depression jump must be applied to the value *at spike arrival*,
# i.e. the recovered/decayed local ``x`` (= x^-), not the pre-decay
# ``self.x`` from the previous step (P11-M1).
self.x.value = jnp.where(pre_spike, x - self.U * x, x)

def filter(self, g):
if jnp.shape(g) != self.x.shape:
Expand Down Expand Up @@ -191,8 +194,11 @@ def derivative(self):

def update(self, pre_spike):
u, x = self.integral(self.u.value, self.x.value, share['t'], share['dt'])
u = jnp.where(pre_spike, u + self.U * (1 - self.u), u)
x = jnp.where(pre_spike, x - u * self.x, x)
# Tsodyks-Markram jumps act on the values *at spike arrival* (the decayed
# locals u^-/x^-), and the depression of x uses the facilitated u^+
# (P11-M1): u^+ = u^- + U(1 - u^-); x^+ = x^- - u^+ x^-.
u = jnp.where(pre_spike, u + self.U * (1 - u), u)
x = jnp.where(pre_spike, x - u * x, x)
self.x.value = x
self.u.value = u

Expand Down
119 changes: 119 additions & 0 deletions brainpy/dynold/synplast/short_term_plasticity_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ``brainpy.dynold.synplast.short_term_plasticity`` (STD / STP).

These STP components are attached as the ``stp=`` slot of a ``TwoEndConn``
synapse; ``register_master`` allocates their state from the master's
pre-synaptic group. The regressions here pin P11-M1: the discrete
Tsodyks-Markram jumps must act on the value *at spike arrival* (the decayed
local), not the pre-decay state held over from the previous step.
"""

import unittest

import numpy as np

import brainpy as bp
import brainpy.math as bm
from brainpy.context import share


def _make_std(num=4, tau=200., U=0.07):
"""Build an STD bound to a real master synapse."""
pre = bp.neurons.LIF(num)
post = bp.neurons.LIF(num)
syn = bp.synapses.Exponential(pre, post, bp.connect.One2One(),
stp=bp.synplast.STD(tau=tau, U=U),
comp_method='dense')
return syn.stp


def _make_stp(num=4, U=0.15, tau_f=1500., tau_d=200.):
pre = bp.neurons.LIF(num)
post = bp.neurons.LIF(num)
syn = bp.synapses.Exponential(pre, post, bp.connect.One2One(),
stp=bp.synplast.STP(U=U, tau_f=tau_f, tau_d=tau_d),
comp_method='dense')
return syn.stp


class TestSTD(unittest.TestCase):
def setUp(self):
bm.random.seed(0)
bm.set_dt(0.1)

def test_first_spike_from_rest(self):
std = _make_std(3, tau=200., U=0.07)
share.save(t=0.0, dt=bm.dt)
std.update(bm.ones(3, dtype=bool))
# from rest x=1 -> x^+ = 1 - U = 0.93 (decay over one dt is negligible)
np.testing.assert_allclose(bm.as_jax(std.x.value), np.full(3, 1 - 0.07), atol=2e-3)

def test_jump_uses_decayed_state(self):
# P11-M1: depress, let x recover for one step (no spike), then spike. The
# depression must scale with the *decayed* x (= x^- at spike arrival),
# i.e. x^+ = x_dec - U*x_dec, NOT x_dec - U*x_prev.
U, tau, dt = 0.5, 50., bm.dt
std = _make_std(1, tau=tau, U=U)
share.save(t=0.0, dt=dt)
std.update(bm.ones(1, dtype=bool)) # x drops to ~1-U
x_prev = float(bm.as_jax(std.x.value)[0])
share.save(t=float(dt), dt=dt)
std.update(bm.ones(1, dtype=bool)) # recover one dt, then spike
x_after = float(bm.as_jax(std.x.value)[0])

# decayed value at spike arrival
x_dec = x_prev + (1 - x_prev) / tau * float(dt)
expected_correct = x_dec - U * x_dec
expected_buggy = x_dec - U * x_prev
self.assertAlmostEqual(x_after, expected_correct, places=5)
# the two differ enough (recovery over dt) that the buggy form is rejected
self.assertNotAlmostEqual(expected_correct, expected_buggy, places=7)


class TestSTP(unittest.TestCase):
def setUp(self):
bm.random.seed(0)
bm.set_dt(0.1)

def test_jump_uses_decayed_state(self):
# P11-M1: u^+ = u^- + U(1-u^-) and x^+ = x^- - u^+ x^- must use the
# decayed (current-time) locals, not the previous-step Variables.
U, tau_f, tau_d, dt = 0.5, 100., 50., bm.dt
stp = _make_stp(1, U=U, tau_f=tau_f, tau_d=tau_d)
share.save(t=0.0, dt=dt)
stp.update(bm.ones(1, dtype=bool))
u_prev = float(bm.as_jax(stp.u.value)[0])
x_prev = float(bm.as_jax(stp.x.value)[0])
share.save(t=float(dt), dt=dt)
stp.update(bm.ones(1, dtype=bool))
u_after = float(bm.as_jax(stp.u.value)[0])
x_after = float(bm.as_jax(stp.x.value)[0])

# decayed locals at spike arrival (exp_auto integrates exactly here)
u_dec = u_prev + (U - u_prev / tau_f) * float(dt)
x_dec = x_prev + (1 - x_prev) / tau_d * float(dt)
u_correct = u_dec + U * (1 - u_dec)
x_correct = x_dec - u_correct * x_dec
self.assertAlmostEqual(u_after, u_correct, places=4)
self.assertAlmostEqual(x_after, x_correct, places=4)
# buggy variants (using the pre-decay Variables) must be distinguishable
u_buggy = u_dec + U * (1 - u_prev)
self.assertNotAlmostEqual(u_correct, u_buggy, places=6)


if __name__ == '__main__':
unittest.main()
Loading
Loading