Skip to content

Commit 134d548

Browse files
ikawrakow and Iwan Kawrakow authored
Fix AVX2 implementation of IQ4_K, IQ4_KS, IQ5_K, IQ6_K (#427)
* Fix IQ4_K on AVX2
* Fix IQ4_KS on AVX2
* Fix IQ5_K on AVX2
* Fix IQ6_K on AVX2

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 34ae71c commit 134d548

File tree

1 file changed

+75
-41
lines changed

1 file changed

+75
-41
lines changed

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 75 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -1673,6 +1673,29 @@ inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i,
16731673
}
16741674
}
16751675

1676+
// Accumulates, for every right-hand-side row iy in [0, Q8::nrc_y), the scaled
// integer dot products of the four dequantized vectors in bits.values[0..3]
// with the matching q8 quants of block i.
//
//   bits   - dequantizer bit helper; only bits.values[0..3] is read here.
//   scales - four vectors of 16-bit scales, one per bits.values[k], applied
//            with _mm256_madd_epi16.
//   j      - half-block selector: j == 0 reads q8 sub-blocks 0..3 and
//            overwrites sumi[iy]; any other j reads sub-blocks 4..7 and adds
//            to the existing sumi[iy].
//   i      - block index forwarded to q8.load_quants().
//   q8     - provider of the q8 quants via load_quants(iy, i, sub_block).
//   sumi   - per-row 32-bit accumulators, at least Q8::nrc_y entries.
//
// _mm256_maddubs_epi16 treats its first operand as unsigned bytes and its
// second as signed bytes, so the sign of each dequantized value is moved onto
// the q8 operand (_mm256_sign_epi8(q8, value)) and the magnitude
// (_mm256_sign_epi8(value, value)) is passed as the unsigned operand.
template <typename Q8, typename Bits>
inline void multiply_add_avx2(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
    // The magnitudes depend only on bits.values, not on the row, so compute
    // them once instead of once per iy.
    __m256i magnitudes[4];
    for (int k = 0; k < 4; ++k) {
        magnitudes[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]);
    }

    const bool first_half = (j == 0);
    const int  sub_block0 = first_half ? 0 : 4;

    for (int iy = 0; iy < Q8::nrc_y; ++iy) {
        __m256i p[4];
        for (int k = 0; k < 4; ++k) {
            const __m256i signed_q8 = _mm256_sign_epi8(q8.load_quants(iy, i, sub_block0 + k), bits.values[k]);
            p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(magnitudes[k], signed_q8));
        }
        // 32-bit lane additions wrap modulo 2^32, so the summation order does
        // not affect the result.
        const __m256i partial = _mm256_add_epi32(_mm256_add_epi32(p[0], p[1]), _mm256_add_epi32(p[2], p[3]));
        if (first_half) {
            sumi[iy] = partial;
        } else {
            sumi[iy] = _mm256_add_epi32(sumi[iy], partial);
        }
    }
}
1698+
16761699
struct SignHelper {
16771700
inline __m256i make_signs(uint32_t sign_bits) const {
16781701
auto aux256 = _mm256_set1_epi32(sign_bits);
@@ -2892,39 +2915,50 @@ struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
28922915
};
28932916

28942917
struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
2895-
DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -128), values(load_iq4nl_values_256()) {}
2918+
DequantizerIQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); }
28962919
template <typename Q8>
2897-
inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
2920+
inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, __m256i * scales) {
28982921
d = GGML_FP16_TO_FP32(x[i].d);
2899-
iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales);
2922+
auto scales8 = make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h);
2923+
auto scales16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales8, hshuff));
2924+
prepare_scales_16(scales16, scales);
29002925
}
29012926
inline void prepare(int i, int j) {
29022927
bits.prepare16(x[i].qs, j);
2903-
bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);
2904-
bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);
2905-
bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);
2906-
bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
2928+
auto extra = x[i].extra >> 8*j;
2929+
bits.values[0] = _mm256_shuffle_epi8(values[extra & 3], bits.values[0]); extra >>= 2;
2930+
bits.values[1] = _mm256_shuffle_epi8(values[extra & 3], bits.values[1]); extra >>= 2;
2931+
bits.values[2] = _mm256_shuffle_epi8(values[extra & 3], bits.values[2]); extra >>= 2;
2932+
bits.values[3] = _mm256_shuffle_epi8(values[extra & 3], bits.values[3]);
29072933
}
29082934
__m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const {
29092935
uint64_t aux64;
29102936
memcpy(&aux64, scales_l, 8);
29112937
auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl);
29122938
const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16);
29132939
auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh);
2914-
auto sch = _mm_shuffle_epi8(aux, iqxk.hshuff);
2940+
auto sch = _mm_shuffle_epi8(aux, hshuff);
29152941
return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
29162942
}
2943+
void load_values() {
2944+
auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0);
2945+
auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1);
2946+
values[0] = MM256_SET_M128I(v1, v1);
2947+
values[1] = MM256_SET_M128I(v1, v2);
2948+
values[2] = MM256_SET_M128I(v2, v1);
2949+
values[3] = MM256_SET_M128I(v2, v2);
2950+
}
29172951

