Skip to content

Commit e56061f

Browse files
ikawrakowIwan Kawrakow
andauthored
IQ2_XXS: much faster CPU prompt processing (#515)
* Much faster iq2_xxs GEMM PP-512 = 290 t/s vs ~110 t/s (iq2_xxs) or 148 t/s (iq2_xxs_r4) on main. * iq2_xxs: q8_2_x4 GEMM * iq2_xxs: use template for q8_2_x4 GEMM * Fix AVX2 * Cleanup * NEON is not working yet, so still use Q8_K GEMM --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 3c1f2c6 commit e56061f

File tree

5 files changed

+346
-10
lines changed

5 files changed

+346
-10
lines changed

ggml/src/ggml.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,7 +1067,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
10671067
.from_float = quantize_row_iq2_xxs,
10681068
.from_float_ref = (ggml_from_float_t)quantize_row_iq2_xxs_ref,
10691069
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
1070+
#ifdef __AVX2__
1071+
.vec_dot_type = GGML_TYPE_Q8_2_X4,
1072+
#else
10701073
.vec_dot_type = GGML_TYPE_Q8_K,
1074+
#endif
10711075
.nrows = 1,
10721076
.row_meta_size = 0,
10731077
},

ggml/src/iqk/iqk_common.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ static inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
172172
aux32[0] = a0 & 0x3f3f3f3f;
173173
}
174174

175-
#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__)
176175
const uint64_t keven_signs[128] = {
177176
0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,
178177
0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,
@@ -207,7 +206,6 @@ const uint64_t keven_signs[128] = {
207206
0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,
208207
0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,
209208
};
210-
#endif
211209

212210
#ifdef __AVX2__
213211

@@ -540,6 +538,24 @@ struct Q4Bits {
540538

541539
#endif
542540

541+
inline void iqk_transpose_8x8(__m256 * m) {
542+
for (int k = 0; k < 8; k += 4) {
543+
auto t0 = _mm256_unpacklo_ps(m[k+0], m[k+1]);
544+
auto t1 = _mm256_unpacklo_ps(m[k+2], m[k+3]);
545+
auto t2 = _mm256_unpackhi_ps(m[k+0], m[k+1]);
546+
auto t3 = _mm256_unpackhi_ps(m[k+2], m[k+3]);
547+
m[k+0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
548+
m[k+1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
549+
m[k+2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
550+
m[k+3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
551+
}
552+
for (int k = 0; k < 4; ++k) {
553+
auto t = _mm256_set_m128(_mm256_extractf128_ps(m[k+4], 1), _mm256_extractf128_ps(m[k], 1));
554+
m[k+0] = _mm256_set_m128(_mm256_castps256_ps128(m[k+4]), _mm256_castps256_ps128(m[k+0]));
555+
m[k+4] = t;
556+
}
557+
}
558+
543559
#else
544560
// ------------------------------------ __aarch64__ --------------------------------------------------
545561

ggml/src/iqk/iqk_gemm_iquants.cpp

Lines changed: 235 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,12 @@ struct EvenSignHelper {
8787
const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0);
8888
const __m256i mask = _mm256_set1_epi32(127);
8989
const __m256i mone = _mm256_set1_epi32(1);
90-
#else
90+
#endif
9191
inline void sign_value(uint32_t aux32, __m256i& value) const {
9292
auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],
9393
keven_signs[(aux32 >> 7) & 127], keven_signs[(aux32 >> 0) & 127]);
9494
value = _mm256_sign_epi8(value, signs);
9595
}
96-
#endif
9796
};
9897

9998
struct SignHelper {
@@ -144,6 +143,35 @@ struct SignHelper {
144143
const __m256i mone = _mm256_set1_epi8(1);
145144
};
146145

146+
// for (int i = 0; i < nb; ++i) {
147+
//
148+
// __m256i sumi[nrc_y], all_scales;
149+
// //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();
150+
// __m256i mins;
151+
// float dmin = deq.new_block(i, &all_scales, mins);
152+
// for (int iy = 0; iy < nrc_y; ++iy) {
153+
// auto bsums = q8.load_bsums(iy, i);
154+
// auto prod = _mm256_madd_epi16(mins, bsums);
155+
// accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
156+
// }
157+
//
158+
// for (int j = 0; j < QK_K/128; ++j) {
159+
// deq.prepare(i, j);
160+
// set_scales_8(&all_scales, j, scales);
161+
// //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);
162+
// multiply_add(deq.bits, scales, j, i, q8, sumi);
163+
// }
164+
// for (int iy = 0; iy < nrc_y; ++iy) {
165+
// const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
166+
// accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
167+
// }
168+
// }
169+
//
170+
// for (int iy = 0; iy < nrc_y; ++iy) {
171+
// info.store(ix, iy, hsum_float_8(accd[iy]));
172+
// }
173+
// }
174+
147175
struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
148176
DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
149177

@@ -165,6 +193,16 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
165193
auto sc16 = load_scales(i);
166194
scales[0] = MM256_SET_M128I(sc16, sc16);
167195
}
196+
inline void new_block_f(int i, __m256 * scales) {
197+
auto sc16 = load_scales(i);
198+
auto scf = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(sc16)));
199+
auto scf_l = _mm256_castps256_ps128(scf);
200+
auto scf_h = _mm256_extractf128_ps(scf, 1);
201+
scales[0] = _mm256_set_m128(scf_l, scf_l);
202+
scales[1] = _mm256_set_m128(scf_h, scf_h);
203+
scales[2] = _mm256_mul_ps(scf, _mm256_set1_ps(-minv));
204+
}
205+
168206
inline float new_block(int i, __m256i * scales, __m256i& mins) {
169207
auto sc16 = load_scales(i);
170208
mins = scb.shuffle(sc16);
@@ -729,6 +767,130 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
729767
}
730768
}
731769

