Skip to content

Commit 20758ed

Browse files
ikawrakow (Iwan Kawrakow) authored
Q8_K_R8: Fastest quantized matrix multiplications (#141)
* q8_k_r8: fastest matrix multiplication known to human kind We get PP-512(LLaMA-3.1-8B) = 370 t/s on a Ryzen-7950X! * q8_k_r8: AVX2 I was worried that we don't have enough vector registrers on AVX2, but it looks like it handles it just fine. We get PP-512(LLaMA-3.1-8B) = 354 t/s on a Ryzen-5975WX. Slightly slower than the Zen4 version with double the threads, but still a huge upgrade compared to Q8_0_R4. * q8_k_r4: NEON We get PP-512(LLaMA-3.1-8B) = 159.2 t/s. Compare this to the 128 t/s we have fr Q8_0_R4. * q8_k_r4: go to signed ints Why? * On AVX2 _mm256_maddubs_epi16() may overflow, so we need to stay within the signed int range and use _mm256_sign_epi8. Not yet tested on the AVX2 comp, vut expect major slowdown. * It is almost 10% faster on ARM_NEON. Somehow the veorrq_u8() needed tto convert from unsigned to signed seems to be extremely slow on the M2-Max * We only lose ~0.5% in oerformance on Zen4 (there the exclusive or that we now use to convert fro signed to unsigned seems to be much faster than on M2-Max) * Shutup useless compiler warnings --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 12f962d commit 20758ed

File tree

10 files changed

+301
-7
lines changed

10 files changed

+301
-7
lines changed

examples/quantize/quantize.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
7070
{ "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", },
7171
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
7272
{ "Q6_K_R4", LLAMA_FTYPE_MOSTLY_Q6_K_R4, "Q6_K repacked", },
73+
{ "Q8_K_R8", LLAMA_FTYPE_MOSTLY_Q8_K_R8, "Q8_K repacked", },
7374
{ "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
7475
{ "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
7576
{ "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },

ggml/include/ggml.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ extern "C" {
408408
GGML_TYPE_IQ4_KSS = 146,
409409
GGML_TYPE_Q8_K16 = 147,
410410
GGML_TYPE_Q8_K32 = 148,
411+
GGML_TYPE_Q8_KR8 = 149,
411412

412413
GGML_TYPE_Q4_0_R4 = 202,
413414
GGML_TYPE_Q5_0_R4 = 206,
@@ -422,6 +423,7 @@ extern "C" {
422423
GGML_TYPE_Q6_0_R4 = 233,
423424
GGML_TYPE_IQ2_BN_R4 = 335,
424425
GGML_TYPE_IQ4_K_R4 = 339,
426+
GGML_TYPE_Q8_K_R8 = 399,
425427
GGML_TYPE_COUNT,
426428
};
427429

@@ -494,6 +496,7 @@ extern "C" {
494496
GGML_FTYPE_MOSTLY_Q6_0_R4 = 227, // except 1d tensors
495497
GGML_FTYPE_MOSTLY_IQ2_BN_R4 = 329, // except 1d tensors
496498
GGML_FTYPE_MOSTLY_IQ4_K_R4 = 332, // except 1d tensors
499+
GGML_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors
497500
};
498501

499502
// available tensor operations:

ggml/src/ggml-common.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,12 @@ typedef struct {
382382
} block_q8_K128;
383383
static_assert(sizeof(block_q8_K128) == sizeof(float) + 128, "wrong q8_K128 block size/padding");
384384

385+
typedef struct {
386+
ggml_half d[8]; // delta
387+
int8_t qs[8*QK_K]; // quants, stored as unsigned ints
388+
} block_q8_k_r8;
389+
static_assert(sizeof(block_q8_k_r8) == 8*sizeof(ggml_half) + 8*QK_K, "wrong q8_k_r8 block size/padding");
390+
385391
// (Almost) "true" 2-bit quantization.
386392
// Due to the need to use blocks as per ggml design, it ends up using
387393
// 2.0625 bpw because of the 16-bit scale for each block of 256.

ggml/src/ggml-quants.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15208,6 +15208,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
1520815208
case GGML_TYPE_Q5_K_R4: break;
1520915209
case GGML_TYPE_Q6_K_R4: break;
1521015210
case GGML_TYPE_IQ4_K_R4: break;
15211+
case GGML_TYPE_Q8_K_R8: break;
1521115212
case GGML_TYPE_Q4_0_4_4:
1521215213
case GGML_TYPE_Q4_0_4_8:
1521315214
{

ggml/src/ggml.c

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -979,6 +979,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
979979
.nrows = 1,
980980
.row_meta_size = 0,
981981
},
982+
[GGML_TYPE_Q8_K_R8] = {
983+
.type_name = "q8_k_r8",
984+
.blck_size = QK_K,
985+
.type_size = sizeof(block_q8_k_r8)/8,
986+
.is_quantized = true,
987+
.to_float = (ggml_to_float_t) dequantize_row_q8_k_r8,
988+
.from_float = quantize_row_q8_k_r8,
989+
.from_float_ref = (ggml_from_float_t) quantize_row_q8_k_r8_ref,
990+
.vec_dot = vec_dot_q8_k_r8_q8_k,
991+
.vec_dot_type = GGML_TYPE_Q8_KR8,
992+
.nrows = 1,
993+
.row_meta_size = 0,
994+
},
982995
[GGML_TYPE_IQ2_XXS] = {
983996
.type_name = "iq2_xxs",
984997
.blck_size = QK_K,
@@ -1197,6 +1210,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
11971210
.from_float = quantize_row_q8_K32,
11981211
.row_meta_size = 0,
11991212
},
1213+
[GGML_TYPE_Q8_KR8] = {
1214+
.type_name = "q8_KR8",
1215+
.blck_size = QK_K,
1216+
.type_size = sizeof(block_q8_K),
1217+
.is_quantized = true,
1218+
.from_float = quantize_row_q8_KR8,
1219+
.row_meta_size = 0,
1220+
},
12001221
[GGML_TYPE_BF16] = {
12011222
.type_name = "bf16",
12021223
.blck_size = 1,
@@ -4105,6 +4126,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
41054126
case GGML_FTYPE_MOSTLY_Q5_K_R4: wtype = GGML_TYPE_Q5_K_R4; break;
41064127
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
41074128
case GGML_FTYPE_MOSTLY_Q6_K_R4: wtype = GGML_TYPE_Q6_K_R4; break;
4129+
case GGML_FTYPE_MOSTLY_Q8_K_R8: wtype = GGML_TYPE_Q8_K_R8; break;
41084130
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
41094131
case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
41104132
case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
@@ -10641,6 +10663,7 @@ static void ggml_compute_forward_add(
1064110663
case GGML_TYPE_Q5_K_R4:
1064210664
case GGML_TYPE_Q6_K:
1064310665
case GGML_TYPE_Q6_K_R4:
10666+
case GGML_TYPE_Q8_K_R8:
1064410667
case GGML_TYPE_IQ2_XXS:
1064510668
case GGML_TYPE_IQ2_XS:
1064610669
case GGML_TYPE_IQ3_XXS:
@@ -11096,6 +11119,7 @@ static void ggml_compute_forward_add1(
1109611119
case GGML_TYPE_Q5_K_R4:
1109711120
case GGML_TYPE_Q6_K:
1109811121
case GGML_TYPE_Q6_K_R4:
11122+
case GGML_TYPE_Q8_K_R8:
1109911123
case GGML_TYPE_IQ2_XXS:
1110011124
case GGML_TYPE_IQ2_XS:
1110111125
case GGML_TYPE_IQ3_XXS:
@@ -11248,6 +11272,7 @@ static void ggml_compute_forward_acc(
1124811272
case GGML_TYPE_Q5_K_R4:
1124911273
case GGML_TYPE_Q6_K:
1125011274
case GGML_TYPE_Q6_K_R4:
11275+
case GGML_TYPE_Q8_K_R8:
1125111276
case GGML_TYPE_IQ2_XXS:
1125211277
case GGML_TYPE_IQ2_XS:
1125311278
case GGML_TYPE_IQ3_XXS:
@@ -14446,6 +14471,7 @@ static void ggml_compute_forward_out_prod(
1444614471
case GGML_TYPE_Q5_K_R4:
1444714472
case GGML_TYPE_Q6_K:
1444814473
case GGML_TYPE_Q6_K_R4:
14474+
case GGML_TYPE_Q8_K_R8:
1444914475
case GGML_TYPE_IQ2_XXS:
1445014476
case GGML_TYPE_IQ2_XS:
1445114477
case GGML_TYPE_IQ3_XXS:
@@ -14838,6 +14864,7 @@ static void ggml_compute_forward_set(
1483814864
case GGML_TYPE_Q5_K_R4:
1483914865
case GGML_TYPE_Q6_K:
1484014866
case GGML_TYPE_Q6_K_R4:
14867+
case GGML_TYPE_Q8_K_R8:
1484114868
case GGML_TYPE_IQ2_XXS:
1484214869
case GGML_TYPE_IQ2_XS:
1484314870
case GGML_TYPE_IQ3_XXS:
@@ -15124,6 +15151,7 @@ static void ggml_compute_forward_get_rows(
1512415151
case GGML_TYPE_Q5_K_R4:
1512515152
case GGML_TYPE_Q6_K:
1512615153
case GGML_TYPE_Q6_K_R4:
15154+
case GGML_TYPE_Q8_K_R8:
1512715155
case GGML_TYPE_IQ2_XXS:
1512815156
case GGML_TYPE_IQ2_XS:
1512915157
case GGML_TYPE_IQ3_XXS:
@@ -15737,6 +15765,8 @@ static void ggml_compute_forward_clamp(
1573715765
case GGML_TYPE_Q5_K_R4:
1573815766
case GGML_TYPE_Q6_K:
1573915767
case GGML_TYPE_Q6_K_R4:
15768+
case GGML_TYPE_Q8_K_R8:
15769+
case GGML_TYPE_Q8_KR8:
1574015770
case GGML_TYPE_IQ2_XXS:
1574115771
case GGML_TYPE_IQ2_XS:
1574215772
case GGML_TYPE_IQ3_XXS:
@@ -22578,6 +22608,7 @@ size_t ggml_quantize_chunk(
2257822608
case GGML_TYPE_Q5_K_R4: result = quantize_q5_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
2257922609
case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
2258022610
case GGML_TYPE_Q6_K_R4: result = quantize_q6_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
22611+
case GGML_TYPE_Q8_K_R8: result = quantize_q8_k_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
2258122612
case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
2258222613
case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
2258322614
case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ struct DataInfo {
9292
inline void store(int ix, int iy, __m128 result) const {
9393
_mm_storeu_ps(dst_row(iy) + ix, result);
9494
}
95+
inline void store(int ix, int iy, __m256 result) const {
96+
_mm256_storeu_ps(dst_row(iy) + ix, result);
97+
}
9598
#endif
9699
#ifdef __ARM_NEON
97100
inline void store(int ix, int iy, float32x4_t result) const {
@@ -175,6 +178,7 @@ struct MulMat {
175178
case GGML_TYPE_IQ4_NL_R4:
176179
case GGML_TYPE_IQ4_XS_R4:
177180
case GGML_TYPE_IQ2_BN_R4: return 4;
181+
case GGML_TYPE_Q8_K_R8: return 8;
178182
default: return 1;
179183
}
180184
}
@@ -3802,6 +3806,76 @@ static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
38023806
}
38033807
}
38043808

3809+
// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__)
3810+
template <int nrc_y>
3811+
static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
3812+
GGML_ASSERT(nrc_x%4 == 0);
3813+
Q8<nrc_y, block_q8_K> q8(info);
3814+
#ifndef HAVE_FANCY_SIMD
3815+
auto m1 = _mm256_set1_epi16(1);
3816+
#endif
3817+
int nbl = n / QK_K;
3818+
__m256 acc[nrc_y] = {};
3819+
__m256i isum[nrc_y] = {};
3820+
__m256i qx[4];
3821+
for (int ix = 0; ix < nrc_x; ix += 8) {
3822+
const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + (ix+0)*bx);
3823+
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
3824+
auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ibl].d));
3825+
for (int ib = 0; ib < QK_K/16; ++ib) {
3826+
qx[0] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+0);
3827+
qx[1] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+1);
3828+
qx[2] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+2);
3829+
qx[3] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+3);
3830+
#ifdef HAVE_FANCY_SIMD
3831+
qx[0] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+0), _mm256_set1_epi8(-128));
3832+
qx[1] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+1), _mm256_set1_epi8(-128));
3833+
qx[2] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+2), _mm256_set1_epi8(-128));
3834+
qx[3] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+3), _mm256_set1_epi8(-128));
3835+
#else
3836+
auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
3837+
auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
3838+
auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
3839+
auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
3840+
#endif
3841+
for (int iy = 0; iy < nrc_y; ++iy) {
3842+
auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+ib);
3843+
auto y = MM256_SET_M128I(y128, y128);
3844+
#ifdef HAVE_FANCY_SIMD
3845+
isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
3846+
isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
3847+
isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
3848+
isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
3849+
#else
3850+
auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])));
3851+
auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])));
3852+
auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])));
3853+
auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])));
3854+
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi1, sumi2));
3855+
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi3, sumi4));
3856+
#endif
3857+
}
3858+
}
3859+
#ifdef HAVE_FANCY_SIMD
3860+
auto m4 = _mm256_mul_ps(d4, _mm256_set1_ps(-128.f));
3861+
#endif
3862+
for (int iy = 0; iy < nrc_y; ++iy) {
3863+
auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
3864+
acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
3865+
#ifdef HAVE_FANCY_SIMD
3866+
auto bsums = (const float *)q8.y[iy][ibl].bsums;
3867+
acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(bsums[0]), acc[iy]);
3868+
#endif
3869+
isum[iy] = _mm256_setzero_si256();
3870+
}
3871+
}
3872+
for (int iy = 0; iy < nrc_y; ++iy) {
3873+
info.store(ix, iy, acc[iy]);
3874+
acc[iy] = _mm256_setzero_ps();
3875+
}
3876+
}
3877+
}
3878+
38053879
template <int nrc_y>
38063880
static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
38073881
GGML_ASSERT(nrc_x%4 == 0);
@@ -5976,6 +6050,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
59766050
mm.funcs[7] = mul_mat_q6_k_r4_q8_k<8>;
59776051
expected_typeB = GGML_TYPE_Q8_K;
59786052
break;
6053+
case GGML_TYPE_Q8_K_R8:
6054+
assert (ne00 % QK_K == 0);
6055+
mm.funcs[0] = mul_mat_q8_k_r8_q8_k<1>;
6056+
mm.funcs[1] = mul_mat_q8_k_r8_q8_k<2>;
6057+
mm.funcs[2] = mul_mat_q8_k_r8_q8_k<3>;
6058+
mm.funcs[3] = mul_mat_q8_k_r8_q8_k<4>;
6059+
mm.funcs[4] = mul_mat_q8_k_r8_q8_k<5>;
6060+
mm.funcs[5] = mul_mat_q8_k_r8_q8_k<6>;
6061+
mm.funcs[6] = mul_mat_q8_k_r8_q8_k<7>;
6062+
mm.funcs[7] = mul_mat_q8_k_r8_q8_k<8>;
6063+
expected_typeB = GGML_TYPE_Q8_KR8;
6064+
break;
59796065
case GGML_TYPE_IQ4_K_R4:
59806066
assert (ne00 % QK_K == 0);
59816067
mm.funcs[0] = mul_mat_iq4_k_r4_q8_k<1>;
@@ -9158,6 +9244,55 @@ void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf
91589244
}
91599245
}
91609246

