Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion brainpy/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,16 @@ def jit_error(pred, err_fun, err_arg=None):
The arguments which passed into `err_f`.
"""
from brainpy.math.interoperability import as_jax

# Fast path for a concrete (non-traced) predicate. ``jax.pure_callback`` only
# surfaces its raised exception when the staged computation is *executed*; for
# an eager, concrete predicate the callback never runs synchronously, so the
# error would be silently swallowed. Evaluate and raise directly instead.
if not isinstance(pred, jax.core.Tracer):
if bool(np.asarray(pred)):
err_fun(err_arg)
return

partial(_cond, err_fun)(as_jax(pred), err_arg)


Expand All @@ -641,7 +651,16 @@ def jit_error_checking_no_args(pred: bool, err: Exception):

assert isinstance(err, Exception), 'Must be instance of Exception.'

def true_err_fun(arg, transforms):
# Fast path for a concrete (non-traced) predicate. The ``jax.pure_callback``
# below only raises when the staged computation is *executed*; eagerly the
# exception is never surfaced synchronously, so an out-of-bound value would
# be silently accepted (e.g. ``is_float(2.0, max_bound=1.0)``). Raise here.
if not isinstance(pred, jax.core.Tracer):
if bool(np.asarray(pred)):
raise err
return

def true_err_fun(*args):
raise err

cond(unvmap(as_jax(pred)),
Expand Down
60 changes: 36 additions & 24 deletions brainpy/check_coverage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
``is_all_vars``, ``is_all_objs``), ``serialize_kwargs``, and the JIT error
helpers (``jit_error``, ``jit_error_checking``, ``jit_error_checking_no_args``).
"""
import jax
import numpy as np
import pytest

Expand Down Expand Up @@ -306,18 +307,18 @@ def test_min_bound_ok(self):
assert checking.is_float(5.0, min_bound=1.0) == 5.0

def test_min_bound_branch(self):
# NOTE: min/max bound checks route through jit_error_checking_no_args
# which uses jax.pure_callback. Eagerly the side-effect raise is NOT
# propagated, so an out-of-bound value returns rather than raising.
# We exercise the branch (line coverage) and pin the no-raise behavior.
assert checking.is_float(0.5, min_bound=1.0, name='v') == 0.5
# P14-H2: eager out-of-bound checks now raise (previously the
# jit_error_checking_no_args pure_callback path silently accepted them).
with pytest.raises(Exception):
checking.is_float(0.5, min_bound=1.0, name='v')

def test_max_bound_ok(self):
assert checking.is_float(5.0, max_bound=10.0) == 5.0

def test_max_bound_branch(self):
# NOTE: see test_min_bound_branch -- eager pure_callback does not raise.
assert checking.is_float(20.0, max_bound=10.0, name='v') == 20.0
# P14-H2: eager out-of-bound checks now raise.
with pytest.raises(Exception):
checking.is_float(20.0, max_bound=10.0, name='v')


# --------------------------------------------------------------------------- #
Expand Down Expand Up @@ -354,16 +355,17 @@ def test_min_bound_ok(self):
assert checking.is_integer(5, min_bound=1) == 5

def test_min_bound_branch(self):
# NOTE: bound check goes through jit_error_checking_no_args (pure_callback);
# eager evaluation does not propagate the raise -- value is returned.
assert checking.is_integer(0, min_bound=1, name='v') == 0
# P14-H2: eager out-of-bound checks now raise.
with pytest.raises(Exception):
checking.is_integer(0, min_bound=1, name='v')

def test_max_bound_ok(self):
assert checking.is_integer(5, max_bound=10) == 5

def test_max_bound_branch(self):
# NOTE: see test_min_bound_branch.
assert checking.is_integer(20, max_bound=10, name='v') == 20
# P14-H2: eager out-of-bound checks now raise.
with pytest.raises(Exception):
checking.is_integer(20, max_bound=10, name='v')


