@@ -20,39 +20,40 @@ description: |
   persistent mode), which is ~2-3x faster than bf16 on MI355X.
 
   DeepSeek R1 forward_absorb MLA config:
-  - num_heads = 16 (query heads, after TP split)
+  - total_num_heads = 128 (query heads before TP split)
+  - num_heads = 128 // tp (query heads per device, tp=4 → 32, tp=8 → 16)
   - num_kv_heads = 1 (shared latent KV head)
   - kv_lora_rank = 512
   - qk_rope_head_dim = 64
   - qk_head_dim = 576 (kv_lora_rank + qk_rope_head_dim, absorbed q/k dim)
   - v_head_dim = 512 (= kv_lora_rank, output dim)
   - sm_scale = 1/sqrt(576)
   - dtype: q=bfloat16
-  - decode only (q_seq_len=1, kv_seq_len up to 8k)
+  - q_seq_len = 1 or 4, kv_seq_len up to 8k
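For concreteness, the derived quantities in the config above follow from `total_num_heads`, `kv_lora_rank`, and `qk_rope_head_dim`. A minimal sketch of that arithmetic (illustrative only, not the benchmark's code):

```python
# Derived MLA dimensions for DeepSeek R1 forward_absorb (values from the config above).
total_num_heads = 128        # query heads before tensor-parallel split
kv_lora_rank = 512           # latent KV dim; also v_head_dim and the output dim
qk_rope_head_dim = 64        # RoPE portion appended to the latent dim

qk_head_dim = kv_lora_rank + qk_rope_head_dim  # 576, absorbed q/k dim
sm_scale = 1.0 / qk_head_dim ** 0.5            # softmax scale = 1/sqrt(576)

for tp in (4, 8):
    num_heads = total_num_heads // tp          # per-device query heads
    print(f"tp={tp}: num_heads={num_heads}")   # tp=4: 32, tp=8: 16
```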
 
   KV buffer format (forward_absorb):
   - Full 576 dims are used as keys (for Q@K^T score computation)
   - First 512 dims (kv_lora_rank) are used as values (for output computation)
 
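The key/value split is a zero-copy slice of one shared buffer. A small numpy sketch of this layout (float32 stands in for bfloat16; shapes are from the spec, names are illustrative):

```python
import numpy as np

total_kv, num_kv_heads = 8, 1
kv_lora_rank, qk_rope_head_dim = 512, 64

# One shared latent buffer per token: (total_kv, 1, 576)
kv_buffer = np.random.randn(
    total_kv, num_kv_heads, kv_lora_rank + qk_rope_head_dim
).astype(np.float32)

k = kv_buffer                      # keys: all 576 dims, used in Q @ K^T
v = kv_buffer[..., :kv_lora_rank]  # values: first 512 dims (a view, no copy)

assert k.shape == (total_kv, 1, 576)
assert v.shape == (total_kv, 1, 512)
```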
   Input tuple: (q, kv_data, qo_indptr, kv_indptr, config)
-  - q: (total_q, 16, 576) bfloat16 — absorbed query
+  - q: (total_q, num_heads, 576) bfloat16 — absorbed query
   - kv_data: dict with three KV cache formats:
       kv_data["bf16"] — Tensor (total_kv, 1, 576) bfloat16
       kv_data["fp8"] — (Tensor, Tensor): kv_buffer fp8 + scalar scale
       kv_data["mxfp4"] — (Tensor, Tensor): kv_buffer fp4x2 + fp8_e8m0 scale
   - qo_indptr: (batch_size+1,) int32 — query segment pointers
   - kv_indptr: (batch_size+1,) int32 — KV segment pointers
-  - config: dict with MLA parameters
+  - config: dict with MLA parameters (includes num_heads computed from tp)
 
   Return:
-  - attention output: (total_q, 16, 512) bfloat16
+  - attention output: (total_q, num_heads, 512) bfloat16
 
   Key optimization opportunities:
   1. Use mxfp4 KV cache for even lower memory bandwidth (4x savings over bf16)
      - Fuse dequantization with attention to skip bf16 materialization
   2. Custom kernel with tighter memory access patterns
-  3. MQA: 1 KV head shared across 16 query heads — minimize redundant memory loads
-  4. Decode: q_seq_len=1, kv_seq_len up to 8k — memory-bound workload
+  3. MQA: 1 KV head shared across num_heads query heads — minimize redundant memory loads
+  4. q_seq_len = 1 or 4, kv_seq_len up to 8k — memory-bound workload
   5. Variable-length batching: indptr-based segmented attention
   6. Split K/V from buffer: full 576 dims for keys, first 512 dims for values
 
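Points 3-6 above can be exercised end to end with a tiny numpy reference: indptr-segmented batching, one latent KV head shared by all query heads, and the key/value split from one buffer. This is a correctness sketch under the stated shapes, not the benchmark's kernel; bf16 is replaced by float32, all names are illustrative, and the causal mask needed when q_seq_len > 1 is omitted for brevity:

```python
import numpy as np

def mla_ref(q, kv_buffer, qo_indptr, kv_indptr, kv_lora_rank=512, sm_scale=576 ** -0.5):
    """Segmented MQA attention: 1 latent KV head shared by all query heads."""
    total_q, num_heads, qk_head_dim = q.shape
    out = np.empty((total_q, num_heads, kv_lora_rank), dtype=q.dtype)
    for b in range(len(qo_indptr) - 1):            # one segment per request
        qs, qe = qo_indptr[b], qo_indptr[b + 1]
        ks, ke = kv_indptr[b], kv_indptr[b + 1]
        k = kv_buffer[ks:ke, 0, :]                 # (kv_len, 576): full dims as keys
        v = kv_buffer[ks:ke, 0, :kv_lora_rank]     # (kv_len, 512): first 512 as values
        s = np.einsum("qhd,kd->qhk", q[qs:qe], k) * sm_scale
        p = np.exp(s - s.max(axis=-1, keepdims=True))
        p /= p.sum(axis=-1, keepdims=True)         # softmax over kv_len
        out[qs:qe] = np.einsum("qhk,kd->qhd", p, v)
    return out

# Example: 2 requests, q_seq_len=1 each, num_heads=16 (tp=8), kv lengths 5 and 7
rng = np.random.default_rng(0)
q = rng.standard_normal((2, 16, 576)).astype(np.float32)
kv = rng.standard_normal((12, 1, 576)).astype(np.float32)
out = mla_ref(q, kv, np.array([0, 1, 2]), np.array([0, 5, 12]))
assert out.shape == (2, 16, 512)
```

With a single KV token per segment, the softmax weight is 1 and each head's output is exactly that token's first 512 dims, which makes a convenient sanity check.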
@@ -69,27 +70,29 @@ benchmark_timeout: 900
 ranked_timeout: 1200
 
 tests:
-  # bs=4
-  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4220}
-  # bs=32
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "seed": 5412}
-  # bs=64
-  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 8192, "seed": 1360}
-  # bs=256
-  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 8192, "seed": 9826}
+  # bs=4, tp=8
+  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "tp": 8, "seed": 4220}
+  - {"batchsize": 4, "qseqlen": 4, "kvseqlen": 1024, "tp": 8, "seed": 4231}
+  # bs=32, tp=4
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 5412}
+  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 5423}
+  # bs=128
+  - {"batchsize": 128, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 7816}
+  - {"batchsize": 128, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 7827}
 
 benchmarks:
-  # bs=4
-  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "seed": 4217}
-  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 8192, "seed": 4220}
-  # bs=32
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "seed": 5412}
-  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 8192, "seed": 5415}
-  # bs=64
-  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 1024, "seed": 1357}
-  - {"batchsize": 64, "qseqlen": 1, "kvseqlen": 8192, "seed": 1360}
-  # bs=256
-  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 1024, "seed": 9823}
-  - {"batchsize": 256, "qseqlen": 1, "kvseqlen": 8192, "seed": 9826}
+  # bs=4, tp=4
+  - {"batchsize": 4, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 4237}
+  - {"batchsize": 4, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 4251}
+  # bs=32, tp=8
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 5415}
+  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 1024, "tp": 8, "seed": 5420}
+  # bs=32, tp=4
+  - {"batchsize": 32, "qseqlen": 1, "kvseqlen": 1024, "tp": 4, "seed": 5432}
+  - {"batchsize": 32, "qseqlen": 4, "kvseqlen": 8192, "tp": 4, "seed": 5443}
+  # bs=128, tp=8
+  - {"batchsize": 128, "qseqlen": 1, "kvseqlen": 8192, "tp": 8, "seed": 7816}
+  - {"batchsize": 128, "qseqlen": 4, "kvseqlen": 8192, "tp": 8, "seed": 7824}
+
 
 ranking_by: "geom"