@@ -17,20 +17,19 @@ description: |
   The sequence is divided into chunks of BT=64 timesteps. Processing is sequential
   across chunks but parallel across (B, H) and within each chunk:
 
-  For each (b, h) pair, starting with h_state = initial_state[b, h] (zeros or provided):
+  For each (b, h) pair, starting with h_state = zeros(K, V):
     For each chunk c = 0, 1, ..., NT-1:
       1. Store: h_out[b, c, h] = h_state
       2. Compute: v_new = u - w @ h_state
       3. Gate: v_gated[t] = v_new[t] * exp(g[last_t] - g[t])
       4. Decay: h_state = h_state * exp(g[last_t])
       5. Update: h_state = h_state + k^T @ v_gated
 
-  Input: tuple(k, w, u, g, initial_state) where:
+  Input: tuple(k, w, u, g) where:
   - k: torch.Tensor of shape [B, T, H, K] (float32) — keys
   - w: torch.Tensor of shape [B, T, H, K] (float32) — WY-transformed keys
   - u: torch.Tensor of shape [B, T, H, V] (float32) — WY-transformed values
   - g: torch.Tensor of shape [B, T, H] (float32) — cumulative gate
-  - initial_state: torch.Tensor of shape [B, H, K, V] (float32) — initial hidden state (zeros or random)
 
   Output: tuple(h, v_new) where:
   - h: torch.Tensor of shape [B, NT, H, K, V] (float32) — per-chunk hidden states
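The five-step recurrence in this hunk can be sketched as a plain sequential PyTorch loop. This is only an illustrative reference under stated assumptions, not the task's eval.py or the Helion kernel: the function name `gdn_fwd_h_reference` is hypothetical, `last_t` is taken to mean the last timestep of the current chunk, and the returned `v_new` is assumed to be the ungated step-2 value with shape [B, T, H, V] (its shape line is not shown in the diff).

```python
import torch

BT = 64  # chunk length stated in the description


def gdn_fwd_h_reference(k, w, u, g):
    """Sequential sketch of the chunked recurrence with a zero initial state.

    Hypothetical helper for illustration; shapes follow the description:
    k, w: [B, T, H, K], u: [B, T, H, V], g: [B, T, H].
    """
    B, T, H, K = k.shape
    V = u.shape[-1]
    assert T % BT == 0, "T must be a multiple of 64"
    NT = T // BT

    h = k.new_empty(B, NT, H, K, V)
    v_new = torch.empty_like(u)        # assumed output shape [B, T, H, V]
    h_state = k.new_zeros(B, H, K, V)  # zeros(K, V) for every (b, h) pair

    for c in range(NT):
        s = slice(c * BT, (c + 1) * BT)
        kc = k[:, s].transpose(1, 2)   # [B, H, BT, K]
        wc = w[:, s].transpose(1, 2)   # [B, H, BT, K]
        uc = u[:, s].transpose(1, 2)   # [B, H, BT, V]
        gc = g[:, s].transpose(1, 2)   # [B, H, BT]

        h[:, c] = h_state                                    # 1. Store
        vc = uc - wc @ h_state                               # 2. Compute
        v_new[:, s] = vc.transpose(1, 2)
        g_last = gc[..., -1:]                                # assumed: last t in chunk
        v_gated = vc * torch.exp(g_last - gc).unsqueeze(-1)  # 3. Gate
        h_state = h_state * torch.exp(g_last).unsqueeze(-1)  # 4. Decay
        h_state = h_state + kc.transpose(-1, -2) @ v_gated   # 5. Update

    return h, v_new
```

Because the state starts at zero, the first stored chunk state is all zeros and the first chunk of `v_new` equals `u`, which gives a quick sanity check for any faster implementation.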
@@ -39,7 +38,7 @@ description: |
   Constraint: T must be a multiple of 64. NT = T // 64.
 
   See also: Helion examples/gdn_fwd_h.py for a related implementation
-  (simpler variant that returns only h, without v_new or initial_state support).
+  (simpler variant that returns only h, without v_new output).
 
 config:
   main: "eval.py"
@@ -48,20 +47,20 @@ templates:
   Python: "../template.py"
 
 tests:
-  - {"B": 1, "T": 64, "H": 2, "K": 64, "V": 64, "use_initial_state": false, "seed": 4242}
-  - {"B": 2, "T": 128, "H": 4, "K": 64, "V": 64, "use_initial_state": true, "seed": 5236}
-  - {"B": 1, "T": 256, "H": 4, "K": 64, "V": 128, "use_initial_state": false, "seed": 1001}
-  - {"B": 1, "T": 64, "H": 1, "K": 128, "V": 128, "use_initial_state": true, "seed": 5531}
-  - {"B": 2, "T": 128, "H": 2, "K": 100, "V": 100, "use_initial_state": true, "seed": 9173}
+  - {"B": 1, "T": 64, "H": 2, "K": 64, "V": 64, "seed": 4242}
+  - {"B": 2, "T": 128, "H": 4, "K": 64, "V": 64, "seed": 5236}
+  - {"B": 1, "T": 256, "H": 4, "K": 64, "V": 128, "seed": 1001}
+  - {"B": 1, "T": 64, "H": 1, "K": 128, "V": 128, "seed": 5531}
+  - {"B": 2, "T": 128, "H": 2, "K": 100, "V": 100, "seed": 9173}
 
 benchmarks:
-  - {"B": 1, "T": 64, "H": 1, "K": 64, "V": 64, "use_initial_state": false, "seed": 31232}
-  - {"B": 2, "T": 512, "H": 3, "K": 64, "V": 64, "use_initial_state": true, "seed": 4052}
-  - {"B": 2, "T": 1024, "H": 3, "K": 64, "V": 64, "use_initial_state": false, "seed": 2146}
-  - {"B": 3, "T": 1024, "H": 4, "K": 100, "V": 100, "use_initial_state": true, "seed": 3129}
-  - {"B": 4, "T": 1024, "H": 4, "K": 128, "V": 128, "use_initial_state": false, "seed": 54352}
-  - {"B": 2, "T": 1536, "H": 4, "K": 128, "V": 128, "use_initial_state": true, "seed": 71234}
-  - {"B": 4, "T": 2048, "H": 8, "K": 64, "V": 64, "use_initial_state": true, "seed": 82345}
+  - {"B": 1, "T": 64, "H": 1, "K": 64, "V": 64, "seed": 31232}
+  - {"B": 2, "T": 512, "H": 3, "K": 64, "V": 64, "seed": 4052}
+  - {"B": 2, "T": 1024, "H": 3, "K": 64, "V": 64, "seed": 2146}
+  - {"B": 3, "T": 1024, "H": 4, "K": 100, "V": 100, "seed": 3129}
+  - {"B": 4, "T": 1024, "H": 4, "K": 128, "V": 128, "seed": 54352}
+  - {"B": 2, "T": 1536, "H": 4, "K": 128, "V": 128, "seed": 71234}
+  - {"B": 4, "T": 2048, "H": 8, "K": 64, "V": 64, "seed": 82345}
 
 test_timeout: 180
 benchmark_timeout: 180
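To make the relationship between a test or benchmark entry and the tensor shapes concrete, the sketch below derives the shapes listed in the description from one config dict and enforces the stated constraint that T is a multiple of 64. The helper name `expected_shapes` is hypothetical and is not part of eval.py; the `seed` key is simply ignored here.

```python
BT = 64  # chunk length from the description


def expected_shapes(cfg):
    """Map one config entry to the tensor shapes named in the description.

    Hypothetical helper for illustration only.
    """
    B, T, H, K, V = (cfg[name] for name in ("B", "T", "H", "K", "V"))
    if T % BT != 0:
        raise ValueError(f"T={T} must be a multiple of {BT}")
    NT = T // BT  # number of chunks
    return {
        "k": (B, T, H, K),
        "w": (B, T, H, K),
        "u": (B, T, H, V),
        "g": (B, T, H),
        "h": (B, NT, H, K, V),
    }


# Second entry of the tests list after this change.
shapes = expected_shapes({"B": 2, "T": 128, "H": 4, "K": 64, "V": 64, "seed": 5236})
print(shapes["h"])
```

For that entry NT = 128 // 64 = 2, so `h` holds two per-chunk states for each (b, h) pair.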