# --------------------------------------------------------------------------- #
Expand Down Expand Up @@ -533,18 +535,26 @@ def test_all_objs_bad(self):
# jit error helpers
# --------------------------------------------------------------------------- #
class TestJitErrors:
# NOTE: these helpers wrap ``jax.lax.cond`` + ``jax.pure_callback``. When the
# predicate is True the error callback executes, but eagerly the raised
# exception is NOT propagated synchronously to the caller (a known quirk of
# pure_callback used for in-jit error signalling). We therefore exercise the
# True/False branches for line coverage and assert they run without crashing.
# P14-H2: for a *concrete* predicate these helpers now raise synchronously
# (previously the ``jax.pure_callback`` path silently swallowed the raise).
# Under tracing they keep the deferred ``cond`` + ``pure_callback`` path.
def test_no_args_pred_false(self):
# pred False -> false branch only
# pred False -> false branch only, no raise
checking.jit_error_checking_no_args(False, ValueError('boom'))

def test_no_args_pred_true(self):
# exercises the true branch (callback path), no propagated raise eagerly
checking.jit_error_checking_no_args(True, ValueError('boom'))
# concrete True -> raises synchronously
with pytest.raises(ValueError):
checking.jit_error_checking_no_args(True, ValueError('boom'))

def test_no_args_under_jit(self):
# tracer predicate -> deferred, no raise at trace time
@jax.jit
def f(x):
checking.jit_error_checking_no_args(x > 1.0, ValueError('boom'))
return x

assert float(f(0.0)) == 0.0

def test_no_args_bad_err(self):
with pytest.raises(AssertionError):
Expand All @@ -561,15 +571,17 @@ def test_jit_error_true(self):
def err_fun(arg):
raise ValueError('boom')

# exercises the true branch + _err_jit_true_branch single-array path
checking.jit_error(True, err_fun, bm.as_jax(bm.zeros(2)))
# concrete True -> raises synchronously
with pytest.raises(ValueError):
checking.jit_error(True, err_fun, bm.as_jax(bm.zeros(2)))

def test_jit_error_true_tuple_arg(self):
def err_fun(arg):
raise ValueError('boom')

# exercises the tuple/list branch of _err_jit_true_branch
checking.jit_error(True, err_fun, (bm.as_jax(bm.zeros(2)), bm.as_jax(bm.ones(3))))
# concrete True with a tuple err_arg -> raises synchronously
with pytest.raises(ValueError):
checking.jit_error(True, err_fun, (bm.as_jax(bm.zeros(2)), bm.as_jax(bm.ones(3))))

def test_alias(self):
assert checking.jit_error_checking is checking.jit_error
53 changes: 53 additions & 0 deletions brainpy/check_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,62 @@
# ==============================================================================
import unittest

import jax

from brainpy import check as checking


class TestBoundChecks(unittest.TestCase):
"""Regression tests for P14-H2: eager bound checks must actually raise.

``is_float``/``is_integer`` route their ``min_bound``/``max_bound`` checks
through ``jit_error_checking_no_args``. The previous implementation used a
``jax.pure_callback`` whose raise never propagated for a *concrete* (eager)
predicate, so out-of-bound values were silently accepted.
"""

def test_is_float_min_bound_raises(self):
with self.assertRaises(Exception):
checking.is_float(0.5, 'v', min_bound=1.0)

def test_is_float_max_bound_raises(self):
with self.assertRaises(Exception):
checking.is_float(20.0, 'v', max_bound=10.0)

def test_is_float_within_bounds_ok(self):
self.assertEqual(checking.is_float(5.0, 'v', min_bound=1.0, max_bound=10.0), 5.0)

def test_is_integer_min_bound_raises(self):
with self.assertRaises(Exception):
checking.is_integer(0, 'v', min_bound=1)

def test_is_integer_max_bound_raises(self):
with self.assertRaises(Exception):
checking.is_integer(20, 'v', max_bound=10)

