@@ -123,6 +123,17 @@ def create_reordered_scale_factor_tensor(l, mn, k, ref_f8_tensor):
123123 return reordered_f8_tensor
124124
125125
def _create_fp4_tensors(l, mn, k):
    """Create a random packed-FP4 (e2m1) tensor of logical shape (mn, k//2, l).

    Two 4-bit values are packed per byte, so the underlying uint8 storage
    holds k // 2 bytes along the packed dimension. The batch dimension l is
    moved last before reinterpreting the storage.

    Args:
        l: batch (group) count — becomes the last dimension after permute.
        mn: row dimension (m or n of the GEMM operand).
        k: reduction dimension in FP4 elements; must be even (packed 2/byte).

    Returns:
        A CUDA tensor viewed as ``torch.float4_e2m1fn_x2`` whose nibbles are
        restricted to [-1.5, -1, -0.5, 0, +0.5, +1, +1.5].
    """
    # Generate every possible byte pattern. torch.randint's upper bound is
    # EXCLUSIVE, so 256 (not 255) is required to include 0xFF.
    ref_i8 = torch.randint(256, size=(l, mn, k // 2), dtype=torch.uint8, device="cuda")

    # For each nibble, keep only the sign bit and the 2 LSBs (low exponent
    # bit + mantissa), i.e. the possible values are
    # [-1.5, -1, -0.5, 0, +0.5, +1, +1.5].
    ref_i8 = ref_i8 & 0b1011_1011

    # Permute batch last, then reinterpret the raw bytes as packed fp4x2.
    return ref_i8.permute(1, 2, 0).view(torch.float4_e2m1fn_x2)
135+
136+
126137def generate_input (
127138 m : tuple ,
128139 n : tuple ,
@@ -165,14 +176,8 @@ def generate_input(
165176 mi = m [group_idx ]
166177 ni = n [group_idx ]
167178 ki = k [group_idx ]
168- a_ref = torch .randint (
169- - 1 , 2 , (l , mi , ki // 2 ), dtype = torch .int8 , device = "cuda"
170- ).permute (1 , 2 , 0 )
171- b_ref = torch .randint (
172- - 1 , 2 , (l , ni , ki // 2 ), dtype = torch .int8 , device = "cuda"
173- ).permute (1 , 2 , 0 )
174- a_ref = a_ref .view (torch .float4_e2m1fn_x2 )
175- b_ref = b_ref .view (torch .float4_e2m1fn_x2 )
179+ a_ref = _create_fp4_tensors (l , mi , ki )
180+ b_ref = _create_fp4_tensors (l , ni , ki )
176181
177182 c_ref = torch .randn ((l , mi , ni ), dtype = torch .float16 , device = "cuda" ).permute (
178183 1 , 2 , 0
0 commit comments