
Commit 55af3b9

Mark Saroufim committed

Sync all amd_202602 problems with AMD-AIM upstream

- moe-mxfp4: updated test cases and benchmark configs (PRs #11-#14)
- mixed-mla: added tp parameter, updated reference/submission/task (PR #15)
- mxfp4-mm: updated benchmark configs
- Removed known non-determinism disclaimer (fixed by quantization changes)
1 parent bc2bc28 commit 55af3b9

6 files changed: 70 additions & 64 deletions


problems/amd_202602/mixed-mla/reference.py

Lines changed: 11 additions & 5 deletions
@@ -6,7 +6,7 @@
 output v_head_dim = kv_lora_rank = 512.
 
 The input provides:
-    q: (total_q, 16, 576) bfloat16 — absorbed query
+    q: (total_q, num_heads, 576) bfloat16 — absorbed query (num_heads = 128 // tp)
     kv_data: dict with KV cache in three formats:
         "bf16": Tensor (total_kv, 1, 576) bfloat16 — highest precision
         "fp8": (Tensor, Tensor) kv_buffer fp8 + scalar scale — per-tensor quantized
@@ -37,7 +37,7 @@
 # DeepSeek R1 latent MQA constants (forward_absorb path)
 # https://huggingface.co/deepseek-ai/DeepSeek-R1-0528/blob/main/config.json
 # ---------------------------------------------------------------------------
-NUM_HEADS = 16
+TOTAL_NUM_HEADS = 128
 NUM_KV_HEADS = 1
 KV_LORA_RANK = 512
 QK_ROPE_HEAD_DIM = 64
@@ -285,17 +285,23 @@ def _aiter_mla_decode(
 # generate_input / ref_kernel / check_implementation
 # ---------------------------------------------------------------------------
 
-def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, seed: int) -> input_t:
+def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, tp: int, seed: int) -> input_t:
     """
     Generate absorbed q and compressed kv_buffer for MLA decode.
 
+    Args:
+        tp: tensor parallelism degree (4 or 8). num_heads = TOTAL_NUM_HEADS // tp.
+
     Returns all three KV cache formats in kv_data dict:
         kv_data = {
             "bf16": Tensor — (total_kv, 1, 576) bfloat16
             "fp8": (Tensor, Tensor) — kv_buffer fp8 + scalar scale
             "mxfp4": (Tensor, Tensor) — kv_buffer fp4x2 + fp8_e8m0 scale
         }
     """
+    assert TOTAL_NUM_HEADS % tp == 0, f"TOTAL_NUM_HEADS ({TOTAL_NUM_HEADS}) must be divisible by tp ({tp})"
+    num_heads = TOTAL_NUM_HEADS // tp
+
     gen = torch.Generator(device="cuda")
     gen.manual_seed(seed)
 
@@ -304,7 +310,7 @@ def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, seed: int) -> in
 
     # Absorbed query: (total_q, num_heads, 576) bf16
     q = torch.randn(
-        (total_q, NUM_HEADS, QK_HEAD_DIM),
+        (total_q, num_heads, QK_HEAD_DIM),
         dtype=torch.bfloat16, device="cuda", generator=gen,
     ) * 0.02
 
@@ -332,7 +338,7 @@ def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, seed: int) -> in
 
     config = {
         "batch_size": batchsize,
-        "num_heads": NUM_HEADS,
+        "num_heads": num_heads,
         "num_kv_heads": NUM_KV_HEADS,
         "qk_head_dim": QK_HEAD_DIM,
         "kv_lora_rank": KV_LORA_RANK,

problems/amd_202602/mixed-mla/submission.py

Lines changed: 6 additions & 5 deletions
@@ -4,7 +4,8 @@
 Implement custom_kernel() to beat the aiter a8w8 reference (fp8 Q + fp8 KV).
 
 DeepSeek R1 forward_absorb MLA config:
-    num_heads = 16 (query heads, after TP split)
+    total_num_heads = 128 (query heads before TP split)
+    num_heads = 128 // tp (query heads per device, tp=4 → 32, tp=8 → 16)
     num_kv_heads = 1 (shared latent KV head)
     kv_lora_rank = 512 (latent dim)
     qk_rope_head_dim = 64 (RoPE dim)
@@ -17,24 +18,24 @@
 - First 512 dims (kv_lora_rank) used as values (for output computation)
 
 Input tuple:
-    q: (total_q, 16, 576) bfloat16 — absorbed query
+    q: (total_q, num_heads, 576) bfloat16 — absorbed query
     kv_data: dict with three KV cache formats:
         kv_data["bf16"] — Tensor (total_kv, 1, 576) bfloat16
         kv_data["fp8"] — (Tensor, Tensor): kv_buffer fp8 (total_kv,1,576) + scalar scale
         kv_data["mxfp4"] — (Tensor, Tensor): kv_buffer fp4x2 (total_kv,1,288) + fp8_e8m0 scale
     qo_indptr: (batch_size + 1,) int32 — query segment pointers
     kv_indptr: (batch_size + 1,) int32 — KV segment pointers
-    config: dict with MLA parameters
+    config: dict with MLA parameters (includes num_heads computed from tp)
 
 Output:
-    attention output: (total_q, 16, 512) bfloat16
+    attention output: (total_q, num_heads, 512) bfloat16
 
 The reference uses aiter's a8w8 persistent MLA kernel (fp8 Q + fp8 KV),
 which is ~2-3x faster than bf16. To beat it, consider:
 1. Use mxfp4 KV cache for even lower memory bandwidth
    - Fuse dequantization with attention to avoid bf16 materialization
 2. Custom kernel with tighter memory access patterns
-3. MQA: 1 KV head shared across 16 query heads — minimize redundant memory loads
+3. MQA: 1 KV head shared across num_heads query heads — minimize redundant memory loads
 4. Variable-length batching: indptr-based segmented attention
 5. Split K/V from buffer: full 576 dims for keys, first 512 dims for values
 """

problems/amd_202602/mixed-mla/task.py

Lines changed: 3 additions & 1 deletion
@@ -5,13 +5,14 @@
 #
 # Input: (q, kv_data, qo_indptr, kv_indptr, config)
 #   q: (total_q, num_heads, qk_head_dim) bfloat16
+#     num_heads = 128 // tp (tp=4 → 32, tp=8 → 16)
 #   kv_data: dict with three KV cache formats:
 #     "bf16": Tensor (total_kv, 1, 576) bfloat16
 #     "fp8": (Tensor, Tensor) kv_buffer fp8 (total_kv, 1, 576) + scalar scale
 #     "mxfp4": (Tensor, Tensor) kv_buffer fp4x2 (total_kv, 1, 288) + fp8_e8m0 scale
 #   qo_indptr: (batch_size + 1,) int32
 #   kv_indptr: (batch_size + 1,) int32
-#   config: dict with MLA parameters
+#   config: dict with MLA parameters (includes num_heads computed from tp)
 #
 # where qk_head_dim = kv_lora_rank + qk_rope_head_dim = 512 + 64 = 576
 #
@@ -33,4 +34,5 @@ class TestSpec(TypedDict):
     batchsize: int
     qseqlen: int
     kvseqlen: int
+    tp: int
     seed: int
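For illustration, the first tp-aware test case from task.yml below, typed against the extended spec (only the tp field is new in this commit):

```
from typing import TypedDict

class TestSpec(TypedDict):
    batchsize: int
    qseqlen: int
    kvseqlen: int
    tp: int          # tensor parallelism degree, added by this commit
    seed: int

case: TestSpec = {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "tp": 8, "seed": 4220}
```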

problems/amd_202602/mixed-mla/task.yml

Lines changed: 30 additions & 27 deletions
@@ -20,39 +20,40 @@ description: |
   persistent mode), which is ~2-3x faster than bf16 on MI355X.
 
   DeepSeek R1 forward_absorb MLA config:
-  - num_heads = 16 (query heads, after TP split)
+  - total_num_heads = 128 (query heads before TP split)
+  - num_heads = 128 // tp (query heads per device, tp=4 → 32, tp=8 → 16)
   - num_kv_heads = 1 (shared latent KV head)
   - kv_lora_rank = 512
   - qk_rope_head_dim = 64
   - qk_head_dim = 576 (kv_lora_rank + qk_rope_head_dim, absorbed q/k dim)
   - v_head_dim = 512 (= kv_lora_rank, output dim)
   - sm_scale = 1/sqrt(576)
   - dtype: q=bfloat16
-  - decode only (q_seq_len=1, kv_seq_len up to 8k)
+  - q_seq_len = 1 or 4, kv_seq_len up to 8k
 
   KV buffer format (forward_absorb):
   - Full 576 dims are used as keys (for Q@K^T score computation)
   - First 512 dims (kv_lora_rank) are used as values (for output computation)
 
   Input tuple: (q, kv_data, qo_indptr, kv_indptr, config)
-  - q: (total_q, 16, 576) bfloat16 — absorbed query
+  - q: (total_q, num_heads, 576) bfloat16 — absorbed query
   - kv_data: dict with three KV cache formats:
      kv_data["bf16"] — Tensor (total_kv, 1, 576) bfloat16
      kv_data["fp8"] — (Tensor, Tensor): kv_buffer fp8 + scalar scale
      kv_data["mxfp4"] — (Tensor, Tensor): kv_buffer fp4x2 + fp8_e8m0 scale
   - qo_indptr: (batch_size+1,) int32 — query segment pointers
   - kv_indptr: (batch_size+1,) int32 — KV segment pointers
-  - config: dict with MLA parameters
+  - config: dict with MLA parameters (includes num_heads computed from tp)
 
   Return:
-  - attention output: (total_q, 16, 512) bfloat16
+  - attention output: (total_q, num_heads, 512) bfloat16
 
   Key optimization opportunities:
   1. Use mxfp4 KV cache for even lower memory bandwidth (4x savings over bf16)
      - Fuse dequantization with attention to skip bf16 materialization
   2. Custom kernel with tighter memory access patterns
-  3. MQA: 1 KV head shared across 16 query heads — minimize redundant memory loads
-  4. Decode: q_seq_len=1, kv_seq_len up to 8k — memory-bound workload
+  3. MQA: 1 KV head shared across num_heads query heads — minimize redundant memory loads
+  4. q_seq_len=1 or 4, kv_seq_len up to 8k — memory-bound workload
   5. Variable-length batching: indptr-based segmented attention
   6. Split K/V from buffer: full 576 dims for keys, first 512 dims for values
 
@@ -69,27 +70,29 @@ benchmark_timeout: 900
 ranked_timeout: 1200
 
 tests:
-  # bs=4
-  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4220}
-  # bs=32
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "seed": 5412}
-  # bs=64
-  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 8192, "seed": 1360}
-  # bs=256
-  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 8192, "seed": 9826}
+  # bs=4, tp=8
+  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "tp": 8, "seed": 4220}
+  - {"batchsize": 4, "qseqlen": 4, "kvseqlen": 1024, "tp": 8, "seed": 4231}
+  # bs=32, tp=4
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 5412}
+  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 5423}
+  # bs=128, tp=8
+  - {"batchsize": 128, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 7816}
+  - {"batchsize": 128, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 7827}
 
 benchmarks:
-  # bs=4
-  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4217}
-  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 8192, "seed": 4220}
-  # bs=32
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "seed": 5412}
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 8192, "seed": 5415}
-  # bs=64
-  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 1024, "seed": 1357}
-  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 8192, "seed": 1360}
-  # bs=256
-  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 1024, "seed": 9823}
-  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 8192, "seed": 9826}
+  # bs=4, tp=4
+  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 4237}
+  - {"batchsize": 4, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 4251}
+  # bs=32, tp=8
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 5415}
+  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 1024, "tp": 8, "seed": 5420}
+  # bs=32, tp=4
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 5432}
+  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 5443}
+  # bs=128, tp=8
+  - {"batchsize": 128, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 7816}
+  - {"batchsize": 128, "qseqlen": 4, "kvseqlen": 8192, "tp": 8, "seed": 7824}
+
 
 ranking_by: "geom"
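On the kv_data["mxfp4"] format described above: the buffer packs two fp4 (e2m1) values per byte, with one e8m0 scale per 32-element block (576 dims -> 288 bytes, 18 blocks). A hedged dequantization sketch; it assumes the scale tensor stores raw e8m0 exponent bytes and that the low nibble holds the first element of each pair, which may not match aiter's actual layout:

```
import torch

# Standard MXFP4 (e2m1) code points, indexed by 4-bit pattern.
FP4_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                           -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

def dequant_mxfp4(packed: torch.Tensor, scale_e8m0: torch.Tensor) -> torch.Tensor:
    # packed: (..., 288) uint8 -> 576 fp4 values; scale_e8m0: (..., 18) uint8
    table = FP4_VALUES.to(packed.device)
    lo, hi = (packed & 0x0F).long(), (packed >> 4).long()
    vals = torch.stack([table[lo], table[hi]], dim=-1).flatten(-2)   # (..., 576)
    scale = torch.exp2(scale_e8m0.float() - 127.0)                   # e8m0 = power of two
    blocks = vals.unflatten(-1, (-1, 32))                            # (..., 18, 32)
    return (blocks * scale.unsqueeze(-1)).flatten(-2).to(torch.bfloat16)
```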

problems/amd_202602/moe-mxfp4/task.yml

Lines changed: 18 additions & 20 deletions
@@ -65,22 +65,18 @@ description: |
 
   d_hidden_pad and d_expert_pad are the dimensions padded to 256-alignment for the CK kernel.
 
-  **Known issue:** The reference submission (which calls aiter's fused_moe) is non-deterministic
-  on MI355X — it does not pass correctness checks against itself. This appears to be an aiter
-  fused_moe kernel bug on gfx950. Submissions will be evaluated on benchmark performance only
-  until this is resolved.
-
   The ranking criteria is the geometric mean of the benchmark results.
 
   ```
   The AITER reference performance is (E includes shared expert, top_k = routed + shared):
     bs     E  d_hidden  d_expert  top_k  time[us]
-    4    257      7168       256      9      46.9
-   64    257      7168       256      9     187.7
-  256    257      7168       256      9     245.7
-   64     33      7168      2048      9     220.6
-  256     33      7168      2048      9     276.4
- 1024     33      7168      2048      9     572.2
+   16    257      7168       256      9     152.7
+  128    257      7168       256      9     239.0
+  512    257      7168       256      9     336.5
+   16     33      7168       512      9     106.2
+  128     33      7168       512      9     141.1
+  512     33      7168       512      9     225.0
+  512     33      7168      2048      9     380.4
   ```
 
   Input:
@@ -112,16 +108,18 @@ ranked_timeout: 840
 ranking_by: "geom"
 
 tests:
-  - {"dhidden": 4096, "dexpert": 1024, "nroutedexperts": 16, "nexpertspertoken": 4, "nsharedexperts": 1, "bs": 8, "seed": 9371}
+  - {"dhidden": 4096, "dexpert": 1024, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 8, "seed": 9371}
   - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 32, "seed": 2291}
   - {"dhidden": 4096, "dexpert": 1536, "nroutedexperts": 64, "nexpertspertoken": 6, "nsharedexperts": 1, "bs": 128, "seed": 81934}
 
 benchmarks:
-  # EP off (all 257 experts on 1 GPU): E=257, top_k=9 (8 routed + 1 shared)
-  - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 4, "seed": 9371}
-  - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 64, "seed": 2291}
-  - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 256, "seed": 81934}
-  # EP on (EP=8, 33 experts per GPU): E=33, top_k=9 (8 routed + 1 shared)
-  - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 64, "seed": 2291}
-  - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 256, "seed": 81934}
-  - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 1024, "seed": 81934}
+  # TP=8
+  - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 16, "seed": 9371}
+  - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 128, "seed": 2291}
+  - {"dhidden": 7168, "dexpert": 256, "nroutedexperts": 256, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 512, "seed": 81934}
+  # TP=4
+  - {"dhidden": 7168, "dexpert": 512, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 16, "seed": 2291}
+  - {"dhidden": 7168, "dexpert": 512, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 128, "seed": 81934}
+  - {"dhidden": 7168, "dexpert": 512, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 512, "seed": 81934}
+  # EP on
+  - {"dhidden": 7168, "dexpert": 2048, "nroutedexperts": 32, "nexpertspertoken": 8, "nsharedexperts": 1, "bs": 512, "seed": 81934}

problems/amd_202602/mxfp4-mm/task.yml

Lines changed: 2 additions & 6 deletions
@@ -36,11 +36,9 @@ description: |
      4   2880    512    8.198
     16   2112   7168   20.873
     32   4096    512    9.462
+    32   2880    512    9.173
     64   7168   2048   12.738
-    64   2880    512    9.873
-   128   2112   7168   27.284
    256   3072   1536   12.219
-   256   7168   2048   13.506
   ```
 config:
   main: "eval.py"
@@ -55,8 +53,6 @@ benchmarks:
   - {"m": 4, "n": 2880, "k": 512, "seed": 4565}
   - {"m": 16, "n": 2112, "k": 7168, "seed": 15}
   - {"m": 32, "n": 4096, "k": 512, "seed": 457}
+  - {"m": 32, "n": 2880, "k": 512, "seed": 54}
   - {"m": 64, "n": 7168, "k": 2048, "seed": 687}
-  - {"m": 64, "n": 2880, "k": 512, "seed": 54}
-  - {"m": 128, "n": 2112, "k": 7168, "seed": 24}
   - {"m": 256, "n": 3072, "k": 1536, "seed": 7856}
-  - {"m": 256, "n": 7168, "k": 2048, "seed": 223}
