 # Autotune locally for each shape, then paste the best config here.
 SHAPE_CONFIGS: dict[tuple, helion.Config] = {
     # Test shapes
-    (1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
-    (2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
-    (1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
+    (1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check
+    (2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check
+    (1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check
     # Benchmark shapes
-    (1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
-    (2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
-    (2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
-    (3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
-    (4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
-    (2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
-    (4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
+    (1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
+    (2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
+    (2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
+    (3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
+    (4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
+    (2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
+    (4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
 }

@@ -57,19 +57,13 @@ def kernel(
         q_s = q[b_idx, tile_t, h_idx, :] * torch.exp(g_vals)[:, None]
         k_s = k[b_idx, tile_t, h_idx, :] * torch.exp(-g_vals)[:, None]

-        sim1 = hl.dot(q_s, k_s.T)
-        sim2 = hl.dot(q_s, k_s.T)
-        sim = (sim1 + sim2) * 0.5
+        sim = hl.dot(q_s, k_s.T)
         idx = hl.arange(tile_t.block_size)
         mask = idx[:, None] >= idx[None, :]
         sim = torch.where(mask, sim, 0.0)
-        local1 = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :])
-        local2 = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :])
-        local_out = (local1 + local2) * 0.5
+        local_out = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :])

-        glob1 = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])
-        glob2 = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])
-        global_out = (glob1 + glob2) * 0.5
+        global_out = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])

         out[b_idx, tile_t, h_idx, :] = ((global_out + local_out) * scale).to(out.dtype)
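
For the "Autotune locally for each shape, then paste the best config here" step, a minimal sketch of the workflow is below. It is hypothetical and not part of this diff: it assumes the Helion kernel object exposes an autotune(args) method that returns a helion.Config (verify against your installed Helion version), and make_inputs is a placeholder for whatever code builds the q/k/v/gate/state tensors for one shape key.

    # Hypothetical autotuning sketch; make_inputs and the kernel.autotune call are
    # assumptions used to illustrate the workflow, not part of this change.
    args = make_inputs(1, 256, 4, 64, 128)  # tensors matching one SHAPE_CONFIGS key
    best = kernel.autotune(args)            # search for the fastest config on this GPU
    print(best)                             # paste the printed helion.Config into SHAPE_CONFIGS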