9247+
template <int nrc_y>
9248+
void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
9249+
GGML_ASSERT(nrc_x%8 == 0);
9250+
Q8<nrc_y, block_q8_K> q8(info);
9251+
int nbl = n / QK_K;
9252+
float32x4_t acc[2*nrc_y] = {};
9253+
for (int ix = 0; ix < nrc_x; ix += 8) {
9254+
const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + ix*bx);
9255+
for (int ibl = 0; ibl < nbl; ++ibl) {
9256+
auto d4l = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+0));
9257+
auto d4h = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+4));
9258+
int32x4_t isum[2*nrc_y] = {};
9259+
for (int ib = 0; ib < QK_K/16; ++ib) {
9260+
auto q1 = vld1q_u8_x4(iq8[ibl].qs + 128*ib + 0);
9261+
auto q2 = vld1q_u8_x4(iq8[ibl].qs + 128*ib + 64);
9262+
for (int k = 0; k < 4; ++k) {
9263+
q1.val[k] = veorq_u8(q1.val[k], vdupq_n_u8(0x80));
9264+
q2.val[k] = veorq_u8(q2.val[k], vdupq_n_u8(0x80));
9265+
}
9266+
for (int iy = 0; iy < nrc_y; ++iy) {
9267+
auto y = vld1q_s8(q8.y[iy][ibl].qs+16*ib);
9268+
isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[0], y, 0);
9269+
isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[1], y, 0);
9270+
isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[2], y, 1);
9271+
isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[3], y, 1);
9272+
isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[0], y, 2);
9273+
isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[1], y, 2);
9274+
isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[2], y, 3);
9275+
isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[3], y, 3);
9276+
}
9277+
}
9278+
for (int iy = 0; iy < nrc_y; ++iy) {
9279+
auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
9280+
const float * bsum = (const float *)q8.y[iy][ibl].bsums;
9281+
auto m8 = vdupq_n_f32(-128.f*bsum[0]);
9282+
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[2*iy+0]));
9283+
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[2*iy+1]));
9284+
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], d4l, m8);
9285+
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], d4l, m8);
9286+
}
9287+
}
9288+
for (int iy = 0; iy < nrc_y; ++iy) {
9289+
info.store(ix+0, iy, acc[2*iy+0]);
9290+
info.store(ix+4, iy, acc[2*iy+1]);
9291+
acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
9292+
}
9293+
}
9294+
}
9295+
91619296
void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
91629297
GGML_ASSERT(nrc_x%4 == 0);
91639298
Q8<1, block_q8_0_x4> q8(info);
@@ -9575,6 +9710,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
95759710
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q6_k_r4_q8_k);
95769711
expected_Btype = GGML_TYPE_Q8_K;
95779712
break;
9713+
case GGML_TYPE_Q8_K_R8:
9714+
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k);
9715+
expected_Btype = GGML_TYPE_Q8_KR8;
9716+
break;
95789717
case GGML_TYPE_IQ4_K_R4:
95799718
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_k_r4_q8_k);
95809719
expected_Btype = GGML_TYPE_Q8_K;

0 commit comments

Comments
 (0)