def test_is_integer_within_bounds_ok(self):
self.assertEqual(checking.is_integer(5, 'v', min_bound=1, max_bound=10), 5)

def test_no_args_concrete_true_raises(self):
with self.assertRaises(ValueError):
checking.jit_error_checking_no_args(True, ValueError('boom'))

def test_no_args_concrete_false_ok(self):
# must not raise
checking.jit_error_checking_no_args(False, ValueError('boom'))

def test_no_args_under_jit_does_not_raise_at_trace(self):
# When the predicate is a tracer (inside jit) the check must NOT raise
# at trace time; it stays a deferred in-jit error signal.
@jax.jit
def f(x):
checking.jit_error_checking_no_args(x > 1.0, ValueError('boom'))
return x

# tracing/compiling with a value that does not trip the predicate runs fine
self.assertEqual(float(f(0.0)), 0.0)


class TestUtils(unittest.TestCase):
def test_check_shape(self):
all_shapes = [
Expand Down
153 changes: 150 additions & 3 deletions brainpy/dyn_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,26 @@
# ==============================================================================
import unittest

import numpy as np

import brainpy as bp
import brainpy.math as bm

# Capture the default ``dt`` before any test mutates it. ``DSRunner(dt=...)``
# permanently writes ``dt`` into the global brainstate environment (via
# ``share.save(dt=...)``), which otherwise leaks ``dt`` into later test files
# (e.g. delay tests that assume the default ``dt=0.1``).
_DEFAULT_DT = bm.get_dt()


class _DtRestoreMixin:
"""Restore the global ``dt`` after each test that runs a ``DSRunner``."""

def tearDown(self):
bm.set_dt(_DEFAULT_DT)

class TestDSRunner(unittest.TestCase):

class TestDSRunner(_DtRestoreMixin, unittest.TestCase):
def test1(self):
class ExampleDS(bp.DynamicalSystem):
def __init__(self):
Expand Down Expand Up @@ -89,5 +104,137 @@ def __init__(self, scale=1.0, method='exp_auto'):
inputs=[(net.E.input, 20.), (net.I.input, 20.)], jit=False).run(0.2)


class TestMemoryEfficient(unittest.TestCase):
pass
class TestMemoryEfficient(_DtRestoreMixin, unittest.TestCase):
"""Regression tests for ``DSRunner(memory_efficient=True)`` (P14-C1).

The memory-efficient path collects monitors via a host-side callback and
stacks the model's ``update()`` outputs manually. A previous bug used
``list.append`` inside ``tree_map`` (which returns ``None``) so ``.run()``
silently returned ``None`` instead of the time-stacked outputs.
"""

def _scalar_ds(self):
class ExampleDS(bp.DynamicalSystem):
def __init__(self):
super().__init__()
self.i = bm.Variable(bm.zeros(1))

def update(self):
self.i += 1.
return self.i.value

return ExampleDS

def test_output_matches_normal_scalar(self):
DS = self._scalar_ds()

out_normal = bp.DSRunner(DS(), dt=1., progress_bar=False,
memory_efficient=False).run(5.)
out_mem = bp.DSRunner(DS(), dt=1., progress_bar=False,
memory_efficient=True).run(5.)

out_normal = np.asarray(out_normal)
out_mem = np.asarray(out_mem)
# the memory-efficient output must not be lost
self.assertIsNotNone(out_mem.dtype)
self.assertEqual(out_normal.shape, out_mem.shape)
self.assertTrue(np.allclose(out_normal, out_mem))
self.assertTrue(np.allclose(out_mem.ravel(), [1., 2., 3., 4., 5.]))

def test_output_matches_normal_pytree(self):
# the update returns a dict (a non-trivial pytree of outputs)
class ExampleDS(bp.DynamicalSystem):
def __init__(self):
super().__init__()
self.i = bm.Variable(bm.zeros(2))

def update(self):
self.i += 1.
return {'a': self.i.value, 'b': self.i.value * 2.}

