Skip to content

Commit b94cd3b

Browse files
ikawrakowIwan KawrakowNexesenex
authored
Refactor iqk_mul_mat.cpp (#435)
* Refactor iqk: WIP * Refactor iqk: Factor out float GEMM (AVX2/AVX512) * Refactor iqk: Factor out GEMM for legacy quants (AVX2/AVX512) * Refactor iqk: Factor out GEMM for k-quants (AVX2/AVX512) * Refactor iqk: fix AVX2 * Refactor iqk: Factor out GEMM for i-quants (AVX2/AVX512) * Refactor iqk: fix AVX2 * Refactor iqk: Factor out GEMM for iqk-quants (AVX2/AVX512) * Refactor iqk: fix AVX2 * Refactor iqk: Factor out GEMM for 1-bit quants (AVX2/AVX512) * Refactor iqk: fix AVX2 * Refactor iqk: Factor out GEMM for iq1_bn, iq2_bn, iq2_bn_r4 * Refactor iqk: Factor out GEMM for repacked legacy quants * Refactor iqk: Factor out GEMM for q8_K_R8, q8_KV * Refactor iqk: Factor out GEMM for repacked i-quants * Refactor iqk: GEMM kernels are refactored on AVX2/AVX512 * Refactor iqk: factor out 1-bit quants (NEON) * Refactor iqk: factor out k-quants (NEON) * Refactor iqk: factor out floats (NEON) * Also iq4_xs belongs to k-quants * Refactor iqk: factor out iqk quants (NEON) * Refactor iqk: factor out legacy quants (NEON) * Refactor iqk: factor out repacked legacy quants (NEON) * Refactor iqk: factor out repacked k-quants (NEON) * Refactor iqk: factor out repacked iqk quants (NEON) * Refactor iqk: GEMM kernels are refactored on NEON * Refactor iqk: FA compiles If it works is a different story. Current compile time: 107.3 seconds on the Ryzen-7950X * Refactor iqk: FA refactored (Zen4) Compile time for the FA files is now ~21 seconds on my Ryzen-7950X, so still slightly too long for my taste but much better than the 142 seconds we had before. * Adding forgotten file * Most helpers don't need to be templates Also hide Q4_0 and Q8_KV behind IQK_FA_ALL_QUANTS. Compilation time drops to 14 seconds on the Ryzen-5975WX * Fix bf16 * Refactor iqk: FA refactored (NEON) * Forgotten MMQ ref and typo (#431) * Adding forgotten iq5_k_r4 * Fix iq4_k_r4 on NEON * Fix iq4_ks on NEON It was broken before the refactoring (the shifts were not correctly applied). 
* Fix q8_0 on NEON * Fix q6_0 K cache --------- Co-authored-by: Iwan Kawrakow <[email protected]> Co-authored-by: Nexes the Elder <[email protected]>
1 parent a2b5057 commit b94cd3b

23 files changed

+18643
-18116
lines changed

ggml/src/CMakeLists.txt

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,29 @@ set (GGML_HEADERS_IQK iqk/iqk_config.h)
258258
if (GGML_IQK_MUL_MAT)
259259
message(STATUS "Using optimized iqk matrix multiplications")
260260
add_compile_definitions(GGML_USE_IQK_MULMAT)
261-
set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp iqk/iqk_flash_attn.cpp)
262-
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h iqk/iqk_flash_impl.h)
261+
set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp
262+
iqk/iqk_flash_attn.cpp
263+
iqk/fa/iqk_fa_576_512.cpp
264+
iqk/fa/iqk_fa_192_128.cpp
265+
iqk/fa/iqk_fa_256_256.cpp
266+
iqk/fa/iqk_fa_128_128.cpp
267+
iqk/fa/iqk_fa_96_96.cpp
268+
iqk/fa/iqk_fa_64_64.cpp
269+
iqk/iqk_gemm_floats.cpp
270+
iqk/iqk_gemm_kquants.cpp
271+
iqk/iqk_gemm_iquants.cpp
272+
iqk/iqk_gemm_iqk_quants.cpp
273+
iqk/iqk_gemm_1bit.cpp
274+
iqk/iqk_gemm_legacy_quants.cpp)
275+
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h
276+
iqk/iqk_flash_impl.h
277+
iqk/fa/iqk_fa_templates.h
278+
iqk/iqk_gemm_floats.h
279+
iqk/iqk_gemm_kquants.h
280+
iqk/iqk_gemm_iquants.h
281+
iqk/iqk_gemm_iqk_quants.h
282+
iqk/iqk_gemm_1bit.h
283+
iqk/iqk_gemm_legacy_quants.h)
263284
if (GGML_IQK_FLASH_ATTENTION)
264285
message(STATUS "Enabling IQK Flash Attention kernels")
265286
add_compile_definitions(GGML_IQK_FLASH_ATTENTION)

