perf(qwen3next): drop q/k/v/a/b contiguous copies in GDN fused_recurrent decode#1349
sufubao wants to merge 1 commit into
…ent decode

The gated-delta-rule decode path passed q/k/v/a_raw/b_raw to FusedRecurrentFunction, which copied each to a contiguous buffer before the Triton kernel. These are column views of one fused projection output, so the copies are pure per-step overhead.

Instead, pass the per-token element stride of each tensor into the kernel so it reads the column views directly (base + i * token_stride). A helper derives the stride for the decode [tokens,1,H,D], varlen [1,tokens,H,D], and 2D a/b layouts, and falls back to .contiguous() only when a tensor can't be addressed that way.

Adds unit_tests/models/qwen3next/test_fused_recurrent_strided.py asserting bit-exact (torch.equal) parity of output and SSM state between the strided and contiguous paths, for decode (bs 1/2/16) and the varlen layout.

Static decode throughput (Qwen3.5-122B-A10B, TP8 / H200, output_len=256): consistent +2-8% across batch 2..128 (mean ~+5%); prefill path unchanged.
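The stride derivation described in the commit message can be sketched in plain Python. This is only an illustration under stated assumptions, not the PR's helper: `token_stride_or_none` and `token_offset` are hypothetical names, the inputs are the tuples that `tensor.shape` and `tensor.stride()` would report, and the rule that each token's inner `[H, D]` block must be densely packed is an assumption about what the kernel needs.

```python
def token_stride_or_none(shape, stride):
    """Return the per-token element stride, or None if a copy would be needed.

    Hypothetical sketch: `shape` and `stride` mirror tensor.shape and
    tensor.stride(). Size-1 dimensions are ignored because their stride
    carries no information.
    """
    if len(shape) == 2:  # a/b layout: [tokens, H]
        _tokens, heads = shape
        inner_ok = heads == 1 or stride[1] == 1
        return stride[0] if inner_ok else None
    if len(shape) == 4:  # q/k/v layouts
        dim0, dim1, heads, head_dim = shape
        # Assumption: within one token, [H, D] must be densely packed.
        inner_ok = (head_dim == 1 or stride[3] == 1) and (
            heads == 1 or stride[2] == head_dim
        )
        if not inner_ok:
            return None
        if dim1 == 1:  # decode layout: [tokens, 1, H, D]
            return stride[0]
        if dim0 == 1:  # varlen layout: [1, tokens, H, D]
            return stride[1]
    return None  # caller would fall back to .contiguous()


def token_offset(base, i, token_stride):
    """Element offset of token i, i.e. base + i * token_stride."""
    return base + i * token_stride
```

For a fused projection row of 3072 elements with H=8 and D=128, a q view of shape `(16, 1, 8, 128)` has strides `(3072, 3072, 128, 1)`, so the sketch returns a token stride of 3072 instead of the 1024 a contiguous copy would have.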
Code Review
This pull request optimizes the fused recurrent gated delta rule kernel by allowing it to read non-contiguous column views of projection outputs (q, k, v, a_raw, b_raw) directly via per-token strides, eliminating the need for contiguous copies. It introduces helper functions to calculate token strides and adds comprehensive unit tests to verify correctness. The review feedback suggests replacing assert statements with explicit ValueError exceptions for validating tensor strides, ensuring these critical runtime checks are not stripped out when Python is run with optimization flags.
    elif ssm_state_indices.ndim == 1:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    else:
        assert ssm_state_indices.stride(-1) == 1, "2D ssm_state_indices must have contiguous rows"
Using assert statements for runtime validation of tensor properties can be risky because assertions are stripped out when Python is run with optimization flags (-O). If these checks are bypassed, it could lead to silent correctness issues or out-of-bounds memory accesses in the Triton kernel. It is safer to raise a ValueError instead.
    -    assert ssm_state_indices.stride(-1) == 1, "2D ssm_state_indices must have contiguous rows"
    +    if ssm_state_indices.stride(-1) != 1:
    +        raise ValueError("2D ssm_state_indices must have contiguous rows")
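The reviewer's concern can be reproduced without the `-O` flag by compiling the same check at two optimization levels with the built-in `compile(..., optimize=...)`. The snippet below is a minimal sketch; `load_check` and `rejects_bad_stride` are hypothetical helper names, and the integer argument stands in for `ssm_state_indices.stride(-1)`.

```python
CHECK_WITH_ASSERT = '''
def check(last_stride):
    assert last_stride == 1, "2D ssm_state_indices must have contiguous rows"
'''

CHECK_WITH_RAISE = '''
def check(last_stride):
    if last_stride != 1:
        raise ValueError("2D ssm_state_indices must have contiguous rows")
'''


def load_check(source, optimize):
    """Compile `source` at the given optimization level and return its check()."""
    namespace = {}
    # optimize=0 keeps asserts; optimize=1 strips them, as `python -O` does.
    exec(compile(source, "<sketch>", "exec", optimize=optimize), namespace)
    return namespace["check"]


def rejects_bad_stride(check):
    """True if check() raises for a non-unit last-dimension stride."""
    try:
        check(2)  # stand-in for a tensor whose rows are not contiguous
    except (AssertionError, ValueError):
        return True
    return False
```

Loaded with `optimize=1`, the assert-based check lets the bad stride through, while the `ValueError` version rejects it at either level.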
    elif ssm_state_write_indices.ndim == 1:
        stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride(0), 1
    else:
        assert ssm_state_write_indices.stride(-1) == 1, "2D ssm_state_write_indices must have contiguous rows"
Using assert statements for runtime validation of tensor properties can be risky because assertions are stripped out when Python is run with optimization flags (-O). If these checks are bypassed, it could lead to silent correctness issues or out-of-bounds memory accesses in the Triton kernel. It is safer to raise a ValueError instead.
    -    assert ssm_state_write_indices.stride(-1) == 1, "2D ssm_state_write_indices must have contiguous rows"
    +    if ssm_state_write_indices.stride(-1) != 1:
    +        raise ValueError("2D ssm_state_write_indices must have contiguous rows")
What
The Qwen3-Next / Qwen3.5 gated-delta-rule (GDN) decode path passes q/k/v/a_raw/b_raw into FusedRecurrentFunction, which previously called .contiguous() on each before launching the Triton kernel. These tensors are column views of a single fused projection output, so the copies are pure per-decode-step overhead (extra allocations + copy kernels).

This PR teaches the kernel to read those views directly: it passes the per-token element stride of each tensor in and indexes token i as base + i * token_stride. A small helper (_ensure_token_strided) derives the stride for the decode [tokens,1,H,D], varlen [1,tokens,H,D], and 2D a/b layouts, and falls back to .contiguous() only when a tensor genuinely can't be addressed that way.

Correctness
Adds unit_tests/models/qwen3next/test_fused_recurrent_strided.py, which asserts bit-exact parity (torch.equal) of both the output and the written SSM state between the strided and the old contiguous path, for decode (bs 1/2/16) and the varlen layout.

Also hardens the 2D index-stride handling for ssm_state_indices and ssm_state_write_indices with explicit contiguous-row asserts (the kernel advances read/write indices with token-stride 1).
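The shape of such a parity check can be sketched on CPU with PyTorch. This is not the PR's test: `standin_step` is a hypothetical elementwise stand-in for the Triton kernel (it is not the gated delta rule), and the sizes are arbitrary. It only shows the pattern of feeding column views of one fused buffer and their contiguous copies through the same function and comparing with `torch.equal`.

```python
import torch


def standin_step(q, k, v, decay, state):
    """Hypothetical stand-in for the kernel: elementwise decay-and-update."""
    new_state = state * decay + k * v
    out = q * new_state
    return out, new_state


def strided_matches_contiguous(tokens, heads=2, head_dim=4, seed=0):
    """Compare strided column views against contiguous copies, bit for bit."""
    torch.manual_seed(seed)
    width = heads * head_dim
    # One fused projection output; q/k/v are column views into it (no copy).
    fused = torch.randn(tokens, 3 * width)
    q, k, v = (
        fused[:, i * width:(i + 1) * width].view(tokens, 1, heads, head_dim)
        for i in range(3)
    )
    state = torch.randn(tokens, 1, heads, head_dim)

    out_s, state_s = standin_step(q, k, v, 0.5, state)
    out_c, state_c = standin_step(
        q.contiguous(), k.contiguous(), v.contiguous(), 0.5, state
    )
    # torch.equal requires identical shapes and values, with no tolerance.
    return torch.equal(out_s, out_c) and torch.equal(state_s, state_c)
```

The stand-in uses only multiplication and addition so that the result does not depend on memory layout. A real kernel test would call both kernel paths in place of `standin_step`.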
Performance
Static decode throughput, Qwen3.5-122B-A10B, TP8 on H200, output_len=256 (decode tok/s = mean of steady-state steps 100/200/255, optimized vs. baseline on the same commit): +2-8% across batch 2..128.

Consistent improvement at every batch size (mean ~+5%). The prefill path is unchanged.