Skip to content

Commit a611712

Browse files
yf225 and claude committed
Fix NaN in _chunk_scaled_dot_kkt_fwd_eager across all 3 gated deltanet kernels
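Zero out g_diff outside the strict lower triangle before calling exp(), preventing inf * 0 = NaN when upper-triangle g differences overflow. With the masked entries set to 0, exp() yields 1 there, and the final multiplication by strict_lower turns them into exact zeros.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>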
1 parent 24b5bce commit a611712

3 files changed

Lines changed: 6 additions & 3 deletions

File tree

problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,9 @@ def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size):
1919
g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
2020
beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
2121
kkt = k_c @ k_c.transpose(-1, -2)
22-
g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
2322
strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1)
23+
g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
24+
g_diff = g_diff * strict_lower
2425
A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower
2526
return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32)
2627

problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,9 @@ def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size):
1919
g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
2020
beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
2121
kkt = k_c @ k_c.transpose(-1, -2)
22-
g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
2322
strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1)
23+
g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
24+
g_diff = g_diff * strict_lower
2425
A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower
2526
return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32)
2627

problems/helion/gated_deltanet_recompute_w_u_py/reference.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,9 @@ def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size):
1919
g_c = g_cumsum.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
2020
beta_c = beta.float().reshape(B, NT, C, H).permute(0, 1, 3, 2)
2121
kkt = k_c @ k_c.transpose(-1, -2)
22-
g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
2322
strict_lower = torch.tril(torch.ones(C, C, device=k.device), diagonal=-1)
23+
g_diff = g_c.unsqueeze(-1) - g_c.unsqueeze(-2)
24+
g_diff = g_diff * strict_lower
2425
A = kkt * beta_c.unsqueeze(-1) * torch.exp(g_diff) * strict_lower
2526
return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, C).to(torch.float32)
2627

0 commit comments

Comments
 (0)