ggml/src/iqk/fa/iqk_fa_128_128.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#include "iqk/iqk_config.h"
2+
3+
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
4+
5+
#include "iqk/fa/iqk_fa_templates.h"
6+
7+
// Flash-attention entry point for the 128/128 case: picks an
// iqk_flash_helper_T<128, 128, N> instantiation based on the K/V cache types and nk.
// NOTE(review): IQK_FA_CASE is defined outside this view (presumably in
// iqk_fa_templates.h) and supplies the function signature; the two 128s are
// presumably the K and V head sizes and N the K step — confirm in that header.
// Returns false when a bf16 K cache is paired with a non-bf16 V cache (AVX512BF16
// builds only); true after the bf16 helper ran; otherwise the typed helper's result.
IQK_FA_CASE(iqk_fa_128_128) {
8+
9+
auto type_k = ggml_type(int_type_k);
10+
auto type_v = ggml_type(int_type_v);
11+
12+
stride_q /= sizeof(float); // q stride as float
13+
auto ck = (const char *)k;
14+
auto cv = (const char *)v;
15+
auto cm = (const char *)mask;
16+
// bf16 path (compiled only with AVX512-BF16): uses the helper overload without type
// arguments, with N = 64 when nk is a multiple of 64 and N = 32 otherwise.
17+
#ifdef __AVX512BF16__
18+
if (type_k == GGML_TYPE_BF16) {
19+
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
20+
if (nk%64 == 0) {
21+
iqk_flash_helper_T<128, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
22+
q, ck, cv, cm, scale, softcap, qkv, M, S);
23+
return true;
24+
}
25+
iqk_flash_helper_T<128, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
26+
q, ck, cv, cm, scale, softcap, qkv, M, S);
27+
return true;
28+
}
29+
#endif
30+
// Generic path: use the largest of 128/64 that divides nk; N = 32 is the fallback
// and is taken without any divisibility check here.
31+
if (nk%128 == 0) {
32+
return iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
33+
q, ck, cv, cm, scale, softcap, qkv, M, S);
34+
}
35+
if (nk%64 == 0) {
36+
return iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
37+
q, ck, cv, cm, scale, softcap, qkv, M, S);
38+
}
39+
40+
return iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
41+
q, ck, cv, cm, scale, softcap, qkv, M, S);
42+
43+
}
44+
45+
#endif

ggml/src/iqk/fa/iqk_fa_192_128.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#include "iqk/iqk_config.h"
2+
3+
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
4+
5+
#include "iqk/fa/iqk_fa_templates.h"
6+
7+
// Flash-attention entry point for the 192/128 case: picks an
// iqk_flash_helper_T<192, 128, N> instantiation based on the K/V cache types and nk.
// NOTE(review): IQK_FA_CASE is defined outside this view (presumably in
// iqk_fa_templates.h) and supplies the function signature; 192 and 128 are
// presumably the K and V head sizes and N the K step — confirm in that header.
// Returns false when a bf16 K cache is paired with a non-bf16 V cache (AVX512BF16
// builds only); true after the bf16 helper ran; otherwise the typed helper's result.
IQK_FA_CASE(iqk_fa_192_128) {
8+
9+
auto type_k = ggml_type(int_type_k);
10+
auto type_v = ggml_type(int_type_v);
11+
12+
stride_q /= sizeof(float); // q stride as float
13+
auto ck = (const char *)k;
14+
auto cv = (const char *)v;
15+
auto cm = (const char *)mask;
16+
// bf16 path (compiled only with AVX512-BF16): uses the helper overload without type
// arguments, with N = 64 when nk is a multiple of 64 and N = 32 otherwise.
17+
#ifdef __AVX512BF16__
18+
if (type_k == GGML_TYPE_BF16) {
19+
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
20+
if (nk%64 == 0) {
21+
iqk_flash_helper_T<192, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
22+
q, ck, cv, cm, scale, softcap, qkv, M, S);
23+
return true;
24+
}
25+
iqk_flash_helper_T<192, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
26+
q, ck, cv, cm, scale, softcap, qkv, M, S);
27+
return true;
28+
}
29+
#endif
30+
// Generic path: use the largest of 128/64 that divides nk; N = 32 is the fallback
// and is taken without any divisibility check here.
31+
if (nk%128 == 0) {
32+
return iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
33+
q, ck, cv, cm, scale, softcap, qkv, M, S);
34+
}
35+
if (nk%64 == 0) {
36+
return iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
37+
q, ck, cv, cm, scale, softcap, qkv, M, S);
38+
}
39+
40+
return iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
41+
q, ck, cv, cm, scale, softcap, qkv, M, S);
42+
43+
}
44+
45+
#endif

