Skip to content

Commit 930d48d

Browse files
author
Mark Saroufim
committed
re-release of dual gemm problem
1 parent 1d795e6 commit 930d48d

7 files changed

Lines changed: 1437 additions & 1 deletion

File tree

problems/nvidia.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,9 @@ problems:
2020
deadline: "2026-02-20 7:59"
2121
gpus:
2222
- NVIDIA
23-
23+
- directory: nvidia/final_nvfp4_dual_gemm
24+
name: final_nvfp4_dual_gemm
25+
deadline: "2026-01-20 7:59"
26+
gpus:
27+
- NVIDIA
28+
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import torch
2+
from task import input_t, output_t
3+
from utils import make_match_reference
4+
5+
# Scaling factor vector size
6+
sf_vec_size = 16
7+
8+
# Helper function for ceiling division
9+
def ceil_div(a, b):
10+
return (a + b - 1) // b
11+
12+
# Helper function to convert scale factor tensor to blocked format
13+
def to_blocked(input_matrix):
14+
rows, cols = input_matrix.shape
15+
16+
# Please ensure rows and cols are multiples of 128 and 4 respectively
17+
n_row_blocks = ceil_div(rows, 128)
18+
n_col_blocks = ceil_div(cols, 4)
19+
20+
padded = input_matrix
21+
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
22+
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
23+
24+
return rearranged.flatten()
25+
26+
27+
def ref_kernel(
28+
data: input_t,
29+
) -> output_t:
30+
"""
31+
PyTorch reference implementation of NVFP4 block-scaled dual GEMM with silu activation,
32+
C = silu(A @ B1) * (A @ B2).
33+
"""
34+
a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, _, _, _, c_ref = data
35+
36+
# Get dimensions from MxNxL layout
37+
m, n, l = c_ref.shape
38+
39+
# Call torch._scaled_mm to compute the GEMV result
40+
ref1 = torch.empty(
41+
(l, m, n),
42+
dtype=torch.float32,
43+
device="cuda",
44+
).permute(1, 2, 0)
45+
ref2 = torch.empty(
46+
(l, m, n),
47+
dtype=torch.float32,
48+
device="cuda",
49+
).permute(1, 2, 0)
50+
for l_idx in range(l):
51+
# Convert the scale factor tensor to blocked format
52+
scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx])
53+
scale_b1 = to_blocked(sfb1_ref_cpu[:, :, l_idx])
54+
scale_b2 = to_blocked(sfb2_ref_cpu[:, :, l_idx])
55+
# (m, k) @ (n, k).T -> (m, n)
56+
res1 = torch._scaled_mm(
57+
a_ref[:, :, l_idx],
58+
b1_ref[:, :, l_idx].transpose(0, 1),
59+
scale_a.cuda(),
60+
scale_b1.cuda(),
61+
bias=None,
62+
out_dtype=torch.float32,
63+
)
64+
ref1[:, :, l_idx] = res1
65+
66+
res2 = torch._scaled_mm(
67+
a_ref[:, :, l_idx],
68+
b2_ref[:, :, l_idx].transpose(0, 1),
69+
scale_a.cuda(),
70+
scale_b2.cuda(),
71+
bias=None,
72+
out_dtype=torch.float32,
73+
)
74+
ref2[:, :, l_idx] = res2
75+
# Do silu on the first GEMM result and multiply with the second GEMM result
76+
c_ref = (torch.nn.functional.silu(ref1) * ref2).to(torch.float16)
77+
return c_ref
78+
79+
80+
def generate_input(
81+
m: int,
82+
n: int,
83+
k: int,
84+
l: int,
85+
seed: int,
86+
):
87+
"""
88+
Generate input tensors for NVFP4 block-scaled dual GEMM with silu activation,
89+
C = silu(A @ B1) * (A @ B2).
90+
91+
Args:
92+
m: Number of rows in matrix A
93+
n: Number of columns in matrix B1 and B2
94+
k: Number of columns in A and rows of B1 and B2
95+
l: Batch size
96+
seed: Random seed for reproducibility
97+
98+
Returns:
99+
Tuple of (a, b, scale_a, scale_b, c) where:
100+
a: [m, k, l] - Input matrix in torch.float4e2m1fn_x2 data type
101+
b1: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type
102+
b2: [n, k, l] - Input matrix in torch.float4e2m1fn_x2 data type
103+
scale_a: [m, k, l] - Input scale factors in torch.float8e4m3fn data type
104+
scale_b1: [n, k, l] - Input scale factors in torch.float8e4m3fn data type
105+
scale_b2: [n, k, l] - Input scale factors in torch.float8e4m3fn data type
106+
scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type
107+
scale_b1_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type
108+
scale_b2_permuted: [32, 4, rest_n, 4, rest_k, l] - Input scale factors in torch.float8e4m3fn data type
109+
c: [m, n, l] - Output matrix in torch.float16 data type
110+
"""
111+
torch.manual_seed(seed)
112+
113+
def create_fp4_tensors(l, mn, k):
114+
# generate uint8 tensor, then convert to float4e2m1fn_x2 data type
115+
# generate all bit patterns
116+
ref_i8 = torch.randint(255, size=(l, mn, k // 2), dtype=torch.uint8, device="cuda")
117+
118+
# for each nibble, only keep the sign bit and 2 LSBs
119+
# the possible values are [-1.5, -1, -0.5, 0, +0.5, +1, +1.5]
120+
ref_i8 = ref_i8 & 0b1011_1011
121+
122+
return ref_i8.permute(1, 2, 0).view(torch.float4_e2m1fn_x2)
123+
124+
# Generate uint8 tensor, then convert to float4e2m1fn_x2 data type
125+
a_ref = create_fp4_tensors(l, m, k)
126+
b1_ref = create_fp4_tensors(l, n, k)
127+
b2_ref = create_fp4_tensors(l, n, k)
128+
a_ref = a_ref.view(torch.float4_e2m1fn_x2)
129+
b1_ref = b1_ref.view(torch.float4_e2m1fn_x2)
130+
b2_ref = b2_ref.view(torch.float4_e2m1fn_x2)
131+
132+
# Create float16 output tensor
133+
c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute(
134+
1, 2, 0
135+
)
136+
137+
# Helper function to prepare the scale factor tensors for both reference
138+
# kernel and customize kernel. The customized data layout can be found in:
139+
# https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout
140+
def create_scale_factor_tensors(l, mn, sf_k):
141+
# Create the reference scale factor tensor (mn, sf_k, l) on CPU.
142+
ref_shape = (l, mn, sf_k)
143+
ref_permute_order = (1, 2, 0)
144+
# Init with fp32 tensor in [0,1), then convert to float8_e4m3fn
145+
ref_f8_random_fp32 = torch.rand(ref_shape, dtype=torch.float32, device='cuda')
146+
ref_f8_torch_tensor = ref_f8_random_fp32.to(dtype=torch.float8_e4m3fn)
147+
# permute to match ref_permute_order
148+
ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order)
149+
150+
atom_m = (32, 4)
151+
atom_k = 4
152+
mma_shape = (
153+
l, # batch size
154+
ceil_div(mn, atom_m[0] * atom_m[1]),
155+
ceil_div(sf_k, atom_k),
156+
atom_m[0],
157+
atom_m[1],
158+
atom_k,
159+
)
160+
161+
# Reorder scale factor tensor to (32, 4, rest_m, 4, rest_k, l) layout
162+
# Which is needed by the CuTe customized kernel
163+
mma_permute_order = (3, 4, 1, 5, 2, 0)
164+
# Generate a random int8 tensor, then convert to float8_e4m3fn
165+
rand_int_tensor = torch.empty(mma_shape, dtype=torch.int8, device='cuda')
166+
reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn)
167+
# Permute according to mma_permute_order
168+
reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order)
169+
170+
# GPU-side vectorized reordering (replaces slow CPU nested loops)
171+
# Create index grids for all dimensions
172+
i_idx = torch.arange(mn, device='cuda')
173+
j_idx = torch.arange(sf_k, device='cuda')
174+
b_idx = torch.arange(l, device='cuda')
175+
176+
# Create meshgrid for all combinations of (i, j, b)
177+
i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij')
178+
179+
# Calculate target indices in vectorized manner
180+
mm = i_grid // (atom_m[0] * atom_m[1])
181+
mm32 = i_grid % atom_m[0]
182+
mm4 = (i_grid % 128) // atom_m[0]
183+
kk = j_grid // atom_k
184+
kk4 = j_grid % atom_k
185+
186+
# Perform the reordering with advanced indexing (all on GPU)
187+
reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid]
188+
189+
return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor
190+
191+
sf_k = ceil_div(k, sf_vec_size)
192+
sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k)
193+
sfb1_ref_cpu, sfb1_ref_permuted = create_scale_factor_tensors(l, n, sf_k)
194+
sfb2_ref_cpu, sfb2_ref_permuted = create_scale_factor_tensors(l, n, sf_k)
195+
196+
return (a_ref, b1_ref, b2_ref, sfa_ref_cpu.to("cuda"), sfb1_ref_cpu.to("cuda"), sfb2_ref_cpu.to("cuda"), sfa_ref_permuted, sfb1_ref_permuted, sfb2_ref_permuted, c_ref)
197+
198+
199+
check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03)

0 commit comments

Comments
 (0)