@@ -3796,8 +3796,11 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
         case 256:
             ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
+        case 512:
+            ggml_vec_dot_iq2_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
+            break;
         default:
-            ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            ggml_vec_dot_iq2_s_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc);
             break;
     }
 #else
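For context on the dispatch pattern changed above: `__riscv_vlenb()` returns the hardware vector register width (VLEN) in bytes, so `__riscv_vlenb() * 8` is VLEN in bits, and each dot-product entry point selects a kernel tuned for the detected width at runtime. The recurring theme of this patch is that `default` branches which previously fell back to the scalar `_generic` path now select the widest vectorized kernel. A minimal standalone sketch of the same mechanism, assuming only the standard `__riscv_vlenb` intrinsic (the messages are illustrative, not from the patch):

    #include <riscv_vector.h>
    #include <stdio.h>

    int main(void) {
        // __riscv_vlenb() reads the vlenb CSR: vector register width in bytes.
        switch (__riscv_vlenb() * 8) {
            case 256: puts("VLEN 256: dispatch to the _vl256 kernel");    break;
            case 512: puts("VLEN 512: dispatch to the _vl512 kernel");    break;
            default:  puts("wider VLEN: dispatch to the _vl1024 kernel"); break;
        }
        return 0;
    }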
@@ -3844,13 +3847,16 @@ static const int8_t keven_signs_q2xs[1024] = {
 };
 #endif
 
-static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n,
+                                                    float * GGML_RESTRICT s,
+                                                    size_t bs,
+                                                    const void * GGML_RESTRICT vx,
+                                                    size_t bx,
+                                                    const void * GGML_RESTRICT vy,
+                                                    size_t by,
+                                                    int nrc) {
     assert(n % QK_K == 0);
-    assert(nrc == 1);
-    UNUSED(nrc);
-    UNUSED(bx);
-    UNUSED(by);
-    UNUSED(bs);
+    (void)nrc; (void)bx; (void)by; (void)bs;
 
     const block_iq2_xs * GGML_RESTRICT x = vx;
     const block_q8_K * GGML_RESTRICT y = vy;
@@ -3869,61 +3875,74 @@ static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT
 
 
         int32_t sum_int = 0;
 
-        // Loop over 4 subblocks of 64 elements (QK_K = 256)
-        for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) {
-            // Load 8 uint16 indices (controls 64 values)
-            vuint16mf2_t v_qs = __riscv_vle16_v_u16mf2(qs, 8);
-            qs += 8;
+        for (int ib128 = 0; ib128 < 2; ++ib128) {
 
-            // Extract indices for grid (low 9 bits) and signs (high 7 bits)
-            // Multiply by 8 (<< 3) for byte offsets into the uint64 tables
-            vuint16mf2_t vidx_grid = __riscv_vsll_vx_u16mf2(__riscv_vand_vx_u16mf2(v_qs, 511, 8), 3, 8);
-            vuint16mf2_t vidx_sign = __riscv_vsll_vx_u16mf2(__riscv_vsrl_vx_u16mf2(v_qs, 9, 8), 3, 8);
+            vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 16);
+            qs += 16;
 
-            vuint64m2_t vq2_64 = __riscv_vluxei16_v_u64m2(grid64, vidx_grid, 8);
-            vuint64m2_t vs2_64 = __riscv_vluxei16_v_u64m2(signs64, vidx_sign, 8);
+            // Prepare offsets for grid and signs
+            vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 16), 3, 16);
+            vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 16), 3, 16);
 
-            vint8m2_t q2u = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64));
-            vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64));
+            // Indexed load 128 weights (16 x 8-byte chunks)
+            vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 16);
+            vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 16);
 
-            // Apply signs
-            vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64);
+            vint8m4_t q2u = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64));
+            vint8m4_t q2s = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64));
 
-            // Load Q8 weights (64 elements)
-            vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64);
-            q8 += 64;
+            // Apply signs to get dequantized IQ2 values
+            vint8m4_t q2_final = __riscv_vmul_vv_i8m4(q2u, q2s, 128);
+            asm volatile("" ::: "memory");
 
-            // Multiply (Widening to int16, 64 elements -> LMUL=4)
-            vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64);
+            // Load corresponding Q8 weights
+            vint8m4_t q8v = __riscv_vle8_v_i8m4(q8, 128);
+            q8 += 128;
+
+            vint16m8_t prod = __riscv_vwmul_vv_i16m8(q2_final, q8v, 128);
+            asm volatile("" ::: "memory");
+
+            uint8_t sc0 = scales[0];
+            uint8_t sc1 = scales[1];
+            uint8_t sc2 = scales[2];
+            uint8_t sc3 = scales[3];
+            scales += 4;
 
-            // Reduction
             vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1);
 
-            int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
-                __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 16));
-            int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
-                __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 16));
-            int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
-                __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 16));
-            int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
-                __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16));
-
-            // Apply Scales
-            const uint8_t scale_byte_1 = scales[0];
-            const uint8_t scale_byte_2 = scales[1];
-            scales += 2;
+            // Reduce each 16-element chunk and apply the corresponding nibble scale
+
+            int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), zero_vec, 16));
+            sum_int += s0 * ((sc0 & 0x0F) * 2 + 1);
+
+            int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), zero_vec, 16));
+            sum_int += s1 * ((sc0 >> 4) * 2 + 1);
+
+            int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), zero_vec, 16));
+            sum_int += s2 * ((sc1 & 0x0F) * 2 + 1);
+
+            int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), zero_vec, 16));
+            sum_int += s3 * ((sc1 >> 4) * 2 + 1);
+
+            int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), zero_vec, 16));
+            sum_int += s4 * ((sc2 & 0x0F) * 2 + 1);
 
-            sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1);
-            sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1);
-            sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1);
-            sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1);
+            int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), zero_vec, 16));
+            sum_int += s5 * ((sc2 >> 4) * 2 + 1);
+
+            int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), zero_vec, 16));
+            sum_int += s6 * ((sc3 & 0x0F) * 2 + 1);
+
+            int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), zero_vec, 16));
+            sum_int += s7 * ((sc3 >> 4) * 2 + 1);
         }
 
-        sumf += d * sum_int;
+        sumf += d * (float) sum_int;
     }
     *s = 0.125f * sumf;
 }
 
+
 static void ggml_vec_dot_iq2_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
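The reworked vl256 kernel above processes 128 elements per iteration instead of 64, using LMUL=4 register groups for the int8 data and LMUL=8 for the widened int16 products; the empty `asm volatile("" ::: "memory")` statements are compiler-level barriers (the optimizer may not move memory accesses across them), presumably inserted to control scheduling and register pressure at high LMUL. The decode itself is easier to follow in scalar form. Here is a sketch under the format assumptions visible in the kernel (each uint16 in `qs` packs a 9-bit grid index and a 7-bit sign index, each table entry holding 8 packed int8 values; the table and function names below are placeholders, not ggml's):

    #include <stdint.h>

    // Placeholder lookup tables: 8 packed int8 values per uint64_t entry.
    extern const uint64_t grid64[512];   // quantized magnitude patterns
    extern const uint64_t signs64[128];  // sign patterns (+1 / -1 factors)

    // Dot product of one 8-element IQ2_XS chunk against the Q8 activations.
    static int32_t iq2_xs_chunk_dot(uint16_t qs_entry, const int8_t * q8) {
        const int8_t * grid = (const int8_t *) &grid64[qs_entry & 511]; // low 9 bits
        const int8_t * sign = (const int8_t *) &signs64[qs_entry >> 9]; // high 7 bits
        int32_t sum = 0;
        for (int j = 0; j < 8; ++j) {
            sum += (int32_t)(grid[j] * sign[j]) * q8[j];
        }
        return sum; // caller weights this by (2 * scale_nibble + 1), then by 0.125f * d
    }

The vector code shifts both indices left by 3 because the indexed loads (`vluxei16`) take byte offsets into the tables; the scalar sketch indexes the `uint64_t` arrays directly.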
@@ -3992,11 +4011,14 @@ static void ggml_vec_dot_iq2_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_
 void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 #if defined __riscv_v_intrinsic
     switch (__riscv_vlenb() * 8) {
+        case 128:
+            ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            break;
         case 256:
             ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
         default:
-            ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            ggml_vec_dot_iq2_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
             break;
     }
 #else
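The new `case 128` here (and in the dispatchers below) pins VLEN=128 hardware to the scalar generic kernel while `default` now selects a wider vectorized one. The likely reason, inferred from the vl256 kernel above rather than stated in the patch: RVV clamps a requested element count to VLMAX, so a kernel that asks for 16 uint16 elements in an LMUL=1 register only gets all 16 when VLEN >= 256. A back-of-envelope helper (illustrative, not from the patch):

    #include <stdint.h>

    // Elements of a given bit width that fit in one LMUL=1 vector register.
    static inline uint32_t rvv_m1_capacity(uint32_t vlen_bits, uint32_t elem_bits) {
        return vlen_bits / elem_bits;
    }

    // The vl256 kernel loads 16 x uint16 with vl=16 at LMUL=1:
    //   rvv_m1_capacity(128, 16) ==  8  -> vl clamps to 8: incorrect, use _generic
    //   rvv_m1_capacity(256, 16) == 16  -> exact fit
    //   rvv_m1_capacity(512, 16) == 32  -> fits; the upper lanes stay unused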
@@ -4268,9 +4290,12 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
         case 128:
             ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
             break;
-        default:
+        case 256:
             ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
+        default:
+            ggml_vec_dot_iq2_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
+            break;
     }
 #else
     ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
@@ -4464,11 +4489,14 @@ static void ggml_vec_dot_iq3_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t
 void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 #if defined __riscv_v_intrinsic
     switch (__riscv_vlenb() * 8) {
+        case 128:
+            ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            break;
         case 256:
             ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
         default:
-            ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            ggml_vec_dot_iq3_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
             break;
     }
 #else
@@ -4756,11 +4784,17 @@ static void ggml_vec_dot_iq3_xxs_q8_K_vl1024(int n,
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 #if defined __riscv_v_intrinsic
     switch (__riscv_vlenb() * 8) {
+        case 128:
+            ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            break;
         case 256:
             ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
+        case 512:
+            ggml_vec_dot_iq3_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
+            break;
         default:
-            ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            ggml_vec_dot_iq3_xxs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc);
             break;
     }
 #else
@@ -5551,11 +5585,11 @@ static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT
 void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
 #if defined __riscv_v_intrinsic
     switch (__riscv_vlenb() * 8) {
-        case 256:
-            ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
+        case 128:
+            ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
             break;
         default:
-            ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
     }
 #else
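The tq2_0 dispatch is inverted rather than extended: previously only VLEN=256 used the vector kernel and everything else fell back to generic; now only VLEN=128 falls back, and every wider machine reuses the vl256 kernel. That reuse is sound because RVV intrinsics take an explicit element count, as in this small sketch (an illustrative function, not from the patch):

    #include <riscv_vector.h>
    #include <stdint.h>

    // Sum 16 int16 values with an explicit vl=16. On VLEN=256 this exactly
    // fills an m1 register; on VLEN=512 or 1024 the same call runs with the
    // upper lanes inactive, producing an identical result.
    static int32_t sum16(const int16_t * p) {
        vint16m1_t v = __riscv_vle16_v_i16m1(p, 16);
        vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, 1);
        return __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v, zero, 16));
    }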