@@ -3918,6 +3918,201 @@ static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT
39183918 * s = sumf ;
39193919}
39203920
// Dot product of one row of IQ4_XS-quantized weights with Q8_K-quantized
// activations, specialized for RISC-V Vector with VLEN = 512 bits
// (selected by the vlenb-based dispatcher in ggml_vec_dot_iq4_xs_q8_K).
//
//   n   - number of elements; must be a multiple of QK_K
//   s   - out: the scalar dot product
//   vx  - weight blocks (block_iq4_xs)
//   vy  - activation blocks (block_q8_K)
//   nrc - number of result columns; only 1 is supported here
// bs/bx/by are stride arguments unused by this single-row kernel.
static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_K == 0);

    const block_iq4_xs * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    // 16-entry non-linear dequantization codebook, indexed by the 4-bit
    // quant values via vrgather below.
    const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16);
    float sumf = 0;

    // Indices for re-ordering IQ4 data. The gather below runs on 64-bit
    // lanes, so each index selects an 8-byte chunk: entries 0..15 address
    // the low-nibble half of iq4bits, 16..31 the high-nibble half. The
    // interleaving restores the natural element order of the super-block
    // (low nibbles hold the first 16 values of each 32-value sub-block,
    // high nibbles the last 16).
    const uint16_t index[32] = {
         0,  1, 16, 17,
         2,  3, 18, 19,
         4,  5, 20, 21,
         6,  7, 22, 23,
         8,  9, 24, 25,
        10, 11, 26, 27,
        12, 13, 28, 29,
        14, 15, 30, 31,
    };
    const vuint16m1_t i_vec = __riscv_vle16_v_u16m1(index, 32);

    for (int ibl = 0; ibl < nb; ++ibl) {
        const int8_t  * q8  = y[ibl].qs;
        const uint8_t * iq4 = x[ibl].qs;
        uint16_t h = x[ibl].scales_h;   // 2 high scale bits per 32-element block, consumed low-to-high

        int sumi = 0;

        // Keep the loop rolled: the body already saturates the vector
        // register file (m4/m8 operands).
        #pragma GCC unroll 1
        // Process the entire super-block together.
        for (int ib = 0; ib < QK_K / 256; ++ib) {
            // Weights and activations. 128 packed bytes = 256 4-bit quants.
            const vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 128);
            iq4 += 128;

            // Unpack the weight blocks: split nibbles, concatenate
            // lo|hi into one m4 group, restore element order via the
            // 64-bit-lane gather, then map through the codebook.
            const vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 128);
            const vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 128);
            const vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi);
            const vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgatherei16_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 32));
            const vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 256);

            // Empty asm = compiler barrier; presumably pins instruction
            // scheduling to limit register pressure — confirm by benchmark.
            __asm__ __volatile__("" ::: "memory");

            // Multiply with activations (widening i8*i8 -> i16).
            const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 256);
            const vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 256);
            q8 += 256;

            // Reduce separately: at VLEN=512 each m1 slice of the m8
            // product holds exactly one 32-element block, so eight
            // widening reductions yield the eight per-block sums.
            const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32));
            const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32));
            const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32));
            const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32));
            const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), __riscv_vmv_v_x_i32m1(0, 1), 32));
            const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), __riscv_vmv_v_x_i32m1(0, 1), 32));
            const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), __riscv_vmv_v_x_i32m1(0, 1), 32));
            const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), __riscv_vmv_v_x_i32m1(0, 1), 32));

            // Per-block 6-bit scales: low 4 bits from a scales_l nibble,
            // high 2 bits from h, biased by -32. Blocks 0..3 take h bits
            // 0-1, 2-3, 4-5, 6-7 respectively (hence the shifts by 4, 2,
            // 0, -2 into mask 0x30); h is then advanced a byte for 4..7.
            const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32;
            const int ls1 = ((x[ibl].scales_l[0] >> 4)  | ((h << 2) & 0x30)) - 32;
            const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32;
            const int ls3 = ((x[ibl].scales_l[1] >> 4)  | ((h >> 2) & 0x30)) - 32;
            h >>= 8;
            const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32;
            const int ls5 = ((x[ibl].scales_l[2] >> 4)  | ((h << 2) & 0x30)) - 32;
            const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32;
            const int ls7 = ((x[ibl].scales_l[3] >> 4)  | ((h >> 2) & 0x30)) - 32;

            sumi += acc0 * ls0;
            sumi += acc1 * ls1;
            sumi += acc2 * ls2;
            sumi += acc3 * ls3;
            sumi += acc4 * ls4;
            sumi += acc5 * ls5;
            sumi += acc6 * ls6;
            sumi += acc7 * ls7;

            // Compiler barrier between loop iterations (see above).
            __asm__ __volatile__("" ::: "memory");
        }

        // Apply the super-block scale (fp16 weight scale * fp32 act scale).
        sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi);
    }

    *s = sumf;
}
4016+
// Dot product of one row of IQ4_XS-quantized weights with Q8_K-quantized
// activations, specialized for RISC-V Vector with VLEN = 1024 bits.
// Same algorithm as the vl512 kernel, but register groups are halved
// (m1/m2/m4 instead of m2/m4/m8) since each register is twice as wide,
// and each m1 slice of the product now holds TWO 32-element blocks, so
// the odd-numbered partial sums use a masked reduction over the upper
// 32 lanes.
//
//   n   - number of elements; must be a multiple of QK_K
//   s   - out: the scalar dot product
//   vx  - weight blocks (block_iq4_xs)
//   vy  - activation blocks (block_q8_K)
//   nrc - number of result columns; only 1 is supported here
// bs/bx/by are stride arguments unused by this single-row kernel.
static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_K == 0);

    const block_iq4_xs * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    // 16-entry non-linear dequantization codebook, indexed by the 4-bit
    // quant values via vrgather below.
    const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16);
    float sumf = 0;

    // Indices for re-ordering IQ4 data. The gather below runs on 64-bit
    // lanes, so each index selects an 8-byte chunk: entries 0..15 address
    // the low-nibble half of iq4bits, 16..31 the high-nibble half,
    // restoring the natural element order of the super-block.
    const uint16_t index[32] = {
         0,  1, 16, 17,
         2,  3, 18, 19,
         4,  5, 20, 21,
         6,  7, 22, 23,
         8,  9, 24, 25,
        10, 11, 26, 27,
        12, 13, 28, 29,
        14, 15, 30, 31,
    };
    const vuint16mf2_t i_vec = __riscv_vle16_v_u16mf2(index, 32);

    for (int ibl = 0; ibl < nb; ++ibl) {
        const int8_t  * q8  = y[ibl].qs;
        const uint8_t * iq4 = x[ibl].qs;
        uint16_t h = x[ibl].scales_h;   // 2 high scale bits per 32-element block, consumed low-to-high

        int sumi = 0;

        // Keep the loop rolled (see the vl512 kernel).
        #pragma GCC unroll 1
        // Process the entire super-block together.
        for (int ib = 0; ib < QK_K / 256; ++ib) {
            // Weights and activations. 128 packed bytes = 256 4-bit quants.
            const vuint8m1_t iq4_packed = __riscv_vle8_v_u8m1(iq4, 128);
            iq4 += 128;

            // Unpack the weight blocks: split nibbles, concatenate
            // lo|hi into one m2 group, restore element order via the
            // 64-bit-lane gather, then map through the codebook.
            const vuint8m1_t iq4bits_lo = __riscv_vand_vx_u8m1(iq4_packed, 0xf, 128);
            const vuint8m1_t iq4bits_hi = __riscv_vsrl_vx_u8m1(iq4_packed, 4, 128);
            const vuint8m2_t iq4bits = __riscv_vcreate_v_u8m1_u8m2(iq4bits_lo, iq4bits_hi);
            const vuint8m2_t iq4bits_reorder = __riscv_vreinterpret_v_u64m2_u8m2(__riscv_vrgatherei16_vv_u64m2(__riscv_vreinterpret_v_u8m2_u64m2(iq4bits), i_vec, 32));
            const vint8m2_t iq4b = __riscv_vrgather_vv_i8m2(values, iq4bits_reorder, 256);

            // Empty asm = compiler barrier; presumably pins instruction
            // scheduling to limit register pressure — confirm by benchmark.
            __asm__ __volatile__("" ::: "memory");

            // Multiply with activations (widening i8*i8 -> i16).
            const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 256);
            const vint16m4_t prod = __riscv_vwmul_vv_i16m4(iq4b, q8b, 256);
            q8 += 256;

            // Mask for processing 32 elements per prod register: at
            // VLEN=1024 an i16 m1 register holds 64 lanes, so the mask
            // selects lanes 32..63 (the second block in each register).
            const vuint16m1_t p_index = __riscv_vid_v_u16m1(64);
            const vbool16_t p_mask = __riscv_vmsgtu_vx_u16m1_b16(p_index, 31, 64);

            // Reduce separately: even accumulators sum the first 32
            // lanes (vl=32), odd ones the masked upper 32 lanes (vl=64).
            const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1  (        __riscv_vget_v_i16m4_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32));
            const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 64));
            const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1  (        __riscv_vget_v_i16m4_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32));
            const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 64));
            const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1  (        __riscv_vget_v_i16m4_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32));
            const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 64));
            const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1  (        __riscv_vget_v_i16m4_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32));
            const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 64));

            // Per-block 6-bit scales: low 4 bits from a scales_l nibble,
            // high 2 bits from h, biased by -32. Blocks 0..3 take h bits
            // 0-1, 2-3, 4-5, 6-7 respectively (hence the shifts by 4, 2,
            // 0, -2 into mask 0x30); h is then advanced a byte for 4..7.
            const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32;
            const int ls1 = ((x[ibl].scales_l[0] >> 4)  | ((h << 2) & 0x30)) - 32;
            const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32;
            const int ls3 = ((x[ibl].scales_l[1] >> 4)  | ((h >> 2) & 0x30)) - 32;
            h >>= 8;
            const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32;
            const int ls5 = ((x[ibl].scales_l[2] >> 4)  | ((h << 2) & 0x30)) - 32;
            const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32;
            const int ls7 = ((x[ibl].scales_l[3] >> 4)  | ((h >> 2) & 0x30)) - 32;

            sumi += acc0 * ls0;
            sumi += acc1 * ls1;
            sumi += acc2 * ls2;
            sumi += acc3 * ls3;
            sumi += acc4 * ls4;
            sumi += acc5 * ls5;
            sumi += acc6 * ls6;
            sumi += acc7 * ls7;

            // Compiler barrier between loop iterations (see above).
            __asm__ __volatile__("" ::: "memory");
        }

        // Apply the super-block scale (fp16 weight scale * fp32 act scale).
        sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi);
    }

    *s = sumf;
}
4115+
39214116void ggml_vec_dot_iq4_xs_q8_K (int n , float * GGML_RESTRICT s , size_t bs , const void * GGML_RESTRICT vx , size_t bx , const void * GGML_RESTRICT vy , size_t by , int nrc ) {
39224117#if defined __riscv_v_intrinsic
39234118 switch (__riscv_vlenb () * 8 ) {
@@ -3927,6 +4122,12 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
39274122 case 256 :
39284123 ggml_vec_dot_iq4_xs_q8_K_vl256 (n , s , bs , vx , bx , vy , by , nrc );
39294124 break ;
4125+ case 512 :
4126+ ggml_vec_dot_iq4_xs_q8_K_vl512 (n , s , bs , vx , bx , vy , by , nrc );
4127+ break ;
4128+ case 1024 :
4129+ ggml_vec_dot_iq4_xs_q8_K_vl1024 (n , s , bs , vx , bx , vy , by , nrc );
4130+ break ;
39304131 default :
39314132 ggml_vec_dot_iq4_xs_q8_K_generic (n , s , bs , vx , bx , vy , by , nrc );
39324133 break ;
0 commit comments