Skip to content

Commit d773d99

Browse files
yf225 and claude authored
Optimize gated deltanet chunk_fwd_o helion kernel (#130)
Remove redundant duplicate dot products and increase warps from 1 to 8. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 206d550 commit d773d99

1 file changed

Lines changed: 13 additions & 19 deletions

File tree

problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py

Lines changed: 13 additions & 19 deletions
Original file line number · Diff line number · Diff line change
@@ -9,17 +9,17 @@
99
# Autotune locally for each shape, then paste the best config here.
1010
SHAPE_CONFIGS: dict[tuple, helion.Config] = {
1111
# Test shapes
12-
(1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
13-
(2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
14-
(1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: use any config that passes correctness check
12+
(1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check
13+
(2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check
14+
(1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: use any config that passes correctness check
1515
# Benchmark shapes
16-
(1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
17-
(2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
18-
(2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
19-
(3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
20-
(4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
21-
(2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
22-
(4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1), # TODO: replace with your autotuned config
16+
(1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
17+
(2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
18+
(2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
19+
(3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
20+
(4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
21+
(2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
22+
(4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=8, num_stages=2), # TODO: replace with your autotuned config
2323
}
2424

2525

@@ -57,19 +57,13 @@ def kernel(
5757
q_s = q[b_idx, tile_t, h_idx, :] * torch.exp(g_vals)[:, None]
5858
k_s = k[b_idx, tile_t, h_idx, :] * torch.exp(-g_vals)[:, None]
5959

60-
sim1 = hl.dot(q_s, k_s.T)
61-
sim2 = hl.dot(q_s, k_s.T)
62-
sim = (sim1 + sim2) * 0.5
60+
sim = hl.dot(q_s, k_s.T)
6361
idx = hl.arange(tile_t.block_size)
6462
mask = idx[:, None] >= idx[None, :]
6563
sim = torch.where(mask, sim, 0.0)
66-
local1 = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :])
67-
local2 = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :])
68-
local_out = (local1 + local2) * 0.5
64+
local_out = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :])
6965

70-
glob1 = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])
71-
glob2 = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])
72-
global_out = (glob1 + glob2) * 0.5
66+
global_out = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])
7367

7468
out[b_idx, tile_t, h_idx, :] = ((global_out + local_out) * scale).to(out.dtype)
7569

0 commit comments

Comments (0)