TurboQuant TQ4 KV cache compression for Qwen 3.5 MoE (pytorch#18687)
Add TurboQuant (arXiv 2504.19874) KV cache compression to the CUDA
backend, reducing KV cache memory by 3.8x by storing nibble-packed uint8
indices + bf16 norms instead of bf16 tensors. A fused Triton SDPA
kernel decompresses K/V per tile in the attention inner loop, so the
full cache is never materialized.
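For orientation, a minimal sketch of the compression layout this describes: the per-row norm is factored out in bf16, each element is mapped to a 16-entry Lloyd-Max codebook index, and two indices are packed per byte. The nibble order and helper names here are assumptions, not the module's actual API:

```python
import torch

def tq4_compress(x: torch.Tensor, codebook: torch.Tensor):
    """Sketch of TQ4 compression for one [N, D] block of K or V rows.

    codebook: [16] bf16 Lloyd-Max centroids fit to unit-norm values.
    Returns nibble-packed uint8 indices [N, D//2] and per-row bf16 norms [N].
    """
    norms = x.float().norm(dim=-1, keepdim=True).clamp_min(1e-6)
    unit = x.float() / norms                            # factor the norm out
    # Nearest-centroid assignment gives a 4-bit index per element.
    idx = (unit.unsqueeze(-1) - codebook.float()).abs().argmin(-1).to(torch.uint8)
    packed = idx[..., 0::2] | (idx[..., 1::2] << 4)     # two nibbles per byte
    return packed, norms.squeeze(-1).to(torch.bfloat16)
```

For D=256 this stores 128 index bytes + 2 norm bytes = 130 bytes per row versus 512 bytes in bf16, i.e., ~3.9x, consistent with the quoted 3.8x once codebook/metadata overhead is counted.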
Components:
- backends/cuda/triton/kernels/tq4_sdpa.py: Fused TQ4 Flash Attention
kernel with the Pack GQA optimization (adapted from the sdpa.py structure),
a precomputed [256]-entry bf16 LUT so centroid gather needs no bit ops,
and norm factoring (norms are multiplied onto the [M,N] QK/P matrices
instead of the [N,D] K/V tiles; see the reference sketch after this
list). NaN-safe softmax guards for sparse masks.
Registered as @triton_op for torch.export + CUDA backend lowering.
- extension/llm/modules/turboquant/: TurboQuantKVCache nn.Module with
the bf16 compression path and a self-contained Lloyd-Max codebook solver
(no hard external dependencies; scipy is imported lazily for codebook
initialization only).
- examples/models/qwen3_5_moe/: --turboquant flag in export.py, and a
branch in FullAttention.forward() that selects between standard SDPA
and tq4_sdpa.
- backends/aoti/: Added aoti_torch_dtype_uint8 shim and Byte ScalarType
to slim headers (required for uint8 KV cache tensors in C++ runtime).
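As referenced in the tq4_sdpa.py item above, here is a plain-PyTorch reference for what the fused kernel's inner loop computes: byte-indexed LUT decode plus norm factoring on the [M, N] matrices. The two-table reading of the [256]-entry LUT and all names below are assumptions; the real kernel does this per tile in Triton with an online softmax:

```python
import torch

def tq4_sdpa_reference(q, k_packed, k_norms, v_packed, v_norms, codebook, scale):
    """Hypothetical reference for the fused kernel's math (not the Triton code).

    Shapes: q [M, D]; *_packed [N, D//2] uint8; *_norms [N] bf16; codebook [16] bf16.
    """
    # Two [256]-entry LUTs map a raw byte straight to each nibble's centroid,
    # so decoding needs no shifts or masks inside the kernel.
    byte_vals = torch.arange(256, dtype=torch.uint8)
    lut_lo = codebook[(byte_vals & 0xF).long()]   # centroid of low nibble
    lut_hi = codebook[(byte_vals >> 4).long()]    # centroid of high nibble

    def decode(packed):                           # [N, D//2] uint8 -> [N, D]
        lo, hi = lut_lo[packed.long()], lut_hi[packed.long()]
        return torch.stack((lo, hi), dim=-1).flatten(-2)

    k_unit, v_unit = decode(k_packed), decode(v_packed)
    s = (q.float() @ k_unit.float().T) * scale
    s = s * k_norms.float()         # norm factoring: scale [M, N] scores, not K tiles
    p = torch.softmax(s, dim=-1)
    p = p * v_norms.float()         # fold V norms into [M, N] probs, not V tiles
    return (p @ v_unit.float()).to(q.dtype)
```

Applying the norms to the [M, N] matrices is cheaper than rescaling the decompressed [N, D] tiles whenever the query tile M is smaller than D=256, which is the common case during decode.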
Performance (A100, Qwen 3.5 MoE, B=1, GQA 16:2, D=256, seq=4096):
- TQ4 SDPA kernel: 0.66ms (was 3.74ms before optimizations)
- Baseline bf16 SDPA: 0.45ms (i.e., 1.5x kernel overhead for the 3.8x memory savings)
- Full AOTI path: 0.79ms (inductor fuses the compress ops)
The full Qwen 3.5 MoE model with TurboQuant KV cache compression decodes
at 75% of baseline speed (78 -> 60) with 3.8x memory savings relative to
the standard KV cache, validated end-to-end through the C++ runner. Note
that the full-attention KV cache is only a small part of Qwen 3.5 MoE,
since 3/4 of its layers use recurrent states. At 200K context length,
TurboQuant saves about 3GB of KV cache (4GB -> 1GB).
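For context, a back-of-envelope check of that 200K figure. GQA 16:2 and D=256 come from the numbers above; the full-attention layer count is an assumption that happens to reproduce the quoted sizes:

```python
# All constants per the PR except full_attn_layers, which is assumed.
n_kv_heads, head_dim, seq, bf16_bytes = 2, 256, 200_000, 2
full_attn_layers = 10                                    # assumption
# Leading 2 accounts for both K and V.
kv_bf16 = 2 * full_attn_layers * n_kv_heads * head_dim * seq * bf16_bytes
print(f"bf16 KV cache: {kv_bf16 / 2**30:.1f} GiB")       # ~3.8 GiB
print(f"TQ4 (/3.8):    {kv_bf16 / 3.8 / 2**30:.1f} GiB") # ~1.0 GiB
```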
.ci/scripts/test_model_e2e.sh (1 addition, 1 deletion):

```diff
@@ -354,7 +354,7 @@ EOF
         fi
         ;;
     qwen3_5_moe)
-        RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 32"
+        RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0"
```