Skip to content

Memory-bounded sparsify + per-row mass mode#843

Merged
selmanozleyen merged 12 commits into
mainfrom
feature/sparsify-batchsize-mass
Jun 11, 2026
Merged

Memory-bounded sparsify + per-row mass mode#843
selmanozleyen merged 12 commits into
mainfrom
feature/sparsify-batchsize-mass

Conversation

@Marius1311

Copy link
Copy Markdown
Collaborator

Summary

Fixes the out-of-memory failure when materializing a sparse transport matrix from an online (linear Sinkhorn) PointCloud solution — the case behind scverse/cellrank#1146 — and adds a predictable per-row sparsification mode.

One sparsify algorithm in the base class, identical for every output type; OTT contributes only a geometry-reconfiguration step. Closes #639.

What & why

BaseDiscreteSolverOutput.sparsify materializes the transport matrix in row blocks via push/pull. For an online PointCloud geometry the cost is recomputed lazily during apply, chunked by geom.batch_size — which is frozen at solve() time and cannot be lowered from sparsify. A large solve-time batch therefore OOMs during sparsify, even on CPU (#639).

1. batch_size now bounds peak memory (OTTOutput.sparsify)

OTTOutput.sparsify rebuilds the online geometry with the requested batch_size (via the geometry pytree, preserving scale_cost/relative_epsilon/epsilon/cost_fn/x/y), then delegates to the base algorithm. It is a no-op for anything that is not an online PointCloud Sinkhorn output — the same algorithm, the knob is simply inert where there is no lazy cost to retune.

2. New mode="mass"

Per source row, keep the largest entries capturing a fraction value of the row's mass, capped at max_k — bounded, predictable nnz. The block loop is unified on the source axis so each block holds rows of T; element-wise modes (threshold/percentile/min_row) are unchanged and transpose-invariant.

What this does per problem type

Problem Rank Coupling stored as online batch_size? rebatch effect
Linear full f,g on lazy/online PointCloud yes lowers it only case changed — bounds the OOM
Linear low factors Q,R,g no no-op already bounded
GW full dense [n,m] linearized cost no no-op dense by construction
GW low factors (LRGWOutput) no no-op already bounded
FGW full as GW-full (+ fused term) no no-op dense
FGW low factors no no-op already bounded

Tests

  • Base (tests/solvers/test_base_solver.py): exact-reconstruction oracle (value=1.0, max_k=None reproduces the transport matrix), mass retention, max_k cap, validation — on dense MockSolverOutput for both n<m and n>m.
  • OTT backend (tests/backends/ott/test_backend.py): online rebatch reconstruction across scale_cost ∈ {1.0, 0.5, mean, max_cost, max_norm, max_bound} × balanced/unbalanced; rebatch mechanics; no-op guards for offline / median / low-rank / GW.

Note

The online-PointCloud backend tests are skip-guarded: ott-jax ≤ 0.6.0 (and currently main) rely on jax.interpreters.batching.is_vmappable, removed in jax ≥ 0.9, which breaks any online apply (see #842). They are validated to pass where the online path works and skip otherwise; the base-class tests cover the core logic regardless.

Links

🤖 Generated with Claude Code

Make `BaseDiscreteSolverOutput.sparsify`'s `batch_size` actually bound peak
memory, and add a per-row `mode="mass"` that keeps the largest entries
capturing a fraction `value` of each row's mass (capped at `max_k`).

- Unify the block loop on the source axis so each block holds rows of the
  transport matrix; this keeps the per-row mass criterion well-defined and
  peak memory at [batch_size, m]. Element-wise modes (threshold/percentile/
  min_row) are unchanged.
- `OTTOutput.sparsify` rebuilds an online `PointCloud` geometry with the
  requested `batch_size` before delegating to the base algorithm, so the
  online (lazy-cost) Sinkhorn case no longer OOMs when solved with a large
  `batch_size`. No-op for dense/offline, low-rank, and GW/FGW outputs, whose
  `apply` is already memory-bounded.

Closes #639. Addresses the OOM in scverse/cellrank#1146.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
selmanozleyen and others added 11 commits June 10, 2026 16:37
Released ott-jax (<=0.6.0) calls jax.interpreters.batching.is_vmappable,
which jax removed from the public namespace after 0.7.2. ott fixed this
upstream (ott-jax#701) but has not released it. Re-expose the symbol from
jax._src so ott's batched_vmap works, unblocking the test suite.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@codecov

codecov Bot commented Jun 11, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 93.54839% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.62%. Comparing base (bacec6c) to head (b7fdcb6).
⚠️ Report is 13 commits behind head on main.

Files with missing lines Patch % Lines
src/moscot/backends/ott/output.py 90.47% 1 Missing and 1 partial ⚠️
src/moscot/base/output.py 95.12% 2 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #843      +/-   ##
==========================================
+ Coverage   76.38%   76.62%   +0.23%     
==========================================
  Files          36       36              
  Lines        4125     4175      +50     
  Branches      661      670       +9     
==========================================
+ Hits         3151     3199      +48     
- Misses        678      679       +1     
- Partials      296      297       +1     
Files with missing lines Coverage Δ
src/moscot/backends/ott/output.py 82.55% <90.47%> (+1.01%) ⬆️
src/moscot/base/output.py 83.57% <95.12%> (+2.89%) ⬆️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@selmanozleyen selmanozleyen merged commit 915643e into main Jun 11, 2026
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

change batch_size in sparsify

2 participants