Skip to content

Commit adb8b4e

Browse files
committed
ggml-cpu: add rvv 512b,1024b impls for iq4_xs
1 parent 2785c94 commit adb8b4e

1 file changed

Lines changed: 201 additions & 0 deletions

File tree

ggml/src/ggml-cpu/arch/riscv/quants.c

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3918,6 +3918,201 @@ static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT
39183918
*s = sumf;
39193919
}
39203920

3921+
// Dot product of n iq4_xs-quantized values (vx) with n q8_K-quantized values (vy).
// RVV variant that ggml_vec_dot_iq4_xs_q8_K dispatches to when __riscv_vlenb() * 8 == 512.
//
//   n           element count; must be a multiple of QK_K (asserted)
//   s           receives the single float result
//   vx, vy      reinterpreted as block_iq4_xs[] / block_q8_K[], n / QK_K blocks each
//   bs, bx, by  unused
//   nrc         must be 1 (asserted)
//
// Register-group sizes are chosen for 512-bit registers: a u8m2 group holds the 128 packed
// bytes loaded per iteration, a u8m4/i8m4 group holds 256 bytes, and each i16m1 part of the
// i16m8 product holds 32 elements.
static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3922+
assert(nrc == 1);
3923+
UNUSED(nrc);
3924+
UNUSED(bx);
3925+
UNUSED(by);
3926+
UNUSED(bs);
3927+
assert(n % QK_K == 0);
3928+
3929+
const block_iq4_xs * GGML_RESTRICT x = vx;
3930+
const block_q8_K * GGML_RESTRICT y = vy;
3931+
3932+
// Number of blocks to process.
const int nb = n / QK_K;
3933+
3934+
// Lookup table: the first 16 bytes of kvalues_iq4nl (vl = 16); only these 16 entries are
// ever addressed because the gather indices below are 4-bit values.
const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16);
3935+
float sumf = 0;
3936+
3937+
// Indices for re-ordering IQ4 data.
// Applied to 64-bit elements: output takes two elements from the first half (0..15),
// then two from the second half (16..31), and so on. The first half will hold the low
// nibbles and the second half the high nibbles, so every 32 output bytes become
// 16 low-nibble values followed by the 16 high-nibble values of the same 16 packed bytes.
// NOTE(review): presumably this matches the iq4_xs sub-block element order — confirm
// against the generic implementation.
const uint16_t index[32] = {
3938+
0, 1, 16, 17,
3939+
2, 3, 18, 19,
3940+
4, 5,20, 21,
3941+
6, 7, 22, 23,
3942+
8, 9, 24, 25,
3943+
10, 11, 26, 27,
3944+
12, 13,28, 29,
3945+
14, 15, 30, 31,
3946+
};
3947+
const vuint16m1_t i_vec = __riscv_vle16_v_u16m1(index, 32);
3948+
3949+
for (int ibl = 0; ibl < nb; ++ibl) {
3950+
const int8_t * q8 = y[ibl].qs;
3951+
const uint8_t * iq4 = x[ibl].qs;
3952+
// Packed upper bits of the sub-block scales; consumed 2 bits per sub-block below.
uint16_t h = x[ibl].scales_h;
3953+
3954+
// Integer accumulator for this block: sum over sub-blocks of (dot product * scale).
int sumi = 0;
3955+
3956+
// NOTE(review): the scale decoding in the body reads scales_l[0..3] and h without using
// ib, so the body is only right for a single iteration (QK_K == 256) — confirm.
#pragma GCC unroll 1
3957+
// Process the entire super-block together.
3958+
for (int ib = 0; ib < QK_K / 256; ++ib) {
3959+
// Weights and activations.
3960+
// 128 packed bytes = 256 4-bit weight indices.
const vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 128);
3961+
iq4 += 128;
3962+
3963+
// Unpack the weight blocks.
3964+
const vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 128);
3965+
const vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 128);
3966+
// Concatenate: low nibbles in the lower half of the group, high nibbles in the upper half.
const vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi);
3967+
// Permute in 8-byte units (group viewed as 32 x u64) using the index table above.
const vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgatherei16_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 32));
3968+
// Table lookup: replace each 4-bit index with its int8 value from `values`.
const vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 256);
3969+
3970+
// Compiler-only barrier (empty asm with a "memory" clobber).
// NOTE(review): presumably placed to control scheduling of the q8 load — confirm intent.
__asm__ __volatile__("" ::: "memory");
3971+
3972+
// Multiply with activations.
3973+
const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 256);
3974+
// Widening multiply: 256 int8 x int8 products kept as int16.
const vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 256);
3975+
q8 += 256;
3976+
3977+
// Reduce separately.
// Part k of prod (one i16m1 register, 32 elements) is summed into an int32 seeded with 0,
// giving the integer dot product of elements 32*k .. 32*k+31.
3978+
const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32));
3979+
const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32));
3980+
const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32));
3981+
const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32));
3982+
const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), __riscv_vmv_v_x_i32m1(0, 1), 32));
3983+
const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), __riscv_vmv_v_x_i32m1(0, 1), 32));
3984+
const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), __riscv_vmv_v_x_i32m1(0, 1), 32));
3985+
const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), __riscv_vmv_v_x_i32m1(0, 1), 32));
3986+
3987+
3988+
// Rebuild the eight 6-bit scales: bits 0-3 from a nibble of scales_l[], bits 4-5 from
// successive 2-bit fields of h (low byte for ls0..ls3), then subtract the bias of 32.
const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32;
3989+
const int ls1 = ((x[ibl].scales_l[0] >> 4) | ((h << 2) & 0x30)) - 32;
3990+
const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32;
3991+
const int ls3 = ((x[ibl].scales_l[1] >> 4) | ((h >> 2) & 0x30)) - 32;
3992+
// Move the high byte of scales_h down so ls4..ls7 reuse the same shift pattern.
h >>= 8;
3993+
const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32;
3994+
const int ls5 = ((x[ibl].scales_l[2] >> 4) | ((h << 2) & 0x30)) - 32;
3995+
const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32;
3996+
const int ls7 = ((x[ibl].scales_l[3] >> 4) | ((h >> 2) & 0x30)) - 32;
3997+
3998+
// Weight each 32-element partial sum by its scale.
sumi += acc0 * ls0;
3999+
sumi += acc1 * ls1;
4000+
sumi += acc2 * ls2;
4001+
sumi += acc3 * ls3;
4002+
sumi += acc4 * ls4;
4003+
sumi += acc5 * ls5;
4004+
sumi += acc6 * ls6;
4005+
sumi += acc7 * ls7;
4006+
4007+
// Second compiler-only barrier at the end of the iteration (same form as above).
__asm__ __volatile__("" ::: "memory");
4008+
}
4009+
4010+
// Apply the per-block float scales: x's fp16 d (converted to fp32) times y's d.
sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi);
4011+
}
4012+
4013+
*s = sumf;
4014+
}
4016+
4017+
// Dot product of n iq4_xs-quantized values (vx) with n q8_K-quantized values (vy).
// RVV variant that ggml_vec_dot_iq4_xs_q8_K dispatches to when __riscv_vlenb() * 8 == 1024.
//
//   n           element count; must be a multiple of QK_K (asserted)
//   s           receives the single float result
//   vx, vy      reinterpreted as block_iq4_xs[] / block_q8_K[], n / QK_K blocks each
//   bs, bx, by  unused
//   nrc         must be 1 (asserted)
//
// Register-group sizes are chosen for 1024-bit registers: one u8m1 register holds the 128
// packed bytes loaded per iteration, a u8m2/i8m2 group holds 256 bytes, and each i16m1 part
// of the i16m4 product holds 64 elements, i.e. two 32-element sub-blocks.
static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
4018+
assert(nrc == 1);
4019+
UNUSED(nrc);
4020+
UNUSED(bx);
4021+
UNUSED(by);
4022+
UNUSED(bs);
4023+
assert(n % QK_K == 0);
4024+
4025+
const block_iq4_xs * GGML_RESTRICT x = vx;
4026+
const block_q8_K * GGML_RESTRICT y = vy;
4027+
4028+
// Number of blocks to process.
const int nb = n / QK_K;
4029+
4030+
// Lookup table: the first 16 bytes of kvalues_iq4nl (vl = 16); only these 16 entries are
// ever addressed because the gather indices below are 4-bit values.
const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16);
4031+
float sumf = 0;
4032+
4033+
// Indices for re-ordering IQ4 data.
// Applied to 64-bit elements: output takes two elements from the first half (0..15),
// then two from the second half (16..31), and so on. The first half will hold the low
// nibbles and the second half the high nibbles, so every 32 output bytes become
// 16 low-nibble values followed by the 16 high-nibble values of the same 16 packed bytes.
// NOTE(review): presumably this matches the iq4_xs sub-block element order — confirm
// against the generic implementation.
const uint16_t index[32] = {
4034+
0, 1, 16, 17,
4035+
2, 3, 18, 19,
4036+
4, 5,20, 21,
4037+
6, 7, 22, 23,
4038+
8, 9, 24, 25,
4039+
10, 11, 26, 27,
4040+
12, 13,28, 29,
4041+
14, 15, 30, 31,
4042+
};
4043+
// A fractional (mf2) u16 group is enough to hold the 32 indices at this vector length.
const vuint16mf2_t i_vec = __riscv_vle16_v_u16mf2(index, 32);
4044+
4045+
for (int ibl = 0; ibl < nb; ++ibl) {
4046+
const int8_t * q8 = y[ibl].qs;
4047+
const uint8_t * iq4 = x[ibl].qs;
4048+
// Packed upper bits of the sub-block scales; consumed 2 bits per sub-block below.
uint16_t h = x[ibl].scales_h;
4049+
4050+
// Integer accumulator for this block: sum over sub-blocks of (dot product * scale).
int sumi = 0;
4051+
4052+
// NOTE(review): the scale decoding in the body reads scales_l[0..3] and h without using
// ib, so the body is only right for a single iteration (QK_K == 256) — confirm.
#pragma GCC unroll 1
4053+
// Process the entire super-block together.
4054+
for (int ib = 0; ib < QK_K / 256; ++ib) {
4055+
// Weights and activations.
4056+
// 128 packed bytes = 256 4-bit weight indices.
const vuint8m1_t iq4_packed = __riscv_vle8_v_u8m1(iq4, 128);
4057+
iq4 += 128;
4058+
4059+
// Unpack the weight blocks.
4060+
const vuint8m1_t iq4bits_lo = __riscv_vand_vx_u8m1(iq4_packed, 0xf, 128);
4061+
const vuint8m1_t iq4bits_hi = __riscv_vsrl_vx_u8m1(iq4_packed, 4, 128);
4062+
// Concatenate: low nibbles in the lower register of the group, high nibbles in the upper.
const vuint8m2_t iq4bits = __riscv_vcreate_v_u8m1_u8m2(iq4bits_lo, iq4bits_hi);
4063+
// Permute in 8-byte units (group viewed as 32 x u64) using the index table above.
const vuint8m2_t iq4bits_reorder = __riscv_vreinterpret_v_u64m2_u8m2(__riscv_vrgatherei16_vv_u64m2(__riscv_vreinterpret_v_u8m2_u64m2(iq4bits), i_vec, 32));
4064+
// Table lookup: replace each 4-bit index with its int8 value from `values`.
const vint8m2_t iq4b = __riscv_vrgather_vv_i8m2(values, iq4bits_reorder, 256);
4065+
4066+
// Compiler-only barrier (empty asm with a "memory" clobber).
// NOTE(review): presumably placed to control scheduling of the q8 load — confirm intent.
__asm__ __volatile__("" ::: "memory");
4067+
4068+
// Multiply with activations.
4069+
const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 256);
4070+
// Widening multiply: 256 int8 x int8 products kept as int16.
const vint16m4_t prod = __riscv_vwmul_vv_i16m4(iq4b, q8b, 256);
4071+
q8 += 256;
4072+
4073+
// Mask for processing 32 elements per prod register.
// p_mask is set for element indices 32..63, i.e. the upper half of a 64-element i16m1.
// NOTE(review): p_index/p_mask do not depend on ib or ibl and could be computed once
// before the outer loop — confirm the compiler hoists them.
4074+
const vuint16m1_t p_index = __riscv_vid_v_u16m1(64);
4075+
const vbool16_t p_mask = __riscv_vmsgtu_vx_u16m1_b16(p_index, 31, 64);
4076+
4077+
// Reduce separately.
// Each i16m1 part of prod carries two sub-blocks: the even accumulator sums elements
// 0..31 (unmasked, vl = 32); the odd one sums elements 32..63 (masked, vl = 64).
// Both reductions widen to int32 and are seeded with 0.
4078+
const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32));
4079+
const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 64));
4080+
const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32));
4081+
const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 64));
4082+
const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32));
4083+
const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 64));
4084+
const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32));
4085+
const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 64));
4086+
4087+
// Rebuild the eight 6-bit scales: bits 0-3 from a nibble of scales_l[], bits 4-5 from
// successive 2-bit fields of h (low byte for ls0..ls3), then subtract the bias of 32.
const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32;
4088+
const int ls1 = ((x[ibl].scales_l[0] >> 4) | ((h << 2) & 0x30)) - 32;
4089+
const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32;
4090+
const int ls3 = ((x[ibl].scales_l[1] >> 4) | ((h >> 2) & 0x30)) - 32;
4091+
// Move the high byte of scales_h down so ls4..ls7 reuse the same shift pattern.
h >>= 8;
4092+
const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32;
4093+
const int ls5 = ((x[ibl].scales_l[2] >> 4) | ((h << 2) & 0x30)) - 32;
4094+
const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32;
4095+
const int ls7 = ((x[ibl].scales_l[3] >> 4) | ((h >> 2) & 0x30)) - 32;
4096+
4097+
// Weight each 32-element partial sum by its scale.
sumi += acc0 * ls0;
4098+
sumi += acc1 * ls1;
4099+
sumi += acc2 * ls2;
4100+
sumi += acc3 * ls3;
4101+
sumi += acc4 * ls4;
4102+
sumi += acc5 * ls5;
4103+
sumi += acc6 * ls6;
4104+
sumi += acc7 * ls7;
4105+
4106+
// Second compiler-only barrier at the end of the iteration (same form as above).
__asm__ __volatile__("" ::: "memory");
4107+
}
4108+
4109+
// Apply the per-block float scales: x's fp16 d (converted to fp32) times y's d.
sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi);
4110+
}
4111+
4112+
*s = sumf;
4113+
}
4115+
39214116
void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
39224117
#if defined __riscv_v_intrinsic
39234118
switch (__riscv_vlenb() * 8) {
@@ -3927,6 +4122,12 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
39274122
case 256:
39284123
ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
39294124
break;
4125+
case 512:
4126+
ggml_vec_dot_iq4_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
4127+
break;
4128+
case 1024:
4129+
ggml_vec_dot_iq4_xs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc);
4130+
break;
39304131
default:
39314132
ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
39324133
break;

0 commit comments

Comments
 (0)