"""
FP4 quant + FP4 GEMM reference: bf16 A, MXFP4 B -> MXFP4 per-1x32 quant A -> gemm_a4w4 -> bf16 C.
Quant logic follows aiter op_tests/test_gemm_a4w4.py (get_triton_quant(QuantType.per_1x32)).

NOTE: Explicitly uses dynamic_mxfp4_quant from aiter.ops.triton.quant (patched in #975)
rather than going through aiter.get_triton_quant, which may dispatch to the
unpatched fp4_utils.py kernel. See ROCm/aiter#974, ROCm/aiter#975.
"""
59import torch
610from task import input_t , output_t
711from utils import make_match_reference
812from aiter import QuantType ,dtypes
913import aiter
1014from aiter .ops .shuffle import shuffle_weight
15+ from aiter .ops .triton .quant import dynamic_mxfp4_quant # #975-patched kernel
16+ from aiter .utility .fp4_utils import e8m0_shuffle
1117# K must be divisible by 64 (scale group 32 and fp4 pack 2)
1218SCALE_GROUP_SIZE = 32
1319
def _quant_mxfp4(x, shuffle=True):
    """Per-1x32 MXFP4 quantize *x* with the #975-patched Triton kernel.

    Args:
        x: input tensor to quantize (bf16 activations/weights in this file).
        shuffle: when True, shuffle the e8m0 block scales into the layout
            expected by aiter.gemm_a4w4.

    Returns:
        Tuple of (packed fp4x2 tensor, e8m0 block-scale tensor), each viewed
        as the corresponding aiter dtype.
    """
    packed, scale_e8m0 = dynamic_mxfp4_quant(x)
    # Scale shuffling is only needed for the gemm_a4w4 fast path.
    scale_e8m0 = e8m0_shuffle(scale_e8m0) if shuffle else scale_e8m0
    return packed.view(dtypes.fp4x2), scale_e8m0.view(dtypes.fp8_e8m0)
25+
1426def generate_input (m : int , n : int , k : int , seed : int ):# -> input_t:
1527 """
1628 Generate random bf16 inputs A [m, k], B [n, k] and quantized MXFP4 B, shuffled B and B_scale.
@@ -23,11 +35,7 @@ def generate_input(m: int, n: int, k: int, seed: int):# -> input_t:
2335 gen .manual_seed (seed )
2436 A = torch .randn ((m , k ), dtype = torch .bfloat16 , device = "cuda" , generator = gen )
2537 B = torch .randn ((n , k ), dtype = torch .bfloat16 , device = "cuda" , generator = gen )
26-
27- # quantized mxfp4 B
28- quant_func = aiter .get_triton_quant (QuantType .per_1x32 )
29- B_q , B_scale_sh = quant_func (B , shuffle = True )
30-
38+ B_q , B_scale_sh = _quant_mxfp4 (B , shuffle = True )
3139 # shuffle B(weight) to (16,16) tile coalesced
3240 B_shuffle = shuffle_weight (B_q , layout = (16 , 16 ))
3341 return (A , B , B_q , B_shuffle , B_scale_sh )
@@ -76,10 +84,8 @@ def ref_kernel(data: input_t) -> output_t:
7684
7785 # 1) PyTorch impl just for your reference: dequant fp4 + e8m0 -> f32 -> mm -> bf16
7886 # Per-1x32 MXFP4 quant
79- # quant_func = aiter.get_triton_quant(QuantType.per_1x32)
80- # quant_func(x, shuffle=False) -> (dtypes.fp4x2, scale); scale layout matches gemm_a4w4
81- # A_q, A_scale = quant_func(A, shuffle=False)
82- # B_q, B_scale = quant_func(B, shuffle=False)
87+ # A_q, A_scale = _quant_mxfp4(A, shuffle=False)
88+ # B_q, B_scale = _quant_mxfp4(B, shuffle=False)
8389
8490 # gemm_a4w4 expects A [M,K/2], B [N,K/2] as dtypes.fp4x2; A_scale/B_scale [*,K/32] E8M0
8591 # quant_func returns scale as dtypes.fp8_e8m0; gemm_a4w4 accepts E8M0, no view to uint8 needed
@@ -91,9 +97,7 @@ def ref_kernel(data: input_t) -> output_t:
9197 # out_torch = run_torch_fp4_mm(A_q, B_q, A_scale, B_scale, torch.bfloat16)
9298
9399 # 2) aiter.gemm_a4w4 path: needs shuffled B_q and shuffled scales (see test_gemm_a4w4.py:102-105)
94- # Per-1x32 MXFP4 quant
95- quant_func = aiter .get_triton_quant (QuantType .per_1x32 )
96- A_q , A_scale_sh = quant_func (A , shuffle = True )
100+ A_q , A_scale_sh = _quant_mxfp4 (A , shuffle = True )
97101 # to be noted, aiter also has other a4w4 implements using triton, https://github.com/ROCm/aiter/blob/main/aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py
98102 out_gemm = aiter .gemm_a4w4 (
99103 A_q ,
0 commit comments