|
5 | 5 | from typing import Dict, Tuple, Optional |
6 | 6 | import math |
7 | 7 |
|
8 | | -from aiter import ActivationType, QuantType |
| 8 | +import aiter |
| 9 | +from aiter import ActivationType, QuantType, dtypes |
9 | 10 | from aiter.fused_moe import fused_moe |
10 | 11 | from aiter.utility import fp4_utils |
11 | 12 | from aiter.ops.shuffle import shuffle_weight |
@@ -98,48 +99,23 @@ def generate_input( |
98 | 99 | topk_ids = torch.cat([routed_ids, shared_ids], dim=-1) # [M, total_top_k] |
99 | 100 | topk_weights = torch.cat([routed_weights, shared_weights], dim=-1) # [M, total_top_k] |
100 | 101 |
|
101 | | - # ── Expert weights: bf16 -> quantize to MXFP4 ── |
102 | | - # Generate weights for ALL experts (routed + shared) |
103 | | - # gate_up = fused [gate_proj; up_proj] per expert: [2*d_expert_pad, d_hidden_pad] |
104 | | - # down = down_proj per expert: [d_hidden_pad, d_expert_pad] |
105 | | - gate_up_q_list, gate_up_s_list = [], [] |
106 | | - down_q_list, down_s_list = [], [] |
107 | | - |
108 | | - for _ in range(E_total): |
109 | | - # gate_proj + up_proj -> fused [2*d_expert_pad, d_hidden_pad] |
110 | | - gate_bf16 = torch.randn( |
111 | | - (d_expert_pad, d_hidden_pad), device='cuda', dtype=torch.bfloat16, generator=gen |
112 | | - ) / math.sqrt(d_hidden) |
113 | | - up_bf16 = torch.randn( |
114 | | - (d_expert_pad, d_hidden_pad), device='cuda', dtype=torch.bfloat16, generator=gen |
115 | | - ) / math.sqrt(d_hidden) |
116 | | - gate_up_bf16 = torch.cat([gate_bf16, up_bf16], dim=0) |
117 | | - |
118 | | - # down_proj -> [d_hidden_pad, d_expert_pad] |
119 | | - down_bf16 = torch.randn( |
120 | | - (d_hidden_pad, d_expert_pad), device='cuda', dtype=torch.bfloat16, generator=gen |
121 | | - ) / math.sqrt(d_expert) |
122 | | - |
123 | | - # Quantize to MXFP4 |
124 | | - gu_q, gu_s = fp4_utils.dynamic_mxfp4_quant(gate_up_bf16) |
125 | | - dn_q, dn_s = fp4_utils.dynamic_mxfp4_quant(down_bf16) |
126 | | - |
127 | | - gate_up_q_list.append(gu_q) |
128 | | - gate_up_s_list.append(gu_s) |
129 | | - down_q_list.append(dn_q) |
130 | | - down_s_list.append(dn_s) |
131 | | - |
132 | | - # Stack into [E_total, ...] tensors — raw (before shuffle) |
133 | | - gate_up_weight = torch.stack(gate_up_q_list) # [E_total, 2*d_expert_pad, d_hidden_pad//2] fp4x2 |
134 | | - gate_up_weight_scale = torch.stack(gate_up_s_list) # [E_total, 2*d_expert_pad, scale_K] e8m0 |
135 | | - down_weight = torch.stack(down_q_list) # [E_total, d_hidden_pad, d_expert_pad//2] fp4x2 |
136 | | - down_weight_scale = torch.stack(down_s_list) # [E_total, d_hidden_pad, scale_K] e8m0 |
137 | | - |
138 | | - # Pre-shuffled weight. You can also shuffle the weights yourself before calling the kernel. |
139 | | - gate_up_weight_shuffled = shuffle_weight(gate_up_weight.clone()) |
140 | | - down_weight_shuffled = shuffle_weight(down_weight.clone()) |
141 | | - gate_up_weight_scale_shuffled = fp4_utils.e8m0_shuffle(gate_up_weight_scale.reshape(E_total, -1)) |
142 | | - down_weight_scale_shuffled = fp4_utils.e8m0_shuffle(down_weight_scale.reshape(E_total, -1)) |
| 102 | + gate_up_bf16 = torch.randn( |
| 103 | + (E_total, 2 * d_expert_pad, d_hidden_pad), device='cuda', dtype=torch.bfloat16, generator=gen, |
| 104 | + ) / math.sqrt(d_hidden) |
| 105 | + down_bf16 = torch.randn( |
| 106 | + (E_total, d_hidden_pad, d_expert_pad), device='cuda', dtype=torch.bfloat16, generator=gen, |
| 107 | + ) / math.sqrt(d_expert) |
| 108 | + |
| 109 | + torch_quant = aiter.get_torch_quant(QuantType.per_1x32) |
| 110 | + gate_up_weight, gate_up_weight_scale = torch_quant(gate_up_bf16, quant_dtype=dtypes.fp4x2) |
| 111 | + down_weight, down_weight_scale = torch_quant(down_bf16, quant_dtype=dtypes.fp4x2) |
| 112 | + gate_up_weight = gate_up_weight.view(E_total, 2 * d_expert_pad, d_hidden_pad // 2) |
| 113 | + down_weight = down_weight.view(E_total, d_hidden_pad, d_expert_pad // 2) |
| 114 | + |
| 115 | + gate_up_weight_shuffled = shuffle_weight(gate_up_weight, layout=(16, 16)) |
| 116 | + down_weight_shuffled = shuffle_weight(down_weight, layout=(16, 16)) |
| 117 | + gate_up_weight_scale_shuffled = fp4_utils.e8m0_shuffle(gate_up_weight_scale) |
| 118 | + down_weight_scale_shuffled = fp4_utils.e8m0_shuffle(down_weight_scale) |
143 | 119 |
|
144 | 120 | return ( |
145 | 121 | hidden_states, # [M, d_hidden] bf16 |
|
0 commit comments