@@ -1104,6 +1104,73 @@ __device__ __forceinline__ void vec_dot_iq3_k_q8_1(
11041104
11051105}
11061106
// Partial dot product of one IQ3_KS super-block with the matching Q8_1 blocks.
//
//   vbq    - start of the quantized row: a single half (row scale) followed by
//            an array of block_iq3_ks
//   bq8_1  - Q8_1 blocks for this super-block; blocks 4*ib128 .. 4*ib128+3 are read
//   kbx    - index of the block_iq3_ks to process within the row
//   iiqs   - per-thread quant index; iiqs/4 selects the (ib128, il8) slice below
//   result - accumulator; the contribution of this slice is added to *result
//
// Each call handles 8 quants in each of 4 consecutive 32-quant blocks.
__device__ __forceinline__ void vec_dot_iq3_ks_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iiqs, float * result) {

    // The row scale is stored as a half in front of the block array.
    const float d = __half2float(*(const half *)vbq);
    const block_iq3_ks * bq3 = (const block_iq3_ks *)((const char *)vbq + sizeof(half)) + kbx;

    const int iqs   = iiqs/4;
    const int ib128 = iqs/4;  // 0 or 1. 0 works on quants 0...127, 1 on quants 128...255
    const int il8   = iqs%4;  // 0...3. 0 works on quants 0...7, 1 on quants 8...15, 2 on 16...23, 3 on 24...31

    const uint16_t * ql = (const uint16_t *)bq3->qs + 16*ib128 + 4*il8;
    const uint16_t * qh = (const uint16_t *)bq3->qh + 4*il8;

    const block_q8_1 * q8b = bq8_1 + 4*ib128;  // first of the 4 Q8_1 blocks used here

    // aux32 is filled as a packed word and read back byte-wise through aux8.
    int32_t aux32;
    const uint8_t * aux8 = (const uint8_t *)&aux32;

    // After the shift, bits 0..3 of `extra` raise the scale of sub-blocks 0..3 by 16
    // and bits 8..11 (bits 0..3 of extra_v) pick the lookup-table half per sub-block.
    const uint16_t extra   = bq3->extra >> 4*ib128;
    const uint16_t extra_v = extra >> 8;

    // Bit j of extra_v set -> sub-block j uses the table at offset 0x40.
    const uint16_t * values[4];
#pragma unroll
    for (int j = 0; j < 4; ++j) {
        values[j] = iq3k_table + ((extra_v << (6 - j)) & 0x40);
    }

    int sumi[4] = {0, 0, 0, 0};
    for (int i = 0; i < 2; ++i) {
        // Assemble the packed words in unsigned arithmetic: a uint16_t promotes to
        // (signed) int, so `u16 << 16` followed by `>>` would otherwise go through
        // the sign bit and an arithmetic right shift.
        uint32_t vl = (uint32_t)ql[2*i+0] | ((uint32_t)ql[2*i+1] << 16);
        uint32_t vh = (((uint32_t)qh[2*i+0] | ((uint32_t)qh[2*i+1] << 16)) >> 4*ib128) << 2;

#pragma unroll
        for (int j = 0; j < 4; ++j) {
            // 4 three-bit indices: low 2 bits of each byte from vl, third bit from vh.
            aux32 = (vl & 0x03030303) | (vh & 0x04040404);
            const int   v  = int_from_table_2(aux8, values[j]);
            const int * q8 = (const int *)q8b[j].qs + 2*il8;
            sumi[j] = ggml_cuda_dp4a(v, q8[i], sumi[j]);
            vl >>= 2; vh >>= 1;  // move on to the bits of the next 32-quant block
        }
    }

    // Four 4-bit block scales (one per byte after masking), each offset by -16.
    const uint16_t * sl16 = (const uint16_t *)bq3->scales;
    const uint32_t packed_scales = (((uint32_t)sl16[0] | ((uint32_t)sl16[1] << 16)) >> 4*ib128) & 0x0f0f0f0f;
    aux32 = __vsub4(packed_scales, 0x10101010);
    const int8_t * a8 = (const int8_t *)&aux32;

    *result += d * (__low2float(q8b[0].ds) * (a8[0] + ((extra << 4) & 0x10)) * sumi[0] +
                    __low2float(q8b[1].ds) * (a8[1] + ((extra << 3) & 0x10)) * sumi[1] +
                    __low2float(q8b[2].ds) * (a8[2] + ((extra << 2) & 0x10)) * sumi[2] +
                    __low2float(q8b[3].ds) * (a8[3] + ((extra << 1) & 0x10)) * sumi[3]);

}
1173+
11071174__device__ __forceinline__ void vec_dot_iq1_bn_q8_1 (
11081175 const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs, float * result) {
11091176
@@ -1302,6 +1369,14 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
13021369 iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
13031370}
13041371
// IQ3_KS x Q8_1 matrix-vector product entry point. Forwards every argument
// unchanged to iqk_mul_mat_vec_q_cuda instantiated with vec_dot_iq3_ks_q8_1.
// NOTE(review): the VDR template argument is the IQ3_K constant
// (VDR_IQ3_K_Q8_1_MMVQ) — presumably shared with IQ3_K on purpose; confirm.
void mul_mat_vec_iq3_ks_q8_1_cuda(
    const void * vx, const void * vy, float * dst, const char * ids_data,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {

    iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ3_KS, VDR_IQ3_K_Q8_1_MMVQ, vec_dot_iq3_ks_q8_1>(
        vx, vy, dst, ids_data,
        ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
        ne2, nb02, nb12, nb2, ids_nb0,
        stream);
}
1379+
13051380void mul_mat_vec_iq4_kt_q8_1_cuda (
13061381 const void * vx, const void * vy, float * dst, const char * ids_data,
13071382 const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
0 commit comments