Skip to content

Commit 617db24

Browse files
authored
cuda : add RDNA4-specific MMVQ parameter table for bs=1 decode (ggml-org#19478)
* mmvq: add RDNA3/RDNA4-specific parameter table (nwarps=8, rows=1) * mmvq: add dedicated RDNA3 parameter table * mmvq: exclude RDNA3.5 (gfx1150/1151) from RDNA3 table
1 parent 1a3d8ed commit 617db24

2 files changed

Lines changed: 81 additions & 16 deletions

File tree

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 73 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,17 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
6060
enum mmvq_parameter_table_id {
6161
MMVQ_PARAMETERS_GENERIC = 0,
6262
MMVQ_PARAMETERS_GCN,
63-
MMVQ_PARAMETERS_RDNA2
63+
MMVQ_PARAMETERS_RDNA2,
64+
MMVQ_PARAMETERS_RDNA3_0,
65+
MMVQ_PARAMETERS_RDNA4
6466
};
6567

6668
static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
67-
#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
69+
#if defined(RDNA4)
70+
return MMVQ_PARAMETERS_RDNA4;
71+
#elif defined(RDNA3_0)
72+
return MMVQ_PARAMETERS_RDNA3_0;
73+
#elif defined(RDNA2) || defined(RDNA3_5)
6874
return MMVQ_PARAMETERS_RDNA2;
6975
#elif defined(GCN) || defined(CDNA)
7076
return MMVQ_PARAMETERS_GCN;
@@ -74,7 +80,13 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
7480
}
7581

7682
static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
77-
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
83+
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
84+
return MMVQ_PARAMETERS_RDNA4;
85+
}
86+
if (GGML_CUDA_CC_IS_RDNA3_0(cc)) {
87+
return MMVQ_PARAMETERS_RDNA3_0;
88+
}
89+
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc)) {
7890
return MMVQ_PARAMETERS_RDNA2;
7991
}
8092
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
@@ -83,7 +95,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
8395
return MMVQ_PARAMETERS_GENERIC;
8496
}
8597

86-
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
98+
static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
8799
if (table_id == MMVQ_PARAMETERS_GENERIC) {
88100
switch (ncols_dst) {
89101
case 1:
@@ -114,6 +126,50 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet
114126
return 1;
115127
}
116128
}
129+
if (table_id == MMVQ_PARAMETERS_RDNA4) {
130+
// nwarps=8 benefits types with simple vec_dot on RDNA4 (ncols_dst=1).
131+
// Types with complex vec_dot (Q3_K, IQ2_*, IQ3_*) regress due to register
132+
// pressure and lookup table contention at higher thread counts.
133+
if (ncols_dst == 1) {
134+
switch (type) {
135+
case GGML_TYPE_Q4_0:
136+
case GGML_TYPE_Q4_1:
137+
case GGML_TYPE_Q5_0:
138+
case GGML_TYPE_Q5_1:
139+
case GGML_TYPE_Q8_0:
140+
case GGML_TYPE_Q2_K:
141+
case GGML_TYPE_Q4_K:
142+
case GGML_TYPE_Q5_K:
143+
case GGML_TYPE_Q6_K:
144+
case GGML_TYPE_IQ4_NL:
145+
case GGML_TYPE_IQ4_XS:
146+
return 8;
147+
default:
148+
return 1;
149+
}
150+
}
151+
return 1;
152+
}
153+
if (table_id == MMVQ_PARAMETERS_RDNA3_0) {
154+
// RDNA3 (W7900): stricter whitelist than RDNA4.
155+
// Q2_K / Q5_K / IQ4_XS regress in full quant sweeps.
156+
if (ncols_dst == 1) {
157+
switch (type) {
158+
case GGML_TYPE_Q4_0:
159+
case GGML_TYPE_Q4_1:
160+
case GGML_TYPE_Q5_0:
161+
case GGML_TYPE_Q5_1:
162+
case GGML_TYPE_Q8_0:
163+
case GGML_TYPE_Q4_K:
164+
case GGML_TYPE_Q6_K:
165+
case GGML_TYPE_IQ4_NL:
166+
return 8;
167+
default:
168+
return 1;
169+
}
170+
}
171+
return 1;
172+
}
117173
return 1;
118174
}
119175

@@ -138,7 +194,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
138194
}
139195

140196
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
141-
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
197+
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
142198
static __global__ void mul_mat_vec_q(
143199
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
144200
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@@ -151,7 +207,7 @@ static __global__ void mul_mat_vec_q(
151207
constexpr int qi = ggml_cuda_type_traits<type>::qi;
152208
constexpr int vdr = get_vdr_mmvq(type);
153209
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
154-
constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
210+
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
155211
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
156212
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
157213

@@ -355,12 +411,13 @@ static __global__ void mul_mat_vec_q(
355411
}
356412
}
357413

414+
template<ggml_type type>
358415
static std::pair<dim3, dim3> calc_launch_params(
359416
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
360417
const int warp_size, const mmvq_parameter_table_id table_id) {
361418
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
362419
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
363-
const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
420+
const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1);
364421
return {block_nums, block_dims};
365422
}
366423

@@ -420,7 +477,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
420477
if (has_ids && ncols_dst > 1) {
421478
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
422479
constexpr int c_ncols_dst = 1;
423-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
480+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
424481
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
425482
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
426483
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@@ -431,63 +488,63 @@ static void mul_mat_vec_q_switch_ncols_dst(
431488
switch (ncols_dst) {
432489
case 1: {
433490
constexpr int c_ncols_dst = 1;
434-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
491+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
435492
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
436493
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
437494
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
438495
dims.first, dims.second, 0, ids_stride, stream);
439496
} break;
440497
case 2: {
441498
constexpr int c_ncols_dst = 2;
442-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
499+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
443500
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
444501
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
445502
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
446503
dims.first, dims.second, 0, ids_stride, stream);
447504
} break;
448505
case 3: {
449506
constexpr int c_ncols_dst = 3;
450-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
507+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
451508
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
452509
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
453510
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
454511
dims.first, dims.second, 0, ids_stride, stream);
455512
} break;
456513
case 4: {
457514
constexpr int c_ncols_dst = 4;
458-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
515+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
459516
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
460517
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
461518
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
462519
dims.first, dims.second, 0, ids_stride, stream);
463520
} break;
464521
case 5: {
465522
constexpr int c_ncols_dst = 5;
466-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
523+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
467524
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
468525
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
469526
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
470527
dims.first, dims.second, 0, ids_stride, stream);
471528
} break;
472529
case 6: {
473530
constexpr int c_ncols_dst = 6;
474-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
531+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
475532
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
476533
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
477534
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
478535
dims.first, dims.second, 0, ids_stride, stream);
479536
} break;
480537
case 7: {
481538
constexpr int c_ncols_dst = 7;
482-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
539+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
483540
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
484541
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
485542
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
486543
dims.first, dims.second, 0, ids_stride, stream);
487544
} break;
488545
case 8: {
489546
constexpr int c_ncols_dst = 8;
490-
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
547+
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
491548
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
492549
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
493550
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,14 @@
207207
#define RDNA3
208208
#endif // defined(__GFX11__)
209209

210+
#if defined(__gfx1150__) || defined(__gfx1151__)
211+
#define RDNA3_5
212+
#endif // defined(__gfx1150__) || defined(__gfx1151__)
213+
214+
#if defined(RDNA3) && !defined(RDNA3_5)
215+
#define RDNA3_0
216+
#endif // defined(RDNA3) && !defined(RDNA3_5)
217+
210218
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
211219
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
212220
#define RDNA2

0 commit comments

Comments
 (0)