Skip to content

Commit 46dba9f

Browse files
authored
vulkan: fix flash attention dot product precision (ggml-org#20589)
1 parent de8f01c commit 46dba9f

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -245,7 +245,7 @@ void main() {
245 245   #endif
246 246   }
247 247   [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
248     - Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
    248 + Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
249 249   }
250 250   }
251 251   }
@@ -270,7 +270,7 @@ void main() {
270 270   #endif
271 271   }
272 272   [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
273     - Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
    273 + Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
274 274   }
275 275   }
276 276   }

0 commit comments

Comments
 (0)