
Commit ee9ef17

Authored by yf225 and claude
Add per-shape config dispatch pattern to all Helion submission.py files (#125)
* Add per-shape config dispatch pattern to all Helion submissions
  - Use SHAPE_CONFIGS dict mapping shape tuples to helion.Config objects
  - Factory pattern with _make_kernel() creates separate kernel instances per config
  - All test and benchmark shapes from task.yml listed in SHAPE_CONFIGS
  - Test shapes: TODO to replace with default config or any config that passes correctness
  - Benchmark shapes: TODO to replace with autotuned config

* Add commented-out ACF usage hint to all Helion submissions
  Each submission now shows the matching ACF file path as a commented example so participants know which booster pack files to try.

* Rename inner kernel functions to 'kernel' in all submissions

* Replace placeholder helion.Config(...) with original baseline configs

* Simplify TODO comments for test shape configs

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2ee400c · commit ee9ef17

6 files changed: 378 additions & 222 deletions
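The "per-shape config dispatch" named in the commit message distills to three pieces: a SHAPE_CONFIGS dict keyed by shape tuples, a _make_kernel factory that closes over one config, and a dict lookup at call time. A minimal pure-Python sketch of that control flow (toy shape keys and plain config dicts here, no Helion or GPU required):

```python
# Pure-Python distillation of the per-shape dispatch pattern: configs keyed
# by shape tuple, a factory that closes over one config, and a dict lookup
# at call time. Shape keys and config fields are toy values.
SHAPE_CONFIGS = {
    (1, 64): {"block_sizes": [1], "num_warps": 1, "num_stages": 1},
    (8, 4096): {"block_sizes": [4], "num_warps": 4, "num_stages": 2},
}


def _make_kernel(config):
    # In the real submissions this returns a separate @helion.kernel
    # instance compiled with `config`; here we just record which config
    # the call was routed to.
    def kernel(shape):
        return {"shape": shape, "config": config}

    return kernel


# One kernel instance per listed shape, built once at import time.
_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}


def custom_kernel(shape):
    # Dispatch on the runtime shape, exactly as each submission.py does.
    return _KERNELS[shape](shape)


assert custom_kernel((8, 4096))["config"]["num_warps"] == 4
```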

File tree

problems/helion/causal_conv1d_py/submission.py

Lines changed: 67 additions & 39 deletions
```diff
@@ -5,49 +5,77 @@
 import helion.language as hl
 
 
+# Per-shape configs: map (B, D, S, W) to optimized helion.Config objects.
+# Autotune locally for each shape, then paste the best config here.
+SHAPE_CONFIGS: dict[tuple, helion.Config] = {
+    # Test shapes
+    (1, 64, 64, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (2, 128, 128, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (1, 256, 256, 3): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (1, 128, 64, 8): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (4, 64, 128, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    # Benchmark shapes
+    (1, 768, 512, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (1, 768, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (1, 1536, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (1, 2560, 2048, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (1, 2560, 4096, 4): helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+}
+
+
+# Optional: add advanced_controls_file to your Config for extra performance (see docs).
+# Autotune with autotune_search_acf to find the best ACF, then hardcode it:
+# helion.Config(..., advanced_controls_file="/opt/booster_pack/causal_conv_0.acf")
+
+
 # NOTE: This is an intentionally inefficient baseline implementation.
-@helion.kernel(
-    static_shapes=True,
-    config=helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1),
-)
-def conv1d_kernel(
-    x_pad: torch.Tensor,  # (B, D, L) zero-padded input
-    w: torch.Tensor,  # (D, W) filter coefficients
-    b: torch.Tensor,  # (D,) additive offset
-) -> torch.Tensor:
-    B = x_pad.size(0)
-    D = x_pad.size(1)
-    L = x_pad.size(2)
-    W = hl.specialize(w.size(1))
-    N = L - W + 1
-
-    y = torch.empty(B, D, N, dtype=x_pad.dtype, device=x_pad.device)
-
-    for rb, rd, rs in hl.tile([B, D, N], block_size=[1, None, None]):
-        bi = rb.begin
-        acc1 = hl.zeros([rd, rs], dtype=torch.float32)
-        acc2 = hl.zeros([rd, rs], dtype=torch.float32)
-        acc3 = hl.zeros([rd, rs], dtype=torch.float32)
-        for j in range(W):
-            c1 = w[rd, j].to(torch.float32)
-            x1 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
-            acc1 = acc1 + x1 * c1[:, None]
-            c2 = w[rd, j].to(torch.float32)
-            x2 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
-            acc2 = acc2 + x2 * c2[:, None]
-            c3 = w[rd, j].to(torch.float32)
-            x3 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
-            acc3 = acc3 + x3 * c3[:, None]
-        acc = (acc1 + acc2 + acc3) / 3.0
-        acc = acc + b[rd].to(torch.float32)[:, None]
-        y[rb, rd, rs] = acc[None, :, :].to(y.dtype)
-
-    return y
+def _make_kernel(config: helion.Config):
+    @helion.kernel(static_shapes=True, config=config)
+    def kernel(
+        x_pad: torch.Tensor,  # (B, D, L) zero-padded input
+        w: torch.Tensor,  # (D, W) filter coefficients
+        b: torch.Tensor,  # (D,) additive offset
+    ) -> torch.Tensor:
+        B = x_pad.size(0)
+        D = x_pad.size(1)
+        L = x_pad.size(2)
+        W = hl.specialize(w.size(1))
+        N = L - W + 1
+
+        y = torch.empty(B, D, N, dtype=x_pad.dtype, device=x_pad.device)
+
+        for rb, rd, rs in hl.tile([B, D, N], block_size=[1, None, None]):
+            bi = rb.begin
+            acc1 = hl.zeros([rd, rs], dtype=torch.float32)
+            acc2 = hl.zeros([rd, rs], dtype=torch.float32)
+            acc3 = hl.zeros([rd, rs], dtype=torch.float32)
+            for j in range(W):
+                c1 = w[rd, j].to(torch.float32)
+                x1 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
+                acc1 = acc1 + x1 * c1[:, None]
+                c2 = w[rd, j].to(torch.float32)
+                x2 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
+                acc2 = acc2 + x2 * c2[:, None]
+                c3 = w[rd, j].to(torch.float32)
+                x3 = hl.load(x_pad, [bi, rd, rs.index + j]).to(torch.float32)
+                acc3 = acc3 + x3 * c3[:, None]
+            acc = (acc1 + acc2 + acc3) / 3.0
+            acc = acc + b[rd].to(torch.float32)[:, None]
+            y[rb, rd, rs] = acc[None, :, :].to(y.dtype)
+
+        return y
+
+    return kernel
+
+
+_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}
 
 
 def custom_kernel(data: input_t) -> output_t:
     x, weight, bias = data
+    B, D, S = x.shape
     W = weight.shape[1]
-    pad_zeros = torch.zeros(x.shape[0], x.shape[1], W - 1, dtype=x.dtype, device=x.device)
+    kernel = _KERNELS[(B, D, S, W)]
+    pad_zeros = torch.zeros(B, D, W - 1, dtype=x.dtype, device=x.device)
     padded = torch.cat([pad_zeros, x], dim=2)
-    return conv1d_kernel(padded, weight, bias)
+    return kernel(padded, weight, bias)
```
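One sharp edge of the direct `_KERNELS[(B, D, S, W)]` lookup above: any shape missing from SHAPE_CONFIGS raises a bare KeyError. If the harness can probe unlisted shapes, a fallback to the baseline config is a small extension; the `_DEFAULT_KERNEL` and `_get_kernel` names below are hypothetical and not part of this commit:

```python
# Hypothetical hardening, not in the commit: route shapes that were never
# listed in SHAPE_CONFIGS to a baseline kernel instead of raising KeyError.
_DEFAULT_KERNEL = _make_kernel(
    helion.Config(block_sizes=[1, 8], num_warps=1, num_stages=1)
)


def _get_kernel(shape: tuple):
    # dict.get keeps listed shapes on their tuned configs and falls back
    # to the baseline config for anything else.
    return _KERNELS.get(shape, _DEFAULT_KERNEL)
```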

problems/helion/fp8_quant_py/submission.py

Lines changed: 61 additions & 43 deletions
```diff
@@ -5,52 +5,68 @@
 import helion.language as hl
 from pathlib import Path
 
-COFIG_DICT={
-    "block_sizes": [1],
-    "num_warps": 1,
-    "num_stages": 1,
+
+# Per-shape configs: map (num_tokens, hidden_dim, group_size) to optimized helion.Config objects.
+# Autotune locally for each shape, then paste the best config here.
+SHAPE_CONFIGS: dict[tuple, helion.Config] = {
+    # Test shapes
+    (1, 256, 64): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (4, 512, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (16, 1024, 64): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (1, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (8, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    # Benchmark shapes
+    # (1, 4096, 128) already covered above
+    (16, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (256, 4096, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (256, 8192, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (4096, 7168, 128): helion.Config(block_sizes=[1], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
 }
 
-ACF_FILE = "booster_pack/fp8_group_quant_0.acf"
-if Path(ACF_FILE).exists():
-    print(f"Using ACF file: {ACF_FILE}")
-    COFIG_DICT["advanced_controls_file"] = ACF_FILE
+
+# Optional: add advanced_controls_file to your Config for extra performance (see docs).
+# Autotune with autotune_search_acf to find the best ACF, then hardcode it:
+# helion.Config(..., advanced_controls_file="/opt/booster_pack/fp8_group_quant_0.acf")
+
 
 # NOTE: This is an intentionally inefficient baseline implementation.
-@helion.kernel(
-    static_shapes=True,
-    config=helion.Config(**COFIG_DICT),
-)
-def normalize_to_range(
-    data: torch.Tensor,  # [N, G] input rows
-    scales_out: torch.Tensor,  # [N] output normalization factors
-) -> torch.Tensor:
-    nrows = data.size(0)
-    ncols = hl.specialize(data.size(1))
-    MAX_VAL = 448.0
-
-    qout = torch.empty(nrows, ncols, dtype=torch.float32, device=data.device)
-
-    for rr in hl.tile(nrows):
-        row = data[rr, :].to(torch.float32)
-
-        abs1 = torch.abs(row)
-        amax1 = torch.amax(abs1, -1)
-        abs2 = torch.abs(row)
-        amax2 = torch.amax(abs2, -1)
-        abs3 = torch.abs(row)
-        amax3 = torch.amax(abs3, -1)
-        amax = (amax1 + amax2 + amax3) / 3.0
-        amax = torch.clamp(amax, min=1e-10)
-        scale = amax / MAX_VAL
-
-        q1 = row / scale[:, None]
-        q2 = row / scale[:, None]
-        q3 = row / scale[:, None]
-        qout[rr, :] = (q1 + q2 + q3) / 3.0
-        scales_out[rr] = scale
-
-    return qout
+def _make_kernel(config: helion.Config):
+    @helion.kernel(static_shapes=True, config=config)
+    def kernel(
+        data: torch.Tensor,  # [N, G] input rows
+        scales_out: torch.Tensor,  # [N] output normalization factors
+    ) -> torch.Tensor:
+        nrows = data.size(0)
+        ncols = hl.specialize(data.size(1))
+        MAX_VAL = 448.0
+
+        qout = torch.empty(nrows, ncols, dtype=torch.float32, device=data.device)
+
+        for rr in hl.tile(nrows):
+            row = data[rr, :].to(torch.float32)
+
+            abs1 = torch.abs(row)
+            amax1 = torch.amax(abs1, -1)
+            abs2 = torch.abs(row)
+            amax2 = torch.amax(abs2, -1)
+            abs3 = torch.abs(row)
+            amax3 = torch.amax(abs3, -1)
+            amax = (amax1 + amax2 + amax3) / 3.0
+            amax = torch.clamp(amax, min=1e-10)
+            scale = amax / MAX_VAL
+
+            q1 = row / scale[:, None]
+            q2 = row / scale[:, None]
+            q3 = row / scale[:, None]
+            qout[rr, :] = (q1 + q2 + q3) / 3.0
+            scales_out[rr] = scale
+
+        return qout
+
+    return kernel
+
+
+_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}
 
 
 def custom_kernel(data: input_t) -> output_t:
@@ -60,10 +76,12 @@ def custom_kernel(data: input_t) -> output_t:
     gsz = H // G
     N = T * G
 
+    kernel = _KERNELS[(T, H, gsz)]
+
     flat_in = x.reshape(N, gsz)
     flat_s = x_s.reshape(N)
 
-    flat_q = normalize_to_range(flat_in, flat_s)
+    flat_q = kernel(flat_in, flat_s)
 
     x_q[...] = flat_q.reshape(T, H)
     x_s[...] = flat_s.reshape(T, G)
```
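Unlike the other two files, the fp8 dispatch key is derived rather than read straight off a tensor: custom_kernel flattens the (T, H) input into N = T * G rows of gsz = H // G elements, then looks up (T, H, gsz). Working through the (256, 8192, 128) benchmark entry, with G = 64 assumed here purely so that gsz matches that key (the real G comes from the task inputs):

```python
# Worked example of the fp8 dispatch-key arithmetic in custom_kernel above.
# G = 64 is an assumed group count chosen so gsz matches the
# (256, 8192, 128) SHAPE_CONFIGS entry.
T, H, G = 256, 8192, 64
gsz = H // G   # 128 -> the third field of the dispatch key
N = T * G      # 16384 rows of gsz elements fed to the kernel

assert (T, H, gsz) == (256, 8192, 128)  # the key looked up in _KERNELS
assert N * gsz == T * H                  # the reshape to (N, gsz) is lossless
```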

problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py

Lines changed: 76 additions & 48 deletions
```diff
@@ -5,65 +5,93 @@
 import helion.language as hl
 
 
+# Per-shape configs: map (B, T, H, K, V) to optimized helion.Config objects.
+# Autotune locally for each shape, then paste the best config here.
+SHAPE_CONFIGS: dict[tuple, helion.Config] = {
+    # Test shapes
+    (1, 64, 2, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (2, 128, 4, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    (1, 256, 4, 64, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: use any config that passes correctness check
+    # Benchmark shapes
+    (1, 64, 1, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (2, 512, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (2, 1024, 3, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (3, 1024, 4, 100, 100): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (4, 1024, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (2, 1536, 4, 128, 128): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+    (4, 2048, 8, 64, 64): helion.Config(block_sizes=[], num_warps=1, num_stages=1),  # TODO: replace with your autotuned config
+}
+
+
+# Optional: add advanced_controls_file to your Config for extra performance (see docs).
+# Autotune with autotune_search_acf to find the best ACF, then hardcode it:
+# helion.Config(..., advanced_controls_file="/opt/booster_pack/chunk_fwd_h_0.acf")
+
+
 # NOTE: This is an intentionally inefficient baseline implementation.
-@helion.kernel(
-    static_shapes=True,
-    dot_precision="ieee",
-    config=helion.Config(block_sizes=[], num_warps=1, num_stages=1),
-)
-def chunk_state_pass(
-    k: torch.Tensor,  # [B, T, H, K]
-    w: torch.Tensor,  # [B, T, H, K]
-    u: torch.Tensor,  # [B, T, H, V]
-    g: torch.Tensor,  # [B, T, H]
-) -> tuple[torch.Tensor, torch.Tensor]:
-    B, T, H, K = k.shape
-    V = u.shape[-1]
-    C = 64
-    K = hl.specialize(K)
-    V = hl.specialize(V)
+def _make_kernel(config: helion.Config):
+    @helion.kernel(static_shapes=True, dot_precision="ieee", config=config)
+    def kernel(
+        k: torch.Tensor,  # [B, T, H, K]
+        w: torch.Tensor,  # [B, T, H, K]
+        u: torch.Tensor,  # [B, T, H, V]
+        g: torch.Tensor,  # [B, T, H]
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        B, T, H, K = k.shape
+        V = u.shape[-1]
+        C = 64
+        K = hl.specialize(K)
+        V = hl.specialize(V)
 
-    NT = (T + C - 1) // C
-    h_out = torch.empty(B, NT, H, K, V, dtype=k.dtype, device=k.device)
-    v_out = torch.empty_like(u)
+        NT = (T + C - 1) // C
+        h_out = torch.empty(B, NT, H, K, V, dtype=k.dtype, device=k.device)
+        v_out = torch.empty_like(u)
 
-    BH = B * H
+        BH = B * H
 
-    for flat, tv in hl.tile([BH, V], block_size=[1, 8]):
-        b_idx = flat.begin // H
-        h_idx = flat.begin % H
-        state = hl.zeros([K, tv], dtype=torch.float32)
+        for flat, tv in hl.tile([BH, V], block_size=[1, 8]):
+            b_idx = flat.begin // H
+            h_idx = flat.begin % H
+            state = hl.zeros([K, tv], dtype=torch.float32)
 
-        for tc in hl.tile(T, block_size=C):
-            chunk_idx = tc.begin // C
-            t_end = min(tc.begin + C, T) - 1
+            for tc in hl.tile(T, block_size=C):
+                chunk_idx = tc.begin // C
+                t_end = min(tc.begin + C, T) - 1
 
-            h_out[b_idx, chunk_idx, h_idx, :, tv] = state.to(k.dtype)
+                h_out[b_idx, chunk_idx, h_idx, :, tv] = state.to(k.dtype)
 
-            proj1 = hl.dot(
-                w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32
-            )
-            proj2 = hl.dot(
-                w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32
-            )
-            proj = (proj1 + proj2) * 0.5
-            diff = u[b_idx, tc, h_idx, tv].to(torch.float32) - proj
-            v_out[b_idx, tc, h_idx, tv] = diff.to(u.dtype)
+                proj1 = hl.dot(
+                    w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32
+                )
+                proj2 = hl.dot(
+                    w[b_idx, tc, h_idx, :], state, out_dtype=torch.float32
+                )
+                proj = (proj1 + proj2) * 0.5
+                diff = u[b_idx, tc, h_idx, tv].to(torch.float32) - proj
+                v_out[b_idx, tc, h_idx, tv] = diff.to(u.dtype)
 
-            g_end = g[b_idx, t_end, h_idx].to(torch.float32)
-            g_t = g[b_idx, tc, h_idx].to(torch.float32)
-            valid = tc.index < T
-            alpha = torch.where(valid, torch.exp(g_end - g_t), 0.0)
-            k_adj = k[b_idx, tc, h_idx, :] * alpha[:, None]
+                g_end = g[b_idx, t_end, h_idx].to(torch.float32)
+                g_t = g[b_idx, tc, h_idx].to(torch.float32)
+                valid = tc.index < T
+                alpha = torch.where(valid, torch.exp(g_end - g_t), 0.0)
+                k_adj = k[b_idx, tc, h_idx, :] * alpha[:, None]
 
-            state = state * torch.exp(g_end)
-            upd1 = hl.dot(k_adj.T, diff, out_dtype=torch.float32)
-            upd2 = hl.dot(k_adj.T, diff, out_dtype=torch.float32)
-            state = state + (upd1 + upd2) * 0.5
+                state = state * torch.exp(g_end)
+                upd1 = hl.dot(k_adj.T, diff, out_dtype=torch.float32)
+                upd2 = hl.dot(k_adj.T, diff, out_dtype=torch.float32)
+                state = state + (upd1 + upd2) * 0.5
 
-    return h_out, v_out
+        return h_out, v_out
+
+    return kernel
+
+
+_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}
 
 
 def custom_kernel(data: input_t) -> output_t:
     k, w, u, g = data
-    return chunk_state_pass(k, w, u, g)
+    B, T, H, K = k.shape
+    V = u.shape[-1]
+    kernel = _KERNELS[(B, T, H, K, V)]
+    return kernel(k, w, u, g)
```
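All three files share the same skeleton, which makes a cheap import-time guard easy to add. A sketch (illustrative, not part of the commit) that would catch a malformed SHAPE_CONFIGS key before the grader does, using the five-field (B, T, H, K, V) arity of this last file:

```python
# Illustrative import-time guard, not in the commit: every dispatch key
# must have the arity custom_kernel will look up -- five fields
# (B, T, H, K, V) for gated_deltanet_chunk_fwd_h.
EXPECTED_KEY_LEN = 5

for shape, kernel in _KERNELS.items():
    assert len(shape) == EXPECTED_KEY_LEN, f"bad key arity: {shape}"
    assert callable(kernel), f"factory returned non-callable for {shape}"
```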
