@@ -60,11 +60,17 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
6060enum mmvq_parameter_table_id {
6161 MMVQ_PARAMETERS_GENERIC = 0 ,
6262 MMVQ_PARAMETERS_GCN,
63- MMVQ_PARAMETERS_RDNA2
63+ MMVQ_PARAMETERS_RDNA2,
64+ MMVQ_PARAMETERS_RDNA3_0,
65+ MMVQ_PARAMETERS_RDNA4
6466};
6567
6668static constexpr __device__ mmvq_parameter_table_id get_device_table_id () {
67- #if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
69+ #if defined(RDNA4)
70+ return MMVQ_PARAMETERS_RDNA4;
71+ #elif defined(RDNA3_0)
72+ return MMVQ_PARAMETERS_RDNA3_0;
73+ #elif defined(RDNA2) || defined(RDNA3_5)
6874 return MMVQ_PARAMETERS_RDNA2;
6975#elif defined(GCN) || defined(CDNA)
7076 return MMVQ_PARAMETERS_GCN;
@@ -74,7 +80,13 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
7480}
7581
7682static __host__ mmvq_parameter_table_id get_device_table_id (int cc) {
77- if (GGML_CUDA_CC_IS_RDNA2 (cc) || GGML_CUDA_CC_IS_RDNA3 (cc) || GGML_CUDA_CC_IS_RDNA4 (cc)) {
83+ if (GGML_CUDA_CC_IS_RDNA4 (cc)) {
84+ return MMVQ_PARAMETERS_RDNA4;
85+ }
86+ if (GGML_CUDA_CC_IS_RDNA3_0 (cc)) {
87+ return MMVQ_PARAMETERS_RDNA3_0;
88+ }
89+ if (GGML_CUDA_CC_IS_RDNA2 (cc) || GGML_CUDA_CC_IS_RDNA3_5 (cc)) {
7890 return MMVQ_PARAMETERS_RDNA2;
7991 }
8092 if (GGML_CUDA_CC_IS_GCN (cc) || GGML_CUDA_CC_IS_CDNA (cc)) {
@@ -83,7 +95,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
8395 return MMVQ_PARAMETERS_GENERIC;
8496}
8597
86- static constexpr __host__ __device__ int calc_nwarps (int ncols_dst, mmvq_parameter_table_id table_id) {
98+ static constexpr __host__ __device__ int calc_nwarps (ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
8799 if (table_id == MMVQ_PARAMETERS_GENERIC) {
88100 switch (ncols_dst) {
89101 case 1 :
@@ -114,6 +126,50 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet
114126 return 1 ;
115127 }
116128 }
129+ if (table_id == MMVQ_PARAMETERS_RDNA4) {
130+ // nwarps=8 benefits types with simple vec_dot on RDNA4 (ncols_dst=1).
131+ // Types with complex vec_dot (Q3_K, IQ2_*, IQ3_*) regress due to register
132+ // pressure and lookup table contention at higher thread counts.
133+ if (ncols_dst == 1 ) {
134+ switch (type) {
135+ case GGML_TYPE_Q4_0:
136+ case GGML_TYPE_Q4_1:
137+ case GGML_TYPE_Q5_0:
138+ case GGML_TYPE_Q5_1:
139+ case GGML_TYPE_Q8_0:
140+ case GGML_TYPE_Q2_K:
141+ case GGML_TYPE_Q4_K:
142+ case GGML_TYPE_Q5_K:
143+ case GGML_TYPE_Q6_K:
144+ case GGML_TYPE_IQ4_NL:
145+ case GGML_TYPE_IQ4_XS:
146+ return 8 ;
147+ default :
148+ return 1 ;
149+ }
150+ }
151+ return 1 ;
152+ }
153+ if (table_id == MMVQ_PARAMETERS_RDNA3_0) {
154+ // RDNA3 (W7900): stricter whitelist than RDNA4.
155+ // Q2_K / Q5_K / IQ4_XS regress in full quant sweeps.
156+ if (ncols_dst == 1 ) {
157+ switch (type) {
158+ case GGML_TYPE_Q4_0:
159+ case GGML_TYPE_Q4_1:
160+ case GGML_TYPE_Q5_0:
161+ case GGML_TYPE_Q5_1:
162+ case GGML_TYPE_Q8_0:
163+ case GGML_TYPE_Q4_K:
164+ case GGML_TYPE_Q6_K:
165+ case GGML_TYPE_IQ4_NL:
166+ return 8 ;
167+ default :
168+ return 1 ;
169+ }
170+ }
171+ return 1 ;
172+ }
117173 return 1 ;
118174}
119175
@@ -138,7 +194,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
138194}
139195
140196template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false >
141- __launch_bounds__ (calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
197+ __launch_bounds__ (calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
142198static __global__ void mul_mat_vec_q(
143199 const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
144200 const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@@ -151,7 +207,7 @@ static __global__ void mul_mat_vec_q(
151207 constexpr int qi = ggml_cuda_type_traits<type>::qi;
152208 constexpr int vdr = get_vdr_mmvq (type);
153209 constexpr mmvq_parameter_table_id table_id = get_device_table_id ();
154- constexpr int nwarps = calc_nwarps (ncols_dst, table_id);
210+ constexpr int nwarps = calc_nwarps (type, ncols_dst, table_id);
155211 constexpr int rows_per_cuda_block = calc_rows_per_block (ncols_dst, table_id);
156212 constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
157213
@@ -355,12 +411,13 @@ static __global__ void mul_mat_vec_q(
355411 }
356412}
357413
414+ template <ggml_type type>
358415static std::pair<dim3 , dim3 > calc_launch_params (
359416 const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
360417 const int warp_size, const mmvq_parameter_table_id table_id) {
361418 const int64_t nblocks = (nrows_x + calc_rows_per_block (ncols_dst, table_id) - 1 ) / calc_rows_per_block (ncols_dst, table_id);
362419 const dim3 block_nums (nblocks, nchannels_dst, nsamples_or_ntokens);
363- const dim3 block_dims (warp_size, calc_nwarps (ncols_dst, table_id), 1 );
420+ const dim3 block_dims (warp_size, calc_nwarps (type, ncols_dst, table_id), 1 );
364421 return {block_nums, block_dims};
365422}
366423
@@ -420,7 +477,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
420477 if (has_ids && ncols_dst > 1 ) {
421478 // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
422479 constexpr int c_ncols_dst = 1 ;
423- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
480+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
424481 mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true >(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
425482 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
426483 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -431,63 +488,63 @@ static void mul_mat_vec_q_switch_ncols_dst(
431488 switch (ncols_dst) {
432489 case 1 : {
433490 constexpr int c_ncols_dst = 1 ;
434- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
491+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
435492 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
436493 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
437494 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
438495 dims.first , dims.second , 0 , ids_stride, stream);
439496 } break ;
440497 case 2 : {
441498 constexpr int c_ncols_dst = 2 ;
442- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
499+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
443500 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
444501 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
445502 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
446503 dims.first , dims.second , 0 , ids_stride, stream);
447504 } break ;
448505 case 3 : {
449506 constexpr int c_ncols_dst = 3 ;
450- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
507+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
451508 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
452509 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
453510 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
454511 dims.first , dims.second , 0 , ids_stride, stream);
455512 } break ;
456513 case 4 : {
457514 constexpr int c_ncols_dst = 4 ;
458- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
515+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
459516 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
460517 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
461518 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
462519 dims.first , dims.second , 0 , ids_stride, stream);
463520 } break ;
464521 case 5 : {
465522 constexpr int c_ncols_dst = 5 ;
466- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
523+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
467524 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
468525 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
469526 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
470527 dims.first , dims.second , 0 , ids_stride, stream);
471528 } break ;
472529 case 6 : {
473530 constexpr int c_ncols_dst = 6 ;
474- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
531+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
475532 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
476533 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
477534 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
478535 dims.first , dims.second , 0 , ids_stride, stream);
479536 } break ;
480537 case 7 : {
481538 constexpr int c_ncols_dst = 7 ;
482- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
539+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
483540 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
484541 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
485542 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
486543 dims.first , dims.second , 0 , ids_stride, stream);
487544 } break ;
488545 case 8 : {
489546 constexpr int c_ncols_dst = 8 ;
490- std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
547+ std::pair<dim3 , dim3 > dims = calc_launch_params<type> (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
491548 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
492549 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
493550 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
0 commit comments