Commit c35a383

ikawrakow and Iwan Kawrakow authored
Zen4: Faster PP for IQ2_KS, IQ4_KS, IQ5_KS (#428)
* Zen4: faster PP for iq4_ks and iq5_ks

* Zen4: faster PP for iq2_ks

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 7abdf2b commit c35a383

File tree: 1 file changed (+119, -22)


ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 119 additions & 22 deletions
@@ -1798,6 +1798,13 @@ struct Q4Bits {
         values[2] = _mm512_and_si512(q4bits, ml);
         values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
     }
+    inline void prepare64a(const uint8_t * q4) {
+        for (int k = 0; k < 4; ++k) {
+            auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + k);
+            values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(q4bits), _mm256_srli_epi16(q4bits, 4), 1);
+            values[k] = _mm512_and_si512(values[k], ml);
+        }
+    }
     __m512i values[4];
     const __m512i ml = _mm512_set1_epi8(0xf);
     BlockPermuter perm;
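
The new Q4Bits::prepare64a loads each 32-byte group of packed nibbles once and places the low nibbles in the lower 256 bits and the high nibbles in the upper 256 bits of the same 512-bit register, so each values[k] ends up holding 64 consecutive quants with no later permute. A minimal scalar sketch of this, assuming the usual k-quant packing where byte j of a 32-byte group stores quant j in its low nibble and quant j+32 in its high nibble (the helper name is mine, not from the source):

#include <cstdint>

// Scalar picture of prepare64a: lane k holds quants 64*k ... 64*k+63 of the
// 256-quant super-block in natural order. Packing assumption as in the lead-in.
static void prepare64a_scalar(const uint8_t * q4, uint8_t out[4][64]) {
    for (int k = 0; k < 4; ++k) {
        const uint8_t * chunk = q4 + 32*k;    // 32 packed bytes = 64 quants
        for (int j = 0; j < 32; ++j) {
            out[k][j]      = chunk[j] & 0xf;  // low nibbles  -> lower 256 bits
            out[k][j + 32] = chunk[j] >> 4;   // high nibbles -> upper 256 bits
        }
    }
}

If that reading is right, it is what lets the IQ5_KS prepare() further down switch to prepare64a and drop its permute step.
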
@@ -2106,16 +2113,26 @@ struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
 struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
     DequantizerIQ2KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}
     template <typename Q8>
-    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+    inline void compute_block(int i, const Q8& q8, __m512 * acc) {
         prepare(x[i].qs);
         auto scales128 = make_scales(x[i].scales, x[i].extra >> 8);
         auto shifts = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi8(x[i].extra), hmask), hmask), m5);
-        auto scales_s = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts)));
-        s8k.accum_mins(scales_s, q8, i, d, accm);
+        auto mins128 = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts)));
+        auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0]));
         auto scales256 = MM256_SET_M128I(scales128, scales128);
         auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
-        scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
-        scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
+        __m512i scales[4];
+        for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]);
+        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+            auto q8s = q8.load_bsums(iy, i);
+            auto prod = _mm256_madd_epi16(mins, q8s);
+            auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0);
+            for (int k = 0; k < 4; ++k) {
+                auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k));
+                sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]);
+            }
+            acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]);
+        }
     }
     inline void prepare(const uint8_t * q2) {
         bits.prepare(q2);
@@ -2140,7 +2157,7 @@ struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true>
         return _mm_cvtepi8_epi16(_mm_add_epi8(scl, sch));
     }
     Q2Bits bits;
-    Scales8K s8k;
+    Scales8KBase s8k;
 
     const __m512i values;
     const __m128i m16 = _mm_set1_epi8(-16);
@@ -2149,6 +2166,12 @@ struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true>
     const __m128i hmask = _mm_set1_epi64x(0x8040201008040201);
     const __m128i shuffle = _mm_set1_epi64x(0x0703060205010400);
     const __m128i shift = _mm_set_epi32(0, 0, 4, 0);
+    const __m512i shuffles[4] = {
+        _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
+        _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
+        _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
+        _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
+    };
 };
 
 struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
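
The IQ2_KS changes above replace the two-step new_block / generic-accumulation path with a single compute_block that handles one 256-value super-block end to end: decode the eight sub-block scales, build the per-sub-block "mins" correction, fold it in through the q8 block sums, then accumulate the four 64-quant maddubs/dpwssd dot products straight into the float accumulators. A rough scalar picture of the quantity added to acc[iy] per super-block, under my reading of the intrinsics (the helper name and the 8 x 32 sub-block layout are inferred, not stated in the diff):

#include <cstdint>

// Scalar sketch of what DequantizerIQ2KS::compute_block appears to accumulate
// for one super-block i and one activation row iy. q[] stands for the unsigned
// codes fed to _mm512_maddubs_epi16, min[b] for the per-sub-block offset
// correction (scale[b] * (shift - 32) in the diff), q8[] for the q8_K quants.
static float iq2_ks_block_sketch(float d, float q8_scale,
                                 const int16_t scale[8], const int16_t min[8],
                                 const uint8_t q[256], const int8_t q8[256]) {
    int32_t sumi = 0;
    for (int b = 0; b < 8; ++b) {                // 8 sub-blocks of 32 quants
        int32_t dot = 0, bsum = 0;
        for (int j = 0; j < 32; ++j) {
            dot  += q[32*b + j] * q8[32*b + j];  // maddubs + dpwssd path
            bsum += q8[32*b + j];                // load_bsums + madd path
        }
        sumi += scale[b] * dot + min[b] * bsum;
    }
    return d * q8_scale * (float)sumi;           // fmadd into acc[iy]
}

The IQ4_KS and IQ5_KS compute_block added below follow the same pattern; only the scale/min decoding differs.
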
@@ -2377,6 +2400,29 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
         scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
         prepare(x[i].qs);
     }
+    template <typename Q8>
+    inline void compute_block(int i, const Q8& q8, __m512 * acc) {
+        auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
+        auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m4);
+        scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+        auto mins128 = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+        auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0]));
+        auto scales256 = MM256_SET_M128I(scales128, scales128);
+        auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
+        __m512i scales[4];
+        for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]);
+        prepare(x[i].qs);
+        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+            auto q8s = q8.load_bsums(iy, i);
+            auto prod = _mm256_madd_epi16(mins, q8s);
+            auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0);
+            for (int k = 0; k < 4; ++k) {
+                auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k));
+                sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]);
+            }
+            acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]);
+        }
+    }
     inline void prepare(const uint8_t * q4) {
         bits.prepare64(q4);
         // We now have in bits.valuse[0]: 0...15, 32...47, 64...79, 96...111
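
The IQ4_KS compute_block above decodes the per-sub-block scale bytes inline instead of in new_block: bit 0 of each byte acts as a shift flag, the remaining bits minus 127 give the signed scale, and the "mins" term folds the value offset (plus the optional shift) into a correction applied through the q8 block sums. A hedged scalar reading of that decode, with a hypothetical helper name (the constants mirror m1, m4, mask, m127 and m128 in the diff):

#include <cstdint>

// Scalar reading of the scale/min decode in DequantizerIQ4KS::compute_block.
// Illustration only; the helper and its signature are not part of the source.
static void iq4_ks_decode_scale(uint8_t raw, int16_t * scale, int16_t * min) {
    int16_t shift = (raw & 1) ? 4 : 0;            // m1 / m4: low bit selects a +4 shift
    *scale = (int16_t)(raw & 254) - 127;          // mask / m127: 7-bit scale minus 127
    *min   = (int16_t)(*scale * (shift - 128));   // m128 + shifts: offset correction
}

The IQ5_KS variant below is identical except that the shift flag selects +2 (m2) rather than +4.
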
@@ -2425,10 +2471,33 @@ struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
         scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
         prepare(x[i].qs, x[i].qh);
     }
+    template <typename Q8>
+    inline void compute_block(int i, const Q8& q8, __m512 * acc) {
+        auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales));
+        auto shifts = _mm_and_si128(_mm_cmpeq_epi16(_mm_and_si128(scales128, m1), m1), m2);
+        scales128 = _mm_add_epi16(_mm_and_si128(scales128, mask), m127);
+        auto mins128 = _mm_mullo_epi16(scales128, _mm_add_epi16(m128, shifts));
+        auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0]));
+        auto scales256 = MM256_SET_M128I(scales128, scales128);
+        auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
+        __m512i scales[4];
+        for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]);
+        prepare(x[i].qs, x[i].qh);
+        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+            auto q8s = q8.load_bsums(iy, i);
+            auto prod = _mm256_madd_epi16(mins, q8s);
+            auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0);
+            for (int k = 0; k < 4; ++k) {
+                auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k));
+                sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]);
+            }
+            acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]);
+        }
+    }
     inline void prepare(const uint8_t * q4, const uint8_t * qh) {
-        bits.prepare64(q4);
+        bits.prepare64a(q4);
         auto h256 = _mm256_loadu_si256((const __m256i *)qh);
-        auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 2), 1);
+        auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 1), 1);
         auto m1 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask1), hmask1);
         auto m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
         bits.values[0] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[0]), m1, values[1], bits.values[0]);
@@ -2438,15 +2507,6 @@ struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
         m2 = _mm512_cmpeq_epi8_mask(_mm512_and_si512(hbits, hmask2), hmask2);
         bits.values[2] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m1), values[0], bits.values[2]), m1, values[1], bits.values[2]);
         bits.values[3] = _mm512_mask_shuffle_epi8(_mm512_maskz_shuffle_epi8(_knot_mask64(m2), values[0], bits.values[3]), m2, values[1], bits.values[3]);
-        // We now have in bits.valuse[0]: 0...31, 64...95
-        //                bits.valuse[1]: 32..63, 96..127
-        //                etc.
-        auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
-        bits.values[1] = _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]);
-        bits.values[0] = tmp;
-        tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
-        bits.values[3] = _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]);
-        bits.values[2] = tmp;
     }
     static void load_values(__m512i * values) {
         static const uint8_t kvalues_iq5nl[32] = {
@@ -2465,9 +2525,7 @@ struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
     Scales8KBase s8k;
     __m512i values[2];
     const __m512i hmask1 = _mm512_set1_epi8(1);
-    const __m512i hmask2 = _mm512_set1_epi8(2);
-    const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
-    const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
+    const __m512i hmask2 = _mm512_set1_epi8(4);
     const __m128i m127 = _mm_set1_epi16(-127);
     const __m128i m128 = _mm_set1_epi16(-128);
     const __m128i mask = _mm_set1_epi16(254);
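
With prepare64a delivering the 4-bit quants in natural 64-value order, the IQ5_KS high-bit handling changes accordingly: hbits is now built from qh and qh >> 1 (instead of qh >> 2), and hmask2 becomes 4, so consecutive bit positions of each qh byte supply the fifth bit for consecutive 32-quant groups, and the permute1/permute2 reshuffle removed in the hunk above is no longer needed. A hedged scalar reading of the indexing this implies (the helper is hypothetical, not from the source):

#include <cstdint>

// Scalar reading of the new IQ5_KS high-bit indexing: bit b of qh[j] appears
// to supply the fifth quant bit of quant 32*b + j of the super-block.
static inline int iq5_ks_high_bit(const uint8_t * qh, int quant_index) {
    int j = quant_index & 31;   // position within a 32-quant group
    int b = quant_index >> 5;   // which of the 8 groups
    return (qh[j] >> b) & 1;
}

In the masked shuffles, that bit selects between the two halves of the non-linear value table (values[0] / values[1]).
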
@@ -2651,6 +2709,34 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D
     }
 }
 
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_iqX_k_q8_K_AVX512_new(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n % QK_K == 0);
+    const int nb = n / QK_K;
+
+    Q8<nrc_y> q8(info);
+
+    Dequantizer deq(vx, bx);
+
+    __m512 accd[nrc_y];
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+
+        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
+
+        deq.new_row(ix);
+
+        for (int i = 0; i < nb; ++i) {
+            deq.compute_block(i, q8, accd);
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, _mm512_reduce_add_ps(accd[iy]));
+        }
+
+    }
+}
+
 template <typename Dequantizer>
 static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n % QK_K == 0);
@@ -9713,8 +9799,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
                   std::is_same_v<Dequantizer, DequantizerIQ4K> ||
                   std::is_same_v<Dequantizer, DequantizerIQ3K> ||
                   std::is_same_v<Dequantizer, DequantizerIQ4XS>||
-                  std::is_same_v<Dequantizer, DequantizerIQ4KS>||
-                  std::is_same_v<Dequantizer, DequantizerIQ5KS>||
+                  //std::is_same_v<Dequantizer, DequantizerIQ4KS>||
+                  //std::is_same_v<Dequantizer, DequantizerIQ5KS>||
                   std::is_same_v<Dequantizer, DequantizerIQ4KSS>) {
         m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
         m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;
@@ -9724,6 +9810,17 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
         m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 6>;
         m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 7>;
         m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 8>;
+    } else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2KS> ||
+                         std::is_same_v<Dequantizer, DequantizerIQ4KS> ||
+                         std::is_same_v<Dequantizer, DequantizerIQ5KS>) {
+        m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 1>;
+        m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 2>;
+        m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 3>;
+        m.funcs[3] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 4>;
+        m.funcs[4] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 5>;
+        m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 6>;
+        m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 7>;
+        m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 8>;
     } else {
         m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
         m.funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;