770+
template <typename Dequantizer, int nrc_y>
771+
static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
772+
static_assert(Dequantizer::num_blocks == 8);
773+
const int nb = n / QK_K;
774+
Q8<nrc_y, block_q8_2_x4> q8(info);
775+
Dequantizer deq(vx, bx);
776+
__m256 scales[3];
777+
__m256 accd[nrc_y];
778+
__m256i sumi[4];
779+
780+
for (int ix = 0; ix < nrc_x; ++ix) {
781+
782+
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
783+
784+
deq.new_row(ix);
785+
786+
for (int i = 0; i < nb; ++i) {
787+
788+
deq.new_block_f(i, scales);
789+
for (int iy = 0; iy < nrc_y; ++iy) {
790+
auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4)));
791+
auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4)));
792+
auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
793+
accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]);
794+
}
795+
796+
for (int j = 0; j < QK_K/128; ++j) {
797+
deq.prepare(i, j);
798+
auto& values = deq.bits.values;
799+
for (int iy = 0; iy < nrc_y; ++iy) {
800+
auto qs = q8.y[iy][2*i+j].qs;
801+
#ifdef HAVE_FANCY_SIMD
802+
sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[0], _mm256_loadu_si256((const __m256i*)qs+0));
803+
sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[1], _mm256_loadu_si256((const __m256i*)qs+1));
804+
sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[2], _mm256_loadu_si256((const __m256i*)qs+2));
805+
sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[3], _mm256_loadu_si256((const __m256i*)qs+3));
806+
#else
807+
sumi[0] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[0], _mm256_loadu_si256((const __m256i*)qs+0)));
808+
sumi[1] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[1], _mm256_loadu_si256((const __m256i*)qs+1)));
809+
sumi[2] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[2], _mm256_loadu_si256((const __m256i*)qs+2)));
810+
sumi[3] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[3], _mm256_loadu_si256((const __m256i*)qs+3)));
811+
#endif
812+
sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1]));
813+
sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3]));
814+
sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2]));
815+
auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16));
816+
auto dy = _mm256_set_m128(d4, d4);
817+
accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]);
818+
}
819+
}
820+
}
821+
822+
for (int iy = 0; iy < nrc_y; ++iy) {
823+
info.store(ix, iy, hsum_float_8(accd[iy]));
824+
}
825+
}
826+
}
827+
828+
template <int nrc_y>
829+
static void mul_mat_iq2_xxs_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
830+
const int nb = n / QK_K;
831+
__m256 scales[2];
832+
__m256 accd[nrc_y];
833+
__m256i sumi[4];
834+
__m256i xv[4];
835+
EvenSignHelper esh;
836+
837+
for (int ix = 0; ix < nrc_x; ++ix) {
838+
839+
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
840+
841+
const block_iq2_xxs * x = (const block_iq2_xxs *)((const char *)vx + ix*bx);
842+
843+
for (int i = 0; i < nb; ++i) {
844+
const float d = GGML_FP16_TO_FP32(x[i].d)*0.125f;
845+
const uint16_t * a16 = x[i].qs;
846+
auto sc16 = _mm_set_epi16(a16[31], a16[27], a16[23], a16[19], a16[15], a16[11], a16[7], a16[3]);
847+
sc16 = _mm_or_si128(_mm_slli_epi16(_mm_srli_epi16(sc16, 12), 1), _mm_set1_epi16(1));
848+
auto sc32 = _mm256_cvtepi16_epi32(sc16);
849+
auto all_scales = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sc32));
850+
auto all_mins = _mm256_mul_ps(all_scales, _mm256_set1_ps(-43.f));
851+
for (int iy = 0; iy < nrc_y; ++iy) {
852+
auto y = (const block_q8_2_x4 *)info.src1_row(iy);
853+
auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(y[2*i+0].d + 4)));
854+
auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(y[2*i+1].d + 4)));
855+
auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
856+
accd[iy] = _mm256_fmadd_ps(all_mins, my, accd[iy]);
857+
}
858+
auto scales_l = _mm256_castps256_ps128(all_scales);
859+
auto scales_h = _mm256_extractf128_ps(all_scales, 1);
860+
scales[0] = _mm256_set_m128(scales_l, scales_l);
861+
scales[1] = _mm256_set_m128(scales_h, scales_h);
862+
863+
for (int j = 0; j < QK_K/128; ++j) {
864+
const uint8_t * a8 = (const uint8_t *)(a16 + 16*j);
865+
for (int k = 0; k < 4; ++k) {
866+
auto a8k = a8 + 8*k;
867+
xv[k] = _mm256_set_epi64x(iq2xxs_grid[a8k[3]], iq2xxs_grid[a8k[2]], iq2xxs_grid[a8k[1]], iq2xxs_grid[a8k[0]]);
868+
uint32_t aux32; std::memcpy(&aux32, a8k+4, sizeof(uint32_t));
869+
esh.sign_value(aux32, xv[k]);
870+
xv[k] = _mm256_add_epi8(xv[k], _mm256_set1_epi8(43));
871+
}
872+
for (int iy = 0; iy < nrc_y; ++iy) {
873+
auto y = (const block_q8_2_x4 *)info.src1_row(iy);
874+
sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[0], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+0));
875+
sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[1], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+1));
876+
sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[2], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+2));
877+
sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), xv[3], _mm256_loadu_si256((const __m256i*)y[2*i+j].qs+3));
878+
sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1]));
879+
sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3]));
880+
sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2]));
881+
auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)y[2*i+j].d)), 16));
882+
auto dy = _mm256_set_m128(d4, d4);
883+
accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]);
884+
}
885+
}
886+
}
887+
888+
for (int iy = 0; iy < nrc_y; ++iy) {
889+
info.store(ix, iy, hsum_float_8(accd[iy]));
890+
}
891+
}
892+
}
893+
732894
template <typename Dequantizer, int nrc_y>
733895
static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
734896
assert(n % QK_K == 0);
@@ -1560,6 +1722,55 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
15601722
}
15611723
}
15621724

