fix: prevent MilsteinGradFree OOM crash; make GPU test suite hardware-independent #837

Merged: chaoming0625 merged 4 commits into master from worktree-fix-test-bugs on Jun 18, 2026

chaoming0625 (Member) commented Jun 18, 2026

Summary

Fixes the test_linux CI failure and makes the test suite pass on GPU developer machines without changing CPU/CI behaviour.

The crash (primary fix)

test_linux (3.13) was OOM-killed (exit code 143) at
brainpy/integrators/sde/normal_coverage_test.py [69%].

Root cause: the vector-Wiener branch of MilsteinGradFree.step used the full
(m, m) diffusion-bar matrix together with minus * jnp.sum(noise_p2, -1),
leaving the noise dimension inside the integrated state. Each step grew the
output by two axes, () → (m, m) → (m, m, m, m) → …, a multi-GB blow-up
over a long integration that exhausted the runner's RAM. macOS/Windows runners
had enough headroom to survive (4055 passed), so only Linux crashed.

Fix: take the diagonal of the diffusion-bar block and contract the
per-component Milstein correction over the noise axis, mirroring the
shape-preserving pattern already used by Milstein. The integrated state now
stays correctly shaped; a full single-process run of brainpy/ peaks at a
bounded ~9 GB with the crash region flat (no spike) and reaches 100%.
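
The shape growth can be reproduced in isolation. The following is a minimal JAX sketch, not the code in brainpy/integrators/sde/normal.py: `diffusion`, `SIGMA`, `M`, `run`, and both `correction_*` helpers are hypothetical stand-ins, the drift term is omitted, and the state starts as a scalar. It contrasts the leaking pattern (full (m, m) block multiplied by `jnp.sum(noise_p2, -1)`) with the diagonal-and-contract pattern described in the fix.

```python
import jax
import jax.numpy as jnp

M = 3                                  # number of Wiener components (assumed)
SIGMA = jnp.array([0.1, 0.2, 0.3])     # hypothetical per-component noise scale


def diffusion(y):
    # Hypothetical vector-Wiener diffusion: appends a trailing noise axis of size M.
    return y[..., None] * SIGMA


def correction_exploding(y, dt, noise):
    g = diffusion(y)                               # (..., M)
    y_bar = y[..., None] + g * jnp.sqrt(dt)        # (..., M); drift omitted
    g_bar = diffusion(y_bar)                       # (..., M, M)
    noise_p2 = noise ** 2 - dt                     # (..., M)
    minus = (g_bar - g) / (2 * jnp.sqrt(dt))       # (..., M, M) after broadcasting
    return minus * jnp.sum(noise_p2, -1)           # noise axes leak into the state


def correction_fixed(y, dt, noise):
    g = diffusion(y)                               # (..., M)
    y_bar = y[..., None] + g * jnp.sqrt(dt)        # (..., M); drift omitted
    # Diagonal of the trailing (M, M) block: component j uses g_j(y_bar_j).
    g_bar = jnp.diagonal(diffusion(y_bar), axis1=-2, axis2=-1)   # (..., M)
    noise_p2 = noise ** 2 - dt                     # (..., M)
    minus = (g_bar - g) / (2 * jnp.sqrt(dt))       # (..., M)
    return jnp.sum(minus * noise_p2, axis=-1)      # contract the noise axis


def run(correction, n_steps=2, dt=0.01):
    # Record the state shape after each step, starting from a scalar state.
    y = jnp.asarray(1.0)
    key = jax.random.PRNGKey(0)
    shapes = [y.shape]
    for _ in range(n_steps):
        key, sub = jax.random.split(key)
        noise = jax.random.normal(sub, diffusion(y).shape) * jnp.sqrt(dt)
        y = y + correction(y, dt, noise)
        shapes.append(y.shape)
    return shapes


print(run(correction_exploding))   # [(), (3, 3), (3, 3, 3, 3)]
print(run(correction_fixed))       # [(), (), ()]
```

With m noise components, the leaking variant squares the element count every step, so even a small m exhausts memory after a handful of steps; the contracted variant keeps the state shape constant.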

Hardware-independent test suite (GPU portability)

  • conftest.py — pin jax_default_matmul_precision='highest'. On NVIDIA GPUs
    the default uses TF32 for float32 matmuls (~1e-4 relative error), which broke
    operator-vs-dense correctness comparisons (JIT-connectivity layers,
    orthonormality checks) on GPU while passing on CPU. No effect on CPU/CI (already
    full float32).
  • brainpy/dnn/linear_test.py — float32-appropriate tolerances
    (rtol=1e-4, atol=1e-5) for the JIT-operator vs dense x @ conn comparison;
    the default atol=1e-8 is tighter than float32 rounding for the near-zero
    symmetric-uniform outputs.
  • brainpy/math/object_transform/object_transform_fixes_test.py — guard the
    .cuda() / .tpu() RuntimeError assertions on device availability so the
    test stays meaningful on CPU-only CI yet does not fail on GPU/TPU machines.
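
As a rough illustration of the first two items, the sketch below pins the matmul precision the way a conftest.py could, then shows why the default `atol=1e-8` rejects a float32-sized discrepancy on near-zero values. The arrays are hypothetical values chosen for the example, and NumPy's `allclose` stands in for the comparison used in the tests; this is not the PR's actual diff.

```python
import jax
import numpy as np

# conftest.py-style pin: request full-precision float32 matmuls on every backend,
# so NVIDIA GPUs do not fall back to TF32 for float32 inputs.
jax.config.update("jax_default_matmul_precision", "highest")

# Hypothetical near-zero outputs, plus a small float32-level discrepancy
# between the dense path and the operator path.
dense = np.array([2e-6, -3e-6, 0.0], dtype=np.float32)
operator = dense + np.float32(2e-7)

# NumPy defaults (rtol=1e-5, atol=1e-8) are tighter than the discrepancy.
strict = np.allclose(operator, dense)
# Tolerances quoted in the PR description.
relaxed = np.allclose(operator, dense, rtol=1e-4, atol=1e-5)

print(strict, relaxed)   # False True
```

The relative term contributes almost nothing when the reference values are near zero, so the absolute tolerance decides the outcome; that is why raising `atol` matters more than raising `rtol` here.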

Testing

  • Full single-process pytest brainpy/ run (GPU, CI-equivalent): 4061 passed,
    14 skipped, no crash, bounded ~9 GB peak. (The only local failures were in
    surrogate/l1_loss tests, which are artifacts of a pinned braintools<0.2.0
    dev env; CI uses braintools>=0.2.0 where they pass — macOS and Windows jobs
    are green.)
  • Verified MilsteinGradFree now produces correctly-shaped (non-exploding)
    output for the vector-Wiener Ito case.

Dependency fix (unblocks CI install)

braintools 0.2.0 was yanked from PyPI, so the existing braintools>=0.2.0
pin became unsatisfiable and pip install -r requirements.txt failed on every
runner (~10 s, before any test ran). Bumped the pin to braintools>=0.3.0
(the next released version, which carries the surrogate / metric fixes the
brainpy.math.surrogate and L1-loss tests assert) in both requirements.txt
and pyproject.toml. Verified locally with braintools 0.3.0: the surrogate +
L1-loss tests that were red under the yanked-version fallback (0.1.10) now pass
(102 passed).

Cleanup: reuse brainunit.math einops (drop the local port)

brainpy.math's ein_reduce / ein_rearrange / ein_repeat / ein_shape
duplicated an einops implementation that now lives in brainunit.math
(einreduce / einrearrange / einrepeat / einshape — behaviour-identical
and accepting brainpy Array instances directly, verified). The historical
ein_* names are now thin aliases re-exported from brainunit.math, and the
duplicated implementation (einops.py, einops_parsing.py) plus its dedicated
tests (einops_test.py, einops_coverage_test.py, einops_parsing_test.py)
are removed. The einops tests in math_compat_fixes_test.py now exercise the
public bm.ein_* aliases and assert the re-export wiring (87 passed locally).

Cleanup: remove taichi / tifunc remnants

The taichi backend is gone; this drops the fully-skipped tifunc_test.py
(a taichi ti.kernel test) and the stale method / Taichi paragraph in the
csrmv docstring (that parameter no longer exists — csrmv dispatches through
brainevent).

…-independent

The Linux CI job (`pytest brainpy/`) was OOM-killed (exit 143) at
`brainpy/integrators/sde/normal_coverage_test.py` [69%]. Root cause: the
vector-Wiener branch of `MilsteinGradFree.step` used the full `(m, m)`
diffusion-bar matrix together with `minus * jnp.sum(noise_p2, -1)`, which left
the noise dimension in the integrated state. Each step grew the output by two
axes — `()` -> `(m, m)` -> `(m, m, m, m)` -> ... — a multi-GB blow-up over a long
integration that exhausted the runner's RAM (macOS/Windows runners had enough
headroom to survive, so only Linux crashed). Take the diagonal of the
diffusion-bar block and contract the per-component Milstein correction over the
noise axis, mirroring the shape-preserving pattern already used by `Milstein`.
The integrated state now stays correctly shaped and the full suite peaks at a
bounded ~9 GB with the crash region flat (no spike).

Also make the suite pass on GPU developer machines without changing CPU/CI
behaviour:

* conftest.py: pin `jax_default_matmul_precision='highest'`. On NVIDIA GPUs the
  default uses TF32 for float32 matmuls (~1e-4 relative error), which broke the
  operator-vs-dense correctness comparisons (JIT-connectivity layers,
  orthonormality checks) on GPU while passing on CPU.
* brainpy/dnn/linear_test.py: use float32-appropriate tolerances
  (`rtol=1e-4, atol=1e-5`) for the JIT-operator vs dense `x @ conn` comparison;
  the default `atol=1e-8` is tighter than float32 rounding for the near-zero
  symmetric-uniform outputs.
* brainpy/math/object_transform/object_transform_fixes_test.py: guard the
  `.cuda()` / `.tpu()` `RuntimeError` assertions on device availability so the
  test stays meaningful on CPU-only CI yet does not fail on GPU/TPU machines.

sourcery-ai Bot commented Jun 18, 2026

Reviewer's Guide

Fixes an OOM crash in MilsteinGradFree’s vector-Wiener path by making the Milstein correction shape-preserving, and adjusts several tests so the suite behaves consistently across CPU- and GPU-equipped machines without changing core CPU/CI behavior.

Flow diagram for updated MilsteinGradFree.step vector-Wiener handling

flowchart LR
  start([Start step loop for key]) --> noise["Compute noise_p2 = (noise^2 - dt) or noise^2"]
  noise --> branch{wiener_type == VECTOR_WIENER?}

  branch -->|yes| vec_path["Take diagonal: g_bar = jnp.diagonal(diffusion_bars[key])"]
  vec_path --> minus_vec["Compute minus = (g_bar - diffusions[key]) / (2 * sqrt(dt))"]
  minus_vec --> update_vec["Update integral += sum(minus * noise_p2, axis=-1)"]

  branch -->|no| scalar_path["Compute minus = (diffusion_bars[key] - diffusions[key]) / (2 * sqrt(dt))"]
  scalar_path --> update_scalar["Update integral += minus * noise_p2"]

  update_vec --> end_node([Accumulate integral for key])
  update_scalar --> end_node

File-Level Changes

Change: Fix MilsteinGradFree vector-Wiener integration to avoid tensor rank blow-up and OOM in long SDE runs.
Details:
  • In MilsteinGradFree.step, for VECTOR_WIENER, take the diagonal of the trailing diffusion-bar (m, m) block so each noise component uses its own g_j(y_bar_j).
  • Compute the Milstein correction as (g_bar - diffusion) / (2 * sqrt(dt)) and contract it with noise_p2 over the noise axis using sum, mirroring the scalar/standard Milstein behavior.
  • Keep the original diffusion_bar-based minus term only for the non-vector-Wiener branch so CPU/CI behavior remains unchanged there.
Files: brainpy/integrators/sde/normal.py

Change: Relax linear layer JIT operator vs dense-matrix equality checks to be numerically appropriate for float32 on GPU.
Details:
  • Update multiple allclose assertions comparing y to x @ conn_matrix.T to pass explicit rtol=1e-4 and atol=1e-5 tolerances.
  • Document in comments that the looser tolerances account for float32-level differences, especially for near-zero symmetric-uniform outputs where the previous atol=1e-8 was too strict.
Files: brainpy/dnn/linear_test.py

Change: Make CUDA/TPU object-transform tests conditional on actual device availability so they pass on both CPU-only CI and GPU/TPU developer machines.
Details:
  • Introduce a helper that queries jax.devices(platform) and returns False on errors to detect whether GPU/TPU is present.
  • Wrap the existing RuntimeError expectations for obj.cuda() and obj.tpu() in guards that only assert the error when the corresponding device type is unavailable.
Files: brainpy/math/object_transform/object_transform_fixes_test.py

Change: Standardize matrix multiplication precision for tests to avoid TF32-induced discrepancies between CPU and NVIDIA GPU runs.
Details:
  • Configure the global test environment to set jax_default_matmul_precision to 'highest' so float32 matmuls behave consistently across hardware.
  • Rely on this setting to fix previously flaky operator-vs-dense and orthonormality checks on GPU without affecting CPU behavior.
Files: conftest.py

@github-actions github-actions Bot added the tests label Jun 18, 2026

@sourcery-ai sourcery-ai Bot left a comment

Hey - I've left some high level feedback:

  • The repeated bm.allclose(..., rtol=1e-4, atol=1e-5) calls and identical explanatory comments across multiple tests in linear_test.py could be centralized into a small helper or a shared constant to reduce duplication and keep the tolerance rationale in one place.
  • In test_cuda_tpu_raise_without_device, consider lifting _device_available to module scope (and possibly reusing it in other tests that care about device presence) so its behavior is easier to reuse and reason about outside this single test.

braintools 0.2.0 was yanked from PyPI, so the previous `braintools>=0.2.0`
pin became unsatisfiable and broke `pip install -r requirements.txt` on every
CI runner (install failed in ~10s, before any test ran). 0.3.0 is the next
released version and carries the surrogate / metric fixes that the
`brainpy.math.surrogate` and L1-loss tests assert. Bump the pin in both
requirements.txt and pyproject.toml.

@github-actions github-actions Bot added the dependencies and build labels Jun 18, 2026

brainpy's ein_reduce / ein_rearrange / ein_repeat / ein_shape duplicated the
einops implementation that now lives in brainunit.math (einreduce / einrearrange
/ einrepeat / einshape — behaviour-identical and accepting brainpy ``Array``
instances directly). Re-export the historical ``ein_*`` names as thin aliases of
brainunit's and delete the duplicated implementation (einops.py,
einops_parsing.py) and its dedicated tests (einops_test.py,
einops_coverage_test.py, einops_parsing_test.py). The einops tests in
math_compat_fixes_test.py now exercise the public ``bm.ein_*`` aliases and assert
the re-export wiring.

The taichi backend is gone; drop the fully-skipped tifunc_test.py (a taichi
``ti.kernel`` test) and the stale ``method`` / Taichi paragraph in the csrmv
docstring (that parameter no longer exists — csrmv dispatches through brainevent).
@chaoming0625 chaoming0625 merged commit 7471264 into master Jun 18, 2026
14 checks passed
@chaoming0625 chaoming0625 deleted the worktree-fix-test-bugs branch June 18, 2026 17:22