@@ -892,7 +892,7 @@ void launch_fattn(
892892 const int ntiles_x = ((Q->ne [1 ] + ncols1 - 1 ) / ncols1);
893893 const int gqa_ratio = Q->ne [2 ] / K->ne [2 ];
894894 const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1 ) / ncols2);
895- const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne [2 ] * Q->ne [3 ];
895+ const int ntiles_dst = ntiles_x * ntiles_z_gqa * K->ne [2 ] * Q->ne [3 ];
896896
897897 // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
898898 // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
@@ -919,37 +919,37 @@ void launch_fattn(
919919 GGML_ASSERT (max_blocks_per_sm > 0 );
920920 int parallel_blocks = max_blocks_per_sm;
921921
922+ const int ntiles_KV = (K->ne [1 ] + nbatch_fa - 1 ) / nbatch_fa; // Max. number of parallel blocks limited by KV cache length.
923+
922924 dim3 blocks_num;
923925 if (stream_k) {
924926 // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
925927 const int max_blocks = max_blocks_per_sm*nsm;
926- const int tiles_nwaves = (ntiles_total + max_blocks - 1 ) / max_blocks;
927- const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
928+ const int tiles_nwaves = (ntiles_dst + max_blocks - 1 ) / max_blocks;
929+ const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);
928930
929- const int nblocks_stream_k = max_blocks;
931+ const int nblocks_stream_k = std::min ( max_blocks, ntiles_KV*ntiles_dst) ;
930932
931933 const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available (cc) || tiles_efficiency_percent < 75 ;
932934
933- blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total ;
935+ blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst ;
934936 blocks_num.y = 1 ;
935937 blocks_num.z = 1 ;
936938
937- if (ntiles_total % blocks_num.x != 0 ) { // Fixup is only needed if the SMs work on fractional tiles.
939+ if (ntiles_dst % blocks_num.x != 0 ) { // Fixup is only needed if the SMs work on fractional tiles.
938940 dst_tmp_meta.alloc ((size_t (blocks_num.x ) * ncols * (2 + DV/2 )));
939941 }
940942 } else {
941- const int ntiles_KQ = (K->ne [1 ] + nbatch_fa - 1 ) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
942-
943943 // parallel_blocks must not be larger than what the tensor size allows:
944- parallel_blocks = std::min (parallel_blocks, ntiles_KQ );
944+ parallel_blocks = std::min (parallel_blocks, ntiles_KV );
945945
946946 // If ntiles_dst % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
947947 // Test whether parallel_blocks can be set to a higher value for better efficiency.
948948 const int blocks_per_wave = nsm * max_blocks_per_sm;
949949 int nwaves_best = 0 ;
950950 int efficiency_percent_best = 0 ;
951- for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ ; ++parallel_blocks_test) {
952- const int nblocks_total = ntiles_total * parallel_blocks_test;
951+ for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KV ; ++parallel_blocks_test) {
952+ const int nblocks_total = ntiles_dst * parallel_blocks_test;
953953 const int nwaves = (nblocks_total + blocks_per_wave - 1 ) / blocks_per_wave;
954954 const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
955955
@@ -1015,7 +1015,7 @@ void launch_fattn(
10151015 CUDA_CHECK (cudaGetLastError ());
10161016
10171017 if (stream_k) {
1018- if (ntiles_total % blocks_num.x != 0 ) { // Fixup is only needed if the SMs work on fractional tiles.
1018+ if (ntiles_dst % blocks_num.x != 0 ) { // Fixup is only needed if the SMs work on fractional tiles.
10191019 const dim3 block_dim_combine (DV, 1 , 1 );
10201020 const dim3 blocks_num_combine = {blocks_num.x , ncols1, ncols2};
10211021
0 commit comments