1725+
void iqk_convert_iq2_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
1726+
GGML_ASSERT(n%QK_K == 0);
1727+
GGML_ASSERT(nrc_x%8 == 0);
1728+
1729+
int nb = n/QK_K;
1730+
1731+
const block_iq2_xxs * x8[8];
1732+
1733+
block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
1734+
1735+
ggml_half dh[8];
1736+
uint16_t all_ls[64];
1737+
EvenSignHelper esh;
1738+
1739+
uint32_t block[8];
1740+
uint32_t aux32[2];
1741+
const uint8_t * aux8 = (const uint8_t *)aux32;
1742+
1743+
for (int ix = 0; ix < nrc_x; ix += 8) {
1744+
for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_xxs *)((const char *)vx + (ix + k)*bx);
1745+
for (int i = 0; i < nb; ++i) {
1746+
// TODO: simdify
1747+
for (int k = 0; k < 8; ++k) {
1748+
dh[k] = x8[k][i].d;
1749+
for (int ib32 = 0; ib32 < 8; ++ib32) {
1750+
std::memcpy(aux32, x8[k][i].qs + 4*ib32, 2*sizeof(uint32_t));
1751+
all_ls[8*ib32 + k] = (2*(aux32[1] >> 28) + 1);
1752+
auto value = _mm256_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
1753+
esh.sign_value(aux32[1], value);
1754+
_mm256_storeu_si256((__m256i *)block, value);
1755+
auto qs = (uint32_t *)y[ib32].qs;
1756+
for (int l = 0; l < 4; ++l) {
1757+
qs[8*l + k + 0] = block[l + 0];
1758+
qs[8*l + k + 32] = block[l + 4];
1759+
}
1760+
}
1761+
}
1762+
auto vd = _mm256_mul_ps(_mm256_set1_ps(0.125f), _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh)));
1763+
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
1764+
auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
1765+
auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
1766+
auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
1767+
_mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
1768+
}
1769+
y += QK_K/32;
1770+
}
1771+
}
1772+
}
1773+
15631774
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
15641775
funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
15651776
funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
@@ -1575,7 +1786,19 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX
15751786

15761787
bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
15771788

1578-
if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
1789+
if (ne00%QK_K != 0) return false;
1790+
1791+
if (ggml_type(typeA) == GGML_TYPE_IQ2_XXS) {
1792+
if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
1793+
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ2XXS, kernels);
1794+
//IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_q8_2_IQ_N, kernels);
1795+
func16 = nullptr;
1796+
return true;
1797+
}
1798+
return false;
1799+
}
1800+
1801+
if (ggml_type(typeB) != GGML_TYPE_Q8_K) {
15791802
return false;
15801803
}
15811804

@@ -1629,6 +1852,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
16291852

16301853
}
16311854

1855+
bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
1856+
if (n%QK_K != 0 || nrc_x%8 != 0) return false;
1857+
switch (ggml_type(type)) {
1858+
case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_0_r8(n, vx, bx, vy, nrc_x); break;
1859+
default: return false;
1860+
}
1861+
return true;
1862+
}
1863+
16321864
#else
16331865
// --------------------------------------- __aarch64__ ---------------------------------------------
16341866

ggml/src/iqk/iqk_gemm_iquants.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,6 @@
88

99
bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
1010

11+
bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x);
12+
1113
#endif

0 commit comments

Comments
 (0)