fix: prevent MilsteinGradFree OOM crash; make GPU test suite hardware-independent #837
Merged
…-independent

The Linux CI job (`pytest brainpy/`) was OOM-killed (exit 143) at `brainpy/integrators/sde/normal_coverage_test.py` [69%].

Root cause: the vector-Wiener branch of `MilsteinGradFree.step` used the full `(m, m)` diffusion-bar matrix together with `minus * jnp.sum(noise_p2, -1)`, which left the noise dimension in the integrated state. Each step grew the output by two axes (`()` -> `(m, m)` -> `(m, m, m, m)` -> ...), a multi-GB blow-up over a long integration that exhausted the runner's RAM (macOS/Windows runners had enough headroom to survive, so only Linux crashed).

Take the diagonal of the diffusion-bar block and contract the per-component Milstein correction over the noise axis, mirroring the shape-preserving pattern already used by `Milstein`. The integrated state now stays correctly shaped and the full suite peaks at a bounded ~9 GB with the crash region flat (no spike).

Also make the suite pass on GPU developer machines without changing CPU/CI behaviour:

* conftest.py: pin `jax_default_matmul_precision='highest'`. On NVIDIA GPUs the default uses TF32 for float32 matmuls (~1e-4 relative error), which broke the operator-vs-dense correctness comparisons (JIT-connectivity layers, orthonormality checks) on GPU while passing on CPU.
* brainpy/dnn/linear_test.py: use float32-appropriate tolerances (`rtol=1e-4, atol=1e-5`) for the JIT-operator vs dense `x @ conn` comparison; the default `atol=1e-8` is tighter than float32 rounding for the near-zero symmetric-uniform outputs.
* brainpy/math/object_transform/object_transform_fixes_test.py: guard the `.cuda()` / `.tpu()` `RuntimeError` assertions on device availability so the test stays meaningful on CPU-only CI yet does not fail on GPU/TPU machines.
Reviewer's Guide

Fixes an OOM crash in MilsteinGradFree's vector-Wiener path by making the Milstein correction shape-preserving, and adjusts several tests so the suite behaves consistently across CPU- and GPU-equipped machines without changing core CPU/CI behavior.

Flow diagram for updated MilsteinGradFree.step vector-Wiener handling

flowchart LR
start([Start step loop for key]) --> noise["Compute noise_p2 = (noise^2 - dt) or noise^2"]
noise --> branch{wiener_type == VECTOR_WIENER?}
branch -->|yes| vec_path["Take diagonal: g_bar = jnp.diagonal(diffusion_bars[key])"]
vec_path --> minus_vec["Compute minus = (g_bar - diffusions[key]) / (2 * sqrt(dt))"]
minus_vec --> update_vec["Update integral += sum(minus * noise_p2, axis=-1)"]
branch -->|no| scalar_path["Compute minus = (diffusion_bars[key] - diffusions[key]) / (2 * sqrt(dt))"]
scalar_path --> update_scalar["Update integral += minus * noise_p2"]
update_vec --> end_node([Accumulate integral for key])
update_scalar --> end_node
Hey - I've left some high level feedback:
- The repeated `bm.allclose(..., rtol=1e-4, atol=1e-5)` calls and identical explanatory comments across multiple tests in `linear_test.py` could be centralized into a small helper or a shared constant to reduce duplication and keep the tolerance rationale in one place.
- In `test_cuda_tpu_raise_without_device`, consider lifting `_device_available` to module scope (and possibly reusing it in other tests that care about device presence) so its behavior is easier to reuse and reason about outside this single test.
braintools 0.2.0 was yanked from PyPI, so the previous `braintools>=0.2.0` pin became unsatisfiable and broke `pip install -r requirements.txt` on every CI runner (install failed in ~10s, before any test ran). 0.3.0 is the next released version and carries the surrogate / metric fixes that the `brainpy.math.surrogate` and L1-loss tests assert. Bump the pin in both requirements.txt and pyproject.toml.
brainpy's ein_reduce / ein_rearrange / ein_repeat / ein_shape duplicated the einops implementation that now lives in brainunit.math (einreduce / einrearrange / einrepeat / einshape — behaviour-identical and accepting brainpy ``Array`` instances directly). Re-export the historical ``ein_*`` names as thin aliases of brainunit's and delete the duplicated implementation (einops.py, einops_parsing.py) and its dedicated tests (einops_test.py, einops_coverage_test.py, einops_parsing_test.py). The einops tests in math_compat_fixes_test.py now exercise the public ``bm.ein_*`` aliases and assert the re-export wiring.
The taichi backend is gone; drop the fully-skipped tifunc_test.py (a taichi ``ti.kernel`` test) and the stale ``method`` / Taichi paragraph in the csrmv docstring (that parameter no longer exists — csrmv dispatches through brainevent).
Summary
Fixes the `test_linux` CI failure and makes the test suite pass on GPU developer machines without changing CPU/CI behaviour.

The crash (primary fix)
`test_linux (3.13)` was OOM-killed (exit code 143) at `brainpy/integrators/sde/normal_coverage_test.py` [69%].

Root cause: the vector-Wiener branch of `MilsteinGradFree.step` used the full `(m, m)` diffusion-bar matrix together with `minus * jnp.sum(noise_p2, -1)`, leaving the noise dimension inside the integrated state. Each step grew the output by two axes, `()` → `(m, m)` → `(m, m, m, m)` → …, a multi-GB blow-up over a long integration that exhausted the runner's RAM. macOS/Windows runners had enough headroom to survive (4055 passed), so only Linux crashed.

Fix: take the diagonal of the diffusion-bar block and contract the per-component Milstein correction over the noise axis, mirroring the shape-preserving pattern already used by `Milstein`. The integrated state now stays correctly shaped; a full single-process run of `brainpy/` peaks at a bounded ~9 GB with the crash region flat (no spike) and reaches 100%.
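A minimal sketch of the shape difference, assuming a scalar state driven by `m = 3` noise components. This is not the brainpy source: the function names and toy inputs are invented for illustration, and NumPy stands in for `jax.numpy` (`diagonal` and `sum` behave the same way for these shapes).

```python
import numpy as np


def correction_before(g_bar, g, noise_p2, dt):
    """Pre-fix pattern: the (m, m) block survives into the result."""
    minus = (g_bar - g) / (2.0 * np.sqrt(dt))   # (m, m) - (m,) broadcasts to (m, m)
    return minus * np.sum(noise_p2, -1)          # scalar factor, result stays (m, m)


def correction_after(g_bar, g, noise_p2, dt):
    """Post-fix pattern: take the diagonal, then contract over the noise axis."""
    g_diag = np.diagonal(g_bar)                  # (m,)
    minus = (g_diag - g) / (2.0 * np.sqrt(dt))   # (m,)
    return np.sum(minus * noise_p2, axis=-1)     # () matches the scalar state


m, dt = 3, 0.01
rng = np.random.default_rng(0)
noise = rng.normal(scale=np.sqrt(dt), size=m)
noise_p2 = noise ** 2 - dt                       # Ito form: noise^2 - dt
g = np.ones(m)                                   # toy diffusion values, shape (m,)
g_bar = np.ones((m, m)) + 0.1 * rng.normal(size=(m, m))  # toy diffusion-bar block

before = correction_before(g_bar, g, noise_p2, dt)
after = correction_after(g_bar, g, noise_p2, dt)
print(before.shape, after.shape)                 # (3, 3) ()
```

Adding the `(3, 3)` result to a scalar state turns the state itself into a `(3, 3)` array, which is how the repeated growth described above starts; the contracted version returns a value with the state's own shape.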
Hardware-independent test suite (GPU portability)
- `conftest.py`: pin `jax_default_matmul_precision='highest'`. On NVIDIA GPUs the default uses TF32 for `float32` matmuls (~1e-4 relative error), which broke operator-vs-dense correctness comparisons (JIT-connectivity layers, orthonormality checks) on GPU while passing on CPU. No effect on CPU/CI (already full `float32`).
- `brainpy/dnn/linear_test.py`: use float32-appropriate tolerances (`rtol=1e-4, atol=1e-5`) for the JIT-operator vs dense `x @ conn` comparison; the default `atol=1e-8` is tighter than float32 rounding for the near-zero symmetric-uniform outputs.
- `brainpy/math/object_transform/object_transform_fixes_test.py`: guard the `.cuda()` / `.tpu()` `RuntimeError` assertions on device availability so the test stays meaningful on CPU-only CI yet does not fail on GPU/TPU machines.
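The tolerance point can be illustrated with a small sketch. It uses `np.allclose` as a stand-in for `bm.allclose`, and the constant and helper names (`F32_RTOL`, `F32_ATOL`, `allclose_f32`) are hypothetical; the toy values are chosen only to mimic near-zero outputs with a rounding-sized error.

```python
import numpy as np

# Hypothetical shared constants, so the rationale lives in one place:
# float32 results near zero carry absolute errors far above atol=1e-8.
F32_RTOL = 1e-4
F32_ATOL = 1e-5


def allclose_f32(actual, expected):
    """Compare two float32 results with float32-appropriate tolerances."""
    return bool(np.allclose(actual, expected, rtol=F32_RTOL, atol=F32_ATOL))


# Near-zero reference values and a copy perturbed by a small absolute error.
expected = np.array([1e-4, -2e-4, 0.0], dtype=np.float32)
actual = expected + np.float32(5e-7)

default_ok = bool(np.allclose(actual, expected))   # defaults: rtol=1e-5, atol=1e-8
relaxed_ok = allclose_f32(actual, expected)
print(default_ok, relaxed_ok)                       # False True
```

Keeping the tolerances in one helper also covers the review suggestion about not repeating the same `rtol` / `atol` pair and comment in every test.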
Testing
- `pytest brainpy/` run (GPU, CI-equivalent): 4061 passed, 14 skipped, no crash, bounded ~9 GB peak. (The only local failures were in `surrogate` / `l1_loss` tests, which are artifacts of a pinned `braintools<0.2.0` dev env; CI uses `braintools>=0.2.0` where they pass, and the macOS and Windows jobs are green.)
- `MilsteinGradFree` now produces correctly-shaped (non-exploding) output for the vector-Wiener Ito case.
Dependency fix (unblocks CI install)
`braintools 0.2.0` was yanked from PyPI, so the existing `braintools>=0.2.0` pin became unsatisfiable and `pip install -r requirements.txt` failed on every runner (~10 s, before any test ran). Bumped the pin to `braintools>=0.3.0` (the next released version, which carries the surrogate / metric fixes the `brainpy.math.surrogate` and L1-loss tests assert) in both `requirements.txt` and `pyproject.toml`. Verified locally with braintools 0.3.0: the surrogate + L1-loss tests that were red under the yanked-version fallback (0.1.10) now pass (102 passed).
Cleanup: reuse `brainunit.math` einops (drop the local port)
`brainpy.math`'s `ein_reduce` / `ein_rearrange` / `ein_repeat` / `ein_shape` duplicated an einops implementation that now lives in `brainunit.math` (`einreduce` / `einrearrange` / `einrepeat` / `einshape`, behaviour-identical and accepting brainpy `Array` instances directly, verified). The historical `ein_*` names are now thin aliases re-exported from `brainunit.math`, and the duplicated implementation (`einops.py`, `einops_parsing.py`) plus its dedicated tests (`einops_test.py`, `einops_coverage_test.py`, `einops_parsing_test.py`) are removed. The einops tests in `math_compat_fixes_test.py` now exercise the public `bm.ein_*` aliases and assert the re-export wiring (87 passed locally).

Cleanup: remove taichi / tifunc remnants
The taichi backend is gone; this drops the fully-skipped `tifunc_test.py` (a taichi `ti.kernel` test) and the stale `method` / Taichi paragraph in the `csrmv` docstring (that parameter no longer exists; `csrmv` dispatches through `brainevent`).