Commit 0e8fe0a

Remove FLA dependency from gated deltanet references (#124)
Replace all flash-linear-attention (FLA) imports and usage with inline PyTorch eager equivalents for the utility functions and reference kernels. This removes the FLA install requirement while keeping behavior identical.
1 parent 4d8288a · commit 0e8fe0a

3 files changed: 199 additions & 40 deletions

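Each of the three reference files below inlines eager replacements for the FLA utilities it previously imported. As a sanity check on the "behavior identical" claim, here is a minimal parity-test sketch (not part of the commit): it assumes flash-linear-attention is still installed in a dev environment and that the new `_*_eager` helpers are importable from one of the updated reference modules; the FLA call signatures are copied from the removed lines.

```python
# Hypothetical parity check -- not part of this commit.
import torch
import torch.nn.functional as F

# Old path: FLA's Triton kernels (the imports this commit removes).
from fla.ops.utils import chunk_local_cumsum, solve_tril
from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd

# New path: assumes the problems tree is importable as a package.
from problems.helion.gated_deltanet_recompute_w_u_py.reference import (
    _chunk_local_cumsum_eager,
    _chunk_scaled_dot_kkt_fwd_eager,
    _solve_tril_eager,
)

B, T, H, K, C = 1, 128, 2, 64, 64
torch.manual_seed(0)
k = F.normalize(torch.randn(B, T, H, K, device="cuda"), p=2, dim=-1)
beta = torch.sigmoid(torch.randn(B, T, H, device="cuda"))
g = F.logsigmoid(torch.randn(B, T, H, device="cuda"))

# Run both pipelines on identical inputs.
g_ref = chunk_local_cumsum(g, chunk_size=C)
A_ref = chunk_scaled_dot_kkt_fwd(k=k, g=g_ref, beta=beta, output_dtype=torch.float32)
A_ref = solve_tril(A=A_ref, output_dtype=k.dtype)

g_new = _chunk_local_cumsum_eager(g, chunk_size=C)
A_new = _chunk_scaled_dot_kkt_fwd_eager(k=k, g_cumsum=g_new, beta=beta, chunk_size=C)
A_new = _solve_tril_eager(A=A_new, output_dtype=k.dtype)

# Same tolerances the problems use in make_match_reference.
torch.testing.assert_close(g_ref, g_new, rtol=1e-2, atol=1e-2, check_dtype=False)
torch.testing.assert_close(A_ref, A_new, rtol=1e-2, atol=1e-2, check_dtype=False)
```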

problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py

Lines changed: 69 additions & 16 deletions
```diff
@@ -5,36 +5,89 @@
 
 CHUNK_SIZE = 64
 
-# Use FLA's Triton kernels as reference (same Triton tl.dot as Helion)
-from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h as fla_chunk_fwd_h
-from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd as fla_recompute_w_u_fwd
-from fla.ops.utils import chunk_local_cumsum, solve_tril
-from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
+
+def _chunk_local_cumsum_eager(g, chunk_size):
+    B, T, H = g.shape
+    C = chunk_size
+    return g.float().reshape(B, T // C, C, H).cumsum(dim=2).reshape(B, T, H)
+
+
+def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size):
+    B, T, H, K = k.shape
+    C = chunk_size
+    NT = T // C
+    k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    kkt = k_c @ k_c.transpose(-1, -2)
+    g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
+    strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1)
+    A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower
+    return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32)
+
+
+def _solve_tril_eager(A, output_dtype):
+    B, T, H, C = A.shape
+    NT = T // C
+    A_mat = A.float().reshape(B, NT, C, H, C).permute(0, 1, 3, 2, 4)
+    eye = torch.eye(C, device=A.device).expand_as(A_mat)
+    result = torch.linalg.solve_triangular(eye + A_mat, eye, upper=False)
+    return result.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(output_dtype)
+
+
+def _recompute_w_u_fwd_eager(k, v, beta, A, g):
+    B, T, H, K = k.shape
+    V = v.shape[-1]
+    C = A.shape[-1]
+    NT = T // C
+    k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    v_c = v.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4)
+    beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    A_c = A.float().reshape(B, NT, C, H, C).permute(0, 1, 3, 2, 4)
+    u_c = A_c @ (v_c * beta_c.unsqueeze(-1))
+    w_c = A_c @ (k_c * (beta_c * torch.exp(g_c)).unsqueeze(-1))
+    w = w_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, K).to(k.dtype)
+    u = u_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(v.dtype)
+    return w, u
 
 
 def generate_input(B: int, T: int, H: int, K: int, V: int, seed: int) -> input_t:
     torch.manual_seed(seed)
     device = "cuda"
-    # Generate pipeline-derived inputs: base inputs -> g_cumsum, A, w, u via FLA utilities
     k = F.normalize(torch.randn(B, T, H, K, dtype=torch.float32, device=device), p=2, dim=-1)
     v = torch.randn(B, T, H, V, dtype=torch.float32, device=device)
     beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device))
     g = F.logsigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device))
-    g_cumsum = chunk_local_cumsum(g, chunk_size=CHUNK_SIZE)
-    A = chunk_scaled_dot_kkt_fwd(k=k, g=g_cumsum, beta=beta, output_dtype=torch.float32)
-    A = solve_tril(A=A, output_dtype=k.dtype)
-    w, u = fla_recompute_w_u_fwd(k=k, v=v, beta=beta, A=A, g=g_cumsum)
+    g_cumsum = _chunk_local_cumsum_eager(g, chunk_size=CHUNK_SIZE)
+    A = _chunk_scaled_dot_kkt_fwd_eager(k=k, g_cumsum=g_cumsum, beta=beta, chunk_size=CHUNK_SIZE)
+    A = _solve_tril_eager(A=A, output_dtype=k.dtype)
+    w, u = _recompute_w_u_fwd_eager(k=k, v=v, beta=beta, A=A, g=g_cumsum)
     return k.contiguous(), w.contiguous(), u.contiguous(), g_cumsum.contiguous()
 
 
 def ref_kernel(data: input_t) -> output_t:
     k, w, u, g = data
-    h, v_new, _ = fla_chunk_fwd_h(
-        k=k, w=w, u=u, g=g,
-        initial_state=None,
-        output_final_state=False,
-    )
-    return h, v_new
+    B, T, H, K = k.shape
+    V = u.shape[-1]
+    C = CHUNK_SIZE
+    NT = T // C
+    k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    w_c = w.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    u_c = u.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4)
+    g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    h_all = torch.zeros(B, NT, H, K, V, dtype=torch.float32, device=k.device)
+    v_new_c = torch.zeros_like(u_c)
+    h = torch.zeros(B, H, K, V, dtype=torch.float32, device=k.device)
+    for c in range(NT):
+        h_all[:, c] = h
+        v_new_c[:, c] = u_c[:, c] - w_c[:, c] @ h
+        g_last = g_c[:, c, :, -1]
+        gate = torch.exp(g_last.unsqueeze(-1) - g_c[:, c])
+        v_gated = v_new_c[:, c] * gate.unsqueeze(-1)
+        h = h * torch.exp(g_last).unsqueeze(-1).unsqueeze(-1) + k_c[:, c].transpose(-1, -2) @ v_gated
+    v_new_out = v_new_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(u.dtype)
+    return h_all.to(k.dtype), v_new_out
 
 
 def check_implementation(data, output):
```
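For readers without the FLA source at hand, the loop in the new `ref_kernel` is the chunkwise recurrence of the gated delta rule. In the notation of the code above (chunk index $c$, within-chunk cumulative gate $g$, chunk length $C = 64$), each iteration computes

$$
v^{\mathrm{new}}_c = u_c - w_c h_c, \qquad
h_{c+1} = e^{g_{c,C}}\, h_c + k_c^{\top}\!\left(e^{g_{c,C} - g_c} \odot v^{\mathrm{new}}_c\right), \qquad h_1 = 0,
$$

where $g_{c,C}$ is the gate at the last position of the chunk; `h_all` records each chunk's incoming state $h_c$, and `v_new` is returned in the dtype of `u`.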

problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py

Lines changed: 86 additions & 15 deletions
```diff
@@ -5,37 +5,108 @@
 
 CHUNK_SIZE = 64
 
-# Use FLA's Triton kernels as reference (same Triton tl.dot as Helion)
-from fla.ops.common.chunk_o import chunk_fwd_o as fla_chunk_fwd_o
-from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h as fla_chunk_fwd_h
-from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd as fla_recompute_w_u_fwd
-from fla.ops.utils import chunk_local_cumsum, solve_tril
-from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
+
+def _chunk_local_cumsum_eager(g, chunk_size):
+    B, T, H = g.shape
+    C = chunk_size
+    return g.float().reshape(B, T // C, C, H).cumsum(dim=2).reshape(B, T, H)
+
+
+def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size):
+    B, T, H, K = k.shape
+    C = chunk_size
+    NT = T // C
+    k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    kkt = k_c @ k_c.transpose(-1, -2)
+    g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
+    strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1)
+    A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower
+    return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32)
+
+
+def _solve_tril_eager(A, output_dtype):
+    B, T, H, C = A.shape
+    NT = T // C
+    A_mat = A.float().reshape(B, NT, C, H, C).permute(0, 1, 3, 2, 4)
+    eye = torch.eye(C, device=A.device).expand_as(A_mat)
+    result = torch.linalg.solve_triangular(eye + A_mat, eye, upper=False)
+    return result.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(output_dtype)
+
+
+def _recompute_w_u_fwd_eager(k, v, beta, A, g):
+    B, T, H, K = k.shape
+    V = v.shape[-1]
+    C = A.shape[-1]
+    NT = T // C
+    k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    v_c = v.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4)
+    beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    A_c = A.float().reshape(B, NT, C, H, C).permute(0, 1, 3, 2, 4)
+    u_c = A_c @ (v_c * beta_c.unsqueeze(-1))
+    w_c = A_c @ (k_c * (beta_c * torch.exp(g_c)).unsqueeze(-1))
+    w = w_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, K).to(k.dtype)
+    u = u_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(v.dtype)
+    return w, u
+
+
+def _chunk_fwd_h_eager(k, w, u, g):
+    B, T, H, K = k.shape
+    V = u.shape[-1]
+    C = CHUNK_SIZE
+    NT = T // C
+    k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    w_c = w.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    u_c = u.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4)
+    g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    h_all = torch.zeros(B, NT, H, K, V, dtype=torch.float32, device=k.device)
+    v_new_c = torch.zeros_like(u_c)
+    h = torch.zeros(B, H, K, V, dtype=torch.float32, device=k.device)
+    for c in range(NT):
+        h_all[:, c] = h
+        v_new_c[:, c] = u_c[:, c] - w_c[:, c] @ h
+        g_last = g_c[:, c, :, -1]
+        gate = torch.exp(g_last.unsqueeze(-1) - g_c[:, c])
+        v_gated = v_new_c[:, c] * gate.unsqueeze(-1)
+        h = h * torch.exp(g_last).unsqueeze(-1).unsqueeze(-1) + k_c[:, c].transpose(-1, -2) @ v_gated
+    v_new_out = v_new_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(u.dtype)
+    return h_all.to(k.dtype), v_new_out
 
 
 def generate_input(B: int, T: int, H: int, K: int, V: int, seed: int) -> input_t:
     torch.manual_seed(seed)
     device = "cuda"
-    # Generate pipeline-derived inputs: base inputs -> g_cumsum, A, w, u, h, v_new via FLA utilities
     q = torch.randn(B, T, H, K, dtype=torch.float32, device=device)
     k = F.normalize(torch.randn(B, T, H, K, dtype=torch.float32, device=device), p=2, dim=-1)
     v = torch.randn(B, T, H, V, dtype=torch.float32, device=device)
     beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device))
     g = F.logsigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device))
-    g_cumsum = chunk_local_cumsum(g, chunk_size=CHUNK_SIZE)
-    A = chunk_scaled_dot_kkt_fwd(k=k, g=g_cumsum, beta=beta, output_dtype=torch.float32)
-    A = solve_tril(A=A, output_dtype=k.dtype)
-    w, u = fla_recompute_w_u_fwd(k=k, v=v, beta=beta, A=A, g=g_cumsum)
-    h, v_new, _ = fla_chunk_fwd_h(k=k, w=w, u=u, g=g_cumsum, output_final_state=False)
+    g_cumsum = _chunk_local_cumsum_eager(g, chunk_size=CHUNK_SIZE)
+    A = _chunk_scaled_dot_kkt_fwd_eager(k=k, g_cumsum=g_cumsum, beta=beta, chunk_size=CHUNK_SIZE)
+    A = _solve_tril_eager(A=A, output_dtype=k.dtype)
+    w, u = _recompute_w_u_fwd_eager(k=k, v=v, beta=beta, A=A, g=g_cumsum)
+    h, v_new = _chunk_fwd_h_eager(k=k, w=w, u=u, g=g_cumsum)
     return q.contiguous(), k.contiguous(), v_new.contiguous(), h.contiguous(), g_cumsum.contiguous()
 
 
 def ref_kernel(data: input_t) -> output_t:
     q, k, v_new, h, g = data
-    K = q.shape[-1]
+    B, T, H, K = q.shape
+    V = v_new.shape[-1]
+    C = CHUNK_SIZE
+    NT = T // C
     scale = K ** -0.5
-    o = fla_chunk_fwd_o(q=q, k=k, v=v_new, h=h, g=g, scale=scale)
-    return o
+    q_c = q.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    v_c = v_new.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4)
+    g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    o_inter = (q_c @ h.float()) * torch.exp(g_c).unsqueeze(-1)
+    qk = q_c @ k_c.transpose(-1, -2) * torch.exp(g_c.unsqueeze(-1) - g_c.unsqueeze(-2))
+    causal = torch.tril(torch.ones(C, C, device=q.device))
+    o = (o_inter + (qk * causal) @ v_c) * scale
+    return o.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(q.dtype)
 
 
 check_implementation = make_match_reference(ref_kernel, rtol=1e-2, atol=1e-2)
```
check_implementation = make_match_reference(ref_kernel, rtol=1e-2, atol=1e-2)

problems/helion/gated_deltanet_recompute_w_u_py/reference.py

Lines changed: 44 additions & 9 deletions
```diff
@@ -5,29 +5,64 @@
 
 CHUNK_SIZE = 64
 
-# Use FLA's Triton kernels as reference (same Triton tl.dot as Helion)
-from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd as fla_recompute_w_u_fwd
-from fla.ops.utils import chunk_local_cumsum, solve_tril
-from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
+
+def _chunk_local_cumsum_eager(g, chunk_size):
+    B, T, H = g.shape
+    C = chunk_size
+    return g.float().reshape(B, T // C, C, H).cumsum(dim=2).reshape(B, T, H)
+
+
+def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size):
+    B, T, H, K = k.shape
+    C = chunk_size
+    NT = T // C
+    k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    kkt = k_c @ k_c.transpose(-1, -2)
+    g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
+    strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1)
+    A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower
+    return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32)
+
+
+def _solve_tril_eager(A, output_dtype):
+    B, T, H, C = A.shape
+    NT = T // C
+    A_mat = A.float().reshape(B, NT, C, H, C).permute(0, 1, 3, 2, 4)
+    eye = torch.eye(C, device=A.device).expand_as(A_mat)
+    result = torch.linalg.solve_triangular(eye + A_mat, eye, upper=False)
+    return result.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(output_dtype)
 
 
 def generate_input(B: int, T: int, H: int, K: int, V: int, seed: int) -> input_t:
     torch.manual_seed(seed)
     device = "cuda"
-    # Generate pipeline-derived inputs: base inputs -> g_cumsum, A via FLA utilities
     k = F.normalize(torch.randn(B, T, H, K, dtype=torch.float32, device=device), p=2, dim=-1)
     v = torch.randn(B, T, H, V, dtype=torch.float32, device=device)
     beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device))
     g = F.logsigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device))
-    g_cumsum = chunk_local_cumsum(g, chunk_size=CHUNK_SIZE)
-    A = chunk_scaled_dot_kkt_fwd(k=k, g=g_cumsum, beta=beta, output_dtype=torch.float32)
-    A = solve_tril(A=A, output_dtype=k.dtype)
+    g_cumsum = _chunk_local_cumsum_eager(g, chunk_size=CHUNK_SIZE)
+    A = _chunk_scaled_dot_kkt_fwd_eager(k=k, g_cumsum=g_cumsum, beta=beta, chunk_size=CHUNK_SIZE)
+    A = _solve_tril_eager(A=A, output_dtype=k.dtype)
     return k.contiguous(), v.contiguous(), beta.contiguous(), A.contiguous(), g_cumsum.contiguous()
 
 
 def ref_kernel(data: input_t) -> output_t:
     k, v, beta, A, g = data
-    w, u = fla_recompute_w_u_fwd(k=k, v=v, beta=beta, A=A, g=g)
+    B, T, H, K = k.shape
+    V = v.shape[-1]
+    C = A.shape[-1]
+    NT = T // C
+    k_c = k.float().reshape(B, NT, C, H, K).permute(0, 1, 3, 2, 4)
+    v_c = v.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4)
+    beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
+    A_c = A.float().reshape(B, NT, C, H, C).permute(0, 1, 3, 2, 4)
+    u_c = A_c @ (v_c * beta_c.unsqueeze(-1))
+    w_c = A_c @ (k_c * (beta_c * torch.exp(g_c)).unsqueeze(-1))
+    w = w_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, K).to(k.dtype)
+    u = u_c.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(v.dtype)
     return w, u
 
 
```
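This file needs only the first three helpers, since its `ref_kernel` is itself the w/u recomputation. Written out per chunk (my rendering of the code, with $i, j$ indexing positions within a chunk), `generate_input` builds

$$
(A^{\mathrm{raw}}_c)_{ij} = \beta_i\, e^{g_i - g_j}\, k_i^{\top} k_j \;\;(i > j), \qquad
A_c = \big(I + A^{\mathrm{raw}}_c\big)^{-1},
$$

and the kernel then applies

$$
u_c = A_c\,(\beta \odot v_c), \qquad w_c = A_c\,\big((\beta\, e^{g}) \odot k_c\big),
$$

matching `_chunk_scaled_dot_kkt_fwd_eager`, `_solve_tril_eager`, and the inlined body above.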
