Skip to content

Fix CPU GQA NaN output for right-padded batched prompts with rotary embeddings#29069

Draft
Copilot wants to merge 4 commits into
mainfrom
copilot/fix-wasm-debug-build-job
Draft

Fix CPU GQA NaN output for right-padded batched prompts with rotary embeddings#29069
Copilot wants to merge 4 commits into
mainfrom
copilot/fix-wasm-debug-build-job

Conversation

Copilot AI commented Jun 16, 2026

Copy link
Copy Markdown
Contributor

Description

Fixes NaN output in the CPU GQA kernel when running batched right-padded prefill. For padding token positions where seq_causal_length > total_seqlen, the softmax loop was reading beyond the GEMM-filled region of the attention probs buffer into uninitialized memory, producing NaN values that propagated through the V GEMM to the output.

Root cause: In ComputeAttentionProbs, seq_causal_length = causal_past_seqlen + seq + 1 grows with each query position. For right-padded batches, a batch item with real_len < sequence_length has total_seqlen = real_len, but padding positions still iterate up to sequence_length, giving seq_causal_length > total_seqlen. The QK GEMM only fills columns [0, total_seqlen) — positions beyond that are uninitialized.

Fix: Cap the effective causal length at total_seqlen before computing the softmax window:

// gqa_attention_base.h - both float and quantized paths
const size_t effective_causal_length = std::min(seq_causal_length, total_seqlen);
// use effective_causal_length for: local window check, start_offset, window_size, masking loops

Applied to both the non-quantized float path (~line 1097) and the quantized MLAS path (~line 436).

Motivation and Context

The new test GroupQueryAttentionTest.BatchedRightPaddedRotaryPrefill_CPU (added in this PR) exercises batched GQA with heterogeneous real sequence lengths {4, 2, 6} padded to sequence_length=6. Batch item 1 (real_len=2) has padding tokens at positions 2–5; position 3 triggered the NaN via uninitialized attention probs memory.

Copilot AI changed the title [WIP] Fix failing GitHub Actions job for wasm_Debug Fix CPU GQA NaN output for right-padded batched prompts with rotary embeddings Jun 16, 2026
Copilot finished work on behalf of qjia7 June 16, 2026 08:18
Copilot AI requested a review from qjia7 June 16, 2026 08:18
GTEST_SKIP() << "WebGPU EP not available";
}
RunBatchedRightPaddedRotaryPrefillForEP(/*use_cuda=*/false, /*use_webgpu=*/true);
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot Please remove BatchedRightPaddedRotaryPrefill_WebGPU test since I will fix it in PR #29002

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — removed BatchedRightPaddedRotaryPrefill_WebGPU test and cleaned up the now-unused use_webgpu parameter from the helper functions in commit Remove BatchedRightPaddedRotaryPrefill_WebGPU test per review feedback.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes a CPU GroupQueryAttention (GQA) correctness bug for batched right-padded prefill with rotary embeddings, where the softmax step could read beyond the QK GEMM-initialized region and propagate NaNs into the output. The fix caps the causal length used to compute the softmax window to the per-batch total_seqlen, and adds a regression test that validates per-batch consistency versus single-prompt execution.

Changes:

  • Cap per-token causal length at total_seqlen in both the float and quantized/MLAS softmax+masking paths to prevent out-of-bounds/uninitialized reads.
  • Add a regression test for batched right-padded packed-QKV rotary prefill, comparing each batch item’s real-last-token output to a batch=1 reference (CPU and CUDA EPs).

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h Caps the effective causal length used by the softmax/masking window to stay within the GEMM-written [0, total_seqlen) region (float + quantized paths).
onnxruntime/test/contrib_ops/group_query_attention_op_test.cc Adds a right-padded batched rotary prefill regression test that compares each prompt’s real-last-token output to a single-prompt reference.

Comment on lines +2565 to +2568
// Each batch's real-last-token output (used to predict next token) must match
// its single-prompt reference. The tolerance is loose enough for fp16 rounding
// while still catching the underflow bug (which produces values that differ
// by orders of magnitude or are NaN/Inf).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants