160 changes: 160 additions & 0 deletions ggml/src/iqk/iqk_mul_mat.cpp
@@ -3666,6 +3666,85 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
}
}

// sum[ qy_i * ls_k * (qx_i - 1 +/- delta_k)]
// = sum[qy_i * qx_i * ls_k] - 1/8*sum[qy_i * ls_k * (8 -/+ o_k)]
// = 1/8 * ( sum[qy_i * qx_i * 8*ls_k] - sum[qy_i * ls_k * (8 -/+ o_k)] )
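// A minimal scalar sketch of the decomposition above (illustration only, not used by the
// kernel; the helper name and its parameters are hypothetical). qx are the unsigned grid
// values in {0,1,2}, ls is the odd block scale 2*s + 1, dsign is +1 when the delta bit of
// qh is set (delta = -1/8) and -1 when it is clear (delta = +1/8), d and d8 are the row scales:
static inline float iq1_s_block_dot_ref_sketch(const uint8_t * qx, const int8_t * qy, int nq,
                                               int ls, int dsign, float d, float d8) {
    int sumi = 0, bsum = 0;
    for (int i = 0; i < nq; ++i) { sumi += qy[i]*qx[i]; bsum += qy[i]; }
    // 8*ls*sum(qy*qx) - ls*(8 + dsign)*sum(qy), with a single final scale by 1/8
    return 0.125f * d * d8 * float(8*ls*sumi - ls*(8 + dsign)*bsum);
}
// The SIMD code below computes exactly these pieces: scales128 is pre-multiplied by 8,
// deltas128 holds -ls*(8 + dsign) (i.e. -7*ls or -9*ls), and the q8 block sums supply sum(qy).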

template <int nrc_y>
static void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
Q8<nrc_y, block_q8_K> q8(info);
__m256i qx[8];
__m256i scales[4];
__m256 acc[nrc_y] = {};
auto delta_mask = _mm_set1_epi16(-32768); // -32768 is int16_t 0x8000; avoids compiler overflow warnings from writing 0x8000 directly
__m256i shuffle0 = _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100);
for (int ix = 0; ix < nrc_x; ++ix) {
auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);
for (int ibl = 0; ibl < n/QK_K; ++ibl) {
float d = GGML_FP16_TO_FP32(iq1s[ibl].d);
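// each 16-bit qh entry packs the high grid-index bits (bits 0..11), the 3-bit block scale (bits 12..14) and the delta sign (bit 15)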
auto qhb = _mm_loadu_si128((const __m128i *)iq1s[ibl].qh);
auto scales128 = _mm_and_si128(_mm_srli_epi16(qhb, 12), _mm_set1_epi16(7));
scales128 = _mm_add_epi16(_mm_slli_epi16(scales128, 1), _mm_set1_epi16(1));
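// bit 15 of qh selects the per-block delta: clear -> +1/8 (correction factor -7*ls), set -> -1/8 (factor -9*ls)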
#ifdef HAVE_FANCY_SIMD
auto mask = _mm_cmpeq_epi16_mask(_mm_and_si128(qhb, delta_mask), delta_mask);
auto deltas128 = _mm_mask_blend_epi16(mask, _mm_set1_epi16(-7), _mm_set1_epi16(-9));
#else
auto mask = _mm_cmpeq_epi16(_mm_and_si128(qhb, delta_mask), delta_mask);
auto deltas128 = _mm_or_si128(_mm_and_si128(mask, _mm_set1_epi16(-9)), _mm_andnot_si128(mask, _mm_set1_epi16(-7)));
#endif
deltas128 = _mm_mullo_epi16(scales128, deltas128);
scales128 = _mm_slli_epi16(scales128, 3);
auto deltas_l = _mm_unpacklo_epi16(deltas128, deltas128);
auto deltas_h = _mm_unpackhi_epi16(deltas128, deltas128);
auto deltas = MM256_SET_M128I(deltas_h, deltas_l); // blocks 0,0, 1,1, 2,2, ..., 7,7
auto all_scales = MM256_SET_M128I(scales128, scales128);
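// replicate each block scale into the four 16-bit lanes that line up with its packed dot products below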
auto shuffle = shuffle0;
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
scales[ib64] = _mm256_shuffle_epi8(all_scales, shuffle);
shuffle = _mm256_add_epi8(shuffle, _mm256_set1_epi8(4));
}
const uint8_t * qs = iq1s[ibl].qs;
const uint16_t * qh = iq1s[ibl].qh;
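// reconstruct the 11-bit grid indices: low 8 bits from qs, top 3 bits from qh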
for (int ib = 0; ib < QK_K/32; ib += 2) {
qx[ib+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[ib+0] << 2) & 0x700)],
iq1s_grid_us[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
qx[ib+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[ib+1] << 2) & 0x700)],
iq1s_grid_us[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
qs += 8;
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto bsums = q8.load_bsums(iy, ibl);
auto sumi = _mm256_setzero_si256();
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
auto qy1 = q8.load_quants(iy, ibl, 2*ib64+0);
auto qy2 = q8.load_quants(iy, ibl, 2*ib64+1);
#ifdef HAVE_FANCY_SIMD
auto dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+0], qy1);
auto dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+1], qy2);
sumi = _mm256_dpwssd_epi32(sumi, scales[ib64], _mm256_packs_epi32(dot1, dot2));
#else
auto dot1 = _mm256_maddubs_epi16(qx[2*ib64+0], qy1);
auto dot2 = _mm256_maddubs_epi16(qx[2*ib64+1], qy2);
auto dot = _mm256_add_epi16(_mm256_unpacklo_epi64(dot1, dot2), _mm256_unpackhi_epi64(dot1, dot2));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(scales[ib64], dot));
#endif
}
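// fold in the -ls*(8 -/+ 1)*sum(qy) delta correction via the precomputed q8 block sums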
#ifdef HAVE_FANCY_SIMD
sumi = _mm256_dpwssd_epi32(sumi, bsums, deltas);
#else
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(bsums, deltas));
#endif
acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, 0.125f*hsum_float_8(acc[iy]));
acc[iy] = _mm256_setzero_ps();
}
}
}

template <int nrc_y>
static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -9473,6 +9552,20 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
mm.funcs[7] = mul_mat_q8_0_r8_q8_1<8>;
expected_typeB = GGML_TYPE_Q8_1_X4;
break;
case GGML_TYPE_IQ1_S:
mm.funcs[0] = mul_mat_iq1_s_q8_K<1>;
mm.funcs[1] = mul_mat_iq1_s_q8_K<2>;
mm.funcs[2] = mul_mat_iq1_s_q8_K<3>;
mm.funcs[3] = mul_mat_iq1_s_q8_K<4>;
mm.funcs[4] = mul_mat_iq1_s_q8_K<5>;
mm.funcs[5] = mul_mat_iq1_s_q8_K<6>;
mm.funcs[6] = mul_mat_iq1_s_q8_K<7>;
mm.funcs[7] = mul_mat_iq1_s_q8_K<8>;
#ifdef HAVE_FANCY_SIMD
mm.func16 = mul_mat_iq1_s_q8_K<16>;
#endif
expected_typeB = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ1_S_R4:
assert (ne00 % QK4_NL == 0);
mm.funcs[0] = mul_mat_iq1_s_r4_q8_1<1>;
@@ -12513,6 +12606,68 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
}
}

template <int nrc_y>
static void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
Q8<nrc_y, block_q8_K> q8(info);
int8x16_t qx[16];
int32x4_t scales[2];
int16x4_t deltas[2];
float32x4_t acc[nrc_y] = {};
auto delta_mask = vdupq_n_u16(0x8000);
for (int ix = 0; ix < nrc_x; ++ix) {
auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);
for (int ibl = 0; ibl < n/QK_K; ++ibl) {
float d = GGML_FP16_TO_FP32(iq1s[ibl].d);
auto qhb = vld1q_u16(iq1s[ibl].qh);
auto scales128 = vandq_u16(vshrq_n_u16(qhb, 12), vdupq_n_u16(7));
scales128 = vaddq_u16(vshlq_n_u16(scales128, 1), vdupq_n_u16(1));
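// the NEON path dots against the signed iq1s_grid table (unlike iq1s_grid_us in the AVX2 path),
// so only the +/- ls*sum(qy) delta term is folded in via the bsums; bit 15 of qh gives its sign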
auto mask = vceqq_u16(vandq_u16(qhb, delta_mask), delta_mask);
// Note: we explicitly assume IQ1S_DELTA = 0.125
auto deltas128 = vsubq_s16(vbicq_s16(scales128, mask), vandq_s16(scales128, mask));
//auto deltas128 = vorrq_s16(vandq_s16(vdupq_n_s16(-1), mask), vbicq_s16(vdupq_n_s16(1), mask));
//deltas128 = vmulq_s16(scales128, deltas128);
scales128 = vshlq_n_u16(scales128, 3);
auto qs = iq1s[ibl].qs;
auto qh = iq1s[ibl].qh;
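// as in the AVX2 path: low 8 grid-index bits come from qs, the top 3 from qh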
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
qx[4*ib64+0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[0] | ((qh[2*ib64+0] << 8) & 0x700)], iq1s_grid[qs[1] | ((qh[2*ib64+0] << 5) & 0x700)]});
qx[4*ib64+1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[2] | ((qh[2*ib64+0] << 2) & 0x700)], iq1s_grid[qs[3] | ((qh[2*ib64+0] >> 1) & 0x700)]});
qx[4*ib64+2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[4] | ((qh[2*ib64+1] << 8) & 0x700)], iq1s_grid[qs[5] | ((qh[2*ib64+1] << 5) & 0x700)]});
qx[4*ib64+3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[6] | ((qh[2*ib64+1] << 2) & 0x700)], iq1s_grid[qs[7] | ((qh[2*ib64+1] >> 1) & 0x700)]});
qs += 8;
}
scales[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales128)));
scales[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales128)));
deltas[0] = vget_low_s16 (deltas128);
deltas[1] = vget_high_s16(deltas128);
for (int iy = 0; iy < nrc_y; ++iy) {
auto bsums = q8.load_bsums8(iy, ibl);
auto sumi = vdupq_n_s32(0);
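// add the +/- ls*sum(qy) delta correction first, then accumulate the scaled dot products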
sumi = vmlal_s16(sumi, deltas[0], vget_low_s16 (bsums));
sumi = vmlal_s16(sumi, deltas[1], vget_high_s16(bsums));
for (int k = 0; k < QK_K/128; ++k) {
auto qy = q8.load_quants_64(iy, ibl, 2*k+0);
auto dot1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+0], qy.val[0]), qx[8*k+1], qy.val[1]);
auto dot2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+2], qy.val[2]), qx[8*k+3], qy.val[3]);
auto dot12 = vpaddq_s32(dot1, dot2);
qy = q8.load_quants_64(iy, ibl, 2*k+1);
auto dot3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+4], qy.val[0]), qx[8*k+5], qy.val[1]);
auto dot4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+6], qy.val[2]), qx[8*k+7], qy.val[3]);
auto dot34 = vpaddq_s32(dot3, dot4);
auto dot = vpaddq_s32(dot12, dot34);
sumi = vmlaq_s32(sumi, dot, scales[k]);
}
acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, ibl)), vcvtq_f32_s32(sumi));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, 0.125f*vaddvq_f32(acc[iy]));
acc[iy] = vdupq_n_f32(0);
}
}
}

template <int nrc_y>
static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -14327,6 +14482,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
m.func16 = mul_mat_iq2_s_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ1_S:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_q8_K);
m.func16 = mul_mat_iq1_s_q8_K<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ1_S_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_r4_q8_1);
m.funcs[0] = mul_mat_iq1_s_r4_q8_1_1;