Skip to content

Commit ec71b91

Browse files
authored
Merge pull request #131 from yf225/fix-gated-deltanet-reference-nan
Fix NaN in all 3 gated deltanet Helion references and submission
2 parents 5465b80 + a611712 commit ec71b91

4 files changed

Lines changed: 22 additions & 12 deletions

File tree

problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size):
1919
g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
2020
beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
2121
kkt = k_c @ k_c.transpose(-1, -2)
22-
g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
2322
strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1)
23+
g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
24+
g_diff = g_diff * strict_lower
2425
A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower
2526
return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32)
2627

problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size):
1919
g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
2020
beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
2121
kkt = k_c @ k_c.transpose(-1, -2)
22-
g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
2322
strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1)
23+
g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
24+
g_diff = g_diff * strict_lower
2425
A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower
2526
return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32)
2627

@@ -103,9 +104,11 @@ def ref_kernel(data: input_t) -> output_t:
103104
v_c = v_new.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4)
104105
g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
105106
o_inter = (q_c @ h.float()) * torch.exp(g_c).unsqueeze(-1)
106-
qk = q_c @ k_c.transpose(-1, -2) * torch.exp(g_c.unsqueeze(-1) - g_c.unsqueeze(-2))
107-
causal = torch.tril(torch.ones(C, C, device=q.device))
108-
o = (o_inter + (qk * causal) @ v_c) * scale
107+
causal = torch.tril(torch.ones(C, C, dtype=torch.bool, device=q.device))
108+
g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
109+
g_diff = torch.where(causal, g_diff, torch.zeros_like(g_diff))
110+
qk = q_c @ k_c.transpose(-1, -2) * torch.exp(g_diff) * causal
111+
o = (o_inter + qk @ v_c) * scale
109112
return o.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(q.dtype)
110113

111114

problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,20 @@ def kernel(
5454
c_idx = tile_t.begin // C
5555

5656
g_vals = g[b_idx, tile_t, h_idx]
57-
q_s = q[b_idx, tile_t, h_idx, :] * torch.exp(g_vals)[:, None]
58-
k_s = k[b_idx, tile_t, h_idx, :] * torch.exp(-g_vals)[:, None]
57+
q_tile = q[b_idx, tile_t, h_idx, :]
58+
k_tile = k[b_idx, tile_t, h_idx, :]
59+
v_tile = v[b_idx, tile_t, h_idx, :]
5960

60-
sim = hl.dot(q_s, k_s.T)
61+
# intra-chunk: q @ k^T * exp(g_i - g_j), with causal mask
62+
qk = hl.dot(q_tile, k_tile.T)
6163
idx = hl.arange(tile_t.block_size)
62-
mask = idx[:, None] >= idx[None, :]
63-
sim = torch.where(mask, sim, 0.0)
64-
local_out = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :])
64+
g_diff = g_vals[:, None] - g_vals[None, :]
65+
causal_mask = idx[:, None] >= idx[None, :]
66+
sim = torch.where(causal_mask, qk * torch.exp(g_diff), 0.0)
67+
local_out = hl.dot(sim.to(v.dtype), v_tile)
6568

69+
# inter-chunk: (q @ h) * exp(g)
70+
q_s = q_tile * torch.exp(g_vals)[:, None]
6671
global_out = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])
6772

6873
out[b_idx, tile_t, h_idx, :] = ((global_out + local_out) * scale).to(out.dtype)

problems/helion/gated_deltanet_recompute_w_u_py/reference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size):
1919
g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
2020
beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
2121
kkt = k_c @ k_c.transpose(-1, -2)
22-
g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
2322
strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1)
23+
g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
24+
g_diff = g_diff * strict_lower
2425
A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower
2526
return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32)
2627

0 commit comments

Comments
 (0)