Skip to content

Commit ae40cd2

Browse files
CUDA: limit number of FA stream-k CUDA blocks (ggml-org#20586)
1 parent ceef6b5 commit ae40cd2

1 file changed

Lines changed: 12 additions & 12 deletions

File tree

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -892,7 +892,7 @@ void launch_fattn(
892892
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
893893
const int gqa_ratio = Q->ne[2] / K->ne[2];
894894
const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
895-
const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
895+
const int ntiles_dst = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
896896

897897
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
898898
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
@@ -919,37 +919,37 @@ void launch_fattn(
919919
GGML_ASSERT(max_blocks_per_sm > 0);
920920
int parallel_blocks = max_blocks_per_sm;
921921

922+
const int ntiles_KV = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by KV cache length.
923+
922924
dim3 blocks_num;
923925
if (stream_k) {
924926
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
925927
const int max_blocks = max_blocks_per_sm*nsm;
926-
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
927-
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
928+
const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks;
929+
const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);
928930

929-
const int nblocks_stream_k = max_blocks;
931+
const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst);
930932

931933
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
932934

933-
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
935+
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst;
934936
blocks_num.y = 1;
935937
blocks_num.z = 1;
936938

937-
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
939+
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
938940
dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
939941
}
940942
} else {
941-
const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
942-
943943
// parallel_blocks must not be larger than what the tensor size allows:
944-
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
944+
parallel_blocks = std::min(parallel_blocks, ntiles_KV);
945945

946946
// If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
947947
// Test whether parallel_blocks can be set to a higher value for better efficiency.
948948
const int blocks_per_wave = nsm * max_blocks_per_sm;
949949
int nwaves_best = 0;
950950
int efficiency_percent_best = 0;
951-
for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
952-
const int nblocks_total = ntiles_total * parallel_blocks_test;
951+
for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KV; ++parallel_blocks_test) {
952+
const int nblocks_total = ntiles_dst * parallel_blocks_test;
953953
const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
954954
const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
955955

@@ -1015,7 +1015,7 @@ void launch_fattn(
10151015
CUDA_CHECK(cudaGetLastError());
10161016

10171017
if (stream_k) {
1018-
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
1018+
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
10191019
const dim3 block_dim_combine(DV, 1, 1);
10201020
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
10211021

0 commit comments

Comments (0)