Skip to content

Commit bc2bc28

Browse files
author
Mark Saroufim
committed
Fix moe-mxfp4 quantization: use batched torch_quant instead of per-expert loop
Port fix from AMD-AIM#17. Switch generate_input to use aiter.get_torch_quant(QuantType.per_1x32) for batched quantization and add layout=(16, 16) to the shuffle_weight calls.
1 parent 93f6bfe commit bc2bc28

1 file changed

Lines changed: 19 additions & 43 deletions

File tree

problems/amd_202602/moe-mxfp4/reference.py

Lines changed: 19 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from typing import Dict, Tuple, Optional
66
import math
77

8-
from aiter import ActivationType, QuantType
8+
import aiter
9+
from aiter import ActivationType, QuantType, dtypes
910
from aiter.fused_moe import fused_moe
1011
from aiter.utility import fp4_utils
1112
from aiter.ops.shuffle import shuffle_weight
@@ -98,48 +99,23 @@ def generate_input(
9899
topk_ids = torch.cat([routed_ids, shared_ids], dim=-1) # [M, total_top_k]
99100
topk_weights = torch.cat([routed_weights, shared_weights], dim=-1) # [M, total_top_k]
100101

101-
# ── Expert weights: bf16 -> quantize to MXFP4 ──
102-
# Generate weights for ALL experts (routed + shared)
103-
# gate_up = fused [gate_proj; up_proj] per expert: [2*d_expert_pad, d_hidden_pad]
104-
# down = down_proj per expert: [d_hidden_pad, d_expert_pad]
105-
gate_up_q_list, gate_up_s_list = [], []
106-
down_q_list, down_s_list = [], []
107-
108-
for _ in range(E_total):
109-
# gate_proj + up_proj -> fused [2*d_expert_pad, d_hidden_pad]
110-
gate_bf16 = torch.randn(
111-
(d_expert_pad, d_hidden_pad), device='cuda', dtype=torch.bfloat16, generator=gen
112-
) / math.sqrt(d_hidden)
113-
up_bf16 = torch.randn(
114-
(d_expert_pad, d_hidden_pad), device='cuda', dtype=torch.bfloat16, generator=gen
115-
) / math.sqrt(d_hidden)
116-
gate_up_bf16 = torch.cat([gate_bf16, up_bf16], dim=0)
117-
118-
# down_proj -> [d_hidden_pad, d_expert_pad]
119-
down_bf16 = torch.randn(
120-
(d_hidden_pad, d_expert_pad), device='cuda', dtype=torch.bfloat16, generator=gen
121-
) / math.sqrt(d_expert)
122-
123-
# Quantize to MXFP4
124-
gu_q, gu_s = fp4_utils.dynamic_mxfp4_quant(gate_up_bf16)
125-
dn_q, dn_s = fp4_utils.dynamic_mxfp4_quant(down_bf16)
126-
127-
gate_up_q_list.append(gu_q)
128-
gate_up_s_list.append(gu_s)
129-
down_q_list.append(dn_q)
130-
down_s_list.append(dn_s)
131-
132-
# Stack into [E_total, ...] tensors — raw (before shuffle)
133-
gate_up_weight = torch.stack(gate_up_q_list) # [E_total, 2*d_expert_pad, d_hidden_pad//2] fp4x2
134-
gate_up_weight_scale = torch.stack(gate_up_s_list) # [E_total, 2*d_expert_pad, scale_K] e8m0
135-
down_weight = torch.stack(down_q_list) # [E_total, d_hidden_pad, d_expert_pad//2] fp4x2
136-
down_weight_scale = torch.stack(down_s_list) # [E_total, d_hidden_pad, scale_K] e8m0
137-
138-
# Pre-shuffled weight. You can also shuffle the weights yourself before calling the kernel.
139-
gate_up_weight_shuffled = shuffle_weight(gate_up_weight.clone())
140-
down_weight_shuffled = shuffle_weight(down_weight.clone())
141-
gate_up_weight_scale_shuffled = fp4_utils.e8m0_shuffle(gate_up_weight_scale.reshape(E_total, -1))
142-
down_weight_scale_shuffled = fp4_utils.e8m0_shuffle(down_weight_scale.reshape(E_total, -1))
102+
gate_up_bf16 = torch.randn(
103+
(E_total, 2 * d_expert_pad, d_hidden_pad), device='cuda', dtype=torch.bfloat16, generator=gen,
104+
) / math.sqrt(d_hidden)
105+
down_bf16 = torch.randn(
106+
(E_total, d_hidden_pad, d_expert_pad), device='cuda', dtype=torch.bfloat16, generator=gen,
107+
) / math.sqrt(d_expert)
108+
109+
torch_quant = aiter.get_torch_quant(QuantType.per_1x32)
110+
gate_up_weight, gate_up_weight_scale = torch_quant(gate_up_bf16, quant_dtype=dtypes.fp4x2)
111+
down_weight, down_weight_scale = torch_quant(down_bf16, quant_dtype=dtypes.fp4x2)
112+
gate_up_weight = gate_up_weight.view(E_total, 2 * d_expert_pad, d_hidden_pad // 2)
113+
down_weight = down_weight.view(E_total, d_hidden_pad, d_expert_pad // 2)
114+
115+
gate_up_weight_shuffled = shuffle_weight(gate_up_weight, layout=(16, 16))
116+
down_weight_shuffled = shuffle_weight(down_weight, layout=(16, 16))
117+
gate_up_weight_scale_shuffled = fp4_utils.e8m0_shuffle(gate_up_weight_scale)
118+
down_weight_scale_shuffled = fp4_utils.e8m0_shuffle(down_weight_scale)
143119

144120
return (
145121
hidden_states, # [M, d_hidden] bf16

0 commit comments

Comments (0)