Memory-bounded sparsify + per-row mass mode#843
Merged
Conversation
Make `BaseDiscreteSolverOutput.sparsify`'s `batch_size` actually bound peak memory, and add a per-row `mode="mass"` that keeps the largest entries capturing a fraction `value` of each row's mass (capped at `max_k`). - Unify the block loop on the source axis so each block holds rows of the transport matrix; this keeps the per-row mass criterion well-defined and peak memory at [batch_size, m]. Element-wise modes (threshold/percentile/ min_row) are unchanged. - `OTTOutput.sparsify` rebuilds an online `PointCloud` geometry with the requested `batch_size` before delegating to the base algorithm, so the online (lazy-cost) Sinkhorn case no longer OOMs when solved with a large `batch_size`. No-op for dense/offline, low-rank, and GW/FGW outputs, whose `apply` is already memory-bounded. Closes #639. Addresses the OOM in scverse/cellrank#1146. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
Released ott-jax (<=0.6.0) calls jax.interpreters.batching.is_vmappable, which jax removed from the public namespace after 0.7.2. ott fixed this upstream (ott-jax#701) but has not released it. Re-expose the symbol from jax._src so ott's batched_vmap works, unblocking the test suite. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #843 +/- ##
==========================================
+ Coverage 76.38% 76.62% +0.23%
==========================================
Files 36 36
Lines 4125 4175 +50
Branches 661 670 +9
==========================================
+ Hits 3151 3199 +48
- Misses 678 679 +1
- Partials 296 297 +1
🚀 New features to boost your workflow:
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fixes the out-of-memory failure when materializing a sparse transport matrix from an online (linear Sinkhorn)
PointCloudsolution — the case behind scverse/cellrank#1146 — and adds a predictable per-row sparsification mode.One sparsify algorithm in the base class, identical for every output type; OTT contributes only a geometry-reconfiguration step. Closes #639.
What & why
BaseDiscreteSolverOutput.sparsifymaterializes the transport matrix in row blocks viapush/pull. For an onlinePointCloudgeometry the cost is recomputed lazily duringapply, chunked bygeom.batch_size— which is frozen atsolve()time and cannot be lowered fromsparsify. A large solve-time batch therefore OOMs during sparsify, even on CPU (#639).1.
batch_sizenow bounds peak memory (OTTOutput.sparsify)OTTOutput.sparsifyrebuilds the online geometry with the requestedbatch_size(via the geometry pytree, preservingscale_cost/relative_epsilon/epsilon/cost_fn/x/y), then delegates to the base algorithm. It is a no-op for anything that is not an onlinePointCloudSinkhorn output — the same algorithm, the knob is simply inert where there is no lazy cost to retune.2. New
mode="mass"Per source row, keep the largest entries capturing a fraction
valueof the row's mass, capped atmax_k— bounded, predictable nnz. The block loop is unified on the source axis so each block holds rows ofT; element-wise modes (threshold/percentile/min_row) are unchanged and transpose-invariant.What this does per problem type
batch_size?f,gon lazy/onlinePointCloudQ,R,g[n,m]linearized costLRGWOutput)Tests
tests/solvers/test_base_solver.py): exact-reconstruction oracle (value=1.0, max_k=Nonereproduces the transport matrix), mass retention,max_kcap, validation — on denseMockSolverOutputfor bothn<mandn>m.tests/backends/ott/test_backend.py): online rebatch reconstruction acrossscale_cost ∈ {1.0, 0.5, mean, max_cost, max_norm, max_bound}× balanced/unbalanced; rebatch mechanics; no-op guards for offline /median/ low-rank / GW.Note
The online-
PointCloudbackend tests are skip-guarded: ott-jax ≤ 0.6.0 (and currentlymain) rely onjax.interpreters.batching.is_vmappable, removed in jax ≥ 0.9, which breaks any online apply (see #842). They are validated to pass where the online path works and skip otherwise; the base-class tests cover the core logic regardless.Links
batch_sizeinsparsify#639🤖 Generated with Claude Code