ggml/src/iqk/fa/iqk_fa_256_256.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#include "iqk/iqk_config.h"
2+
3+
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
4+
5+
#include "iqk/fa/iqk_fa_templates.h"
6+
7+
// Flash-attention entry point for the 256/256 case: picks an
// iqk_flash_helper_T<256, 256, N> instantiation based on the K/V cache types and nk.
// NOTE(review): IQK_FA_CASE is defined outside this view (presumably in
// iqk_fa_templates.h) and supplies the function signature; the two 256s are
// presumably the K and V head sizes and N the K step — confirm in that header.
// Returns false when a bf16 K cache is paired with a non-bf16 V cache (AVX512BF16
// builds only); true after the bf16 helper ran; otherwise the typed helper's result.
IQK_FA_CASE(iqk_fa_256_256) {
8+
9+
auto type_k = ggml_type(int_type_k);
10+
auto type_v = ggml_type(int_type_v);
11+
12+
stride_q /= sizeof(float); // q stride as float
13+
auto ck = (const char *)k;
14+
auto cv = (const char *)v;
15+
auto cm = (const char *)mask;
16+
// bf16 path (compiled only with AVX512-BF16): uses the helper overload without type
// arguments, with N = 64 when nk is a multiple of 64 and N = 32 otherwise.
17+
#ifdef __AVX512BF16__
18+
if (type_k == GGML_TYPE_BF16) {
19+
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
20+
if (nk%64 == 0) {
21+
iqk_flash_helper_T<256, 256, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
22+
q, ck, cv, cm, scale, softcap, qkv, M, S);
23+
return true;
24+
}
25+
iqk_flash_helper_T<256, 256, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
26+
q, ck, cv, cm, scale, softcap, qkv, M, S);
27+
return true;
28+
}
29+
#endif
30+
// Generic path: use the largest of 128/64 that divides nk; N = 32 is the fallback
// and is taken without any divisibility check here.
31+
if (nk%128 == 0) {
32+
return iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
33+
q, ck, cv, cm, scale, softcap, qkv, M, S);
34+
}
35+
if (nk%64 == 0) {
36+
return iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
37+
q, ck, cv, cm, scale, softcap, qkv, M, S);
38+
}
39+
40+
return iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
41+
q, ck, cv, cm, scale, softcap, qkv, M, S);
42+
43+
}
44+
45+
#endif

