Skip to content

Fix: Resolve KV Cache OOM & Logit Extraction Memory Spikes#678

Open
prince-shakyaa wants to merge 3 commits into
google-deepmind:mainfrom
prince-shakyaa:fix-memory-oom-v3
Open

Fix: Resolve KV Cache OOM & Logit Extraction Memory Spikes#678
prince-shakyaa wants to merge 3 commits into
google-deepmind:mainfrom
prince-shakyaa:fix-memory-oom-v3

Conversation

@prince-shakyaa

@prince-shakyaa prince-shakyaa commented Jun 7, 2026

Copy link
Copy Markdown

Pull Request: Resolve KV Cache OOM & Logit Extraction Memory Spikes

Related Issue

Fixes: #675

Overview

This PR addresses critical architectural limitations in the text sampling backend that cause severe memory inflation and Out-Of-Memory (OOM) errors during long multi-turn conversations and extended context generation.

By introducing a rolling KV cache and optimizing the stateful preservation of prediction logits, we significantly reduce the VRAM overhead per sequence.

Key Changes

1. Rolling Cache via Turn Eviction

File Modified: gemma/gm/text/_chat_sampler.py

  • Introduced a rolling_cache_threshold to the ChatSampler.
  • When the context exceeds this threshold, the oldest conversational turn is evicted from self.turns. The last_state is then invalidated and gracefully re-prefilled against the truncated conversation. This ensures context lengths never surpass our safe cache boundary, eliminating sequential multi-turn OOMs.

2. Ephemeral Top-K Logit Extraction

Files Modified: gemma/gm/text/_sampler_loop.py, gemma/gm/text/_sampler.py, gemma/gm/text/_prefill.py

  • Refactored SamplingState to replace the massive $V$-dimensional vocabulary distribution (predicted_logits) with predicted_top_logits and predicted_top_indices.
  • Added a configuration top_k_logits: int = 0 to the low-level Sampler which pipes down to the SamplerLoop and _prefill.py state allocations.
  • Within the _sample_step, we use jax.lax.top_k immediately against the model emission if top_k_logits > 0. The large vocabulary tensor is then destroyed instead of cached along the sequence length, freeing massive amounts of GPU memory.

Validation & Testing

  • Unit Tests: Ensured existing tests in gemma/gm/text/ continue to pass without regressions.
  • Memory Profiling: Verified via memory tracing that multi-turn generation can now run indefinitely up to the bounds of the cache_length without linearly leaking memory.
  • Correctness: Output token quality remains identical, as the core sampling logic is unchanged.

@prince-shakyaa prince-shakyaa marked this pull request as ready for review June 7, 2026 16:30
@prince-shakyaa

Copy link
Copy Markdown
Author

This PR is now ready for review!
I've implemented the rolling cache logic and top-k logit extraction, and verified locally that it successfully bounds VRAM usage during long multi-turn sessions without causing shape regressions.

Let me know if you need any changes.

Thank You.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: Context Exhaustion and VRAM Spikes in KV Cache & SamplerLoop

1 participant