@@ -1298,61 +1298,8 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
12981298 }
12991299#else
13001300
1301- float sumf[8];
1302- float sum_minf[8];
1303- int sumi1,sumi2,sumi3,sumi4;
1304- int sumi;
1301+ ggml_gemv_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
13051302
1306- const block_q8_K * a_ptr = (const block_q8_K *)vy;
1307- for(int x = 0; x < nc / ncols_interleaved; x++) {
1308- const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
1309- for (int j = 0; j < ncols_interleaved; j++) {
1310- sumf[j] = 0.0;
1311- sum_minf[j] = 0.0;
1312- }
1313- for (int l = 0; l < nb; l++) {
1314- for (int k = 0; k < (qk / (4 * blocklen)); k++) {
1315- uint8_t *scales_0 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 ;
1316- uint8_t *scales_1 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 16;
1317- uint8_t *scales_2 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 32;
1318- uint8_t *scales_3 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 48;
1319- for (int j = 0; j < ncols_interleaved; j++) {
1320- sumi1 = 0;
1321- sumi2 = 0;
1322- sumi3 = 0;
1323- sumi4 = 0;
1324- sumi = 0;
1325- int offset = ((k / 2) % 2) + j * 2;
1326- for (int i = 0; i < blocklen; ++i){
1327- const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
1328- const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);
1329- const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);
1330- const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);
1331- sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]);
1332- sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]);
1333- sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]);
1334- sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]);
1335-
1336- sumi1 = sumi1 * (scales_0[offset] & 0xF);
1337- sumi2 = sumi2 * (scales_1[offset] & 0xF);
1338- sumi3 = sumi3 * (scales_2[offset] & 0xF);
1339- sumi4 = sumi4 * (scales_3[offset] & 0xF);
1340- sumi += sumi1 + sumi2 + sumi3 + sumi4;
1341- }
1342- sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
1343- }
1344- }
1345- for(int sb = 0; sb < 8; sb++) {
1346- uint8_t *mins = (uint8_t*) b_ptr[l].scales + sb * 16;
1347- for(int j = 0; j < ncols_interleaved; j++){
1348- sum_minf[j] += ((mins[j * 2] >> 4) * a_ptr[l].bsums[sb * 2] + (mins[(j * 2)+ 1] >> 4) * a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
1349- }
1350- }
1351- }
1352- for (int j = 0; j < ncols_interleaved; j++) {
1353- s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
1354- }
1355- }
13561303#endif
13571304}
13581305
@@ -6527,74 +6474,8 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
65276474 }
65286475#else
65296476
6530- float sumf[4][8];
6531- float sum_minf[4][8];
6532- int sumi1, sumi2, sumi3, sumi4;
6533- int sumi;
6477+ ggml_gemm_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
65346478
6535- for (int y = 0; y < nr / 4; y++) {
6536- const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
6537- for (int x = 0; x < nc / ncols_interleaved; x++) {
6538- const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
6539- for (int m = 0; m < 4; m++) {
6540- for (int j = 0; j < ncols_interleaved; j++) {
6541- sumf[m][j] = 0.0;
6542- sum_minf[m][j] = 0.0;
6543- }
6544- }
6545- for (int l = 0; l < nb; l++) {
6546- for (int k = 0; k < (qk / (4 * blocklen)); k++) {
6547-
6548- uint8_t *scales_0 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 ;
6549- uint8_t *scales_1 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 16;
6550- uint8_t *scales_2 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 32;
6551- uint8_t *scales_3 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 48;
6552- for (int m = 0; m < 4; m++) {
6553- for (int j = 0; j < ncols_interleaved; j++) {
6554- sumi1 = 0;
6555- sumi2 = 0;
6556- sumi3 = 0;
6557- sumi4 = 0;
6558- sumi = 0;
6559- int offset = ((k / 2) % 2) + j * 2;
6560- for (int i = 0; i < blocklen; ++i){
6561- const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x03);
6562- const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 0x03);
6563- const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 0x03);
6564- const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 0x03);
6565- sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]);
6566- sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
6567- sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 256]);
6568- sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 384]);
6569- sumi1 = sumi1 * (scales_0[offset] & 0xF);
6570- sumi2 = sumi2 * (scales_1[offset] & 0xF);
6571- sumi3 = sumi3 * (scales_2[offset] & 0xF);
6572- sumi4 = sumi4 * (scales_3[offset] & 0xF);
6573- sumi += sumi1 + sumi2 + sumi3 + sumi4;
6574- }
6575- sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
6576- }
6577- }
6578- }
6579- for(int sb = 0; sb < 8; sb++) {
6580- uint8_t *mins = (uint8_t*) b_ptr[l].scales + sb * 16;
6581- for(int m = 0; m < 4; m++) {
6582- const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
6583- for(int j = 0; j < ncols_interleaved; j++) {
6584- int mins_prod = ((mins[j * 2] >> 4) * bsums[0] + (mins[(j * 2)+ 1] >> 4) * bsums[1]);
6585- sum_minf[m][j] += (mins_prod) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
6586- }
6587- }
6588- }
6589- }
6590-
6591- for (int m = 0; m < 4; m++) {
6592- for (int j = 0; j < ncols_interleaved; j++) {
6593- s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
6594- }
6595- }
6596- }
6597- }
65986479
65996480#endif
66006481}
0 commit comments