ggml/src/iqk/fa/iqk_fa_576_512.cpp

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#include "iqk/iqk_config.h"
2+
3+
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
4+
5+
#include "iqk/fa/iqk_fa_templates.h"
6+
7+
namespace {
8+
9+
// Runs the 576/512 flash-attention kernel over nq1 query rows using the given K and V
// helpers. The rows are consumed in batches handled by FlashAttn instantiations whose
// third template argument is 16, 8, 4, 2 and finally 1, so each batch uses the widest
// variant that still fits the rows that remain.
// q, mask, qkv, M and S are taken by value and advanced locally between batches.
// NOTE(review): FlashAttn is defined outside this view; the third template argument is
// presumably the number of query rows per step — confirm in iqk_fa_templates.h.
template <int step_k, typename KHelper, typename VHelper>
10+
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
11+
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
12+
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
13+
// update(n): account for n processed rows. Returns true when no rows are left;
// otherwise moves q/mask/qkv (and M/S when both are non-null) past those rows.
// mask is a char pointer, so stride_m is applied in bytes; q and qkv advance in floats.
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
14+
nq1 -= n;
15+
if (nq1 == 0) return true;
16+
q += n*stride_q;
17+
mask += n*stride_m;
18+
qkv += n*stride_qkv;
19+
if (M && S) { M += n; S += n; }
20+
return false;
21+
};
22+
// Largest multiple of 16 rows first.
if (nq1 >= 16) {
23+
int n_step = nq1/16;
24+
FlashAttn<576, 512, 16, step_k> fa(scale, softcap);
25+
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
26+
if (update(16*n_step)) return;
27+
}
28+
if (nq1 >= 8) {
29+
int n_step = nq1/8;
30+
FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
31+
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
32+
if (update(8*n_step)) return;
33+
}
34+
if (nq1 >= 4) {
35+
int n_step = nq1/4;
36+
FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
37+
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
38+
if (update(4*n_step)) return;
39+
}
40+
if (nq1 >= 2) {
41+
int n_step = nq1/2;
42+
FlashAttn<576, 512, 2, step_k> fa(scale, softcap);
43+
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
44+
if (update(2*n_step)) return;
45+
}
46+
// Whatever is left at this point is fewer than 2 rows.
FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
47+
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
48+
}
49+
50+
// Type dispatcher for the 576/512 case: builds the K and V helper objects that match
// type_k and forwards to the KHelper/VHelper overload of iqk_deepseek_helper.
// Only the K type is inspected here; the V helper is implied by it (a Q8_0_R8 K cache
// is paired with a plain Q8_0 V helper, every other case uses the same family for V).
// Q8_KV is handled only when GGML_IQK_FA_ALL_QUANTS is set, and BF16 only when
// compiled with AVX512-BF16.
// Returns true if a kernel was run, false if type_k matches none of the cases below.
template <int step_k>
51+
inline bool iqk_deepseek_helper(ggml_type type_k,
52+
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
53+
const float * q, const char * k, const char * v, const char * mask,
54+
float scale, float softcap, float * qkv, float * M, float * S) {
55+
if (type_k == GGML_TYPE_Q8_0) {
56+
HelperQ80 kh((const char *)k, stride_k);
57+
HelperQ80 vh((const char *)v, stride_v);
58+
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
59+
return true;
60+
}
61+
// Repacked Q8_0 K cache: V still goes through the non-repacked Q8_0 helper.
if (type_k == GGML_TYPE_Q8_0_R8) {
62+
HelperQ80R8<576> kh((const char *)k, stride_k);
63+
HelperQ80 vh((const char *)v, stride_v);
64+
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
65+
return true;
66+
}
67+
if (type_k == GGML_TYPE_Q6_0) {
68+
HelperQ60 kh((const char *)k, stride_k);
69+
HelperQ60 vh((const char *)v, stride_v);
70+
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
71+
return true;
72+
}
73+
#if GGML_IQK_FA_ALL_QUANTS
74+
if (type_k == GGML_TYPE_Q8_KV) {
75+
HelperQ8KV<576> kh((const char *)k, stride_k);
76+
HelperQ8KV<512> vh((const char *)v, stride_v);
77+
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
78+
return true;
79+
}
80+
#endif
81+
if (type_k == GGML_TYPE_F16) {
82+
HelperF16 kh((const char *)k, stride_k);
83+
HelperF16 vh((const char *)v, stride_v);
84+
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
85+
return true;
86+
}
87+
// bf16 does not go through the batching overload: it calls FlashAttnBF16 directly,
// with 8 as the third template argument when nq1 is a multiple of 8 and 1 otherwise.
#ifdef __AVX512BF16__
88+
if (type_k == GGML_TYPE_BF16) {
89+
HelperBF16<576, step_k> kh((const char *)k, stride_k);
90+
HelperBF16<512, step_k> vh((const char *)v, stride_v);
91+
if (nq1 % 8 == 0) {
92+
FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
93+
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
94+
} else {
95+
FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
96+
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
97+
}
98+
return true;
99+
}
100+
#endif
101+
return false;
102+
}
103+
104+
}
105+
106+
// Flash-attention entry point for the 576/512 case.
// Accepts only matching K/V cache types, plus the single mixed pair of a Q8_0_R8 K
// cache with a Q8_0 V cache; anything else returns false without doing any work.
// Otherwise forwards to iqk_deepseek_helper<32> and returns its result.
// NOTE(review): IQK_FA_CASE is defined outside this view (presumably in
// iqk_fa_templates.h) and supplies the function signature — confirm there.
IQK_FA_CASE(iqk_fa_576_512) {
107+
108+
auto type_k = ggml_type(int_type_k);
109+
auto type_v = ggml_type(int_type_v);
110+
111+
if (!(type_k == type_v || (type_k == GGML_TYPE_Q8_0_R8 && type_v == GGML_TYPE_Q8_0))) {
112+
return false;
113+
}
114+
stride_q /= sizeof(float); // q stride as float
115+
return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
116+
q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S);
117+
118+
}
119+
120+
#endif

