Skip to content

Commit c6d9c9f

Browse files
improved upstreamed iq2_xs_vl256 implementation
1 parent 8ea532d commit c6d9c9f

1 file changed

Lines changed: 87 additions & 53 deletions

File tree

ggml/src/ggml-cpu/arch/riscv/quants.c

Lines changed: 87 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3796,8 +3796,11 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
37963796
case 256:
37973797
ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
37983798
break;
3799+
case 512:
3800+
ggml_vec_dot_iq2_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
3801+
break;
37993802
default:
3800-
ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3803+
ggml_vec_dot_iq2_s_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc);
38013804
break;
38023805
}
38033806
#else
@@ -3844,13 +3847,16 @@ static const int8_t keven_signs_q2xs[1024] = {
38443847
};
38453848
#endif
38463849

3847-
static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3850+
static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n,
3851+
float * GGML_RESTRICT s,
3852+
size_t bs,
3853+
const void * GGML_RESTRICT vx,
3854+
size_t bx,
3855+
const void * GGML_RESTRICT vy,
3856+
size_t by,
3857+
int nrc) {
38483858
assert(n % QK_K == 0);
3849-
assert(nrc == 1);
3850-
UNUSED(nrc);
3851-
UNUSED(bx);
3852-
UNUSED(by);
3853-
UNUSED(bs);
3859+
(void)nrc; (void)bx; (void)by; (void)bs;
38543860

38553861
const block_iq2_xs * GGML_RESTRICT x = vx;
38563862
const block_q8_K * GGML_RESTRICT y = vy;
@@ -3869,61 +3875,74 @@ static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT
38693875

38703876
int32_t sum_int = 0;
38713877

3872-
// Loop over 4 subblocks of 64 elements (QK_K = 256)
3873-
for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) {
3874-
// Load 8 uint16 indices (controls 64 values)
3875-
vuint16mf2_t v_qs = __riscv_vle16_v_u16mf2(qs, 8);
3876-
qs += 8;
3878+
for (int ib128 = 0; ib128 < 2; ++ib128) {
38773879

3878-
// Extract indices for grid (low 9 bits) and signs (high 7 bits)
3879-
// Multiply by 8 (<< 3) for byte offsets into the uint64 tables
3880-
vuint16mf2_t vidx_grid = __riscv_vsll_vx_u16mf2(__riscv_vand_vx_u16mf2(v_qs, 511, 8), 3, 8);
3881-
vuint16mf2_t vidx_sign = __riscv_vsll_vx_u16mf2(__riscv_vsrl_vx_u16mf2(v_qs, 9, 8), 3, 8);
3880+
vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 16);
3881+
qs += 16;
38823882

3883-
vuint64m2_t vq2_64 = __riscv_vluxei16_v_u64m2(grid64, vidx_grid, 8);
3884-
vuint64m2_t vs2_64 = __riscv_vluxei16_v_u64m2(signs64, vidx_sign, 8);
3883+
// Prepare offsets for grid and signs
3884+
vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 16), 3, 16);
3885+
vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 16), 3, 16);
38853886

3886-
vint8m2_t q2u = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64));
3887-
vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64));
3887+
// Indexed load 128 weights (16 x 8-byte chunks)
3888+
vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 16);
3889+
vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 16);
38883890

3889-
// Apply signs
3890-
vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64);
3891+
vint8m4_t q2u = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64));
3892+
vint8m4_t q2s = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64));
38913893

3892-
// Load Q8 weights (64 elements)
3893-
vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64);
3894-
q8 += 64;
3894+
// Apply signs to get dequantized IQ2 values
3895+
vint8m4_t q2_final = __riscv_vmul_vv_i8m4(q2u, q2s, 128);
3896+
asm volatile("" ::: "memory");
38953897

3896-
// Multiply (Widening to int16, 64 elements -> LMUL=4)
3897-
vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64);
3898+
// Load corresponding Q8 weights
3899+
vint8m4_t q8v = __riscv_vle8_v_i8m4(q8, 128);
3900+
q8 += 128;
3901+
3902+
vint16m8_t prod = __riscv_vwmul_vv_i16m8(q2_final, q8v, 128);
3903+
asm volatile("" ::: "memory");
3904+
3905+
uint8_t sc0 = scales[0];
3906+
uint8_t sc1 = scales[1];
3907+
uint8_t sc2 = scales[2];
3908+
uint8_t sc3 = scales[3];
3909+
scales += 4;
38983910

