@@ -19,8 +19,9 @@ def _chunk_scaled_dot_kkt_fwd_eager(k, g_cumsum, beta, chunk_size):
1919 g_c = g_cumsum .float ().reshape (B , NT , C , H ).permute (0 , 1 , 3 , 2 )
2020 beta_c = beta .float ().reshape (B , NT , C , H ).permute (0 , 1 , 3 , 2 )
2121 kkt = k_c @ k_c .transpose (- 1 , - 2 )
22- g_diff = g_c .unsqueeze (- 1 ) - g_c .unsqueeze (- 2 )
2322 strict_lower = torch .tril (torch .ones (C , C , device = k .device ), diagonal = - 1 )
23+ g_diff = g_c .unsqueeze (- 1 ) - g_c .unsqueeze (- 2 )
24+ g_diff = g_diff * strict_lower
2425 A = kkt * beta_c .unsqueeze (- 1 ) * torch .exp (g_diff ) * strict_lower
2526 return A .permute (0 , 1 , 3 , 2 , 4 ).reshape (B , T , H , C ).to (torch .float32 )
2627
@@ -103,9 +104,11 @@ def ref_kernel(data: input_t) -> output_t:
103104 v_c = v_new .float ().reshape (B , NT , C , H , V ).permute (0 , 1 , 3 , 2 , 4 )
104105 g_c = g .float ().reshape (B , NT , C , H ).permute (0 , 1 , 3 , 2 )
105106 o_inter = (q_c @ h .float ()) * torch .exp (g_c ).unsqueeze (- 1 )
106- qk = q_c @ k_c .transpose (- 1 , - 2 ) * torch .exp (g_c .unsqueeze (- 1 ) - g_c .unsqueeze (- 2 ))
107- causal = torch .tril (torch .ones (C , C , device = q .device ))
108- o = (o_inter + (qk * causal ) @ v_c ) * scale
107+ causal = torch .tril (torch .ones (C , C , dtype = torch .bool , device = q .device ))
108+ g_diff = g_c .unsqueeze (- 1 ) - g_c .unsqueeze (- 2 )
109+ g_diff = torch .where (causal , g_diff , torch .zeros_like (g_diff ))
110+ qk = q_c @ k_c .transpose (- 1 , - 2 ) * torch .exp (g_diff ) * causal
111+ o = (o_inter + qk @ v_c ) * scale
109112 return o .permute (0 , 1 , 3 , 2 , 4 ).reshape (B , T , H , V ).to (q .dtype )
110113
111114
0 commit comments