Skip to content

Commit 2ee400c

Browse files
authored
Fix numerically unstable test inputs in gated deltanet references (#127)
Scale k by 1/sqrt(K) and make gates cumulative across all three gated deltanet references to prevent exponential state growth in the recurrence, which caused correctness check failures due to floating-point reduction ordering differences between backends.
1 parent 0e8fe0a commit 2ee400c

3 files changed

Lines changed: 9 additions & 9 deletions

File tree

problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import torch
2-
import torch.nn.functional as F
32
from task import input_t, output_t
43
from utils import verbose_allclose
54

@@ -55,10 +54,11 @@ def _recompute_w_u_fwd_eager(k, v, beta, A, g):
5554
def generate_input(B: int, T: int, H: int, K: int, V: int, seed: int) -> input_t:
5655
torch.manual_seed(seed)
5756
device = "cuda"
58-
k = F.normalize(torch.randn(B, T, H, K, dtype=torch.float32, device=device), p=2, dim=-1)
57+
k = torch.randn(B, T, H, K, dtype=torch.float32, device=device) / K**0.5
5958
v = torch.randn(B, T, H, V, dtype=torch.float32, device=device)
6059
beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device))
61-
g = F.logsigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device))
60+
g_inc = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, device=device))
61+
g = g_inc.cumsum(dim=1)
6262
g_cumsum = _chunk_local_cumsum_eager(g, chunk_size=CHUNK_SIZE)
6363
A = _chunk_scaled_dot_kkt_fwd_eager(k=k, g_cumsum=g_cumsum, beta=beta, chunk_size=CHUNK_SIZE)
6464
A = _solve_tril_eager(A=A, output_dtype=k.dtype)

problems/helion/gated_deltanet_chunk_fwd_o_py/reference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import torch
2-
import torch.nn.functional as F
32
from task import input_t, output_t
43
from utils import make_match_reference
54

@@ -79,10 +78,11 @@ def generate_input(B: int, T: int, H: int, K: int, V: int, seed: int) -> input_t
7978
torch.manual_seed(seed)
8079
device = "cuda"
8180
q = torch.randn(B, T, H, K, dtype=torch.float32, device=device)
82-
k = F.normalize(torch.randn(B, T, H, K, dtype=torch.float32, device=device), p=2, dim=-1)
81+
k = torch.randn(B, T, H, K, dtype=torch.float32, device=device) / K**0.5
8382
v = torch.randn(B, T, H, V, dtype=torch.float32, device=device)
8483
beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device))
85-
g = F.logsigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device))
84+
g_inc = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, device=device))
85+
g = g_inc.cumsum(dim=1)
8686
g_cumsum = _chunk_local_cumsum_eager(g, chunk_size=CHUNK_SIZE)
8787
A = _chunk_scaled_dot_kkt_fwd_eager(k=k, g_cumsum=g_cumsum, beta=beta, chunk_size=CHUNK_SIZE)
8888
A = _solve_tril_eager(A=A, output_dtype=k.dtype)

problems/helion/gated_deltanet_recompute_w_u_py/reference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import torch
2-
import torch.nn.functional as F
32
from task import input_t, output_t
43
from utils import verbose_allclose
54

@@ -38,10 +37,11 @@ def _solve_tril_eager(A, output_dtype):
3837
def generate_input(B: int, T: int, H: int, K: int, V: int, seed: int) -> input_t:
3938
torch.manual_seed(seed)
4039
device = "cuda"
41-
k = F.normalize(torch.randn(B, T, H, K, dtype=torch.float32, device=device), p=2, dim=-1)
40+
k = torch.randn(B, T, H, K, dtype=torch.float32, device=device) / K**0.5
4241
v = torch.randn(B, T, H, V, dtype=torch.float32, device=device)
4342
beta = torch.sigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device))
44-
g = F.logsigmoid(torch.randn(B, T, H, dtype=torch.float32, device=device))
43+
g_inc = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, device=device))
44+
g = g_inc.cumsum(dim=1)
4545
g_cumsum = _chunk_local_cumsum_eager(g, chunk_size=CHUNK_SIZE)
4646
A = _chunk_scaled_dot_kkt_fwd_eager(k=k, g_cumsum=g_cumsum, beta=beta, chunk_size=CHUNK_SIZE)
4747
A = _solve_tril_eager(A=A, output_dtype=k.dtype)

0 commit comments

Comments
 (0)