Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 105 additions & 59 deletions ggml/src/iqk/iqk_gemm_iquants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,35 +145,6 @@ struct SignHelper {
const __m256i mone = _mm256_set1_epi8(1);
};

// for (int i = 0; i < nb; ++i) {
//
// __m256i sumi[nrc_y], all_scales;
// //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();
// __m256i mins;
// float dmin = deq.new_block(i, &all_scales, mins);
// for (int iy = 0; iy < nrc_y; ++iy) {
// auto bsums = q8.load_bsums(iy, i);
// auto prod = _mm256_madd_epi16(mins, bsums);
// accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
// }
//
// for (int j = 0; j < QK_K/128; ++j) {
// deq.prepare(i, j);
// set_scales_8(&all_scales, j, scales);
// //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);
// multiply_add(deq.bits, scales, j, i, q8, sumi);
// }
// for (int iy = 0; iy < nrc_y; ++iy) {
// const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
// accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
// }
// }
//
// for (int iy = 0; iy < nrc_y; ++iy) {
// info.store(ix, iy, hsum_float_8(accd[iy]));
// }
// }

struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}

Expand Down Expand Up @@ -221,7 +192,7 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
}

IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const {
#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
#if defined z_HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1])), values+0);
esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2);
#else
Expand All @@ -246,7 +217,11 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
}
inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
make4(data.val, bits.values, q8_quants);
}
inline void prepare(int i, int j, __m256i * q8_quants) {
Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
make4(data.val, bits.values, q8_quants);
}

Expand Down Expand Up @@ -526,6 +501,13 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
sign_2_values(signs+0, q8_quants+0);
sign_2_values(signs+4, q8_quants+2);
}
inline void prepare(int i, int j, __m256i * q8_quants) {
auto qs = x[i].qs + 32*j;
const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j;
make4_unsigned(qs, bits.values);
sign_2_values(signs+0, q8_quants+0);
sign_2_values(signs+4, q8_quants+2);
}

constexpr static int minv = 64;

Expand Down Expand Up @@ -625,6 +607,10 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants);
}
inline void prepare(int i, int j, __m256i * q8_quants) {
prepare_unsigned(i, j);
sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants);
}

inline void prepare_unsigned(int i, int j) {
auto qs = x[i].qs + 32*j;
Expand Down Expand Up @@ -787,15 +773,69 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
}
}

template <typename Dequantizer, int nrc_y>
template <int n_sum>
inline __m256i compute_dot_4(const __m256i * x, const __m256i * y) {
#ifdef HAVE_FANCY_SIMD
auto sumi0 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[0], y[0]);
auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[1], y[1]);
auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[2], y[2]);
auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[3], y[3]);
sumi0 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
sumi2 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
#else
auto m1 = _mm256_set1_epi16(1);
if constexpr (n_sum == 2) {
auto sumi0 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[0], y[0]));
auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[1], y[1]));
auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[2], y[2]));
auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[3], y[3]));
sumi0 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
sumi2 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
}
else {
auto sumi0 = _mm256_maddubs_epi16(x[0], y[0]);
auto sumi1 = _mm256_maddubs_epi16(x[1], y[1]);
auto sumi2 = _mm256_maddubs_epi16(x[2], y[2]);
auto sumi3 = _mm256_maddubs_epi16(x[3], y[3]);
if constexpr (n_sum == 4) {
sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
sumi2 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
sumi0 = _mm256_madd_epi16(m1, sumi0);
sumi2 = _mm256_madd_epi16(m1, sumi2);
return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
}
else {
auto sumi0 = _mm256_maddubs_epi16(x[0], y[0]);
auto sumi1 = _mm256_maddubs_epi16(x[1], y[1]);
auto sumi2 = _mm256_maddubs_epi16(x[2], y[2]);
auto sumi3 = _mm256_maddubs_epi16(x[3], y[3]);
sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
sumi2 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
return _mm256_madd_epi16(m1, sumi0);
}
}
#endif
}

template <typename Dequantizer, int nrc_y, int n_sum = 2>
static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
static_assert(Dequantizer::num_blocks == 8);
static_assert(n_sum == 2 || n_sum == 4 || n_sum == 8);
#ifdef HAVE_FANCY_SIMD
constexpr bool use_1_row = nrc_y == 1;
#else
constexpr bool use_1_row = nrc_y == 1 && !std::is_same_v<Dequantizer, DequantizerIQ2XXS>;
#endif

const int nb = n / QK_K;
Q8<nrc_y, block_q8_2_x4> q8(info);
Dequantizer deq(vx, bx);
__m256 scales[3];
__m256 accd[nrc_y];
__m256i sumi[4];
__m256i vy[4];

for (int ix = 0; ix < nrc_x; ++ix) {

Expand All @@ -806,35 +846,33 @@ static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const Data
for (int i = 0; i < nb; ++i) {

deq.new_block_f(i, scales);
for (int iy = 0; iy < nrc_y; ++iy) {
auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4)));
auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4)));
auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]);
if constexpr (!use_1_row) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4)));
auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4)));
auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]);
}
}

for (int j = 0; j < QK_K/128; ++j) {
deq.prepare(i, j);
auto& values = deq.bits.values;
for (int iy = 0; iy < nrc_y; ++iy) {
auto qs = q8.y[iy][2*i+j].qs;
#ifdef HAVE_FANCY_SIMD
sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[0], _mm256_loadu_si256((const __m256i*)qs+0));
sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[1], _mm256_loadu_si256((const __m256i*)qs+1));
sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[2], _mm256_loadu_si256((const __m256i*)qs+2));
sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[3], _mm256_loadu_si256((const __m256i*)qs+3));
#else
sumi[0] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[0], _mm256_loadu_si256((const __m256i*)qs+0)));
sumi[1] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[1], _mm256_loadu_si256((const __m256i*)qs+1)));
sumi[2] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[2], _mm256_loadu_si256((const __m256i*)qs+2)));
sumi[3] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[3], _mm256_loadu_si256((const __m256i*)qs+3)));
#endif
sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1]));
sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3]));
sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2]));
auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16));
if constexpr (use_1_row) {
for (int k = 0; k < 4; ++k) vy[k] = _mm256_loadu_si256((const __m256i*)q8.y[0][2*i+j].qs+k);
deq.prepare(i, j, vy);
auto sumi = compute_dot_4<2*n_sum>(deq.bits.values, vy);
auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[0][2*i+j].d)), 16));
auto dy = _mm256_set_m128(d4, d4);
accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]);
accd[0] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi), accd[0]);
} else {
deq.prepare(i, j);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qs = q8.y[iy][2*i+j].qs;
for (int k = 0; k < 4; ++k) vy[k] = _mm256_loadu_si256((const __m256i*)qs+k);
auto sumi = compute_dot_4<n_sum>(deq.bits.values, vy);
auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16));
auto dy = _mm256_set_m128(d4, d4);
accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi), accd[iy]);
}
}
}
}
Expand Down Expand Up @@ -1934,7 +1972,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_

if (ggml_type(typeA) == GGML_TYPE_IQ3_S) {
if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels);
//IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels);
kernels[0] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 1, 8>;
kernels[1] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 2, 8>;
kernels[2] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 3, 8>;
kernels[3] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 4, 8>;
kernels[4] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 5, 8>;
kernels[5] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 6, 8>;
kernels[6] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 7, 8>;
kernels[7] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 8, 8>;
func16 = nullptr;
return true;
}
Expand Down