Skip to content

Commit 07f0321

Browse files
authored
change fp4 init range (#96)
1 parent 9ca8ea5 commit 07f0321

1 file changed

Lines changed: 13 additions & 8 deletions

File tree

problems/nvidia/nvfp4_group_gemm/reference.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,17 @@ def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor):
123123
return reordered_f8_tensor
124124

125125

def _create_fp4_tensors(l, mn, k, device="cuda"):
    """Create a random packed-fp4 (e2m1) tensor with layout (mn, k // 2, l).

    Two fp4 values are packed per byte (``torch.float4_e2m1fn_x2``), so the
    middle axis holds ``k // 2`` bytes.

    Args:
        l: group/batch dimension; permuted to the trailing axis.
        mn: row dimension of the matrix (M or N).
        k: logical K dimension; assumed even since values are packed in pairs.
        device: torch device string for the allocation (default ``"cuda"``,
            matching the original behavior).

    Returns:
        A tensor viewed as ``torch.float4_e2m1fn_x2`` of shape (mn, k // 2, l).
    """
    # Generate uint8 tensor covering every bit pattern uniformly, then
    # reinterpret as the packed float4_e2m1fn_x2 dtype.  torch.randint's high
    # bound is exclusive, so 256 (not 255) is needed to include pattern 0xFF.
    ref_i8 = torch.randint(256, size=(l, mn, k // 2), dtype=torch.uint8, device=device)

    # For each nibble, keep only the sign bit and the 2 LSBs (zeroing the
    # exponent MSB).  The representable e2m1 values after masking are
    # [-1.5, -1, -0.5, 0, +0.5, +1, +1.5].
    ref_i8 = ref_i8 & 0b1011_1011

    return ref_i8.permute(1, 2, 0).view(torch.float4_e2m1fn_x2)
136+
126137
def generate_input(
127138
m: tuple,
128139
n: tuple,
@@ -165,14 +176,8 @@ def generate_input(
165176
mi = m[group_idx]
166177
ni = n[group_idx]
167178
ki = k[group_idx]
168-
a_ref = torch.randint(
169-
-1, 2, (l, mi, ki // 2), dtype=torch.int8, device="cuda"
170-
).permute(1, 2, 0)
171-
b_ref = torch.randint(
172-
-1, 2, (l, ni, ki // 2), dtype=torch.int8, device="cuda"
173-
).permute(1, 2, 0)
174-
a_ref = a_ref.view(torch.float4_e2m1fn_x2)
175-
b_ref = b_ref.view(torch.float4_e2m1fn_x2)
179+
a_ref = _create_fp4_tensors(l, mi, ki)
180+
b_ref = _create_fp4_tensors(l, ni, ki)
176181

177182
c_ref = torch.randn((l, mi, ni), dtype=torch.float16, device="cuda").permute(
178183
1, 2, 0

0 commit comments

Comments
 (0)