Skip to content

Commit ebc65d7

Browse files
authored
[MXFP4] Hard code dynamic_mxfp4_quant from aiter.ops.triton.quant (#120)
1 parent 3ff2aa5 commit ebc65d7

2 files changed

Lines changed: 25 additions & 14 deletions

File tree

problems/amd_202602/mxfp4-mm/reference.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,28 @@
11
"""
22
FP4 quant + FP4 GEMM reference: bf16 A, MXFP4 B -> MXFP4 per-1x32 quant A -> gemm_a4w4 -> bf16 C.
33
Quant logic follows aiter op_tests/test_gemm_a4w4.py (get_triton_quant(QuantType.per_1x32)).
4+
5+
NOTE: Explicitly uses dynamic_mxfp4_quant from aiter.ops.triton.quant (patched in #975)
6+
rather than going through aiter.get_triton_quant, which may dispatch to the
7+
unpatched fp4_utils.py kernel. See ROCm/aiter#974, ROCm/aiter#975.
48
"""
59
import torch
610
from task import input_t, output_t
711
from utils import make_match_reference
812
from aiter import QuantType,dtypes
913
import aiter
1014
from aiter.ops.shuffle import shuffle_weight
15+
from aiter.ops.triton.quant import dynamic_mxfp4_quant # #975-patched kernel
16+
from aiter.utility.fp4_utils import e8m0_shuffle
1117
# K must be divisible by 64 (scale group 32 and fp4 pack 2)
1218
SCALE_GROUP_SIZE = 32
1319

20+
def _quant_mxfp4(x, shuffle=True):
21+
x_fp4, bs_e8m0 = dynamic_mxfp4_quant(x)
22+
if shuffle:
23+
bs_e8m0 = e8m0_shuffle(bs_e8m0)
24+
return x_fp4.view(dtypes.fp4x2), bs_e8m0.view(dtypes.fp8_e8m0)
25+
1426
def generate_input(m: int, n: int, k: int, seed: int):# -> input_t:
1527
"""
1628
Generate random bf16 inputs A [m, k], B [n, k] and quantized MXFP4 B, shuffled B and B_scale.
@@ -23,11 +35,7 @@ def generate_input(m: int, n: int, k: int, seed: int):# -> input_t:
2335
gen.manual_seed(seed)
2436
A = torch.randn((m, k), dtype=torch.bfloat16, device="cuda", generator=gen)
2537
B = torch.randn((n, k), dtype=torch.bfloat16, device="cuda", generator=gen)
26-
27-
# quantized mxfp4 B
28-
quant_func = aiter.get_triton_quant(QuantType.per_1x32)
29-
B_q, B_scale_sh = quant_func(B, shuffle=True)
30-
38+
B_q, B_scale_sh = _quant_mxfp4(B, shuffle=True)
3139
# shuffle B (the weight) into a (16, 16) tile-coalesced layout
3240
B_shuffle = shuffle_weight(B_q, layout=(16, 16))
3341
return (A, B, B_q, B_shuffle, B_scale_sh)
@@ -76,10 +84,8 @@ def ref_kernel(data: input_t) -> output_t:
7684

7785
# 1) PyTorch impl just for your reference: dequant fp4 + e8m0 -> f32 -> mm -> bf16
7886
# Per-1x32 MXFP4 quant
79-
# quant_func = aiter.get_triton_quant(QuantType.per_1x32)
80-
# quant_func(x, shuffle=False) -> (dtypes.fp4x2, scale); scale layout matches gemm_a4w4
81-
# A_q, A_scale = quant_func(A, shuffle=False)
82-
# B_q, B_scale = quant_func(B, shuffle=False)
87+
# A_q, A_scale = _quant_mxfp4(A, shuffle=False)
88+
# B_q, B_scale = _quant_mxfp4(B, shuffle=False)
8389

8490
# gemm_a4w4 expects A [M,K/2], B [N,K/2] as dtypes.fp4x2; A_scale/B_scale [*,K/32] E8M0
8591
# _quant_mxfp4 returns the scale as dtypes.fp8_e8m0; gemm_a4w4 accepts E8M0, so no view to uint8 is needed
@@ -91,9 +97,7 @@ def ref_kernel(data: input_t) -> output_t:
9197
# out_torch = run_torch_fp4_mm(A_q, B_q, A_scale, B_scale, torch.bfloat16)
9298

9399
# 2) aiter.gemm_a4w4 path: needs shuffled B_q and shuffled scales (see test_gemm_a4w4.py:102-105)
94-
# Per-1x32 MXFP4 quant
95-
quant_func = aiter.get_triton_quant(QuantType.per_1x32)
96-
A_q, A_scale_sh = quant_func(A, shuffle=True)
100+
A_q, A_scale_sh = _quant_mxfp4(A, shuffle=True)
97101
# Note: aiter also has other a4w4 implementations written in Triton; see https://github.com/ROCm/aiter/blob/main/aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py
98102
out_gemm = aiter.gemm_a4w4(
99103
A_q,

problems/amd_202602/mxfp4-mm/submission.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,22 @@ def custom_kernel(data: input_t) -> output_t:
1212
"""
1313
import aiter
1414
from aiter import QuantType, dtypes
15+
from aiter.ops.triton.quant import dynamic_mxfp4_quant
16+
from aiter.utility.fp4_utils import e8m0_shuffle
1517

18+
def _quant_mxfp4(x, shuffle=True):
19+
x_fp4, bs_e8m0 = dynamic_mxfp4_quant(x)
20+
if shuffle:
21+
bs_e8m0 = e8m0_shuffle(bs_e8m0)
22+
return x_fp4.view(dtypes.fp4x2), bs_e8m0.view(dtypes.fp8_e8m0)
23+
1624
A, B, B_q, B_shuffle, B_scale_sh = data
1725
A = A.contiguous()
1826
B = B.contiguous()
1927
m, k = A.shape
2028
n, _ = B.shape
2129

22-
quant_func = aiter.get_triton_quant(QuantType.per_1x32)
23-
A_q, A_scale_sh = quant_func(A, shuffle=True)
30+
A_q, A_scale_sh = _quant_mxfp4(A, shuffle=True)
2431
out_gemm = aiter.gemm_a4w4(
2532
A_q,
2633
B_shuffle,

0 commit comments

Comments
 (0)