3899-
// Reduction
39003911
vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1);
39013912

3902-
int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
3903-
__riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 16));
3904-
int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
3905-
__riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 16));
3906-
int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
3907-
__riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 16));
3908-
int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
3909-
__riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16));
3910-
3911-
// Apply Scales
3912-
const uint8_t scale_byte_1 = scales[0];
3913-
const uint8_t scale_byte_2 = scales[1];
3914-
scales += 2;
3913+
// 9. Reduce each 16-element chunk and apply corresponding nibble scale
3914+
3915+
int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), zero_vec, 16));
3916+
sum_int += s0 * ((sc0 & 0x0F) * 2 + 1);
3917+
3918+
int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), zero_vec, 16));
3919+
sum_int += s1 * ((sc0 >> 4) * 2 + 1);
3920+
3921+
int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), zero_vec, 16));
3922+
sum_int += s2 * ((sc1 & 0x0F) * 2 + 1);
3923+
3924+
int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), zero_vec, 16));
3925+
sum_int += s3 * ((sc1 >> 4) * 2 + 1);
3926+
3927+
int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), zero_vec, 16));
3928+
sum_int += s4 * ((sc2 & 0x0F) * 2 + 1);
39153929

3916-
sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1);
3917-
sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1);
3918-
sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1);
3919-
sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1);
3930+
int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), zero_vec, 16));
3931+
sum_int += s5 * ((sc2 >> 4) * 2 + 1);
3932+
3933+
int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), zero_vec, 16));
3934+
sum_int += s6 * ((sc3 & 0x0F) * 2 + 1);
3935+
3936+
int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), zero_vec, 16));
3937+
sum_int += s7 * ((sc3 >> 4) * 2 + 1);
39203938
}
39213939

3922-
sumf += d * sum_int;
3940+
sumf += d * (float)sum_int;
39233941
}
39243942
*s = 0.125f * sumf;
39253943
}
39263944

3945+
39273946
static void ggml_vec_dot_iq2_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
39283947
assert(n % QK_K == 0);
39293948
assert(nrc == 1);
@@ -3992,11 +4011,14 @@ static void ggml_vec_dot_iq2_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_
39924011
void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
39934012
#if defined __riscv_v_intrinsic
39944013
switch (__riscv_vlenb() * 8) {
4014+
case 128:
4015+
ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
4016+
break;
39954017
case 256:
39964018
ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
39974019
break;
39984020
default:
3999-
ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
4021+
ggml_vec_dot_iq2_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
40004022
break;
40014023
}
40024024
#else
@@ -4268,9 +4290,12 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
42684290
case 128:
42694291
ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
42704292
break;
4271-
default:
4293+
case 256:
42724294
ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
42734295
break;
4296+
default:
4297+
ggml_vec_dot_iq2_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
4298+
break;
42744299
}
42754300
#else
42764301
ggml_vec_dot_iq2_xxs_q8_K(n, s, bs, vx, bx, vy, by, nrc);
@@ -4464,11 +4489,14 @@ static void ggml_vec_dot_iq3_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t
44644489
void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
44654490
#if defined __riscv_v_intrinsic
44664491
switch (__riscv_vlenb() * 8) {
4492+
case 128:
4493+
ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
4494+
break;
44674495
case 256:
44684496
ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
44694497
break;
44704498
default:
4471-
ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
4499+
ggml_vec_dot_iq3_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
44724500
break;
44734501
}
44744502
#else
@@ -4756,11 +4784,17 @@ static void ggml_vec_dot_iq3_xxs_q8_K_vl1024(int n,
47564784
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
47574785
#if defined __riscv_v_intrinsic
47584786
switch (__riscv_vlenb() * 8) {
4787+
case 128:
4788+
ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
4789+
break;
47594790
case 256:
47604791
ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
47614792
break;
4793+
case 512:
4794+
ggml_vec_dot_iq3_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
4795+
break;
47624796
default:
4763-
ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
4797+
ggml_vec_dot_iq3_xxs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc);
47644798
break;
47654799
}
47664800
#else
@@ -5551,11 +5585,11 @@ static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT
55515585
void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
55525586
#if defined __riscv_v_intrinsic
55535587
switch (__riscv_vlenb() * 8) {
5554-
case 256:
5555-
ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
5588+
case 128:
5589+
ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
55565590
break;
55575591
default:
5558-
ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
5592+
ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
55595593
break;
55605594
}
55615595
#else

0 commit comments

Comments (0)