
Commit a043b3b
Author: Mark Saroufim

Revert mixed-mla to original config (no tp parameter)

The AMD-AIM tp parameter changes (num_heads=128//tp, qseqlen=4) made the test cases ~16x slower due to more heads and larger queries. Revert to the original 16-head decode-only config that runs within the test timeout.

1 parent a846c7e
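The commit's cost argument can be sanity-checked with a toy model (an assumption for illustration, not profiling data): decode attention work grows roughly with num_heads * q_seq_len * kv_seq_len, so the tp=4 cases (32 heads, qseqlen=4) do about 8x the per-case work of the reverted 16-head decode-only config, before accounting for the larger kvseqlen in some cases.

```python
# Hypothetical back-of-envelope model (not a measurement): decode attention
# cost scales roughly with num_heads * q_seq_len * kv_seq_len.
def relative_work(num_heads: int, qseqlen: int, kvseqlen: int = 1024) -> int:
    return num_heads * qseqlen * kvseqlen

base = relative_work(16, 1)       # reverted config: 16 heads, decode-only
tp4 = relative_work(128 // 4, 4)  # AMD-AIM config at tp=4: 32 heads, qseqlen=4
print(tp4 // base)                # 8x per case, before kvseqlen growth
```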

4 files changed: 38 additions & 50 deletions


problems/amd_202602/mixed-mla/reference.py

Lines changed: 5 additions & 11 deletions
@@ -6,7 +6,7 @@
     output v_head_dim = kv_lora_rank = 512.
 
     The input provides:
-      q: (total_q, num_heads, 576) bfloat16 — absorbed query (num_heads = 128 // tp)
+      q: (total_q, 16, 576) bfloat16 — absorbed query
       kv_data: dict with KV cache in three formats:
         "bf16": Tensor (total_kv, 1, 576) bfloat16 — highest precision
         "fp8": (Tensor, Tensor) kv_buffer fp8 + scalar scale — per-tensor quantized
@@ -37,7 +37,7 @@
 # DeepSeek R1 latent MQA constants (forward_absorb path)
 # https://huggingface.co/deepseek-ai/DeepSeek-R1-0528/blob/main/config.json
 # ---------------------------------------------------------------------------
-TOTAL_NUM_HEADS = 128
+NUM_HEADS = 16
 NUM_KV_HEADS = 1
 KV_LORA_RANK = 512
 QK_ROPE_HEAD_DIM = 64
@@ -285,23 +285,17 @@ def _aiter_mla_decode(
 # generate_input / ref_kernel / check_implementation
 # ---------------------------------------------------------------------------
 
-def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, tp: int, seed: int) -> input_t:
+def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, seed: int) -> input_t:
     """
     Generate absorbed q and compressed kv_buffer for MLA decode.
 
-    Args:
-        tp: tensor parallelism degree (4 or 8). num_heads = TOTAL_NUM_HEADS // tp.
-
     Returns all three KV cache formats in kv_data dict:
         kv_data = {
             "bf16": Tensor — (total_kv, 1, 576) bfloat16
             "fp8": (Tensor, Tensor) — kv_buffer fp8 + scalar scale
             "mxfp4": (Tensor, Tensor) — kv_buffer fp4x2 + fp8_e8m0 scale
         }
     """
-    assert TOTAL_NUM_HEADS % tp == 0, f"TOTAL_NUM_HEADS ({TOTAL_NUM_HEADS}) must be divisible by tp ({tp})"
-    num_heads = TOTAL_NUM_HEADS // tp
-
     gen = torch.Generator(device="cuda")
     gen.manual_seed(seed)
@@ -310,7 +304,7 @@ def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, tp: int, seed: i
 
     # Absorbed query: (total_q, num_heads, 576) bf16
     q = torch.randn(
-        (total_q, num_heads, QK_HEAD_DIM),
+        (total_q, NUM_HEADS, QK_HEAD_DIM),
         dtype=torch.bfloat16, device="cuda", generator=gen,
     ) * 0.02
@@ -338,7 +332,7 @@ def generate_input(batchsize: int, qseqlen: int, kvseqlen: int, tp: int, seed: i
 
     config = {
         "batch_size": batchsize,
-        "num_heads": num_heads,
+        "num_heads": NUM_HEADS,
         "num_kv_heads": NUM_KV_HEADS,
         "qk_head_dim": QK_HEAD_DIM,
         "kv_lora_rank": KV_LORA_RANK,
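The "fp8 + scalar scale" format in this file is per-tensor quantized. A minimal numpy sketch of the per-tensor scheme (an illustration, not the repo's code: integer rounding stands in for the actual fp8 e4m3 cast, and 448.0 is assumed as the e4m3 max):

```python
import numpy as np

FP8_MAX = 448.0  # assumed fp8 e4m3 max; rounding below stands in for the fp8 cast

def quantize_per_tensor(x: np.ndarray):
    # One scalar scale for the whole kv_buffer, mirroring the (tensor, scale) pair
    scale = float(np.abs(x).max()) / FP8_MAX
    q = np.clip(np.round(x / scale), -FP8_MAX, FP8_MAX)
    return q, scale

def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    return q * scale

rng = np.random.default_rng(0)
kv = rng.normal(size=(32, 1, 576)).astype(np.float32) * 0.02  # toy kv_buffer
q, scale = quantize_per_tensor(kv)
err = np.abs(dequantize(q, scale) - kv).max()  # round-trip error, bounded by scale / 2
```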

problems/amd_202602/mixed-mla/submission.py

Lines changed: 5 additions & 6 deletions
@@ -4,8 +4,7 @@
 Implement custom_kernel() to beat the aiter a8w8 reference (fp8 Q + fp8 KV).
 
 DeepSeek R1 forward_absorb MLA config:
-    total_num_heads = 128 (query heads before TP split)
-    num_heads = 128 // tp (query heads per device, tp=4 → 32, tp=8 → 16)
+    num_heads = 16 (query heads, after TP split)
     num_kv_heads = 1 (shared latent KV head)
     kv_lora_rank = 512 (latent dim)
     qk_rope_head_dim = 64 (RoPE dim)
@@ -18,24 +17,24 @@
 - First 512 dims (kv_lora_rank) used as values (for output computation)
 
 Input tuple:
-    q: (total_q, num_heads, 576) bfloat16 — absorbed query
+    q: (total_q, 16, 576) bfloat16 — absorbed query
     kv_data: dict with three KV cache formats:
         kv_data["bf16"] — Tensor (total_kv, 1, 576) bfloat16
         kv_data["fp8"] — (Tensor, Tensor): kv_buffer fp8 (total_kv,1,576) + scalar scale
         kv_data["mxfp4"] — (Tensor, Tensor): kv_buffer fp4x2 (total_kv,1,288) + fp8_e8m0 scale
     qo_indptr: (batch_size + 1,) int32 — query segment pointers
     kv_indptr: (batch_size + 1,) int32 — KV segment pointers
-    config: dict with MLA parameters (includes num_heads computed from tp)
+    config: dict with MLA parameters
 
 Output:
-    attention output: (total_q, num_heads, 512) bfloat16
+    attention output: (total_q, 16, 512) bfloat16
 
 The reference uses aiter's a8w8 persistent MLA kernel (fp8 Q + fp8 KV),
 which is ~2-3x faster than bf16. To beat it, consider:
 1. Use mxfp4 KV cache for even lower memory bandwidth
    - Fuse dequantization with attention to avoid bf16 materialization
 2. Custom kernel with tighter memory access patterns
-3. MQA: 1 KV head shared across num_heads query heads — minimize redundant memory loads
+3. MQA: 1 KV head shared across 16 query heads — minimize redundant memory loads
 4. Variable-length batching: indptr-based segmented attention
 5. Split K/V from buffer: full 576 dims for keys, first 512 dims for values
 """
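Items 3-5 of that list can be sketched end-to-end in numpy: one shared latent KV head serving all 16 query heads, indptr-based variable-length segments, and keys/values split from the same 576-dim buffer. This is a shape-level illustration under a bf16-as-float simplification; the function name is made up here, not the repo's API.

```python
import numpy as np

NUM_HEADS, QK_HEAD_DIM, V_HEAD_DIM = 16, 576, 512
SM_SCALE = 1.0 / np.sqrt(QK_HEAD_DIM)

def mla_decode_sketch(q, kv_buffer, qo_indptr, kv_indptr):
    """Segmented MQA decode: 1 shared KV head, values = first 512 dims of the buffer."""
    out = np.zeros((q.shape[0], NUM_HEADS, V_HEAD_DIM), dtype=q.dtype)
    for b in range(len(qo_indptr) - 1):
        ks, ke = kv_indptr[b], kv_indptr[b + 1]
        k = kv_buffer[ks:ke, 0, :]             # (kv_len, 576): full dims as keys
        v = kv_buffer[ks:ke, 0, :V_HEAD_DIM]   # (kv_len, 512): first 512 dims as values
        for i in range(qo_indptr[b], qo_indptr[b + 1]):
            scores = q[i] @ k.T * SM_SCALE     # (16, kv_len): KV shared by all 16 heads
            scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
            p = np.exp(scores)
            p /= p.sum(axis=-1, keepdims=True)
            out[i] = p @ v
    return out

# Toy batch of 2 segments: query lengths 1 and 2, KV lengths 4 and 6
rng = np.random.default_rng(0)
q = rng.normal(size=(3, NUM_HEADS, QK_HEAD_DIM)) * 0.02
kv_buffer = np.ones((10, 1, QK_HEAD_DIM))  # constant KV: softmax weights sum to 1
out = mla_decode_sketch(q, kv_buffer, np.array([0, 1, 3]), np.array([0, 4, 10]))
```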

problems/amd_202602/mixed-mla/task.py

Lines changed: 1 addition & 3 deletions
@@ -5,14 +5,13 @@
 #
 # Input: (q, kv_data, qo_indptr, kv_indptr, config)
 #   q: (total_q, num_heads, qk_head_dim) bfloat16
-#     num_heads = 128 // tp (tp=4 → 32, tp=8 → 16)
 #   kv_data: dict with three KV cache formats:
 #     "bf16": Tensor (total_kv, 1, 576) bfloat16
 #     "fp8": (Tensor, Tensor) kv_buffer fp8 (total_kv, 1, 576) + scalar scale
 #     "mxfp4": (Tensor, Tensor) kv_buffer fp4x2 (total_kv, 1, 288) + fp8_e8m0 scale
 #   qo_indptr: (batch_size + 1,) int32
 #   kv_indptr: (batch_size + 1,) int32
-#   config: dict with MLA parameters (includes num_heads computed from tp)
+#   config: dict with MLA parameters
 #
 # where qk_head_dim = kv_lora_rank + qk_rope_head_dim = 512 + 64 = 576
 #
@@ -34,5 +33,4 @@ class TestSpec(TypedDict):
     batchsize: int
     qseqlen: int
     kvseqlen: int
-    tp: int
     seed: int
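After this change the TestSpec schema reduces to four fields. As a plain Python restatement (the case literal below is one of the test entries from task.yml in this commit):

```python
from typing import TypedDict

class TestSpec(TypedDict):
    batchsize: int
    qseqlen: int
    kvseqlen: int
    seed: int  # "tp" is gone: num_heads is now the fixed constant 16

# Example test case from task.yml after the revert (no "tp" key)
case: TestSpec = {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4220}
```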

problems/amd_202602/mixed-mla/task.yml

Lines changed: 27 additions & 30 deletions
@@ -20,40 +20,39 @@ description: |
   persistent mode), which is ~2-3x faster than bf16 on MI355X.
 
   DeepSeek R1 forward_absorb MLA config:
-  - total_num_heads = 128 (query heads before TP split)
-  - num_heads = 128 // tp (query heads per device, tp=4 → 32, tp=8 → 16)
+  - num_heads = 16 (query heads, after TP split)
   - num_kv_heads = 1 (shared latent KV head)
   - kv_lora_rank = 512
   - qk_rope_head_dim = 64
   - qk_head_dim = 576 (kv_lora_rank + qk_rope_head_dim, absorbed q/k dim)
   - v_head_dim = 512 (= kv_lora_rank, output dim)
   - sm_scale = 1/sqrt(576)
   - dtype: q=bfloat16
-  - q_seq_len = 1 or 4, kv_seq_len up to 8k
+  - decode only (q_seq_len=1, kv_seq_len up to 8k)
 
   KV buffer format (forward_absorb):
   - Full 576 dims are used as keys (for Q@K^T score computation)
   - First 512 dims (kv_lora_rank) are used as values (for output computation)
 
   Input tuple: (q, kv_data, qo_indptr, kv_indptr, config)
-  - q: (total_q, num_heads, 576) bfloat16 — absorbed query
+  - q: (total_q, 16, 576) bfloat16 — absorbed query
   - kv_data: dict with three KV cache formats:
       kv_data["bf16"] — Tensor (total_kv, 1, 576) bfloat16
       kv_data["fp8"] — (Tensor, Tensor): kv_buffer fp8 + scalar scale
       kv_data["mxfp4"] — (Tensor, Tensor): kv_buffer fp4x2 + fp8_e8m0 scale
   - qo_indptr: (batch_size+1,) int32 — query segment pointers
   - kv_indptr: (batch_size+1,) int32 — KV segment pointers
-  - config: dict with MLA parameters (includes num_heads computed from tp)
+  - config: dict with MLA parameters
 
   Return:
-  - attention output: (total_q, num_heads, 512) bfloat16
+  - attention output: (total_q, 16, 512) bfloat16
 
   Key optimization opportunities:
   1. Use mxfp4 KV cache for even lower memory bandwidth (4x savings over bf16)
      - Fuse dequantization with attention to skip bf16 materialization
   2. Custom kernel with tighter memory access patterns
-  3. MQA: 1 KV head shared across num_heads query heads — minimize redundant memory loads
-  4. q_seq_len=1 or 4, kv_seq_len up to 8k — memory-bound workload
+  3. MQA: 1 KV head shared across 16 query heads — minimize redundant memory loads
+  4. Decode: q_seq_len=1, kv_seq_len up to 8k — memory-bound workload
   5. Variable-length batching: indptr-based segmented attention
   6. Split K/V from buffer: full 576 dims for keys, first 512 dims for values
@@ -70,29 +69,27 @@ benchmark_timeout: 900
 ranked_timeout: 1200
 
 tests:
-  # bs=4, tp=8
-  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "tp": 8, "seed": 4220}
-  - {"batchsize": 4, "qseqlen": 4, "kvseqlen": 1024, "tp": 8, "seed": 4231}
-  # bs=32, tp=4
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 5412}
-  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 5423}
-  # bs=128, tp=8
-  - {"batchsize": 128, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 7816}
-  - {"batchsize": 128, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 7827}
+  # bs=4
+  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4220}
+  # bs=32
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "seed": 5412}
+  # bs=64
+  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 8192, "seed": 1360}
+  # bs=256
+  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 8192, "seed": 9826}
 
 benchmarks:
-  # bs=4, tp=4
-  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 4237}
-  - {"batchsize": 4, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 4251}
-  # bs=32, tp=8
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 5415}
-  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 1024, "tp": 8, "seed": 5420}
-  # bs=32, tp=4
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 5432}
-  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 5443}
-  # bs=128, tp=8
-  - {"batchsize": 128, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 7816}
-  - {"batchsize": 128, "qseqlen": 4, "kvseqlen": 8192, "tp": 8, "seed": 7824}
+  # bs=4
+  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4217}
+  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 8192, "seed": 4220}
+  # bs=32
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "seed": 5412}
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 8192, "seed": 5415}
+  # bs=64
+  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 1024, "seed": 1357}
+  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 8192, "seed": 1360}
+  # bs=256
+  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 1024, "seed": 9823}
+  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 8192, "seed": 9826}
 
 ranking_by: "geom"
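The ranking_by: "geom" setting presumably aggregates per-benchmark timings with a geometric mean (an assumption about the harness, not something this diff confirms). A minimal sketch of that aggregation:

```python
import math

def geom_mean(times):
    # Geometric mean of per-benchmark times: equal relative improvement on any
    # benchmark moves the score equally, unlike an arithmetic mean.
    return math.exp(sum(math.log(t) for t in times) / len(times))

print(geom_mean([1.0, 4.0]))  # 2.0
```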
