@@ -92,6 +92,9 @@ struct DataInfo {
9292 inline void store(int ix, int iy, __m128 result) const {
9393 _mm_storeu_ps(dst_row(iy) + ix, result);
9494 }
95+ inline void store(int ix, int iy, __m256 result) const {
96+ _mm256_storeu_ps(dst_row(iy) + ix, result);
97+ }
9598#endif
9699#ifdef __ARM_NEON
97100 inline void store(int ix, int iy, float32x4_t result) const {
@@ -175,6 +178,7 @@ struct MulMat {
175178 case GGML_TYPE_IQ4_NL_R4:
176179 case GGML_TYPE_IQ4_XS_R4:
177180 case GGML_TYPE_IQ2_BN_R4: return 4;
181+ case GGML_TYPE_Q8_K_R8: return 8;
178182 default: return 1;
179183 }
180184 }
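The new num_rows() value of 8 matches the interleaving used by the Q8_K_R8 kernels added further down: one block packs 8 rows of the same QK_K super-block. A minimal sketch of the layout those kernels appear to index into (the struct name, field names, and the ggml_half type are assumptions for illustration, not copied from the actual header):

    // hypothetical sketch of the interleaved Q8_K_R8 block assumed by the kernels below
    struct block_q8_k_r8_sketch {
        ggml_half d[8];        // one fp16 scale per interleaved row
        int8_t    qs[8*QK_K];  // quants: for each group of 4 activation positions,
                               // 8 rows x 4 bytes are stored back to back, so one
                               // 16-value sub-block of y meets a 128-byte chunk of x
    };

This is consistent with the AVX kernel stepping 4 x 32-byte vectors per sub-block and the NEON kernel stepping 128 bytes per sub-block.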
@@ -3802,6 +3806,76 @@ static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
38023806 }
38033807}
38043808
3809+ // HAVE_FANCY_SIMD should only be defined as #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
3810+ template <int nrc_y>
3811+ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
3812+ GGML_ASSERT(nrc_x%8 == 0);
3813+ Q8<nrc_y, block_q8_K> q8(info);
3814+ #ifndef HAVE_FANCY_SIMD
3815+ auto m1 = _mm256_set1_epi16(1);
3816+ #endif
3817+ int nbl = n / QK_K;
3818+ __m256 acc[nrc_y] = {};
3819+ __m256i isum[nrc_y] = {};
3820+ __m256i qx[4];
3821+ for (int ix = 0; ix < nrc_x; ix += 8) {
3822+ const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + (ix+0)*bx);
3823+ for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
3824+ auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ibl].d));
3825+ for (int ib = 0; ib < QK_K/16; ++ib) {
3826+ qx[0] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+0);
3827+ qx[1] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+1);
3828+ qx[2] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+2);
3829+ qx[3] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+3);
3830+ #ifdef HAVE_FANCY_SIMD
3831+ qx[0] = _mm256_xor_si256(qx[0], _mm256_set1_epi8(-128));
3832+ qx[1] = _mm256_xor_si256(qx[1], _mm256_set1_epi8(-128));
3833+ qx[2] = _mm256_xor_si256(qx[2], _mm256_set1_epi8(-128));
3834+ qx[3] = _mm256_xor_si256(qx[3], _mm256_set1_epi8(-128));
3835+ #else
3836+ auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
3837+ auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
3838+ auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
3839+ auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
3840+ #endif
3841+ for (int iy = 0; iy < nrc_y; ++iy) {
3842+ auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+ib);
3843+ auto y = MM256_SET_M128I(y128, y128);
3844+ #ifdef HAVE_FANCY_SIMD
3845+ isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
3846+ isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
3847+ isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
3848+ isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
3849+ #else
3850+ auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])));
3851+ auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])));
3852+ auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])));
3853+ auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])));
3854+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi1, sumi2));
3855+ isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi3, sumi4));
3856+ #endif
3857+ }
3858+ }
3859+ #ifdef HAVE_FANCY_SIMD
3860+ auto m4 = _mm256_mul_ps(d4, _mm256_set1_ps(-128.f));
3861+ #endif
3862+ for (int iy = 0; iy < nrc_y; ++iy) {
3863+ auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
3864+ acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
3865+ #ifdef HAVE_FANCY_SIMD
3866+ auto bsums = (const float *)q8.y[iy][ibl].bsums;
3867+ acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(bsums[0]), acc[iy]);
3868+ #endif
3869+ isum[iy] = _mm256_setzero_si256();
3870+ }
3871+ }
3872+ for (int iy = 0; iy < nrc_y; ++iy) {
3873+ info.store(ix, iy, acc[iy]);
3874+ acc[iy] = _mm256_setzero_ps();
3875+ }
3876+ }
3877+ }
3878+
38053879template <int nrc_y>
38063880static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
38073881 GGML_ASSERT(nrc_x%4 == 0);
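The two branches of the Q8_K_R8 kernel above handle the unsigned-operand requirement of the byte dot products differently: the plain AVX2 path uses the usual _mm256_sign_epi8 trick so that _mm256_maddubs_epi16 gets a non-negative first operand, while the AVX512-VNNI path biases the row quants by +128 (xor with 0x80) for _mm256_dpbusd_epi32 and removes the bias afterwards through the activation sums (the m4/bsums terms). A minimal scalar sketch of that identity, for illustration only (biased_dot is a hypothetical helper, not part of the commit):

    // (x + 128) . y == x . y + 128 * sum(y), so the bias can be removed after the fact
    static inline int32_t biased_dot(const int8_t * x, const int8_t * y, int n, int32_t sum_y) {
        int32_t acc = 0;
        for (int i = 0; i < n; ++i) acc += (int32_t)(uint8_t)(x[i] ^ 0x80) * (int32_t)y[i];
        return acc - 128 * sum_y;  // the kernel folds this correction into the float accumulators
    }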
@@ -5976,6 +6050,18 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
59766050 mm.funcs[7] = mul_mat_q6_k_r4_q8_k<8>;
59776051 expected_typeB = GGML_TYPE_Q8_K;
59786052 break;
6053+ case GGML_TYPE_Q8_K_R8:
6054+ assert (ne00 % QK_K == 0);
6055+ mm.funcs[0] = mul_mat_q8_k_r8_q8_k<1>;
6056+ mm.funcs[1] = mul_mat_q8_k_r8_q8_k<2>;
6057+ mm.funcs[2] = mul_mat_q8_k_r8_q8_k<3>;
6058+ mm.funcs[3] = mul_mat_q8_k_r8_q8_k<4>;
6059+ mm.funcs[4] = mul_mat_q8_k_r8_q8_k<5>;
6060+ mm.funcs[5] = mul_mat_q8_k_r8_q8_k<6>;
6061+ mm.funcs[6] = mul_mat_q8_k_r8_q8_k<7>;
6062+ mm.funcs[7] = mul_mat_q8_k_r8_q8_k<8>;
6063+ expected_typeB = GGML_TYPE_Q8_KR8;
6064+ break;
59796065 case GGML_TYPE_IQ4_K_R4:
59806066 assert (ne00 % QK_K == 0);
59816067 mm.funcs[0] = mul_mat_iq4_k_r4_q8_k<1>;
@@ -9158,6 +9244,55 @@ void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf
91589244 }
91599245}
91609246
9247+ template <int nrc_y>
9248+ void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
9249+ GGML_ASSERT(nrc_x%8 == 0);
9250+ Q8<nrc_y, block_q8_K> q8(info);
9251+ int nbl = n / QK_K;
9252+ float32x4_t acc[2*nrc_y] = {};
9253+ for (int ix = 0; ix < nrc_x; ix += 8) {
9254+ const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + ix*bx);
9255+ for (int ibl = 0; ibl < nbl; ++ibl) {
9256+ auto d4l = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+0));
9257+ auto d4h = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+4));
9258+ int32x4_t isum[2*nrc_y] = {};
9259+ for (int ib = 0; ib < QK_K/16; ++ib) {
9260+ auto q1 = vld1q_u8_x4(iq8[ibl].qs + 128*ib + 0);
9261+ auto q2 = vld1q_u8_x4(iq8[ibl].qs + 128*ib + 64);
9262+ for (int k = 0; k < 4; ++k) {
9263+ q1.val[k] = veorq_u8(q1.val[k], vdupq_n_u8(0x80));
9264+ q2.val[k] = veorq_u8(q2.val[k], vdupq_n_u8(0x80));
9265+ }
9266+ for (int iy = 0; iy < nrc_y; ++iy) {
9267+ auto y = vld1q_s8(q8.y[iy][ibl].qs+16*ib);
9268+ isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[0], y, 0);
9269+ isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[1], y, 0);
9270+ isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[2], y, 1);
9271+ isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[3], y, 1);
9272+ isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[0], y, 2);
9273+ isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[1], y, 2);
9274+ isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[2], y, 3);
9275+ isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[3], y, 3);
9276+ }
9277+ }
9278+ for (int iy = 0; iy < nrc_y; ++iy) {
9279+ auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
9280+ const float * bsum = (const float *)q8.y[iy][ibl].bsums;
9281+ auto m8 = vdupq_n_f32(-128.f*bsum[0]);
9282+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[2*iy+0]));
9283+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[2*iy+1]));
9284+ acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], d4l, m8);
9285+ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], d4h, m8);
9286+ }
9287+ }
9288+ for (int iy = 0; iy < nrc_y; ++iy) {
9289+ info.store(ix+0, iy, acc[2*iy+0]);
9290+ info.store(ix+4, iy, acc[2*iy+1]);
9291+ acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
9292+ }
9293+ }
9294+ }
9295+
91619296void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
91629297 GGML_ASSERT(nrc_x%4 == 0);
91639298 Q8<1, block_q8_0_x4> q8(info);
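In the NEON kernel above, y holds the 16 activation values of one sub-block, and each vdotq_laneq_s32 call dots one 4-byte lane of y against the matching 4 quants of four interleaved rows, so isum[2*iy+0] collects rows 0-3 and isum[2*iy+1] rows 4-7; the xor with 0x80 plus the -128.f*bsum correction mirrors the bias trick of the AVX512 path. A rough scalar model of one lane-dot call, for illustration only (dot_lane_model is a hypothetical helper, not part of the commit):

    // scalar model of vdotq_laneq_s32(acc, q, y, lane): element r of acc accumulates
    // the dot of q bytes 4r..4r+3 with the selected 4-byte lane of y
    static inline void dot_lane_model(int32_t acc[4], const int8_t q[16], const int8_t y[16], int lane) {
        for (int r = 0; r < 4; ++r)
            for (int j = 0; j < 4; ++j)
                acc[r] += (int32_t)q[4*r + j] * (int32_t)y[4*lane + j];
    }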
@@ -9575,6 +9710,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
95759710 SET_MUL_MAT_FUNCTIONS(m, mul_mat_q6_k_r4_q8_k);
95769711 expected_Btype = GGML_TYPE_Q8_K;
95779712 break;
9713+ case GGML_TYPE_Q8_K_R8:
9714+ SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k);
9715+ expected_Btype = GGML_TYPE_Q8_KR8;
9716+ break;
95789717 case GGML_TYPE_IQ4_K_R4:
95799718 SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_k_r4_q8_k);
95809719 expected_Btype = GGML_TYPE_Q8_K;