29182952
Q4Bits bits;
2919-
const IQXKScales iqxk;
2920-
const __m256i values;
29212953
const __m128i maskl = _mm_set1_epi8(0xf);
29222954
const __m128i maskh = _mm_set1_epi8(0x30);
29232955
const __m128i m32 = _mm_set1_epi8(-32);
2956+
const __m128i hshuff = _mm_set_epi32(0x0f070e06, 0x0d050c04, 0x0b030a02, 0x09010800);
2957+
__m256i values[4];
29242958
};
29252959

29262960
struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
2927-
DequantizerIQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(2, -128) { load_values(values); }
2961+
DequantizerIQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(2, 0) { load_values(values); }
29282962
template <typename Q8>
29292963
inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
29302964
d = GGML_FP16_TO_FP32(x[i].d);
@@ -2951,12 +2985,8 @@ struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
29512985
return _mm_add_epi8(_mm_or_si128(scl, sch), m32);
29522986
}
29532987
static void load_values(__m256i * values) {
2954-
static const uint8_t kvalues_iq5nl[32] = {
2955-
2, 14, 25, 36, 45, 54, 63, 71, 78, 85, 92, 98, 104, 110, 116, 122, 127,
2956-
133, 139, 145, 151, 157, 164, 171, 179, 187, 196, 205, 215, 225, 237, 249,
2957-
};
2958-
auto values128_1 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 0);
2959-
auto values128_2 = _mm_loadu_si128((const __m128i *)kvalues_iq5nl + 1);
2988+
auto values128_1 = _mm_loadu_si128((const __m128i *)iq5nl_values + 0);
2989+
auto values128_2 = _mm_loadu_si128((const __m128i *)iq5nl_values + 1);
29602990
values[0] = MM256_SET_M128I(values128_1, values128_1);
29612991
values[1] = MM256_SET_M128I(values128_2, values128_2);
29622992
}
@@ -2972,7 +3002,7 @@ struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
29723002
};
29733003

