webgpu: fix GQA batched right-padded prefill with do_rotary #29002

Open
qjia7 wants to merge 5 commits into microsoft:main from qjia7:fix/webgpu-gqa-batched-right-pad-rotary

@qjia7 qjia7 commented Jun 11, 2026

Contributor

Summary

  • Clamp past_seqlen to 0 in three WebGPU rotary embedding shaders to prevent u32 underflow during batched right-padded prefill (RotaryEmbeddingProgram seqlens variant, FusedQKRotaryEmbeddingProgram, and split_packed_qkv_with_rotary_embedding.wgsl.template).
  • Extend CanApplyFlashAttention to bypass FlashAttention for batched cases with per-batch seqlens (the unpatched split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template still has the underflow), while explicitly allowing it for shared-KV layers where it is mandatory.
  • Add a regression test WebGPU_BatchedRightPaddedRotaryPrefill exercising the packed-QKV do_rotary path with three batches of unequal real lengths.

Motivation

When GenAI runs a batched prefill with prompts of unequal lengths, shorter prompts are right-padded up to the batch max sequence_length and each batch's real length is reported via seqlens_k[b] = real_len[b] - 1. The WebGPU rotary embedding shaders computed past_seqlen = (seqlens_k[b] + 1) - sequence_length per batch, which underflowed u32 for any batch shorter than sequence_length. The resulting astronomically large position_id indexed past the cos/sin caches and produced garbage rotated Q/K, which manifested as gibberish output text for the shorter batches.
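
The arithmetic above can be reproduced on the host. The following is a minimal sketch in C++, assuming uint32_t as a stand-in for WGSL u32 (both wrap modulo 2^32 on subtraction). The function names are hypothetical and are not taken from the ONNX Runtime sources.

```cpp
#include <cstdint>

// Emulates the pre-fix shader arithmetic: u32 subtraction wraps on underflow.
// Illustrative names only; not the actual shader or host code.
std::uint32_t PastSeqlenUnclamped(std::uint32_t seqlens_k_b, std::uint32_t sequence_length) {
  const std::uint32_t total_seqlen = seqlens_k_b + 1u;  // real_len[b]
  return total_seqlen - sequence_length;  // wraps when total_seqlen < sequence_length
}

// Emulates the clamped form: 0 whenever the subtraction would underflow.
std::uint32_t PastSeqlenClamped(std::uint32_t seqlens_k_b, std::uint32_t sequence_length) {
  const std::uint32_t total_seqlen = seqlens_k_b + 1u;
  return total_seqlen <= sequence_length ? 0u : total_seqlen - sequence_length;
}
```

For a batch with real length 2 padded to sequence_length 6 (seqlens_k[b] = 1), the unclamped form yields 4294967292 while the clamped form yields 0. For a decode step with seqlens_k[b] = 9 and sequence_length 1, both yield 9.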

Reproduced on phi4-prune via three concurrent prompts with unequal lengths: batches 0 and 1 produced output such as - - - - while batch 2 (longest, no underflow) matched CPU output. After the fix, all three batches produce coherent text matching CPU output.

Test plan

  • New unit test GroupQueryAttentionTest.WebGPU_BatchedRightPaddedRotaryPrefill passes with fix, fails without (max diff 0.0355 in short-batch channel vs 5e-3 tolerance).
  • All 7 GroupQueryAttentionTest.WebGPU_* tests pass (including previously-correct WebGPU_SharedKV_Rotary_MultiBatch).
  • Full GroupQueryAttentionTest.* suite: 48 pass, 7 skipped (CUDA-only on this build).
  • phi4-prune 3-prompt batched generation produces coherent output for all batches matching CPU reference.
  • phi4-prune single-prompt generation still correct (FlashAttention path on batch=1).
  • whisper-tiny-int4: 2/2 cases byte-exact match with CPU reference.
  • lintrunner -a clean.

When GenAI runs a batched prefill with prompts of unequal lengths, short
prompts are right-padded up to the batch max sequence_length and each
batch's real length is reported via seqlens_k[b] = real_len[b] - 1. The
WebGPU rotary embedding shaders computed past_seqlen = (seqlens_k[b]+1)
- sequence_length per batch, which underflowed u32 for any batch shorter
than sequence_length. The resulting astronomically large position_id
indexed past the cos/sin caches and produced garbage rotated Q/K, which
manifested as gibberish output text for the shorter batches.

Clamp past_seqlen to 0 in all three rotary embedding shaders:
RotaryEmbeddingProgram (seqlens variant), FusedQKRotaryEmbeddingProgram,
and the split_packed_qkv_with_rotary_embedding template. Also extend
CanApplyFlashAttention to bypass FlashAttention for batched cases with
per-batch seqlens (which exercise the unpatched and-copykv variant),
while still allowing it for shared-KV layers where it is mandatory.

Adds a regression test exercising the packed-QKV do_rotary path with
three batches of unequal real lengths.
@qjia7 qjia7 marked this pull request as draft June 15, 2026 01:05
@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Jun 15, 2026
The test added in the previous commit was scoped to WebGPU, but the
property it asserts (each batch's real-last-token output equals the
single-prompt reference) is generic and applies to any GQA-supporting
EP. CPU and CUDA both support packed-QKV with do_rotary, so generalizing
the test gives meaningful cross-EP coverage instead of leaving CPU and
CUDA uncovered.

Mirror the existing convention in this file: the runner takes
bool use_cuda, bool use_webgpu defaulting to false (same as
RunGQASharedKV / RunGQASharedKVWithRotary), and three thin TEST cases
named _CPU / _CUDA / _WebGPU dispatch the inner helper for each EP with
runtime availability checks via GTEST_SKIP. No production code touched.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The CPU variant of the GQA packed-QKV do_rotary right-padded prefill
regression test exposes a separate CPU-side issue tracked and fixed in
another PR. Drop the _CPU TEST here so this PR's CI stays green while
the CPU fix lands independently. The _CUDA and _WebGPU variants remain
and continue to exercise the property under test.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@qjia7 qjia7 marked this pull request as ready for review June 17, 2026 04:20
@qjia7 qjia7 requested a review from Copilot June 17, 2026 04:21

Copilot AI left a comment

Contributor


Pull request overview

Fixes an incorrect RoPE position computation in WebGPU GroupQueryAttention when running batched, right-padded prefill with per-batch seqlens_k (packed QKV + do_rotary), preventing u32 underflow that could index past cos/sin caches and corrupt outputs.

Changes:

  • Clamp past_seqlen to 0 in WebGPU rotary embedding shader codepaths to avoid u32 underflow for right-padded batches.
  • Extend WebGPU FlashAttention gating to opt out when running multi-batch with per-batch seqlens_k (while still allowing shared-KV / kv_sequence_length==0 cases).
  • Add a regression test that compares each batch’s real-last-token output in a padded multi-batch prefill against its batch=1 reference run.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| onnxruntime/test/contrib_ops/group_query_attention_op_test.cc | Adds a regression test for batched right-padded packed-QKV prefill with do_rotary. |
| onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template | Clamps past_seqlen to prevent u32 underflow in the packed-QKV split + rotary shader. |
| onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc | Clamps past_seqlen to prevent u32 underflow in generated rotary shaders (seqlens and fused QK variants). |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Passes seqlen_k into FlashAttention eligibility check so batched-per-seqlen cases can be gated off. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | Updates CanApplyFlashAttention signature to accept optional seqlen_k. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Implements the new eligibility rule to bypass FlashAttention for batched runs with per-batch seqlen_k (except kv_empty). |

Comment thread onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
@qjia7 qjia7 requested review from guschmue and hariharans29 June 17, 2026 04:52
@hariharans29

Member

Review of PR #29002 — webgpu: fix GQA batched right-padded prefill with do_rotary

Verdict: approve, with one follow-up worth either folding in or filing as a tracking issue. The diagnosis (u32 underflow in WGSL when total_seqlen < sequence_length) is sharp, the select(...) clamp is the right fix, the FlashAttention bypass is necessary and minimal, and the test is well-constructed. 86/86 checks green.

Why the fix is correct

WGSL u32 - u32 wraps on underflow — there's no signed semantics or implicit promotion. In a right-padded multi-batch prefill, seqlens_k[b] = real_len[b] - 1 for each batch, so for any short batch total_seqlen = seqlens_k[b] + 1 < sequence_length and total_seqlen - sequence_length wraps to 0xFFFFFFFF-ish. position_id = (huge) + sequence_idx then indexes far past cos_cache_shape[0]. Where the shader has a downstream bounds check (e.g. FusedQKRotaryEmbeddingProgram at rotary_embedding.cc:91-93 has if (position_id >= max_position) → passthrough), the wrapped index is "saved" by passthrough — silently producing rotated-shaped output that wasn't rotated, which manifests as gibberish, exactly the symptom in the PR description. Where it doesn't (the templated shader), it'd be a hard OOB. Either way the clamp is mandatory.
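
The passthrough effect described above can be sketched on the host. This is a hedged illustration in C++, assuming a guard of the form position_id >= max_position as quoted in the review; the enum and function names are hypothetical.

```cpp
#include <cstdint>

// Hypothetical outcome type for illustration.
enum class RotaryOutcome { Rotated, PassedThrough };

// Sketch of the guard: a position at or beyond max_position skips rotation.
// uint32_t addition wraps modulo 2^32, as u32 addition does in WGSL.
RotaryOutcome ApplyGuard(std::uint32_t past_seqlen, std::uint32_t sequence_idx,
                         std::uint32_t max_position) {
  const std::uint32_t position_id = past_seqlen + sequence_idx;
  return position_id >= max_position ? RotaryOutcome::PassedThrough
                                     : RotaryOutcome::Rotated;
}
```

With a wrapped past_seqlen of 4294967292 (real length 2, sequence_length 6) and an assumed max_position of 14, the first real token (sequence_idx 0) is passed through unrotated. At sequence_idx 4 the sum wraps back to 0, so a padding token is the one that ends up inside the cache range.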

select(total_seqlen - uniforms.sequence_length, 0u, total_seqlen <= uniforms.sequence_length) is WGSL-correct: select(a, b, cond) returns b when cond is true, so this returns 0u on the would-underflow case and the normal difference otherwise. The author got the argument order right (a common WGSL footgun where the convention differs from ?: operands).
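
To make the argument order concrete, here is a small C++ stand-in for WGSL select(f, t, cond). It is a sketch for illustration only: Select, ClampedPastSeqlen, and SwappedPastSeqlen are hypothetical names, and the swapped variant exists purely to show the footgun.

```cpp
#include <cstdint>

// C++ stand-in for WGSL select(f, t, cond): returns t when cond is true, else f.
constexpr std::uint32_t Select(std::uint32_t f, std::uint32_t t, bool cond) {
  return cond ? t : f;
}

// The clamp as written in the review, expressed with the stand-in.
// The wrapped difference is still computed, but discarded when cond is true.
constexpr std::uint32_t ClampedPastSeqlen(std::uint32_t total_seqlen,
                                          std::uint32_t sequence_length) {
  return Select(total_seqlen - sequence_length, 0u, total_seqlen <= sequence_length);
}

// The footgun: operands swapped as if select were cond ? a : b.
constexpr std::uint32_t SwappedPastSeqlen(std::uint32_t total_seqlen,
                                          std::uint32_t sequence_length) {
  return Select(0u, total_seqlen - sequence_length, total_seqlen <= sequence_length);
}
```

With total_seqlen 2 and sequence_length 6, the correct order returns 0 while the swapped order returns the wrapped value 4294967292, which is exactly the case the clamp is meant to remove.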

Coverage of shaders is intentional but not exhaustive

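The PR fixes three sites, all the non-flash codepaths reachable from the GQA packed-QKV + do_rotary flow:

  • RotaryEmbeddingProgram (seqlens variant) in rotary_embedding.cc
  • FusedQKRotaryEmbeddingProgram in rotary_embedding.cc
  • split_packed_qkv_with_rotary_embedding.wgsl.template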

Then it explicitly does not fix the fourth, sibling site: split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template:34:

let past_seqlen = total_seqlen - uniforms.sequence_length;
let position_id = past_seqlen + seq_idx;

Same underflow pattern, byte for byte. Instead, the PR closes the door by tightening CanApplyFlashAttention so the multi-batch right-padded case can no longer reach this shader.

This is pragmatic, but it leaves a latent footgun: the gate's kv_empty (kv_sequence_length_ == 0) exemption still allows batch_size_ > 1 && seqlen_k != nullptr to take the FlashAttention path. That's the shared-KV regime, and per the test plan WebGPU_SharedKV_* passes — so the combination "shared-KV layer in a batched right-padded prefill" presumably either isn't reached in practice or is shaped such that total_seqlen >= sequence_length for every batch. But that's a subtle invariant to leave undocumented. Two reasonable options:

  1. Just fix it here. It's a one-line select(...) change matching the three already in the PR. The unfused gating change can stay as defense-in-depth, but the shader stops being a hidden bear-trap. This is what I'd prefer.
  2. File a follow-up issue referenced from a code comment at the _and_copykv:34 site explaining "still has u32 underflow; reachable only if CanApplyFlashAttention lets through kv_empty && batch>1 && seqlen_k!=nullptr".

Either way, please don't let this fix-by-gating live silently — the diff at rotary_embedding.cc and the template have nice "would underflow u32; clamp to 0" comments; the unfixed sibling should at least call out that it's known broken.

CanApplyFlashAttention gating change

bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters,
                            onnxruntime::webgpu::ComputeContext& context,
                            const Tensor* seqlen_k) {
  const bool kv_empty = parameters.kv_sequence_length_ == 0;
  return (parameters.batch_size_ == 1 || seqlen_k == nullptr || kv_empty) &&
         !parameters.is_packed_qkv_ &&
         parameters.head_size_ == parameters.v_head_size_ &&
         (...);
}

The three escape hatches are correct in spirit:

  • batch_size_ == 1 — no padding possible.
  • seqlen_k == nullptr — caller never supplied per-batch lengths, so right-padding can't be expressed.
  • kv_empty — the shared-KV layer case where FlashAttention is needed.
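
The three-way OR can be isolated and checked on its own. The sketch below assumes a simplified input struct; GateInputs and PaddingSafeForFlash are hypothetical names and model only the first clause of the real predicate, not the remaining conditions.

```cpp
#include <cstdint>

// Hypothetical stand-in for the fields the first clause reads.
struct GateInputs {
  std::int64_t batch_size;
  std::int64_t kv_sequence_length;
  bool has_seqlen_k;  // models seqlen_k != nullptr
};

// True when right-padded per-batch prefill cannot occur, or when the layer is
// shared-KV (kv_sequence_length == 0) and the FlashAttention path is required.
bool PaddingSafeForFlash(const GateInputs& in) {
  const bool kv_empty = in.kv_sequence_length == 0;
  return in.batch_size == 1 || !in.has_seqlen_k || kv_empty;
}
```

Under these assumptions, only the combination batch_size > 1, seqlen_k supplied, and a non-empty KV input is routed away from FlashAttention.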

The default const Tensor* seqlen_k = nullptr in flash_attention.h:208 keeps existing callers compiling unchanged, and the one direct caller in group_query_attention.cc:353 now passes the real seqlen_k. Good.

One readability nit: the conjunction is now four conditions, the first wrapped in parens. Worth a one-line comment above the return explaining the three-way OR — future readers shouldn't have to grep the PR to understand why kv_empty is an escape hatch:

// Multi-batch + per-batch seqlens (right-padded prefill) hits a u32 underflow
// in split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template. Allow:
//   - batch_size_ == 1: no padding possible
//   - seqlen_k == nullptr: no per-batch lengths, padding inexpressible
//   - kv_empty (shared-KV layer): FlashAttention is mandatory here

Test design

BatchedRightPaddedRotaryPrefill_{CUDA,WebGPU} is well-shaped:

  • Comparative consistency check: each batch's real-last-token slice from the batched-padded run is compared against a batch=1 reference, both on the same EP. Cleanly factors out EP-vs-EP numerics and isolates "batched padding preserves per-batch correctness", which is exactly the contract the bug violated.
  • Real lengths {4, 2, 6} with sequence_length=6: two short batches, one full-length batch. Exactly the case where the bug shows up; the full batch (no underflow) serves as a sanity anchor.
  • Tolerance 5e-3: loose enough for fp16, tight enough that the bug (max diff 0.0355 in the short-batch channel, per the test plan) cannot pass. Good.
  • SetOutputTolerance(1e6f): the framework comparator is intentionally disabled because the test does its own per-batch slicing. Slightly hacky but standard in this file; fine.
  • max_seq_len = sequence_length + 8: a defensive cos/sin cache slack so the bounds checks downstream don't false-trigger for any legitimate position. Sensible.
  • CUDA variant included: cheap to add, broadens regression coverage to CUDA's rotary path. Nice.
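
The per-batch comparison can be sketched as follows. This is not the test's actual code: it assumes a flat [batch, sequence_length, hidden_size] output for the padded run and a [1, real_len, hidden_size] output for the batch=1 reference, and the function name is hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Max absolute difference between batch b's real last token in the padded
// batched output and the last token of its batch=1 reference output.
// Layouts are assumptions: batched is [batch, sequence_length, hidden_size],
// reference is [1, real_len, hidden_size], both row-major.
float MaxDiffAtRealLastToken(const std::vector<float>& batched,
                             const std::vector<float>& reference,
                             std::size_t b, std::size_t real_len,
                             std::size_t sequence_length, std::size_t hidden_size) {
  const std::size_t batched_offset = (b * sequence_length + (real_len - 1)) * hidden_size;
  const std::size_t reference_offset = (real_len - 1) * hidden_size;
  float max_diff = 0.0f;
  for (std::size_t i = 0; i < hidden_size; ++i) {
    const float diff = std::fabs(batched[batched_offset + i] - reference[reference_offset + i]);
    max_diff = std::max(max_diff, diff);
  }
  return max_diff;
}
```

A caller would compare the returned value against the 5e-3 tolerance for each batch, ignoring the padded positions entirely.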

Two minor test-side observations:

  • RunGQAPackedQKVRotaryPrefill(..., bool use_cuda = false, bool use_webgpu = false): the branching is if (use_cuda) ... else if (use_webgpu) ... else CPU. Calling with both true silently picks CUDA. Not a real bug since both call sites pass exclusive bools, but an enum class TargetEp { Cpu, Cuda, WebGpu } would scale better once you start adding more EPs. Pure nit.
  • AddOptionalInputEdge<float>(); // position_ids: the schema type for position_ids is int64, not float. For an unfilled optional edge it doesn't matter in practice, but the wrong template arg is mildly misleading to a reader trying to understand the input layout.
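
The enum suggestion might look like the sketch below. ProviderLabelFor and its returned labels are hypothetical and stand in for whatever the test helper actually constructs; std::logic_error is used here in place of any project-specific throw macro.

```cpp
#include <stdexcept>
#include <string>

enum class TargetEp { Cpu, Cuda, WebGpu };

// Hypothetical helper: one enumerator selects exactly one EP, so the
// "both bools true" state cannot be expressed. An unhandled value throws
// instead of silently falling through to a default provider.
std::string ProviderLabelFor(TargetEp ep) {
  switch (ep) {
    case TargetEp::Cpu:
      return "CPU";
    case TargetEp::Cuda:
      return "CUDA";
    case TargetEp::WebGpu:
      return "WebGPU";
  }
  throw std::logic_error("unhandled TargetEp");
}
```

Leaving the switch without a default case lets compilers that warn on unhandled enumerators flag a newly added EP, while the trailing throw covers out-of-range values at runtime.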

Other things worth a one-line mention

  • The fix is the same one-line shader change in three places. A small inline helper or template would be worthwhile if a fourth site shows up, but with only three call sites and three identical fixes, inlining is fine for now.
  • use_smooth_softmax_ and use_sliding_window already disable will_use_flash_attention at group_query_attention.cc:265, so this PR's gating change interacts cleanly with the existing escape hatches. No double-gating to worry about.
  • Test plan ran phi4-prune end-to-end (real model, real generation, three concurrent prompts of unequal length) and whisper-tiny-int4 — that's the right shape of model-level smoke test to back the unit test.

Bottom line

Land it. The unit test, the WGSL clamp, and the FlashAttention bypass together solve the reported failure cleanly. Strongly suggest folding the same select(...) clamp into split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template so the broader RoPE underflow class is fully closed rather than gated around. Otherwise, please at least add a comment at that site naming it as a known-broken codepath reachable only through the kv_empty exemption.

qjia7 added 2 commits June 18, 2026 14:17
…otary prefill

Apply three reviewer suggestions on the GQA WebGPU right-padded prefill fix:

- Clamp the 4th packed-QKV rotary shader site against u32 underflow, byte
  identical to the sibling shader pattern.
- Document the disjunction in CanApplyFlashAttention as a positive contract:
  FlashAttention does not implement right-padded per-batch prefill, so the
  first disjunction restricts inputs to shapes where padding cannot occur.
- Replace the (bool use_cuda, bool use_webgpu) pair across three test
  helpers with a single GqaTargetEp enum, route EP construction through a
  central MakeExecutionProviderForGqaTest helper with an ORT_THROW default
  so a future enumerator cannot silently fall through to an empty provider
  vector, and migrate every caller.

All 56 GroupQueryAttentionTest cases pass locally (48 OK + 8 CUDA-skipped
on a machine without CUDA EP).
- Spelling: "Modelled" -> "Modeled" (US English, per misspell linter)
- clang-format: split the GqaTargetEp enum onto one enumerator per line
Labels

ep:WebGPU ort-web webgpu provider

4 participants