Commit 6c7120f
Remove initial_state from gated_deltanet chunk_fwd_h problem (#110)
initial_state is an inference-only feature (multi-turn/streaming) and is not used during training. Simplify the problem to always start from zeros, matching the typical training workload.
1 parent eea891f commit 6c7120f

4 files changed: 22 additions & 28 deletions


problems/helion/gated_deltanet_chunk_fwd_h_py/reference.py

Lines changed: 4 additions & 8 deletions
@@ -5,23 +5,19 @@
 CHUNK_SIZE = 64
 
 
-def generate_input(B: int, T: int, H: int, K: int, V: int, use_initial_state: bool, seed: int) -> input_t:
+def generate_input(B: int, T: int, H: int, K: int, V: int, seed: int) -> input_t:
     gen = torch.Generator(device="cuda")
     gen.manual_seed(seed)
     k = torch.randn(B, T, H, K, dtype=torch.float32, device="cuda", generator=gen).contiguous()
     w = torch.randn(B, T, H, K, dtype=torch.float32, device="cuda", generator=gen).contiguous()
     u = torch.randn(B, T, H, V, dtype=torch.float32, device="cuda", generator=gen).contiguous()
     # Use negative values for g to keep exp(g) bounded in (0, 1] and prevent overflow
     g = -torch.abs(torch.randn(B, T, H, dtype=torch.float32, device="cuda", generator=gen)).contiguous()
-    if use_initial_state:
-        initial_state = torch.randn(B, H, K, V, dtype=torch.float32, device="cuda", generator=gen).contiguous()
-    else:
-        initial_state = torch.zeros(B, H, K, V, dtype=torch.float32, device="cuda").contiguous()
-    return k, w, u, g, initial_state
+    return k, w, u, g
 
 
 def ref_kernel(data: input_t) -> output_t:
-    k, w, u, g, initial_state = data
+    k, w, u, g = data
     B, T, H, K = k.shape
     V = u.shape[-1]
     BT = CHUNK_SIZE
@@ -32,7 +28,7 @@ def ref_kernel(data: input_t) -> output_t:
 
     for b in range(B):
         for hh in range(H):
-            b_h = initial_state[b, hh].float().clone()  # [K, V]
+            b_h = torch.zeros(K, V, dtype=torch.float32, device=k.device)
 
             for c in range(NT):
                 cs = c * BT

problems/helion/gated_deltanet_chunk_fwd_h_py/submission.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 def custom_kernel(data: input_t) -> output_t:
     import torch
 
-    k, w, u, g, initial_state = data
+    k, w, u, g = data
     B, T, H, K = k.shape
     V = u.shape[-1]
     BT = 64
@@ -15,7 +15,7 @@ def custom_kernel(data: input_t) -> output_t:
 
     for b in range(B):
         for hh in range(H):
-            b_h = initial_state[b, hh].float().clone()
+            b_h = torch.zeros(K, V, dtype=torch.float32, device=k.device)
 
             for c in range(NT):
                 cs = c * BT
Lines changed: 1 addition & 2 deletions
@@ -1,7 +1,7 @@
 from typing import TypedDict, TypeVar
 import torch
 
-input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
+input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
 output_t = TypeVar("output_t", bound=tuple[torch.Tensor, torch.Tensor])
 
 
 class TestSpec(TypedDict):
@@ -10,5 +10,4 @@ class TestSpec(TypedDict):
     H: int
     K: int
     V: int
-    use_initial_state: bool
     seed: int
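With use_initial_state gone, a test case is just shape parameters plus a seed, and inputs are plain 4-tuples. A minimal sketch of how such a spec drives input construction, using CPU tensors and toy sizes for illustration (the actual reference.py builds float32 CUDA tensors, and the B/T fields are implied by the test entries in task.yml):

```python
import torch
from typing import TypedDict

# Mirrors the simplified TestSpec (use_initial_state removed).
class TestSpec(TypedDict):
    B: int
    T: int
    H: int
    K: int
    V: int
    seed: int

spec: TestSpec = {"B": 1, "T": 64, "H": 2, "K": 8, "V": 8, "seed": 4242}
gen = torch.Generator().manual_seed(spec["seed"])
k = torch.randn(spec["B"], spec["T"], spec["H"], spec["K"], generator=gen)
w = torch.randn(spec["B"], spec["T"], spec["H"], spec["K"], generator=gen)
u = torch.randn(spec["B"], spec["T"], spec["H"], spec["V"], generator=gen)
# Negative g keeps exp(g) in (0, 1], as in the reference generator
g = -torch.abs(torch.randn(spec["B"], spec["T"], spec["H"], generator=gen))
data = (k, w, u, g)  # the new 4-tuple input_t
```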

problems/helion/gated_deltanet_chunk_fwd_h_py/task.yml

Lines changed: 15 additions & 16 deletions
@@ -17,20 +17,19 @@ description: |
   The sequence is divided into chunks of BT=64 timesteps. Processing is sequential
   across chunks but parallel across (B, H) and within each chunk:
 
-  For each (b, h) pair, starting with h_state = initial_state[b, h] (zeros or provided):
+  For each (b, h) pair, starting with h_state = zeros(K, V):
     For each chunk c = 0, 1, ..., NT-1:
       1. Store: h_out[b, c, h] = h_state
       2. Compute: v_new = u - w @ h_state
       3. Gate: v_gated[t] = v_new[t] * exp(g[last_t] - g[t])
       4. Decay: h_state = h_state * exp(g[last_t])
       5. Update: h_state = h_state + k^T @ v_gated
 
-  Input: tuple(k, w, u, g, initial_state) where:
+  Input: tuple(k, w, u, g) where:
   - k: torch.Tensor of shape [B, T, H, K] (float32) — keys
   - w: torch.Tensor of shape [B, T, H, K] (float32) — WY-transformed keys
   - u: torch.Tensor of shape [B, T, H, V] (float32) — WY-transformed values
   - g: torch.Tensor of shape [B, T, H] (float32) — cumulative gate
-  - initial_state: torch.Tensor of shape [B, H, K, V] (float32) — initial hidden state (zeros or random)
 
   Output: tuple(h, v_new) where:
   - h: torch.Tensor of shape [B, NT, H, K, V] (float32) — per-chunk hidden states
@@ -39,7 +38,7 @@ description: |
   Constraint: T must be a multiple of 64. NT = T // 64.
 
   See also: Helion examples/gdn_fwd_h.py for a related implementation
-  (simpler variant that returns only h, without v_new or initial_state support).
+  (simpler variant that returns only h, without v_new output).
 
 config:
   main: "eval.py"
@@ -48,20 +47,20 @@ templates:
   Python: "../template.py"
 
 tests:
-  - {"B": 1, "T": 64, "H": 2, "K": 64, "V": 64, "use_initial_state": false, "seed": 4242}
-  - {"B": 2, "T": 128, "H": 4, "K": 64, "V": 64, "use_initial_state": true, "seed": 5236}
-  - {"B": 1, "T": 256, "H": 4, "K": 64, "V": 128, "use_initial_state": false, "seed": 1001}
-  - {"B": 1, "T": 64, "H": 1, "K": 128, "V": 128, "use_initial_state": true, "seed": 5531}
-  - {"B": 2, "T": 128, "H": 2, "K": 100, "V": 100, "use_initial_state": true, "seed": 9173}
+  - {"B": 1, "T": 64, "H": 2, "K": 64, "V": 64, "seed": 4242}
+  - {"B": 2, "T": 128, "H": 4, "K": 64, "V": 64, "seed": 5236}
+  - {"B": 1, "T": 256, "H": 4, "K": 64, "V": 128, "seed": 1001}
+  - {"B": 1, "T": 64, "H": 1, "K": 128, "V": 128, "seed": 5531}
+  - {"B": 2, "T": 128, "H": 2, "K": 100, "V": 100, "seed": 9173}
 
 benchmarks:
-  - {"B": 1, "T": 64, "H": 1, "K": 64, "V": 64, "use_initial_state": false, "seed": 31232}
-  - {"B": 2, "T": 512, "H": 3, "K": 64, "V": 64, "use_initial_state": true, "seed": 4052}
-  - {"B": 2, "T": 1024, "H": 3, "K": 64, "V": 64, "use_initial_state": false, "seed": 2146}
-  - {"B": 3, "T": 1024, "H": 4, "K": 100, "V": 100, "use_initial_state": true, "seed": 3129}
-  - {"B": 4, "T": 1024, "H": 4, "K": 128, "V": 128, "use_initial_state": false, "seed": 54352}
-  - {"B": 2, "T": 1536, "H": 4, "K": 128, "V": 128, "use_initial_state": true, "seed": 71234}
-  - {"B": 4, "T": 2048, "H": 8, "K": 64, "V": 64, "use_initial_state": true, "seed": 82345}
+  - {"B": 1, "T": 64, "H": 1, "K": 64, "V": 64, "seed": 31232}
+  - {"B": 2, "T": 512, "H": 3, "K": 64, "V": 64, "seed": 4052}
+  - {"B": 2, "T": 1024, "H": 3, "K": 64, "V": 64, "seed": 2146}
+  - {"B": 3, "T": 1024, "H": 4, "K": 100, "V": 100, "seed": 3129}
+  - {"B": 4, "T": 1024, "H": 4, "K": 128, "V": 128, "seed": 54352}
+  - {"B": 2, "T": 1536, "H": 4, "K": 128, "V": 128, "seed": 71234}
+  - {"B": 4, "T": 2048, "H": 8, "K": 64, "V": 64, "seed": 82345}
 
 test_timeout: 180
 benchmark_timeout: 180
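The five-step recurrence spelled out in the task.yml description can be sketched end to end in plain PyTorch. This is a toy-sized CPU sketch, not the Helion kernel; the function name chunk_fwd_h and the small shapes are ours, and the real problem runs on CUDA with BT=64 chunks:

```python
import torch

def chunk_fwd_h(k, w, u, g, BT=64):
    """Chunked gated-deltanet forward state pass, starting from zeros."""
    B, T, H, K = k.shape
    V = u.shape[-1]
    NT = T // BT  # T must be a multiple of BT
    h = torch.zeros(B, NT, H, K, V, dtype=torch.float32, device=k.device)
    v_new = torch.empty(B, T, H, V, dtype=torch.float32, device=k.device)
    for b in range(B):
        for hh in range(H):
            h_state = torch.zeros(K, V, dtype=torch.float32, device=k.device)
            for c in range(NT):
                cs, ce = c * BT, (c + 1) * BT
                h[b, c, hh] = h_state                # 1. store state entering the chunk
                kc = k[b, cs:ce, hh]                 # [BT, K]
                wc = w[b, cs:ce, hh]                 # [BT, K]
                vc = u[b, cs:ce, hh] - wc @ h_state  # 2. v_new for this chunk [BT, V]
                v_new[b, cs:ce, hh] = vc
                g_c = g[b, cs:ce, hh]                # [BT] cumulative gate
                g_last = g_c[-1]
                v_gated = vc * torch.exp(g_last - g_c).unsqueeze(-1)  # 3. gate
                h_state = h_state * torch.exp(g_last)                 # 4. decay
                h_state = h_state + kc.t() @ v_gated                  # 5. update
    return h, v_new

# Toy run (CPU; the problem itself uses CUDA tensors)
B, T, H, K, V = 1, 128, 2, 8, 8
gen = torch.Generator().manual_seed(0)
k = torch.randn(B, T, H, K, generator=gen)
w = torch.randn(B, T, H, K, generator=gen)
u = torch.randn(B, T, H, V, generator=gen)
g = -torch.abs(torch.randn(B, T, H, generator=gen))
h, vn = chunk_fwd_h(k, w, u, g)
```

Since the initial state is now always zeros, h[:, 0] (the state stored before the first chunk) is all zeros, which gives a quick sanity check on any submission.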