ggml/src/iqk/fa/iqk_fa_64_64.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#include "iqk/iqk_config.h"
2+
3+
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
4+
5+
#include "iqk/fa/iqk_fa_templates.h"
6+
7+
// Flash-attention entry point for the 64/64 case: picks an
// iqk_flash_helper_T<64, 64, N> instantiation based on the K/V cache types and nk.
// NOTE(review): IQK_FA_CASE is defined outside this view (presumably in
// iqk_fa_templates.h) and supplies the function signature; the first two 64s are
// presumably the K and V head sizes and N the K step — confirm in that header.
// Returns false when a bf16 K cache is paired with a non-bf16 V cache (AVX512BF16
// builds only); true after the bf16 helper ran; otherwise the typed helper's result.
IQK_FA_CASE(iqk_fa_64_64) {
8+
9+
auto type_k = ggml_type(int_type_k);
10+
auto type_v = ggml_type(int_type_v);
11+
12+
stride_q /= sizeof(float); // q stride as float
13+
auto ck = (const char *)k;
14+
auto cv = (const char *)v;
15+
auto cm = (const char *)mask;
16+
// bf16 path (compiled only with AVX512-BF16): uses the helper overload without type
// arguments, with N = 64 when nk is a multiple of 64 and N = 32 otherwise.
17+
#ifdef __AVX512BF16__
18+
if (type_k == GGML_TYPE_BF16) {
19+
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
20+
if (nk%64 == 0) {
21+
iqk_flash_helper_T<64, 64, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
22+
q, ck, cv, cm, scale, softcap, qkv, M, S);
23+
return true;
24+
}
25+
iqk_flash_helper_T<64, 64, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
26+
q, ck, cv, cm, scale, softcap, qkv, M, S);
27+
return true;
28+
}
29+
#endif
30+
// Generic path: use the largest of 128/64 that divides nk; N = 32 is the fallback
// and is taken without any divisibility check here.
31+
if (nk%128 == 0) {
32+
return iqk_flash_helper_T<64, 64, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
33+
q, ck, cv, cm, scale, softcap, qkv, M, S);
34+
}
35+
if (nk%64 == 0) {
36+
return iqk_flash_helper_T<64, 64, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
37+
q, ck, cv, cm, scale, softcap, qkv, M, S);
38+
}
39+
40+
return iqk_flash_helper_T<64, 64, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
41+
q, ck, cv, cm, scale, softcap, qkv, M, S);
42+
43+
}
44+
45+
#endif

ggml/src/iqk/fa/iqk_fa_96_96.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#include "iqk/iqk_config.h"
2+
3+
#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION
4+
5+
#include "iqk/fa/iqk_fa_templates.h"
6+
7+
// Flash-attention entry point for the 96/96 case: picks an
// iqk_flash_helper_T<96, 96, N> instantiation based on the K/V cache types and nk.
// NOTE(review): IQK_FA_CASE is defined outside this view (presumably in
// iqk_fa_templates.h) and supplies the function signature; the two 96s are
// presumably the K and V head sizes and N the K step — confirm in that header.
// Returns false when a bf16 K cache is paired with a non-bf16 V cache (AVX512BF16
// builds only); true after the bf16 helper ran; otherwise the typed helper's result.
IQK_FA_CASE(iqk_fa_96_96) {
8+
9+
auto type_k = ggml_type(int_type_k);
10+
auto type_v = ggml_type(int_type_v);
11+
12+
stride_q /= sizeof(float); // q stride as float
13+
auto ck = (const char *)k;
14+
auto cv = (const char *)v;
15+
auto cm = (const char *)mask;
16+
// bf16 path (compiled only with AVX512-BF16): uses the helper overload without type
// arguments, with N = 64 when nk is a multiple of 64 and N = 32 otherwise.
17+
#ifdef __AVX512BF16__
18+
if (type_k == GGML_TYPE_BF16) {
19+
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
20+
if (nk%64 == 0) {
21+
iqk_flash_helper_T<96, 96, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
22+
q, ck, cv, cm, scale, softcap, qkv, M, S);
23+
return true;
24+
}
25+
iqk_flash_helper_T<96, 96, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
26+
q, ck, cv, cm, scale, softcap, qkv, M, S);
27+
return true;
28+
}
29+
#endif
30+
// Generic path: use the largest of 128/64 that divides nk; N = 32 is the fallback
// and is taken without any divisibility check here.
31+
if (nk%128 == 0) {
32+
return iqk_flash_helper_T<96, 96, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
33+
q, ck, cv, cm, scale, softcap, qkv, M, S);
34+
}
35+
if (nk%64 == 0) {
36+
return iqk_flash_helper_T<96, 96, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
37+
q, ck, cv, cm, scale, softcap, qkv, M, S);
38+
}
39+
40+
return iqk_flash_helper_T<96, 96, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
41+
q, ck, cv, cm, scale, softcap, qkv, M, S);
42+
43+
}
44+
45+
#endif

0 commit comments

Comments
 (0)