|
CHUNK_SIZE = 64
|
-# Use FLA's Triton kernels as reference (same Triton tl.dot as Helion)
-from fla.ops.common.chunk_o import chunk_fwd_o as fla_chunk_fwd_o
-from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h as fla_chunk_fwd_h
-from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd as fla_recompute_w_u_fwd
-from fla.ops.utils import chunk_local_cumsum, solve_tril
-from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
+
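+# Eager replacement for fla.ops.utils.chunk_local_cumsum: cumulative sum of the
+# log-gates g within each length-`chunk_size` chunk along the time axis.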
+def _chunk_local_cumsum_eager(g, chunk_size):
+    B, T, H = g.shape
+    C = chunk_size
+    return g.float().reshape(B, T // C, C, H).cumsum(dim=2).reshape(B, T, H)
+
+
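+# Eager replacement for chunk_scaled_dot_kkt_fwd: per chunk, builds the strictly
+# lower-triangular matrix A[i, j] = beta_i * (k_i . k_j) * exp(g_i - g_j) for i > j.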
+def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size):
+    B, T, H, K = k.shape
+    C = chunk_size
+    NT = T // C
+    k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    kkt = k_c @ k_c.transpose(-1, -2)
+    g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
+    strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1)
+    A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower
+    return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32)
+
+
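+# Eager replacement for fla.ops.utils.solve_tril: inverts the unit lower-triangular
+# matrix (I + A) for each chunk via a triangular solve.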
+def _solve_tril_eager(A, output_dtype):
+    B, T, H, C = A.shape
+    NT = T // C
+    A_mat = A.float().reshape(B, NT, C, H, C).permute(0, 1, 3, 2, 4)
+    eye = torch.eye(C, device=A.device).expand_as(A_mat)
+    result = torch.linalg.solve_triangular(eye + A_mat, eye, upper=False)
+    return result.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(output_dtype)
+
+
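+# Eager replacement for recompute_w_u_fwd: per chunk computes
+# w = A @ (beta * exp(g) * k) and u = A @ (beta * v).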
+def _recompute_w_u_fwd_eager(k, v, beta, A, g):
+    B, T, H, K = k.shape
+    V = v.shape[-1]
+    C = A.shape[-1]
+    NT = T // C
+    k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    v_c = v.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4)
+    beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    A_c = A.float().reshape(B, NT, C, H, C).permute(0, 1, 3, 2, 4)
+    u_c = A_c @ (v_c * beta_c.unsqueeze(-1))
+    w_c = A_c @ (k_c * (beta_c * torch.exp(g_c)).unsqueeze(-1))
+    w = w_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, K).to(k.dtype)
+    u = u_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(v.dtype)
+    return w, u
+
+
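+# Eager replacement for chunk_gated_delta_rule_fwd_h: scans over chunks, recording
+# the state h at the start of each chunk and the updated values v_new = u - w @ h.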
+def _chunk_fwd_h_eager(k, w, u, g):
+    B, T, H, K = k.shape
+    V = u.shape[-1]
+    C = CHUNK_SIZE
+    NT = T // C
+    k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    w_c = w.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    u_c = u.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4)
+    g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    h_all = torch.zeros(B, NT, H, K, V, dtype=torch.float32, device=k.device)
+    v_new_c = torch.zeros_like(u_c)
+    h = torch.zeros(B, H, K, V, dtype=torch.float32, device=k.device)
+    for c in range(NT):
+        # State at the start of chunk c, and values corrected by that state.
+        h_all[:, c] = h
+        v_new_c[:, c] = u_c[:, c] - w_c[:, c] @ h
+        # Decay the state by the gate at the end of the chunk and add this chunk's contribution.
+        g_last = g_c[:, c, :, -1]
+        gate = torch.exp(g_last.unsqueeze(-1) - g_c[:, c])
+        v_gated = v_new_c[:, c] * gate.unsqueeze(-1)
+        h = h * torch.exp(g_last).unsqueeze(-1).unsqueeze(-1) + k_c[:, c].transpose(-1, -2) @ v_gated
+    v_new_out = v_new_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(u.dtype)
+    return h_all.to(k.dtype), v_new_out


def generate_input(B: int, T: int, H: int, K: int, V: int, seed: int) -> input_t:
    torch.manual_seed(seed)
    device = "cuda"
-    # Generate pipeline-derived inputs: base inputs -> g_cumsum, A, w, u, h, v_new via FLA utilities
    q = torch.randn(B, T, H, K, dtype=torch.float32, device=device)
    k = F.normalize(torch.randn(B, T, H, K, dtype=torch.float32, device=device), p=2, dim=-1)
    v = torch.randn(B, T, H, V, dtype=torch.float32, device=device)
    beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device))
    g = F.logsigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device))
-    g_cumsum = chunk_local_cumsum(g, chunk_size=CHUNK_SIZE)
-    A = chunk_scaled_dot_kkt_fwd(k=k, g=g_cumsum, beta=beta, output_dtype=torch.float32)
-    A = solve_tril(A=A, output_dtype=k.dtype)
-    w, u = fla_recompute_w_u_fwd(k=k, v=v, beta=beta, A=A, g=g_cumsum)
-    h, v_new, _ = fla_chunk_fwd_h(k=k, w=w, u=u, g=g_cumsum, output_final_state=False)
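+    # Pipeline-derived inputs: base tensors -> g_cumsum, A, w, u, h, v_new via the eager helpers above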
+    g_cumsum = _chunk_local_cumsum_eager(g, chunk_size=CHUNK_SIZE)
+    A = _chunk_scaled_dot_kkt_fwd_eager(k=k, g_cumsum=g_cumsum, beta=beta, chunk_size=CHUNK_SIZE)
+    A = _solve_tril_eager(A=A, output_dtype=k.dtype)
+    w, u = _recompute_w_u_fwd_eager(k=k, v=v, beta=beta, A=A, g=g_cumsum)
+    h, v_new = _chunk_fwd_h_eager(k=k, w=w, u=u, g=g_cumsum)
    return q.contiguous(), k.contiguous(), v_new.contiguous(), h.contiguous(), g_cumsum.contiguous()


def ref_kernel(data: input_t) -> output_t:
    q, k, v_new, h, g = data
-    K = q.shape[-1]
+    B, T, H, K = q.shape
+    V = v_new.shape[-1]
+    C = CHUNK_SIZE
+    NT = T // C
    scale = K ** -0.5
-    o = fla_chunk_fwd_o(q=q, k=k, v=v_new, h=h, g=g, scale=scale)
-    return o
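+    # Eager replacement for chunk_fwd_o: inter-chunk term (q @ h) * exp(g) plus
+    # causal intra-chunk attention (q @ k^T) * exp(g_i - g_j) applied to v_new,
+    # with the 1/sqrt(K) scale applied at the end.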
+    q_c = q.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    v_c = v_new.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4)
+    g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    o_inter = (q_c @ h.float()) * torch.exp(g_c).unsqueeze(-1)
+    qk = q_c @ k_c.transpose(-1, -2) * torch.exp(g_c.unsqueeze(-1) - g_c.unsqueeze(-2))
+    causal = torch.tril(torch.ones(C, C, device=q.device))
+    o = (o_inter + (qk * causal) @ v_c) * scale
+    return o.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(q.dtype)


check_implementation = make_match_reference(ref_kernel, rtol=1e-2, atol=1e-2)