webgpu: fix GQA batched right-padded prefill with do_rotary #29002
qjia7 wants to merge 5 commits into
When GenAI runs a batched prefill with prompts of unequal lengths, short prompts are right-padded up to the batch max sequence_length and each batch's real length is reported via seqlens_k[b] = real_len[b] - 1. The WebGPU rotary embedding shaders computed past_seqlen = (seqlens_k[b] + 1) - sequence_length per batch, which underflowed u32 for any batch shorter than sequence_length. The resulting astronomically large position_id indexed past the cos/sin caches and produced garbage rotated Q/K, which manifested as gibberish output text for the shorter prompts in the batch.

Clamp past_seqlen to 0 in all three rotary embedding shaders: RotaryEmbeddingProgram (seqlens variant), FusedQKRotaryEmbeddingProgram, and the split_packed_qkv_with_rotary_embedding template. Also extend CanApplyFlashAttention to bypass FlashAttention for batched cases with per-batch seqlens (which exercise the unpatched and-copykv variant), while still allowing it for shared-KV layers where it is mandatory.

Adds a regression test exercising the packed-QKV do_rotary path with three batches of unequal real lengths.
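The underflow and the clamp described above can be sketched in C++, with `uint32_t` standing in for WGSL's `u32`. This is a minimal illustration, not the shader code from the PR: the function names are hypothetical, and the actual WGSL may express the clamp differently.

```cpp
#include <cstdint>

// Unclamped form: follows the arithmetic quoted in the commit message.
// Unsigned subtraction wraps modulo 2^32 when the result would be negative.
inline uint32_t PastSeqlenUnclamped(uint32_t seqlens_k_b, uint32_t sequence_length) {
  const uint32_t total_seqlen = seqlens_k_b + 1u;
  return total_seqlen - sequence_length;  // wraps if total_seqlen < sequence_length
}

// Clamped form: saturate at 0 instead of wrapping (hypothetical helper).
inline uint32_t PastSeqlenClamped(uint32_t seqlens_k_b, uint32_t sequence_length) {
  const uint32_t total_seqlen = seqlens_k_b + 1u;
  return total_seqlen > sequence_length ? total_seqlen - sequence_length : 0u;
}

// Position used to index the cos/sin caches for token seq_idx of a batch.
inline uint32_t PositionId(uint32_t past_seqlen, uint32_t seq_idx) {
  return past_seqlen + seq_idx;
}
```

For a prompt with real length 3 padded to `sequence_length = 8`, `seqlens_k[b]` is 2, so the unclamped form yields 3 - 8 wrapped to 4294967291, while the clamped form yields 0. For a decode step with `sequence_length = 1`, both forms agree.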
The test added in the previous commit was scoped to WebGPU, but the property it asserts (each batch's real-last-token output equals the single-prompt reference) is generic and applies to any GQA-supporting EP. CPU and CUDA both support packed-QKV with do_rotary, so generalizing the test gives meaningful cross-EP coverage instead of leaving CPU and CUDA uncovered. Mirror the existing convention in this file: the runner takes bool use_cuda, bool use_webgpu defaulting to false (same as RunGQASharedKV / RunGQASharedKVWithRotary), and three thin TEST cases named _CPU / _CUDA / _WebGPU dispatch the inner helper for each EP with runtime availability checks via GTEST_SKIP. No production code touched.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

The CPU variant of the GQA packed-QKV do_rotary right-padded prefill regression test exposes a separate CPU-side issue tracked and fixed in another PR. Drop the _CPU TEST here so this PR's CI stays green while the CPU fix lands independently. The _CUDA and _WebGPU variants remain and continue to exercise the property under test.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pull request overview
Fixes an incorrect RoPE position computation in WebGPU GroupQueryAttention when running batched, right-padded prefill with per-batch seqlens_k (packed QKV + do_rotary), preventing u32 underflow that could index past cos/sin caches and corrupt outputs.
Changes:
- Clamp `past_seqlen` to `0` in WebGPU rotary embedding shader codepaths to avoid `u32` underflow for right-padded batches.
- Extend WebGPU FlashAttention gating to opt out when running multi-batch with per-batch `seqlens_k` (while still allowing shared-KV / `kv_sequence_length==0` cases).
- Add a regression test that compares each batch's real-last-token output in a padded multi-batch prefill against its batch=1 reference run.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxruntime/test/contrib_ops/group_query_attention_op_test.cc | Adds a regression test for batched right-padded packed-QKV prefill with do_rotary. |
| onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template | Clamps past_seqlen to prevent u32 underflow in the packed-QKV split + rotary shader. |
| onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc | Clamps past_seqlen to prevent u32 underflow in generated rotary shaders (seqlens and fused QK variants). |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Passes seqlen_k into FlashAttention eligibility check so batched-per-seqlen cases can be gated off. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | Updates CanApplyFlashAttention signature to accept optional seqlen_k. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Implements the new eligibility rule to bypass FlashAttention for batched runs with per-batch seqlen_k (except kv_empty). |
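The eligibility rule in the last row can be read as a small predicate. The sketch below is a simplified illustration only: `GateInputs` and `PaddingRuleAllowsFlash` are hypothetical names, the real `CanApplyFlashAttention` takes different arguments and checks further conditions not modeled here, and treating "batched" as `batch_size > 1` is an assumption.

```cpp
#include <cstdint>

// Hypothetical, simplified inputs for illustration.
struct GateInputs {
  uint32_t batch_size;
  uint32_t kv_sequence_length;  // 0 for shared-KV layers (the kv_empty case)
  bool has_seqlen_k;            // true when per-batch seqlens_k is supplied
};

// Returns true when the padding-related rule alone would permit FlashAttention.
inline bool PaddingRuleAllowsFlash(const GateInputs& in) {
  const bool kv_empty = in.kv_sequence_length == 0;
  // Assumption: right padding is only possible with more than one batch
  // and per-batch sequence lengths.
  const bool may_be_right_padded = in.batch_size > 1 && in.has_seqlen_k;
  return kv_empty || !may_be_right_padded;
}
```

Under these assumptions, a three-batch run with per-batch `seqlen_k` and a non-empty KV input is routed away from FlashAttention, while single-batch runs and shared-KV layers remain eligible.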
Review of PR #29002 - webgpu: fix GQA batched right-padded prefill with do_rotary

Verdict: approve, with one follow-up worth either folding in or filing as a tracking issue. The diagnosis (u32 underflow in WGSL when

Why the fix is correct

WGSL

Coverage of shaders is intentional but not exhaustive

The PR fixes three sites, all the non-flash codepaths reachable from the GQA packed-QKV +

Then it explicitly does not fix the fourth, sibling site:

    let past_seqlen = total_seqlen - uniforms.sequence_length;
    let position_id = past_seqlen + seq_idx;

Same underflow pattern, byte for byte. Instead, the PR closes the door by tightening

This is pragmatic, but it leaves a latent footgun: the gate's
Either way, please don't let this fix-by-gating live silently — the diff at
…otary prefill

Apply three reviewer suggestions on the GQA WebGPU right-padded prefill fix:
- Clamp the 4th packed-QKV rotary shader site against u32 underflow, byte identical to the sibling shader pattern.
- Document the disjunction in CanApplyFlashAttention as a positive contract: FlashAttention does not implement right-padded per-batch prefill, so the first disjunction restricts inputs to shapes where padding cannot occur.
- Replace the (bool use_cuda, bool use_webgpu) pair across three test helpers with a single GqaTargetEp enum, route EP construction through a central MakeExecutionProviderForGqaTest helper with an ORT_THROW default so a future enumerator cannot silently fall through to an empty provider vector, and migrate every caller.

All 56 GroupQueryAttentionTest cases pass locally (48 OK + 8 CUDA-skipped on a machine without CUDA EP).
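The enum-plus-central-factory pattern from the third suggestion might look roughly like the following. This is a standalone sketch under stated assumptions: `TargetEp`, its enumerators, and `MakeProvidersForTest` are hypothetical stand-ins for `GqaTargetEp` and `MakeExecutionProviderForGqaTest`, strings stand in for execution provider objects, and a standard exception replaces `ORT_THROW`.

```cpp
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for the test file's enum, one enumerator per line.
enum class TargetEp {
  kCpu,
  kCuda,
  kWebGpu,
};

// Central factory: every caller goes through here, so an unhandled
// enumerator throws instead of silently yielding an empty provider list.
inline std::vector<std::string> MakeProvidersForTest(TargetEp ep) {
  std::vector<std::string> providers;
  switch (ep) {
    case TargetEp::kCpu:
      providers.push_back("cpu");
      break;
    case TargetEp::kCuda:
      providers.push_back("cuda");
      break;
    case TargetEp::kWebGpu:
      providers.push_back("webgpu");
      break;
    default:
      throw std::invalid_argument("unhandled TargetEp enumerator");
  }
  return providers;
}
```

Compared with a pair of boolean flags, a single enum rules out contradictory combinations such as both flags being true, and the throwing default turns a forgotten case into an immediate test failure.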
- Spelling: "Modelled" -> "Modeled" (US English, per misspell linter)
- clang-format: split the GqaTargetEp enum onto one enumerator per line
Summary
- Clamp `past_seqlen` to 0 in three WebGPU rotary embedding shaders to prevent u32 underflow during batched right-padded prefill (`RotaryEmbeddingProgram` seqlens variant, `FusedQKRotaryEmbeddingProgram`, and `split_packed_qkv_with_rotary_embedding.wgsl.template`).
- Extend `CanApplyFlashAttention` to bypass FlashAttention for batched cases with per-batch seqlens (the unpatched `split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template` still has the underflow), while explicitly allowing it for shared-KV layers where it is mandatory.
- Add regression test `WebGPU_BatchedRightPaddedRotaryPrefill` exercising the packed-QKV `do_rotary` path with three batches of unequal real lengths.

Motivation

When GenAI runs a batched prefill with prompts of unequal lengths, shorter prompts are right-padded up to the batch max `sequence_length` and each batch's real length is reported via `seqlens_k[b] = real_len[b] - 1`. The WebGPU rotary embedding shaders computed `past_seqlen = (seqlens_k[b] + 1) - sequence_length` per batch, which underflowed `u32` for any batch shorter than `sequence_length`. The resulting astronomically large `position_id` indexed past the cos/sin caches and produced garbage rotated Q/K, which manifested as gibberish output text for the shorter batches.

Reproduced on phi4-prune via three concurrent prompts with unequal lengths: batches 0 and 1 produced output such as `- - - -` while batch 2 (longest, no underflow) matched CPU output. After the fix, all three batches produce coherent text matching CPU output.

Test plan

- `GroupQueryAttentionTest.WebGPU_BatchedRightPaddedRotaryPrefill` passes with fix, fails without (max diff 0.0355 in short-batch channel vs 5e-3 tolerance).
- `GroupQueryAttentionTest.WebGPU_*` tests pass (including previously-correct `WebGPU_SharedKV_Rotary_MultiBatch`).
- `GroupQueryAttentionTest.*` suite: 48 pass, 7 skipped (CUDA-only on this build).
- `lintrunner -a` clean.
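The property the regression test checks, that each batch's real-last-token output matches its batch=1 reference within tolerance, can be sketched as a small helper. The function name and the row-major `[batch, padded_seq_len, hidden]` layout are assumptions for illustration; the actual test in `group_query_attention_op_test.cc` may be organized differently.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper. Assumes batched_output is laid out row-major as
// [batch, padded_seq_len, hidden] and reference_last_token has `hidden` values.
inline float MaxAbsDiffAtLastRealToken(const std::vector<float>& batched_output,
                                       std::size_t batch_index,
                                       std::size_t padded_seq_len,
                                       std::size_t hidden,
                                       std::size_t real_len,
                                       const std::vector<float>& reference_last_token) {
  // Offset of the last real (non-padding) token of this batch.
  const std::size_t offset = (batch_index * padded_seq_len + (real_len - 1)) * hidden;
  float max_diff = 0.0f;
  for (std::size_t i = 0; i < hidden; ++i) {
    const float diff = std::fabs(batched_output[offset + i] - reference_last_token[i]);
    max_diff = std::max(max_diff, diff);
  }
  return max_diff;
}
```

A test built on this would assert that the returned value stays below the tolerance (5e-3 in the test plan above) for every batch, including the shorter, padded ones.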