Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions brainpy/dnn/linear_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ def test_JitFPHomoLinear(self, prob, weight, shape):
self.assertTrue(y.shape == shape + (200,))

conn_matrix = f.get_conn_matrix()
self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
# float32-appropriate tolerances: the JIT operator and the dense ``x @ conn``
# differ at float32 rounding level; the default ``atol=1e-8`` is tighter than that
# for the symmetric-uniform layer whose outputs sit near zero.
self.assertTrue(bm.allclose(y, x @ conn_matrix.T, rtol=1e-4, atol=1e-5))
# print(conn_matrix.shape)
# self.assertTrue(conn_matrix.shape == (200, 100))

Expand All @@ -168,7 +171,10 @@ def test_JitFPUniformLinear(self, prob, w_low, w_high, shape):
self.assertTrue(y.shape == shape + (200,))

conn_matrix = f.get_conn_matrix()
self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
# float32-appropriate tolerances: the JIT operator and the dense ``x @ conn``
# differ at float32 rounding level; the default ``atol=1e-8`` is tighter than that
# for the symmetric-uniform layer whose outputs sit near zero.
self.assertTrue(bm.allclose(y, x @ conn_matrix.T, rtol=1e-4, atol=1e-5))

@parameterized.product(
prob=[0.1],
Expand All @@ -184,7 +190,10 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
self.assertTrue(y.shape == shape + (200,))

conn_matrix = f.get_conn_matrix()
self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
# float32-appropriate tolerances: the JIT operator and the dense ``x @ conn``
# differ at float32 rounding level; the default ``atol=1e-8`` is tighter than that
# for the symmetric-uniform layer whose outputs sit near zero.
self.assertTrue(bm.allclose(y, x @ conn_matrix.T, rtol=1e-4, atol=1e-5))

@parameterized.product(
prob=[0.1],
Expand All @@ -202,7 +211,10 @@ def test_EventJitFPHomoLinear(self, prob, weight, shape):
self.assertTrue(y2.shape == shape + (200,))

conn_matrix = f.get_conn_matrix()
self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
# float32-appropriate tolerances: the JIT operator and the dense ``x @ conn``
# differ at float32 rounding level; the default ``atol=1e-8`` is tighter than that
# for the symmetric-uniform layer whose outputs sit near zero.
self.assertTrue(bm.allclose(y, x @ conn_matrix.T, rtol=1e-4, atol=1e-5))

@parameterized.product(
prob=[0.1],
Expand All @@ -221,7 +233,10 @@ def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape):
self.assertTrue(y2.shape == shape + (200,))

conn_matrix = f.get_conn_matrix()
self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
# float32-appropriate tolerances: the JIT operator and the dense ``x @ conn``
# differ at float32 rounding level; the default ``atol=1e-8`` is tighter than that
# for the symmetric-uniform layer whose outputs sit near zero.
self.assertTrue(bm.allclose(y, x @ conn_matrix.T, rtol=1e-4, atol=1e-5))

@parameterized.product(
prob=[0.1],
Expand All @@ -240,7 +255,10 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
self.assertTrue(y2.shape == shape + (200,))

conn_matrix = f.get_conn_matrix()
self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
# float32-appropriate tolerances: the JIT operator and the dense ``x @ conn``
# differ at float32 rounding level; the default ``atol=1e-8`` is tighter than that
# for the symmetric-uniform layer whose outputs sit near zero.
self.assertTrue(bm.allclose(y, x @ conn_matrix.T, rtol=1e-4, atol=1e-5))


if __name__ == '__main__':
Expand Down
14 changes: 12 additions & 2 deletions brainpy/integrators/sde/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,10 +502,20 @@ def step(self, *args, **kwargs):
else:
integral += diffusions[key] * noise
noise_p2 = (noise ** 2 - dt) if self.intg_type == constants.ITO_SDE else noise ** 2
minus = (diffusion_bars[key] - diffusions[key]) / 2 / jnp.sqrt(dt)
if self.wiener_type == constants.VECTOR_WIENER:
integral += minus * jnp.sum(noise_p2, axis=-1)
# ``y_bars[key]`` carries the noise axis (one support value per noise
# component ``j``: ``y_bar_j = Y + f dt + g_j sqrt(dt)``), so
# ``diffusion_bars[key]`` has a trailing ``(m, m)`` block whose diagonal is
# ``g_j(y_bar_j)``. Previously the full ``(m, m)`` matrix was used together
# with ``minus * jnp.sum(noise_p2, -1)``, which left the noise dimension in
# the state and grew the output by two axes every step (a multi-GB blow-up
# for long integrations). Take the diagonal and contract the per-component
# Milstein correction over the noise axis, mirroring ``Milstein``.
g_bar = jnp.diagonal(diffusion_bars[key], axis1=-2, axis2=-1)
minus = (g_bar - diffusions[key]) / 2 / jnp.sqrt(dt)
integral += jnp.sum(minus * noise_p2, axis=-1)
else:
minus = (diffusion_bars[key] - diffusions[key]) / 2 / jnp.sqrt(dt)
integral += minus * noise_p2
integrals.append(integral)
return integrals if len(self.variables) > 1 else integrals[0]
Expand Down
9 changes: 8 additions & 1 deletion brainpy/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,14 @@
from .compat_tensorflow import *
from .datatypes import *
from .delayvars import *
from .einops import *
# einops-style helpers are reused from ``brainunit.math`` (the local port was
# removed); keep the historical ``ein_*`` names as thin aliases.
from brainunit.math import (
einreduce as ein_reduce,
einrearrange as ein_rearrange,
einrepeat as ein_repeat,
einshape as ein_shape,
)
from .environment import *
from .interoperability import *
# environment settings
Expand Down
Loading
Loading