29743004
struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
2975-
DequantizerIQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(1, -128) { load_values(values); }
3005+
DequantizerIQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(1, 0) { load_values(values); }
29763006
template <typename Q8>
29773007
inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
29783008
d = GGML_FP16_TO_FP32(x[i].d);
@@ -3000,14 +3030,8 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
30003030
_mm256_and_si256(mask4, _mm256_shuffle_epi8(values[3], l))));
30013031
}
30023032
static void load_values(__m256i * values) {
3003-
static const uint8_t kvalues_iq6nl[64] = {
3004-
1, 7, 13, 19, 24, 30, 35, 40, 44, 49, 54, 58, 62, 66, 70, 74,
3005-
77, 81, 84, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 117, 120, 123,
3006-
126, 128, 131, 134, 137, 140, 142, 145, 148, 151, 155, 158, 161, 164, 168, 172,
3007-
175, 179, 183, 187, 191, 196, 200, 205, 210, 215, 220, 226, 231, 237, 243, 249,
3008-
};
30093033
for (int k = 0; k < 4; ++k) {
3010-
auto values128 = _mm_loadu_si128((const __m128i *)kvalues_iq6nl + k);
3034+
auto values128 = _mm_loadu_si128((const __m128i *)iq6nl_values + k);
30113035
values[k] = MM256_SET_M128I(values128, values128);
30123036
}
30133037
}
@@ -3022,32 +3046,32 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
30223046
};
30233047

30243048
struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
3025-
DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {}
3049+
DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); }
30263050
template <typename Q8>
3027-
inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
3051+
inline __m256i new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accd) {
30283052
auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
3029-
auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4);
30303053
scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
3031-
auto scales_s = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
3032-
s8k.accum_mins(scales_s, q8, i, d, accd);
30333054
return MM256_SET_M128I(scales128, scales128);
30343055
}
30353056
inline void prepare(int i, int j) {
30363057
bits.prepare16(x[i].qs, j);
3037-
bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);
3038-
bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);
3039-
bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);
3040-
bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
3058+
bits.values[0] = _mm256_shuffle_epi8(values[x[i].scales[4*j+0] & 1], bits.values[0]);
3059+
bits.values[1] = _mm256_shuffle_epi8(values[x[i].scales[4*j+1] & 1], bits.values[1]);
3060+
bits.values[2] = _mm256_shuffle_epi8(values[x[i].scales[4*j+2] & 1], bits.values[2]);
3061+
bits.values[3] = _mm256_shuffle_epi8(values[x[i].scales[4*j+3] & 1], bits.values[3]);
30413062
}
3063+
void load_values() {
3064+
auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0);
3065+
auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1);
3066+
values[0] = MM256_SET_M128I(v1, v1);
3067+
values[1] = MM256_SET_M128I(v2, v2);
3068+
}
3069+
30423070

30433071
Q4Bits bits;
3044-
Scales8KBase s8k;
3045-
const __m256i values;
3072+
__m256i values[2];
30463073
const __m128i mask = _mm_set1_epi16(254);
30473074
const __m128i m127 = _mm_set1_epi16(-127);
3048-
const __m128i m128 = _mm_set1_epi16(-128);
3049-
const __m128i m1 = _mm_set1_epi16(1);
3050-
const __m128i m4 = _mm_set1_epi16(4);
30513075
};
30523076

30533077
struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
@@ -3304,7 +3328,13 @@ static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
33043328
for (int j = 0; j < QK_K/128; ++j) {
33053329
deq.prepare(i, j);
33063330
set_scales_16(all_scales[j], scales);
3307-
multiply_add(deq.bits, scales, j, i, q8, sumi);
3331+
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4K> ||
3332+
std::is_same_v<Dequantizer, DequantizerIQ5K> ||
3333+
std::is_same_v<Dequantizer, DequantizerIQ6K>) {
3334+
multiply_add_avx2(deq.bits, scales, j, i, q8, sumi);
3335+
} else {
3336+
multiply_add(deq.bits, scales, j, i, q8, sumi);
3337+
}
33083338
}
33093339

33103340
for (int iy = 0; iy < nrc_y; ++iy) {
@@ -3351,7 +3381,11 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
33513381

33523382
set_scales_8(all_scales, j, scales);
33533383

3354-
multiply_add(deq.bits, scales, j, i, q8, sumi);
3384+
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS>) {
3385+
multiply_add_avx2(deq.bits, scales, j, i, q8, sumi);
3386+
} else {
3387+
multiply_add(deq.bits, scales, j, i, q8, sumi);
3388+
}
33553389

33563390
}
33573391

0 commit comments

Comments
 (0)