@@ -20,40 +20,39 @@ description: |
   persistent mode), which is ~2-3x faster than bf16 on MI355X.
 
   DeepSeek R1 forward_absorb MLA config:
-  - total_num_heads = 128 (query heads before TP split)
-  - num_heads = 128 // tp (query heads per device, tp=4 → 32, tp=8 → 16)
+  - num_heads = 16 (query heads, after TP split)
   - num_kv_heads = 1 (shared latent KV head)
   - kv_lora_rank = 512
   - qk_rope_head_dim = 64
   - qk_head_dim = 576 (kv_lora_rank + qk_rope_head_dim, absorbed q/k dim)
   - v_head_dim = 512 (= kv_lora_rank, output dim)
   - sm_scale = 1/sqrt(576)
   - dtype: q=bfloat16
-  - q_seq_len = 1 or 4, kv_seq_len up to 8k
+  - decode only (q_seq_len=1, kv_seq_len up to 8k)
 
   KV buffer format (forward_absorb):
   - Full 576 dims are used as keys (for Q@K^T score computation)
   - First 512 dims (kv_lora_rank) are used as values (for output computation)
 
   Input tuple: (q, kv_data, qo_indptr, kv_indptr, config)
-  - q: (total_q, num_heads, 576) bfloat16 — absorbed query
+  - q: (total_q, 16, 576) bfloat16 — absorbed query
   - kv_data: dict with three KV cache formats:
      kv_data["bf16"] — Tensor (total_kv, 1, 576) bfloat16
      kv_data["fp8"] — (Tensor, Tensor): kv_buffer fp8 + scalar scale
      kv_data["mxfp4"] — (Tensor, Tensor): kv_buffer fp4x2 + fp8_e8m0 scale
   - qo_indptr: (batch_size+1,) int32 — query segment pointers
   - kv_indptr: (batch_size+1,) int32 — KV segment pointers
-  - config: dict with MLA parameters (includes num_heads computed from tp)
+  - config: dict with MLA parameters
 
   Return:
-  - attention output: (total_q, num_heads, 512) bfloat16
+  - attention output: (total_q, 16, 512) bfloat16
 
   Key optimization opportunities:
   1. Use mxfp4 KV cache for even lower memory bandwidth (4x savings over bf16)
      - Fuse dequantization with attention to skip bf16 materialization
   2. Custom kernel with tighter memory access patterns
-  3. MQA: 1 KV head shared across num_heads query heads — minimize redundant memory loads
-  4. q_seq_len=1 or 4, kv_seq_len up to 8k — memory-bound workload
+  3. MQA: 1 KV head shared across 16 query heads — minimize redundant memory loads
+  4. Decode: q_seq_len=1, kv_seq_len up to 8k — memory-bound workload
   5. Variable-length batching: indptr-based segmented attention
   6. Split K/V from buffer: full 576 dims for keys, first 512 dims for values
 
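The absorbed-MLA decode path described in the hunk above (full 576-dim rows as keys, first 512 dims as values, indptr-segmented variable-length batches) can be sketched as a NumPy reference. This is a minimal sketch: it assumes toy float32 inputs in place of the real bf16/fp8/mxfp4 KV buffers, and the helper name `mla_decode_ref` is an assumption, not part of the benchmark harness.

```python
# Reference sketch of forward_absorb MLA decode attention.
# Assumptions: float32 toy data (not bf16/fp8/mxfp4); `mla_decode_ref`
# is a hypothetical name, not a harness API.
import numpy as np

KV_LORA_RANK = 512  # value dims = first 512 of each 576-dim KV row
QK_HEAD_DIM = 576   # key dims = kv_lora_rank + qk_rope_head_dim

def mla_decode_ref(q, kv_buffer, qo_indptr, kv_indptr, sm_scale):
    """q: (total_q, num_heads, 576); kv_buffer: (total_kv, 1, 576)."""
    total_q, num_heads, _ = q.shape
    out = np.zeros((total_q, num_heads, KV_LORA_RANK), dtype=q.dtype)
    for b in range(len(qo_indptr) - 1):
        q_seg = q[qo_indptr[b]:qo_indptr[b + 1]]              # (sq, H, 576)
        kv_seg = kv_buffer[kv_indptr[b]:kv_indptr[b + 1], 0]  # (sk, 576)
        k = kv_seg                                            # keys: all 576 dims
        v = kv_seg[:, :KV_LORA_RANK]                          # values: first 512 dims
        # MQA: the single latent KV head is broadcast across all query heads.
        scores = np.einsum("qhd,kd->qhk", q_seg, k) * sm_scale
        scores -= scores.max(axis=-1, keepdims=True)          # numerically stable softmax
        p = np.exp(scores)
        p /= p.sum(axis=-1, keepdims=True)
        out[qo_indptr[b]:qo_indptr[b + 1]] = np.einsum("qhk,kd->qhd", p, v)
    return out
```

Any optimized kernel (fused mxfp4 dequant, custom memory layout) can be checked against this reference for numerical agreement.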
@@ -70,29 +69,27 @@ benchmark_timeout: 900
 ranked_timeout: 1200
 
 tests:
-  # bs=4, tp=8
-  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "tp": 8, "seed": 4220}
-  - {"batchsize": 4, "qseqlen": 4, "kvseqlen": 1024, "tp": 8, "seed": 4231}
-  # bs=32, tp=4
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 5412}
-  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 5423}
-  # bs=128, tp=8
-  - {"batchsize": 128, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 7816}
-  - {"batchsize": 128, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 7827}
+  # bs=4
+  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4220}
+  # bs=32
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "seed": 5412}
+  # bs=64
+  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 8192, "seed": 1360}
+  # bs=256
+  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 8192, "seed": 9826}
 
 benchmarks:
-  # bs=4, tp=4
-  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 4237}
-  - {"batchsize": 4, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 4251}
-  # bs=32, tp=8
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 5415}
-  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 1024, "tp": 8, "seed": 5420}
-  # bs=32, tp=4
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 5432}
-  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 5443}
-  # bs=128, tp=8
-  - {"batchsize": 128, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 7816}
-  - {"batchsize": 128, "qseqlen": 4, "kvseqlen": 8192, "tp": 8, "seed": 7824}
-
+  # bs=4
+  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4217}
+  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 8192, "seed": 4220}
+  # bs=32
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "seed": 5412}
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 8192, "seed": 5415}
+  # bs=64
+  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 1024, "seed": 1357}
+  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 8192, "seed": 1360}
+  # bs=256
+  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 1024, "seed": 9823}
+  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 8192, "seed": 9826}
 
 ranking_by: "geom"
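Each test/benchmark entry above expands into the segment-pointer inputs described in the first hunk (`qo_indptr`, `kv_indptr`). A minimal sketch, assuming uniform segment lengths per batch (the real harness supports arbitrary per-segment lengths, and the helper name `make_indptrs` is an assumption, not from the harness):

```python
# Hypothetical expansion of one test entry into indptr arrays.
# Assumption: every segment in the batch has the same q/kv length.
import numpy as np

def make_indptrs(batchsize, qseqlen, kvseqlen):
    # indptr[i] marks where segment i starts; indptr[-1] is the total length.
    qo_indptr = np.arange(batchsize + 1, dtype=np.int32) * qseqlen
    kv_indptr = np.arange(batchsize + 1, dtype=np.int32) * kvseqlen
    return qo_indptr, kv_indptr

# The bs=4, kvseqlen=1024 test entry:
qo, kv = make_indptrs(batchsize=4, qseqlen=1, kvseqlen=1024)
# total_q = qo[-1] = 4; total_kv = kv[-1] = 4096
```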