out_normal = bp.DSRunner(ExampleDS(), dt=1., progress_bar=False,
memory_efficient=False).run(4.)
out_mem = bp.DSRunner(ExampleDS(), dt=1., progress_bar=False,
memory_efficient=True).run(4.)

for key in ('a', 'b'):
a = np.asarray(out_normal[key])
b = np.asarray(out_mem[key])
self.assertEqual(a.shape, b.shape)
self.assertEqual(a.shape, (4, 2))
self.assertTrue(np.allclose(a, b))

def test_monitors_still_match(self):
DS = self._scalar_ds()
r_n = bp.DSRunner(DS(), dt=1., monitors=['i'], progress_bar=False,
memory_efficient=False)
r_n.run(5.)
r_m = bp.DSRunner(DS(), dt=1., monitors=['i'], progress_bar=False,
memory_efficient=True)
r_m.run(5.)
self.assertTrue(np.allclose(np.asarray(r_n.mon['i']),
np.asarray(r_m.mon['i'])))
self.assertTrue(np.allclose(np.asarray(r_n.mon['ts']),
np.asarray(r_m.mon['ts'])))

def test_output_none_when_update_returns_none(self):
# an ``update()`` with no explicit return must give ``None`` in both
# paths (the all-``None`` pytree collapses to ``None``), not crash.
class DS(bp.DynamicalSystem):
def __init__(self):
super().__init__()
self.i = bm.Variable(bm.zeros(1))

def update(self):
self.i += 1. # returns None

out_n = bp.DSRunner(DS(), dt=1., monitors=['i'], progress_bar=False,
memory_efficient=False).run(5.)
r_m = bp.DSRunner(DS(), dt=1., monitors=['i'], progress_bar=False,
memory_efficient=True)
out_m = r_m.run(5.)
self.assertIsNone(out_n)
self.assertIsNone(out_m)
self.assertTrue(np.allclose(np.asarray(r_m.mon['i']).ravel(),
[1., 2., 3., 4., 5.]))

def test_batched_monitor_axis_matches_normal(self):
"""P14-H1: for BatchingMode + data_first_axis='B' (the default for
batched models) the memory-efficient monitors must come out batch-major
``(B, T, ...)``, identical to the standard path. The bug returned
time-major ``(T, B, ...)`` only for the memory-efficient path."""

class Net(bp.DynamicalSystem):
def __init__(self):
super().__init__(mode=bm.BatchingMode(4))
self.n = bp.dyn.LifRef(3, mode=bm.BatchingMode(4))

def update(self, inp):
self.n(inp)
return self.n.V.value

inp = bm.ones((4, 8, 3)) * 2.0 # (batch, time, features), data_first_axis='B'

bm.random.seed(0)
net = Net(); net.reset(4)
r_n = bp.DSRunner(net, monitors=['n.V'], memory_efficient=False,
progress_bar=False)
out_n = r_n.run(inputs=inp)

bm.random.seed(0)
net2 = Net(); net2.reset(4)
r_m = bp.DSRunner(net2, monitors=['n.V'], memory_efficient=True,
progress_bar=False)
out_m = r_m.run(inputs=inp)

# outputs are batch-major in both paths
self.assertEqual(np.asarray(out_n).shape, np.asarray(out_m).shape)
self.assertEqual(np.asarray(out_m).shape, (4, 8, 3))
self.assertTrue(np.allclose(np.asarray(out_n), np.asarray(out_m)))

# monitors must share the same (batch, time, features) layout
self.assertEqual(np.asarray(r_n.mon['n.V']).shape,
np.asarray(r_m.mon['n.V']).shape)
self.assertEqual(np.asarray(r_m.mon['n.V']).shape, (4, 8, 3))
self.assertTrue(np.allclose(np.asarray(r_n.mon['n.V']),
np.asarray(r_m.mon['n.V'])))
Loading
Loading