Fix CPU GQA NaN output for right-padded batched prompts with rotary embeddings#29069
Fix CPU GQA NaN output for right-padded batched prompts with rotary embeddings#29069Copilot wants to merge 4 commits into
Conversation
| GTEST_SKIP() << "WebGPU EP not available"; | ||
| } | ||
| RunBatchedRightPaddedRotaryPrefillForEP(/*use_cuda=*/false, /*use_webgpu=*/true); | ||
| } |
There was a problem hiding this comment.
Done — removed BatchedRightPaddedRotaryPrefill_WebGPU test and cleaned up the now-unused use_webgpu parameter from the helper functions in commit Remove BatchedRightPaddedRotaryPrefill_WebGPU test per review feedback.
There was a problem hiding this comment.
Pull request overview
This PR fixes a CPU GroupQueryAttention (GQA) correctness bug for batched right-padded prefill with rotary embeddings, where the softmax step could read beyond the QK GEMM-initialized region and propagate NaNs into the output. The fix caps the causal length used to compute the softmax window to the per-batch total_seqlen, and adds a regression test that validates per-batch consistency versus single-prompt execution.
Changes:
- Cap per-token causal length at
total_seqlenin both the float and quantized/MLAS softmax+masking paths to prevent out-of-bounds/uninitialized reads. - Add a regression test for batched right-padded packed-QKV rotary prefill, comparing each batch item’s real-last-token output to a batch=1 reference (CPU and CUDA EPs).
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h |
Caps the effective causal length used by the softmax/masking window to stay within the GEMM-written [0, total_seqlen) region (float + quantized paths). |
onnxruntime/test/contrib_ops/group_query_attention_op_test.cc |
Adds a right-padded batched rotary prefill regression test that compares each prompt’s real-last-token output to a single-prompt reference. |
| // Each batch's real-last-token output (used to predict next token) must match | ||
| // its single-prompt reference. The tolerance is loose enough for fp16 rounding | ||
| // while still catching the underflow bug (which produces values that differ | ||
| // by orders of magnitude or are NaN/Inf). |
Description
Fixes NaN output in the CPU GQA kernel when running batched right-padded prefill. For padding token positions where
seq_causal_length > total_seqlen, the softmax loop was reading beyond the GEMM-filled region of the attention probs buffer into uninitialized memory, producing NaN values that propagated through the V GEMM to the output.Root cause: In
ComputeAttentionProbs,seq_causal_length = causal_past_seqlen + seq + 1grows with each query position. For right-padded batches, a batch item withreal_len < sequence_lengthhastotal_seqlen = real_len, but padding positions still iterate up tosequence_length, givingseq_causal_length > total_seqlen. The QK GEMM only fills columns[0, total_seqlen)— positions beyond that are uninitialized.Fix: Cap the effective causal length at
total_seqlenbefore computing the softmax window:Applied to both the non-quantized float path (~line 1097) and the quantized MLAS path (~line 436).
Motivation and Context
The new test
GroupQueryAttentionTest.BatchedRightPaddedRotaryPrefill_CPU(added in this PR) exercises batched GQA with heterogeneous real sequence lengths{4, 2, 6}padded tosequence_length=6. Batch item 1 (real_len=2) has padding tokens at positions 2–5; position 3 triggered the NaN via uninitialized attention probs memory.