diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index bb48bfe49c..e4d80e6ff9 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -422,8 +422,8 @@ def _gdn_decode_kernel( conv_state_indices=infer_state.b_buffer_idx, ) - # Recurrent processing with fused gating - # FusedRecurrentFunction.forward calls .contiguous() on q/k/v/a/b internally + # Recurrent processing with fused gating; the kernel reads the + # q/k/v/a/b column views directly via per-token strides (no copies) query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True) core_attn_out, _ = fused_recurrent_gated_delta_rule( q=query, diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py index 22a93a2c99..dda4710c4c 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py @@ -54,6 +54,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + stride_q_tok: tl.constexpr, + stride_k_tok: tl.constexpr, + stride_v_tok: tl.constexpr, + stride_a_tok: tl.constexpr, + stride_b_tok: tl.constexpr, stride_init_state_token: tl.constexpr, stride_final_state_token: tl.constexpr, stride_indices_seq: tl.constexpr, @@ -94,15 +99,15 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( o_k = i_k * BK + tl.arange(0, BK) o_v = i_v * BV + tl.arange(0, BV) - p_q = q + (bos * H + i_h) * K + o_k - p_k = k + (bos * H + i_h) * K + o_k - p_v = v + (bos * HV + i_hv) * V + o_v + p_q = q + bos * stride_q_tok + i_h * K + o_k + p_k = k + bos * stride_k_tok + i_h * K + o_k + p_v = v + bos * stride_v_tok + i_hv * V + o_v if FUSE_GATING: # Fused gating: load per-head constants once, compute g/beta inline per token b_A_log = tl.load(A_log + i_hv).to(tl.float32) b_dt_bias = tl.load(dt_bias + i_hv).to(tl.float32) - p_a_raw = a_raw + bos * HV + i_hv - p_b_raw = b_raw + bos * HV + i_hv + p_a_raw = a_raw + bos * stride_a_tok + i_hv + p_b_raw = b_raw + bos * stride_b_tok + i_hv else: if IS_BETA_HEADWISE: p_beta = beta + (bos * HV + i_hv) * V + o_v @@ -193,13 +198,13 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) - p_q += H * K - p_k += H * K + p_q += stride_q_tok + p_k += stride_k_tok p_o += HV * V - p_v += HV * V + p_v += stride_v_tok if FUSE_GATING: - p_a_raw += HV - p_b_raw += HV + p_a_raw += stride_a_tok + p_b_raw += stride_b_tok else: if not IS_KDA: p_g += HV @@ -208,6 +213,43 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( p_beta += HV * (V if IS_BETA_HEADWISE else 1) +def _token_stride(x: torch.Tensor, inner_numel: int, cu_seqlens) -> int: + """Per-token element stride of x addressed as [tokens, ...inner dims...]. + + The kernel reads token ``i`` at ``base + i * token_stride`` with the inner + dims packed, which supports column views of one wider projection output + (token stride larger than inner_numel). Returns -1 if x's layout cannot be + addressed that way (caller must fall back to .contiguous()). + """ + if x.dim() == 2: + # [tokens, inner] (a_raw / b_raw) + return x.stride(0) if x.stride(1) == 1 else -1 + # 4D q/k/v + if cu_seqlens is not None: + # varlen layout [1, tokens, head, dim] + token_dim = 1 + elif x.shape[1] == 1: + # decode layout [tokens, 1, head, dim] + token_dim = 0 + else: + # [B, T>1, head, dim]: a single token stride only exists if contiguous + return inner_numel if x.is_contiguous() else -1 + if x.stride(-1) == 1 and x.stride(-2) == x.shape[-1]: + return x.stride(token_dim) + return -1 + + +def _ensure_token_strided(x: torch.Tensor, inner_numel: int, cu_seqlens): + """Return (tensor, token_stride); copies to contiguous only when needed.""" + if x is None: + return None, 0 + stride = _token_stride(x, inner_numel, cu_seqlens) + if stride < 0: + x = x.contiguous() + stride = inner_numel + return x, stride + + def fused_recurrent_gated_delta_rule_fwd( q: torch.Tensor, k: torch.Tensor, @@ -232,6 +274,11 @@ def fused_recurrent_gated_delta_rule_fwd( B, T, H, K, V = *k.shape, v.shape[-1] HV = v.shape[2] N = B if cu_seqlens is None else len(cu_seqlens) - 1 + q, stride_q_tok = _ensure_token_strided(q, H * K, cu_seqlens) + k, stride_k_tok = _ensure_token_strided(k, H * K, cu_seqlens) + v, stride_v_tok = _ensure_token_strided(v, HV * V, cu_seqlens) + a_raw, stride_a_tok = _ensure_token_strided(a_raw, HV, cu_seqlens) + b_raw, stride_b_tok = _ensure_token_strided(b_raw, HV, cu_seqlens) BK = triton.next_power_of_2(K) if T == 1: # Decode path: use larger BV to reduce kernel instances (4 blocks instead of 16) @@ -261,20 +308,23 @@ def fused_recurrent_gated_delta_rule_fwd( stride_init_state_token = initial_state.stride(0) stride_final_state_token = final_state.stride(0) - # Strides for read indices + # Strides for read indices. The kernel advances along a row with `+ i_t` + # (token stride 1), so 2D index tensors must have contiguous rows. if ssm_state_indices is None: stride_indices_seq, stride_indices_tok = 1, 1 elif ssm_state_indices.ndim == 1: stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 else: + assert ssm_state_indices.stride(-1) == 1, "2D ssm_state_indices must have contiguous rows" stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() - # Strides for write indices (if provided) + # Strides for write indices (if provided); same contiguous-row requirement if ssm_state_write_indices is None: stride_write_indices_seq, stride_write_indices_tok = 1, 1 elif ssm_state_write_indices.ndim == 1: stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride(0), 1 else: + assert ssm_state_write_indices.stride(-1) == 1, "2D ssm_state_write_indices must have contiguous rows" stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride() grid = (NK, NV, N * HV) @@ -305,6 +355,11 @@ def fused_recurrent_gated_delta_rule_fwd( V=V, BK=BK, BV=BV, + stride_q_tok=stride_q_tok, + stride_k_tok=stride_k_tok, + stride_v_tok=stride_v_tok, + stride_a_tok=stride_a_tok, + stride_b_tok=stride_b_tok, stride_init_state_token=stride_init_state_token, stride_final_state_token=stride_final_state_token, stride_indices_seq=stride_indices_seq, @@ -348,10 +403,12 @@ def forward( b_raw: torch.Tensor | None = None, out: torch.Tensor | None = None, ): + # q/k/v/a_raw/b_raw may be non-contiguous column views of one projection + # output; the kernel handles them via per-token strides (no copies). o, final_state = fused_recurrent_gated_delta_rule_fwd( - q=q.contiguous(), - k=k.contiguous(), - v=v.contiguous(), + q=q, + k=k, + v=v, g=g.contiguous() if g is not None else None, beta=beta.contiguous() if beta is not None else None, scale=scale, @@ -364,8 +421,8 @@ def forward( use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, A_log=A_log, dt_bias=dt_bias, - a_raw=a_raw.contiguous() if a_raw is not None else None, - b_raw=b_raw.contiguous() if b_raw is not None else None, + a_raw=a_raw, + b_raw=b_raw, out=out, ) diff --git a/unit_tests/models/qwen3next/test_fused_recurrent_strided.py b/unit_tests/models/qwen3next/test_fused_recurrent_strided.py new file mode 100644 index 0000000000..748da704d4 --- /dev/null +++ b/unit_tests/models/qwen3next/test_fused_recurrent_strided.py @@ -0,0 +1,125 @@ +import pytest +import torch + +from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import ( + fused_recurrent_gated_delta_rule, +) + +if not torch.cuda.is_available(): + pytest.skip("CUDA required", allow_module_level=True) + + +@pytest.mark.parametrize("batch", [1, 2, 16]) +def test_decode_strided_views_match_contiguous(batch): + """q/k/v/a/b passed as column views of one projection output (the decode + path layout) must produce the same result as contiguous copies.""" + torch.manual_seed(0) + H, HV, K, V = 2, 8, 128, 128 + key_dim, value_dim = H * K, HV * V + qkv_dim = 2 * key_dim + value_dim + total_dim = qkv_dim + value_dim + 2 * HV # qkv + z + b + a + cache_slots = 64 + + mixed = torch.randn(batch, total_dim, device="cuda", dtype=torch.bfloat16) + mixed_qkv = mixed[:, :qkv_dim] + b_raw = mixed[:, qkv_dim + value_dim : qkv_dim + value_dim + HV] + a_raw = mixed[:, qkv_dim + value_dim + HV :] + + query, key, value = torch.split(mixed_qkv, [key_dim, key_dim, value_dim], dim=-1) + q = query.view(batch, 1, H, K) + k = key.view(batch, 1, H, K) + v = value.view(batch, 1, HV, V) + + A_log = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1 + dt_bias = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1 + ssm_state = torch.randn(cache_slots, HV, K, V, device="cuda", dtype=torch.bfloat16) + idx = torch.randperm(cache_slots, device="cuda")[:batch].to(torch.int32) + + def run(q_, k_, v_, a_, b_, state): + out, _ = fused_recurrent_gated_delta_rule( + q=q_, + k=k_, + v=v_, + initial_state=state, + inplace_final_state=True, + ssm_state_indices=idx, + use_qk_l2norm_in_kernel=True, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_, + b_raw=b_, + ) + return out + + state_ref = ssm_state.clone() + out_ref = run(q.contiguous(), k.contiguous(), v.contiguous(), a_raw.contiguous(), b_raw.contiguous(), state_ref) + state_strided = ssm_state.clone() + out_strided = run(q, k, v, a_raw, b_raw, state_strided) + + assert torch.equal(out_ref, out_strided) + assert torch.equal(state_ref, state_strided) + + +def test_varlen_strided_views_match_contiguous(): + """Varlen layout [1, tokens, H, K] with column-view inputs.""" + torch.manual_seed(1) + H, HV, K, V = 2, 8, 128, 128 + key_dim, value_dim = H * K, HV * V + qkv_dim = 2 * key_dim + value_dim + total_dim = qkv_dim + value_dim + 2 * HV + seqlens = [3, 5, 1] + tokens = sum(seqlens) + cu = torch.tensor([0, 3, 8, 9], device="cuda", dtype=torch.long) + + mixed = torch.randn(tokens, total_dim, device="cuda", dtype=torch.bfloat16) + mixed_qkv = mixed[:, :qkv_dim] + b_raw = mixed[:, qkv_dim + value_dim : qkv_dim + value_dim + HV] + a_raw = mixed[:, qkv_dim + value_dim + HV :] + query, key, value = torch.split(mixed_qkv, [key_dim, key_dim, value_dim], dim=-1) + q = query.view(1, tokens, H, K) + k = key.view(1, tokens, H, K) + v = value.view(1, tokens, HV, V) + + A_log = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1 + dt_bias = torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1 + # ssm_state_indices is required: the non-continuous-batching varlen branch + # indexes h0 by token offset (bos) instead of sequence index, reading out + # of bounds for any sequence after the first (latent upstream bug; all + # production call sites pass ssm_state_indices). With inplace_final_state + # the kernel writes a state per token, so indices are 2D [N, max_seqlen] + # mapping each (seq, token) to a distinct slot; the seq's initial state is + # read from its token-0 slot. + max_len = max(seqlens) + idx = torch.zeros(len(seqlens), max_len, device="cuda", dtype=torch.int32) + slot = 0 + for i, sl in enumerate(seqlens): + idx[i, :sl] = torch.arange(slot, slot + sl, device="cuda", dtype=torch.int32) + slot += sl + init_state = torch.randn(tokens, HV, K, V, device="cuda", dtype=torch.bfloat16) + + def run(q_, k_, v_, a_, b_): + state = init_state.clone() + out, _ = fused_recurrent_gated_delta_rule( + q=q_, + k=k_, + v=v_, + initial_state=state, + inplace_final_state=True, + cu_seqlens=cu, + ssm_state_indices=idx, + use_qk_l2norm_in_kernel=True, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_, + b_raw=b_, + ) + return out, state + + out_ref, final_ref = run(q.contiguous(), k.contiguous(), v.contiguous(), a_raw.contiguous(), b_raw.contiguous()) + out_strided, final_strided = run(q, k, v, a_raw, b_raw) + assert torch.equal(out_ref, out_strided) + assert torch.equal(final_ref, final_strided) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])