
Commit 24b5bce

yf225 and claude committed
Fix NaN in gated_deltanet_chunk_fwd_o reference and submission
The reference kernel computed exp(g_i - g_j) before applying the causal mask. When the g values are very negative (they are cumulative sums of negative increments), the upper-triangle differences g_i - g_j overflow exp() to inf, and inf * 0 (the causal mask) produces NaN.

Fix: zero out g_diff in the upper triangle before calling exp(), so we never compute exp(large_positive). Apply the same fix in the submission kernel, which had a similar issue with exp(-g) overflowing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5465b80 commit 24b5bce
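The failure mode in the commit message can be reproduced in isolation. A minimal sketch (illustrative, not from the repo; numpy stands in for the torch float32 numerics):

```python
import numpy as np

# float32 exp() overflows for arguments above ~88.7.
big_diff = np.float32(150.0)        # e.g. g_i - g_j in the upper triangle
e = np.exp(big_diff)                # overflows to inf
masked_after = e * np.float32(0.0)  # inf * 0 -> nan: masking after exp is too late
masked_before = np.exp(np.float32(0.0)) * np.float32(0.0)  # mask first -> 0.0
```

This is why the fix must zero the argument before exp() rather than zero the result afterwards.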

2 files changed: 16 additions & 9 deletions

problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py

Lines changed: 5 additions & 3 deletions
@@ -103,9 +103,11 @@ def ref_kernel(data: input_t) -> output_t:
     v_c = v_new.float().reshape(B, NT, C, H, V).permute(0, 1, 3, 2, 4)
     g_c = g.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
     o_inter = (q_c @ h.float()) * torch.exp(g_c).unsqueeze(-1)
-    qk = q_c @ k_c.transpose(-1, -2) * torch.exp(g_c.unsqueeze(-1) - g_c.unsqueeze(-2))
-    causal = torch.tril(torch.ones(C, C, device=q.device))
-    o = (o_inter + (qk * causal) @ v_c) * scale
+    causal = torch.tril(torch.ones(C, C, dtype=torch.bool, device=q.device))
+    g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
+    g_diff = torch.where(causal, g_diff, torch.zeros_like(g_diff))
+    qk = q_c @ k_c.transpose(-1, -2) * torch.exp(g_diff) * causal
+    o = (o_inter + qk @ v_c) * scale
     return o.permute(0, 1, 3, 2, 4).reshape(B, T, H, V).to(q.dtype)
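To see the two orderings of the reference kernel side by side, here is a small sketch of the intra-chunk term only (shapes, names, and the numpy stand-in for torch are all illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
C, K, Vd = 4, 3, 2  # chunk length, key dim, value dim
q = rng.standard_normal((C, K)).astype(np.float32)
k = rng.standard_normal((C, K)).astype(np.float32)
v = rng.standard_normal((C, Vd)).astype(np.float32)
# Gates: cumulative sums of negative increments, so g decays steeply.
g = np.cumsum(rng.uniform(-80.0, -40.0, C)).astype(np.float32)

causal = np.tril(np.ones((C, C), dtype=bool))
g_diff = g[:, None] - g[None, :]  # upper triangle: large positive values

# Old order: exp() first, mask second -> inf * 0 = nan.
qk_old = (q @ k.T) * np.exp(g_diff) * causal

# New order: zero g_diff above the diagonal, then exp() -> finite.
g_diff_m = np.where(causal, g_diff, 0.0).astype(np.float32)
qk_new = (q @ k.T) * np.exp(g_diff_m) * causal
o_intra = qk_new @ v
```

On the causal (lower-triangle) entries the two orderings agree exactly; the fix only removes the poisoned upper-triangle values.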

problems/helion/gated_deltanet_chunk_fwd_o_py/submission.py

Lines changed: 11 additions & 6 deletions
@@ -54,15 +54,20 @@ def kernel(
         c_idx = tile_t.begin // C

         g_vals = g[b_idx, tile_t, h_idx]
-        q_s = q[b_idx, tile_t, h_idx, :] * torch.exp(g_vals)[:, None]
-        k_s = k[b_idx, tile_t, h_idx, :] * torch.exp(-g_vals)[:, None]
+        q_tile = q[b_idx, tile_t, h_idx, :]
+        k_tile = k[b_idx, tile_t, h_idx, :]
+        v_tile = v[b_idx, tile_t, h_idx, :]

-        sim = hl.dot(q_s, k_s.T)
+        # intra-chunk: q @ k^T * exp(g_i - g_j), with causal mask
+        qk = hl.dot(q_tile, k_tile.T)
         idx = hl.arange(tile_t.block_size)
-        mask = idx[:, None] >= idx[None, :]
-        sim = torch.where(mask, sim, 0.0)
-        local_out = hl.dot(sim.to(v.dtype), v[b_idx, tile_t, h_idx, :])
+        g_diff = g_vals[:, None] - g_vals[None, :]
+        causal_mask = idx[:, None] >= idx[None, :]
+        sim = torch.where(causal_mask, qk * torch.exp(g_diff), 0.0)
+        local_out = hl.dot(sim.to(v.dtype), v_tile)

+        # inter-chunk: (q @ h) * exp(g)
+        q_s = q_tile * torch.exp(g_vals)[:, None]
         global_out = hl.dot(q_s, h[b_idx, c_idx, h_idx, :, :])

         out[b_idx, tile_t, h_idx, :] = ((global_out + local_out) * scale).to(out.dtype)
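The submission's original formulation factored exp(g_i - g_j) as exp(g_i) * exp(-g_j). Once -g exceeds the float32 exp() range, exp(-g) is inf while exp(g) underflows to 0, so 0 * inf = NaN lands inside the causal region itself. A sketch of the two formulations (numpy stands in for the Helion/torch numerics; names illustrative):

```python
import numpy as np

C = 4
g = np.cumsum(np.full(C, -60.0, dtype=np.float32))  # [-60, -120, -180, -240]
q = np.ones((C, 2), dtype=np.float32)
k = np.ones((C, 2), dtype=np.float32)
causal = np.tril(np.ones((C, C), dtype=bool))

# Old: pre-scale q by exp(g) and k by exp(-g).
# exp(-g) overflows to inf for g <= -120, exp(g) underflows to 0,
# and 0 * inf = nan appears inside the causal (lower) triangle.
sim_old = (q * np.exp(g)[:, None]) @ (k * np.exp(-g)[:, None]).T

# New: form pairwise differences inside the chunk. Causal entries
# satisfy g_i - g_j <= 0, so exp() never overflows where it matters;
# the where() selection simply discards the upper triangle.
g_diff = g[:, None] - g[None, :]
sim_new = np.where(causal, (q @ k.T) * np.exp(g_diff), 0.0)
```

Note that unlike the reference fix, the submission leaves exp(g_diff) unmasked and relies on where() to select, not multiply: any inf produced above the diagonal is discarded rather than multiplied by zero.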

0 commit comments
