 from task import input_t, output_t
 
+import torch
+import helion
+import helion.language as hl
+
+
+# NOTE: This is an intentionally inefficient baseline implementation: every
+# hl.dot below is issued twice and averaged, and the config pins the kernel
+# to a single warp with no pipelining.
+@helion.kernel(
+    static_shapes=True,
+    dot_precision="ieee",
+    config=helion.Config(block_sizes=[], num_warps=1, num_stages=1),
+)
+def gated_chunk_attn(
+    q: torch.Tensor,  # [B, T, H, K]
+    k: torch.Tensor,  # [B, T, H, K]
+    v: torch.Tensor,  # [B, T, H, V]
+    h: torch.Tensor,  # [B, NT, H, K, V], one state per 64-token chunk
+    g: torch.Tensor,  # [B, T, H]
+    scale: float,
+) -> torch.Tensor:
+    B, T, H, K = q.shape
+    V = v.shape[-1]
+    C = 64  # chunk length along T
+    K = hl.specialize(K)
+    V = hl.specialize(V)
 
-def custom_kernel(data: input_t) -> output_t:
-    import torch
+    out = torch.empty_like(v)
 
-    q, k, v_new, h, g = data
-    B, T, H, K = q.shape
-    V = v_new.shape[-1]
-    BT = 64
-    scale = K ** -0.5
+    BH = B * H
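+    # One tile per (batch * head, 64-token chunk) pair; the first dimension
+    # uses block size 1, so flat_bh.begin is the flat (b, h) index.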
+    for flat_bh, tile_t in hl.tile([BH, T], block_size=[1, C]):
+        b_idx = flat_bh.begin // H
+        h_idx = flat_bh.begin % H
+        c_idx = tile_t.begin // C
 
-    o = torch.empty_like(v_new)
-    causal = torch.tril(torch.ones(BT, BT, device=q.device, dtype=torch.bool))
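+        # Fold the gates into q and k: scaling q by exp(g_i) and k by
+        # exp(-g_j) makes q_s @ k_s.T equal (q @ k.T) * exp(g_i - g_j),
+        # replacing the old explicit g_diff term.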
+        g_vals = g[b_idx, tile_t, h_idx]
+        q_s = q[b_idx, tile_t, h_idx, :] * torch.exp(g_vals)[:, None]
+        k_s = k[b_idx, tile_t, h_idx, :] * torch.exp(-g_vals)[:, None]
 
-    for cs in range(0, T, BT):
-        ce = cs + BT
-        c_idx = cs // BT
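+        # Intra-chunk attention; the duplicated dots are the deliberate
+        # inefficiency called out in the NOTE above.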
+        sim1 = hl.dot(q_s, k_s.T)
+        sim2 = hl.dot(q_s, k_s.T)
+        sim = (sim1 + sim2) * 0.5
+        # Causal mask restricted to the current chunk.
+        idx = hl.arange(tile_t.block_size)
+        mask = idx[:, None] >= idx[None, :]
+        sim = torch.where(mask, sim, 0.0)
+        local1 = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :])
+        local2 = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :])
+        local_out = (local1 + local2) * 0.5
 
-        b_q = q[:, cs:ce, :, :].permute(0, 2, 1, 3).float()
-        b_k = k[:, cs:ce, :, :].permute(0, 2, 1, 3).float()
-        b_v = v_new[:, cs:ce, :, :].permute(0, 2, 1, 3).float()
-        b_h = h[:, c_idx, :, :, :].float()
-        b_g = g[:, cs:ce, :].permute(0, 2, 1).float()
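+        # Inter-chunk contribution: apply the gated query to the carried
+        # state for this chunk (the old `inter` term).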
+        glob1 = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])
+        glob2 = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])
+        global_out = (glob1 + glob2) * 0.5
 
-        inter = torch.matmul(b_q, b_h)
-        inter = inter * torch.exp(b_g).unsqueeze(-1)
+        out[b_idx, tile_t, h_idx, :] = ((global_out + local_out) * scale).to(out.dtype)
 
-        attn = torch.matmul(b_q, b_k.transpose(-1, -2))
-        g_diff = b_g.unsqueeze(-1) - b_g.unsqueeze(-2)
-        attn = attn * torch.exp(g_diff)
-        attn = attn.masked_fill(~causal, 0.0)
-        intra = torch.matmul(attn, b_v)
+    return out
 
-        b_o = (inter + intra) * scale
-        o[:, cs:ce, :, :] = b_o.permute(0, 2, 1, 3)
 
-    return o
+def custom_kernel(data: input_t) -> output_t:
+    q, k, v_new, h, g = data
+    scale = q.shape[-1] ** -0.5  # 1 / sqrt(K), matching the old baseline
+    return gated_chunk_attn(q, k, v_new, h, g, scale)
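
A quick way to sanity-check the Helion port against the eager baseline this diff removes (a minimal sketch: `reference` condenses the removed `custom_kernel` code, the example sizes are arbitrary, a CUDA device is assumed, and T is assumed to be a multiple of the 64-token chunk):

```python
import torch

def reference(q, k, v_new, h, g):
    # Eager re-implementation of the removed custom_kernel baseline.
    B, T, H, K = q.shape
    BT = 64
    scale = K ** -0.5
    o = torch.empty_like(v_new)
    causal = torch.tril(torch.ones(BT, BT, device=q.device, dtype=torch.bool))
    for cs in range(0, T, BT):
        ce = cs + BT
        b_q = q[:, cs:ce].permute(0, 2, 1, 3).float()
        b_k = k[:, cs:ce].permute(0, 2, 1, 3).float()
        b_v = v_new[:, cs:ce].permute(0, 2, 1, 3).float()
        b_h = h[:, cs // BT].float()
        b_g = g[:, cs:ce].permute(0, 2, 1).float()
        inter = torch.matmul(b_q, b_h) * torch.exp(b_g).unsqueeze(-1)
        attn = torch.matmul(b_q, b_k.transpose(-1, -2))
        attn = attn * torch.exp(b_g.unsqueeze(-1) - b_g.unsqueeze(-2))
        intra = torch.matmul(attn.masked_fill(~causal, 0.0), b_v)
        o[:, cs:ce] = ((inter + intra) * scale).permute(0, 2, 1, 3).to(o.dtype)
    return o

B, T, H, K, V = 2, 128, 4, 64, 64  # arbitrary example sizes; T % 64 == 0
q = torch.randn(B, T, H, K, device="cuda")
k = torch.randn(B, T, H, K, device="cuda")
v_new = torch.randn(B, T, H, V, device="cuda")
h = torch.randn(B, T // 64, H, K, V, device="cuda")
g = torch.randn(B, T, H, device="cuda")

torch.testing.assert_close(
    custom_kernel((q, k, v_new, h, g)),
    reference(q, k, v_new, h, g),
    rtol=1e-3,
    atol=1e-3,
)
```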