-
Notifications
You must be signed in to change notification settings - Fork 333
perf(qwen3next): drop q/k/v/a/b contiguous copies in GDN fused_recurrent decode #1349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -54,6 +54,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( | |||||||
| V: tl.constexpr, | ||||||||
| BK: tl.constexpr, | ||||||||
| BV: tl.constexpr, | ||||||||
| stride_q_tok: tl.constexpr, | ||||||||
| stride_k_tok: tl.constexpr, | ||||||||
| stride_v_tok: tl.constexpr, | ||||||||
| stride_a_tok: tl.constexpr, | ||||||||
| stride_b_tok: tl.constexpr, | ||||||||
| stride_init_state_token: tl.constexpr, | ||||||||
| stride_final_state_token: tl.constexpr, | ||||||||
| stride_indices_seq: tl.constexpr, | ||||||||
|
|
@@ -94,15 +99,15 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( | |||||||
| o_k = i_k * BK + tl.arange(0, BK) | ||||||||
| o_v = i_v * BV + tl.arange(0, BV) | ||||||||
|
|
||||||||
| p_q = q + (bos * H + i_h) * K + o_k | ||||||||
| p_k = k + (bos * H + i_h) * K + o_k | ||||||||
| p_v = v + (bos * HV + i_hv) * V + o_v | ||||||||
| p_q = q + bos * stride_q_tok + i_h * K + o_k | ||||||||
| p_k = k + bos * stride_k_tok + i_h * K + o_k | ||||||||
| p_v = v + bos * stride_v_tok + i_hv * V + o_v | ||||||||
| if FUSE_GATING: | ||||||||
| # Fused gating: load per-head constants once, compute g/beta inline per token | ||||||||
| b_A_log = tl.load(A_log + i_hv).to(tl.float32) | ||||||||
| b_dt_bias = tl.load(dt_bias + i_hv).to(tl.float32) | ||||||||
| p_a_raw = a_raw + bos * HV + i_hv | ||||||||
| p_b_raw = b_raw + bos * HV + i_hv | ||||||||
| p_a_raw = a_raw + bos * stride_a_tok + i_hv | ||||||||
| p_b_raw = b_raw + bos * stride_b_tok + i_hv | ||||||||
| else: | ||||||||
| if IS_BETA_HEADWISE: | ||||||||
| p_beta = beta + (bos * HV + i_hv) * V + o_v | ||||||||
|
|
@@ -193,13 +198,13 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( | |||||||
| p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] | ||||||||
| tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) | ||||||||
|
|
||||||||
| p_q += H * K | ||||||||
| p_k += H * K | ||||||||
| p_q += stride_q_tok | ||||||||
| p_k += stride_k_tok | ||||||||
| p_o += HV * V | ||||||||
| p_v += HV * V | ||||||||
| p_v += stride_v_tok | ||||||||
| if FUSE_GATING: | ||||||||
| p_a_raw += HV | ||||||||
| p_b_raw += HV | ||||||||
| p_a_raw += stride_a_tok | ||||||||
| p_b_raw += stride_b_tok | ||||||||
| else: | ||||||||
| if not IS_KDA: | ||||||||
| p_g += HV | ||||||||
|
|
@@ -208,6 +213,43 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( | |||||||
| p_beta += HV * (V if IS_BETA_HEADWISE else 1) | ||||||||
|
|
||||||||
|
|
||||||||
| def _token_stride(x: torch.Tensor, inner_numel: int, cu_seqlens) -> int: | ||||||||
| """Per-token element stride of x addressed as [tokens, ...inner dims...]. | ||||||||
|
|
||||||||
| The kernel reads token ``i`` at ``base + i * token_stride`` with the inner | ||||||||
| dims packed, which supports column views of one wider projection output | ||||||||
| (token stride larger than inner_numel). Returns -1 if x's layout cannot be | ||||||||
| addressed that way (caller must fall back to .contiguous()). | ||||||||
| """ | ||||||||
| if x.dim() == 2: | ||||||||
| # [tokens, inner] (a_raw / b_raw) | ||||||||
| return x.stride(0) if x.stride(1) == 1 else -1 | ||||||||
| # 4D q/k/v | ||||||||
| if cu_seqlens is not None: | ||||||||
| # varlen layout [1, tokens, head, dim] | ||||||||
| token_dim = 1 | ||||||||
| elif x.shape[1] == 1: | ||||||||
| # decode layout [tokens, 1, head, dim] | ||||||||
| token_dim = 0 | ||||||||
| else: | ||||||||
| # [B, T>1, head, dim]: a single token stride only exists if contiguous | ||||||||
| return inner_numel if x.is_contiguous() else -1 | ||||||||
| if x.stride(-1) == 1 and x.stride(-2) == x.shape[-1]: | ||||||||
| return x.stride(token_dim) | ||||||||
| return -1 | ||||||||
|
|
||||||||
|
|
||||||||
| def _ensure_token_strided(x: torch.Tensor, inner_numel: int, cu_seqlens): | ||||||||
| """Return (tensor, token_stride); copies to contiguous only when needed.""" | ||||||||
| if x is None: | ||||||||
| return None, 0 | ||||||||
| stride = _token_stride(x, inner_numel, cu_seqlens) | ||||||||
| if stride < 0: | ||||||||
| x = x.contiguous() | ||||||||
| stride = inner_numel | ||||||||
| return x, stride | ||||||||
|
|
||||||||
|
|
||||||||
| def fused_recurrent_gated_delta_rule_fwd( | ||||||||
| q: torch.Tensor, | ||||||||
| k: torch.Tensor, | ||||||||
|
|
@@ -232,6 +274,11 @@ def fused_recurrent_gated_delta_rule_fwd( | |||||||
| B, T, H, K, V = *k.shape, v.shape[-1] | ||||||||
| HV = v.shape[2] | ||||||||
| N = B if cu_seqlens is None else len(cu_seqlens) - 1 | ||||||||
| q, stride_q_tok = _ensure_token_strided(q, H * K, cu_seqlens) | ||||||||
| k, stride_k_tok = _ensure_token_strided(k, H * K, cu_seqlens) | ||||||||
| v, stride_v_tok = _ensure_token_strided(v, HV * V, cu_seqlens) | ||||||||
| a_raw, stride_a_tok = _ensure_token_strided(a_raw, HV, cu_seqlens) | ||||||||
| b_raw, stride_b_tok = _ensure_token_strided(b_raw, HV, cu_seqlens) | ||||||||
| BK = triton.next_power_of_2(K) | ||||||||
| if T == 1: | ||||||||
| # Decode path: use larger BV to reduce kernel instances (4 blocks instead of 16) | ||||||||
|
|
@@ -261,20 +308,23 @@ def fused_recurrent_gated_delta_rule_fwd( | |||||||
| stride_init_state_token = initial_state.stride(0) | ||||||||
| stride_final_state_token = final_state.stride(0) | ||||||||
|
|
||||||||
| # Strides for read indices | ||||||||
| # Strides for read indices. The kernel advances along a row with `+ i_t` | ||||||||
| # (token stride 1), so 2D index tensors must have contiguous rows. | ||||||||
| if ssm_state_indices is None: | ||||||||
| stride_indices_seq, stride_indices_tok = 1, 1 | ||||||||
| elif ssm_state_indices.ndim == 1: | ||||||||
| stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 | ||||||||
| else: | ||||||||
| assert ssm_state_indices.stride(-1) == 1, "2D ssm_state_indices must have contiguous rows" | ||||||||
| stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() | ||||||||
|
|
||||||||
| # Strides for write indices (if provided) | ||||||||
| # Strides for write indices (if provided); same contiguous-row requirement | ||||||||
| if ssm_state_write_indices is None: | ||||||||
| stride_write_indices_seq, stride_write_indices_tok = 1, 1 | ||||||||
| elif ssm_state_write_indices.ndim == 1: | ||||||||
| stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride(0), 1 | ||||||||
| else: | ||||||||
| assert ssm_state_write_indices.stride(-1) == 1, "2D ssm_state_write_indices must have contiguous rows" | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||||
| stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride() | ||||||||
|
|
||||||||
| grid = (NK, NV, N * HV) | ||||||||
|
|
@@ -305,6 +355,11 @@ def fused_recurrent_gated_delta_rule_fwd( | |||||||
| V=V, | ||||||||
| BK=BK, | ||||||||
| BV=BV, | ||||||||
| stride_q_tok=stride_q_tok, | ||||||||
| stride_k_tok=stride_k_tok, | ||||||||
| stride_v_tok=stride_v_tok, | ||||||||
| stride_a_tok=stride_a_tok, | ||||||||
| stride_b_tok=stride_b_tok, | ||||||||
| stride_init_state_token=stride_init_state_token, | ||||||||
| stride_final_state_token=stride_final_state_token, | ||||||||
| stride_indices_seq=stride_indices_seq, | ||||||||
|
|
@@ -348,10 +403,12 @@ def forward( | |||||||
| b_raw: torch.Tensor | None = None, | ||||||||
| out: torch.Tensor | None = None, | ||||||||
| ): | ||||||||
| # q/k/v/a_raw/b_raw may be non-contiguous column views of one projection | ||||||||
| # output; the kernel handles them via per-token strides (no copies). | ||||||||
| o, final_state = fused_recurrent_gated_delta_rule_fwd( | ||||||||
| q=q.contiguous(), | ||||||||
| k=k.contiguous(), | ||||||||
| v=v.contiguous(), | ||||||||
| q=q, | ||||||||
| k=k, | ||||||||
| v=v, | ||||||||
| g=g.contiguous() if g is not None else None, | ||||||||
| beta=beta.contiguous() if beta is not None else None, | ||||||||
| scale=scale, | ||||||||
|
|
@@ -364,8 +421,8 @@ def forward( | |||||||
| use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, | ||||||||
| A_log=A_log, | ||||||||
| dt_bias=dt_bias, | ||||||||
| a_raw=a_raw.contiguous() if a_raw is not None else None, | ||||||||
| b_raw=b_raw.contiguous() if b_raw is not None else None, | ||||||||
| a_raw=a_raw, | ||||||||
| b_raw=b_raw, | ||||||||
| out=out, | ||||||||
| ) | ||||||||
|
|
||||||||
|
|
||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| import pytest | ||
| import torch | ||
|
|
||
| from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import ( | ||
| fused_recurrent_gated_delta_rule, | ||
| ) | ||
|
|
||
| if not torch.cuda.is_available(): | ||
| pytest.skip("CUDA required", allow_module_level=True) | ||
|
|
||
|
|
||
@pytest.mark.parametrize("batch", [1, 2, 16])
def test_decode_strided_views_match_contiguous(batch):
    """Decode layout: strided column views must match contiguous copies.

    q/k/v/a/b are passed as column views of one projection output (the decode
    path layout) and must produce the same output and the same state tensor
    as their ``.contiguous()`` copies.
    """
    torch.manual_seed(0)
    H, HV, K, V = 2, 8, 128, 128
    key_dim, value_dim = H * K, HV * V
    qkv_dim = 2 * key_dim + value_dim
    total_dim = qkv_dim + value_dim + 2 * HV  # qkv + z + b + a
    cache_slots = 64

    # One wide buffer; every input below is a column slice of it, so each
    # row (token) is total_dim elements apart rather than densely packed.
    mixed = torch.randn(batch, total_dim, device="cuda", dtype=torch.bfloat16)
    mixed_qkv = mixed[:, :qkv_dim]
    b_raw = mixed[:, qkv_dim + value_dim : qkv_dim + value_dim + HV]
    a_raw = mixed[:, qkv_dim + value_dim + HV :]

    # Reshape the q/k/v column slices to [batch, 1, heads, dim] without copying.
    query, key, value = torch.split(mixed_qkv, [key_dim, key_dim, value_dim], dim=-1)
    q = query.view(batch, 1, H, K)
    k = key.view(batch, 1, H, K)
    v = value.view(batch, 1, HV, V)

    A_log = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1
    dt_bias = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1
    ssm_state = torch.randn(cache_slots, HV, K, V, device="cuda", dtype=torch.bfloat16)
    # `batch` distinct slot ids drawn from a random permutation of the cache slots.
    idx = torch.randperm(cache_slots, device="cuda")[:batch].to(torch.int32)

    def run(q_, k_, v_, a_, b_, state):
        # Shared invocation so both runs differ only in the tensors' layout
        # and in which state clone they receive.
        out, _ = fused_recurrent_gated_delta_rule(
            q=q_,
            k=k_,
            v=v_,
            initial_state=state,
            inplace_final_state=True,
            ssm_state_indices=idx,
            use_qk_l2norm_in_kernel=True,
            A_log=A_log,
            dt_bias=dt_bias,
            a_raw=a_,
            b_raw=b_,
        )
        return out

    # Each run gets its own clone of the state; presumably it is updated in
    # place (inplace_final_state=True), which is why the clones are compared.
    state_ref = ssm_state.clone()
    out_ref = run(q.contiguous(), k.contiguous(), v.contiguous(), a_raw.contiguous(), b_raw.contiguous(), state_ref)
    state_strided = ssm_state.clone()
    out_strided = run(q, k, v, a_raw, b_raw, state_strided)

    # Exact equality, not a tolerance check.
    assert torch.equal(out_ref, out_strided)
    assert torch.equal(state_ref, state_strided)
|
|
||
|
|
||
def test_varlen_strided_views_match_contiguous():
    """Varlen layout [1, tokens, H, K] with column-view inputs.

    Three sequences are packed along the token dim and described by
    ``cu_seqlens``; the strided column views must give the same output and
    the same state tensor as their ``.contiguous()`` copies.
    """
    torch.manual_seed(1)
    H, HV, K, V = 2, 8, 128, 128
    key_dim, value_dim = H * K, HV * V
    qkv_dim = 2 * key_dim + value_dim
    total_dim = qkv_dim + value_dim + 2 * HV
    seqlens = [3, 5, 1]
    tokens = sum(seqlens)
    # Cumulative sequence boundaries matching seqlens above.
    cu = torch.tensor([0, 3, 8, 9], device="cuda", dtype=torch.long)

    # One wide buffer; every input below is a column slice of it.
    mixed = torch.randn(tokens, total_dim, device="cuda", dtype=torch.bfloat16)
    mixed_qkv = mixed[:, :qkv_dim]
    b_raw = mixed[:, qkv_dim + value_dim : qkv_dim + value_dim + HV]
    a_raw = mixed[:, qkv_dim + value_dim + HV :]
    query, key, value = torch.split(mixed_qkv, [key_dim, key_dim, value_dim], dim=-1)
    q = query.view(1, tokens, H, K)
    k = key.view(1, tokens, H, K)
    v = value.view(1, tokens, HV, V)

    A_log = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1
    dt_bias = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1
    # ssm_state_indices is required: the non-continuous-batching varlen branch
    # indexes h0 by token offset (bos) instead of sequence index, reading out
    # of bounds for any sequence after the first (latent upstream bug; all
    # production call sites pass ssm_state_indices). With inplace_final_state
    # the kernel writes a state per token, so indices are 2D [N, max_seqlen]
    # mapping each (seq, token) to a distinct slot; the seq's initial state is
    # read from its token-0 slot.
    max_len = max(seqlens)
    idx = torch.zeros(len(seqlens), max_len, device="cuda", dtype=torch.int32)
    slot = 0
    for i, sl in enumerate(seqlens):
        # Consecutive slot ids for this sequence; entries past its length stay 0.
        idx[i, :sl] = torch.arange(slot, slot + sl, device="cuda", dtype=torch.int32)
        slot += sl
    init_state = torch.randn(tokens, HV, K, V, device="cuda", dtype=torch.bfloat16)

    def run(q_, k_, v_, a_, b_):
        # Fresh clone per run so the two runs cannot affect each other's state.
        state = init_state.clone()
        out, _ = fused_recurrent_gated_delta_rule(
            q=q_,
            k=k_,
            v=v_,
            initial_state=state,
            inplace_final_state=True,
            cu_seqlens=cu,
            ssm_state_indices=idx,
            use_qk_l2norm_in_kernel=True,
            A_log=A_log,
            dt_bias=dt_bias,
            a_raw=a_,
            b_raw=b_,
        )
        return out, state

    out_ref, final_ref = run(q.contiguous(), k.contiguous(), v.contiguous(), a_raw.contiguous(), b_raw.contiguous())
    out_strided, final_strided = run(q, k, v, a_raw, b_raw)
    # Exact equality, not a tolerance check.
    assert torch.equal(out_ref, out_strided)
    assert torch.equal(final_ref, final_strided)
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| pytest.main([__file__, "-v"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using `assert` statements for runtime validation of tensor properties can be risky because assertions are stripped out when Python is run with optimization flags (`-O`). If these checks are bypassed, it could lead to silent correctness issues or out-of-bounds memory accesses in the Triton kernel. It is safer to raise a `ValueError` instead.