Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
68b782e
Refactor iqk: WIP
May 17, 2025
51a87cf
Refactor iqk: Factor out float GEMM (AVX2/AVX512)
May 17, 2025
f83e64d
Refactor iqk: Factor out GEMM for legacy quants (AVX2/AVX512)
May 17, 2025
4ef94c2
Refactor iqk: Factor out GEMM for k-quants (AVX2/AVX512)
May 17, 2025
d355ff9
Refactor iqk: fix AVX2
May 17, 2025
2cbbc55
Refactor iqk: Factor out GEMM for i-quants (AVX2/AVX512)
May 17, 2025
8dae13c
Refactor iqk: fix AVX2
May 17, 2025
de5660c
Refactor iqk: Factor out GEMM for iqk-quants (AVX2/AVX512)
May 17, 2025
082a9bd
Refactor iqk: fix AVX2
May 17, 2025
9b6e75c
Refactor iqk: Factor out GEMM for 1-bit quants (ABX2/AVX512)
May 17, 2025
d66ec60
Refactor iqk: fix AVX2
May 17, 2025
7868545
Refactor iqk: Factor out GEMM for iq1_bn, iq2_bn, iq2_bn_r4
May 17, 2025
6cd3609
Refactor iqk: Factor out GEMM for repacked legacy quants
May 18, 2025
f501200
Refactor iqk: Factor out GEMM for q8_K_R8, q8_KV
May 18, 2025
0d96f3b
Refactor iqk: Factor out GEMM for repacked i-quants
May 18, 2025
c63a0af
Refactor iqk: GEMM kernels are refactored on AVX2/AVX512
May 18, 2025
28b9480
Refactor iqk: factor out 1-bit quants (NEON)
May 18, 2025
c805a19
Refactor iqk: factor out k-quants (NEON)
May 18, 2025
f4ab917
Refactor iqk: factor out floats (NEON)
May 18, 2025
3124136
Also iq4_xs belongs to k-quants
May 18, 2025
465d717
Refactor iqk: factor out iqk quants (NEON)
May 18, 2025
bd1e4d4
Refactor iqk: factor out legacy quants (NEON)
May 18, 2025
2b8a231
Refactor iqk: factor out repacked legacy quants (NEON)
May 19, 2025
7e59d2b
Refactor iqk: factor out repacked k-quants (NEON)
May 19, 2025
7aa2de6
Refactor iqk: factor out repacked iqk quants (NEON)
May 19, 2025
4b4b4fd
Refactor iqk: GEMM kernels are refactored on NEON
May 19, 2025
131e5ac
Refactor iqk: FA compiles
May 19, 2025
630279c
Refactor iqk: FA refactored (Zen4)
May 19, 2025
fbfe79e
Adding forgotten file
May 19, 2025
9541631
Most helpers don't need to be templates
May 19, 2025
9ae8f75
Fix bf16
May 19, 2025
65c8e86
Refactor iqk: FA refactored (NEON)
May 19, 2025
380ab3f
Forgotten MMQ ref and typo (#431)
Nexesenex May 18, 2025
06efa17
Adding forgotten iq5_k_r4
May 19, 2025
7090f17
Fix iq4_k_r4 on NEON
May 19, 2025
4fdb50b
Fix iq4_ks on NEON
May 20, 2025
5351ec0
Fix q8_0 on NEON
May 20, 2025
0943331
Fix q6_0 K cache
May 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions ggml/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,29 @@ set (GGML_HEADERS_IQK iqk/iqk_config.h)
if (GGML_IQK_MUL_MAT)
message(STATUS "Using optimized iqk matrix multiplications")
add_compile_definitions(GGML_USE_IQK_MULMAT)
set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp iqk/iqk_flash_attn.cpp)
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h iqk/iqk_flash_impl.h)
set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp
iqk/iqk_flash_attn.cpp
iqk/fa/iqk_fa_576_512.cpp
iqk/fa/iqk_fa_192_128.cpp
iqk/fa/iqk_fa_256_256.cpp
iqk/fa/iqk_fa_128_128.cpp
iqk/fa/iqk_fa_96_96.cpp
iqk/fa/iqk_fa_64_64.cpp
iqk/iqk_gemm_floats.cpp
iqk/iqk_gemm_kquants.cpp
iqk/iqk_gemm_iquants.cpp
iqk/iqk_gemm_iqk_quants.cpp
iqk/iqk_gemm_1bit.cpp
iqk/iqk_gemm_legacy_quants.cpp)
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h
iqk/iqk_flash_impl.h
iqk/fa/iqk_fa_templates.h
iqk/iqk_gemm_floats.h
iqk/iqk_gemm_kquants.h
iqk/iqk_gemm_iquants.h
iqk/iqk_gemm_iqk_quants.h
iqk/iqk_gemm_1bit.h
iqk/iqk_gemm_legacy_quants.h)
if (GGML_IQK_FLASH_ATTENTION)
message(STATUS "Enabling IQK Flash Attention kernels")
add_compile_definitions(GGML_IQK_FLASH_ATTENTION)
Expand Down
3 changes: 2 additions & 1 deletion ggml/src/ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ static constexpr __device__ int get_mmq_y_device() {

static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
switch (type) {
case GGML_TYPE_Q4_0 : return MMQ_DP4A_TXS_Q4_0;
case GGML_TYPE_Q4_1 : return MMQ_DP4A_TXS_Q4_1;
case GGML_TYPE_Q5_0 : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_Q5_1 : return MMQ_DP4A_TXS_Q8_1;
Expand Down Expand Up @@ -3363,7 +3364,7 @@ static __global__ void mul_mat_q(
const int jt = kbc / (blocks_per_ne00*nty);
const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;

constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks.
constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
(x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
it, jt, kb0_start, kb0_stop);
Expand Down
45 changes: 45 additions & 0 deletions ggml/src/iqk/fa/iqk_fa_128_128.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include "iqk/iqk_config.h"

#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION

#include "iqk/fa/iqk_fa_templates.h"

IQK_FA_CASE(iqk_fa_128_128) {

auto type_k = ggml_type(int_type_k);
auto type_v = ggml_type(int_type_v);

stride_q /= sizeof(float); // q stride as float
auto ck = (const char *)k;
auto cv = (const char *)v;
auto cm = (const char *)mask;

#ifdef __AVX512BF16__
if (type_k == GGML_TYPE_BF16) {
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
if (nk%64 == 0) {
iqk_flash_helper_T<128, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
return true;
}
iqk_flash_helper_T<128, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
return true;
}
#endif

if (nk%128 == 0) {
return iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
}
if (nk%64 == 0) {
return iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
}

return iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);

}

#endif
45 changes: 45 additions & 0 deletions ggml/src/iqk/fa/iqk_fa_192_128.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include "iqk/iqk_config.h"

#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION

#include "iqk/fa/iqk_fa_templates.h"

IQK_FA_CASE(iqk_fa_192_128) {

auto type_k = ggml_type(int_type_k);
auto type_v = ggml_type(int_type_v);

stride_q /= sizeof(float); // q stride as float
auto ck = (const char *)k;
auto cv = (const char *)v;
auto cm = (const char *)mask;

#ifdef __AVX512BF16__
if (type_k == GGML_TYPE_BF16) {
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
if (nk%64 == 0) {
iqk_flash_helper_T<192, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
return true;
}
iqk_flash_helper_T<192, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
return true;
}
#endif

if (nk%128 == 0) {
return iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
}
if (nk%64 == 0) {
return iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
}

return iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);

}

#endif
45 changes: 45 additions & 0 deletions ggml/src/iqk/fa/iqk_fa_256_256.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include "iqk/iqk_config.h"

#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION

#include "iqk/fa/iqk_fa_templates.h"

IQK_FA_CASE(iqk_fa_256_256) {

auto type_k = ggml_type(int_type_k);
auto type_v = ggml_type(int_type_v);

stride_q /= sizeof(float); // q stride as float
auto ck = (const char *)k;
auto cv = (const char *)v;
auto cm = (const char *)mask;

#ifdef __AVX512BF16__
if (type_k == GGML_TYPE_BF16) {
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
if (nk%64 == 0) {
iqk_flash_helper_T<256, 256, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
return true;
}
iqk_flash_helper_T<256, 256, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
return true;
}
#endif

if (nk%128 == 0) {
return iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
}
if (nk%64 == 0) {
return iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
}

return iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);

}

#endif
120 changes: 120 additions & 0 deletions ggml/src/iqk/fa/iqk_fa_576_512.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#include "iqk/iqk_config.h"

#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION

#include "iqk/fa/iqk_fa_templates.h"

namespace {

template <int step_k, typename KHelper, typename VHelper>
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
nq1 -= n;
if (nq1 == 0) return true;
q += n*stride_q;
mask += n*stride_m;
qkv += n*stride_qkv;
if (M && S) { M += n; S += n; }
return false;
};
if (nq1 >= 16) {
int n_step = nq1/16;
FlashAttn<576, 512, 16, step_k> fa(scale, softcap);
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(16*n_step)) return;
}
if (nq1 >= 8) {
int n_step = nq1/8;
FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(8*n_step)) return;
}
if (nq1 >= 4) {
int n_step = nq1/4;
FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(4*n_step)) return;
}
if (nq1 >= 2) {
int n_step = nq1/2;
FlashAttn<576, 512, 2, step_k> fa(scale, softcap);
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(2*n_step)) return;
}
FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
}

template <int step_k>
inline bool iqk_deepseek_helper(ggml_type type_k,
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask,
float scale, float softcap, float * qkv, float * M, float * S) {
if (type_k == GGML_TYPE_Q8_0) {
HelperQ80 kh((const char *)k, stride_k);
HelperQ80 vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
return true;
}
if (type_k == GGML_TYPE_Q8_0_R8) {
HelperQ80R8<576> kh((const char *)k, stride_k);
HelperQ80 vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
return true;
}
if (type_k == GGML_TYPE_Q6_0) {
HelperQ60 kh((const char *)k, stride_k);
HelperQ60 vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
return true;
}
#if GGML_IQK_FA_ALL_QUANTS
if (type_k == GGML_TYPE_Q8_KV) {
HelperQ8KV<576> kh((const char *)k, stride_k);
HelperQ8KV<512> vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
return true;
}
#endif
if (type_k == GGML_TYPE_F16) {
HelperF16 kh((const char *)k, stride_k);
HelperF16 vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
return true;
}
#ifdef __AVX512BF16__
if (type_k == GGML_TYPE_BF16) {
HelperBF16<576, step_k> kh((const char *)k, stride_k);
HelperBF16<512, step_k> vh((const char *)v, stride_v);
if (nq1 % 8 == 0) {
FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
} else {
FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
}
return true;
}
#endif
return false;
}

}

IQK_FA_CASE(iqk_fa_576_512) {

auto type_k = ggml_type(int_type_k);
auto type_v = ggml_type(int_type_v);

if (!(type_k == type_v || (type_k == GGML_TYPE_Q8_0_R8 && type_v == GGML_TYPE_Q8_0))) {
return false;
}
stride_q /= sizeof(float); // q stride as float
return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S);

}

#endif
45 changes: 45 additions & 0 deletions ggml/src/iqk/fa/iqk_fa_64_64.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include "iqk/iqk_config.h"

#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION

#include "iqk/fa/iqk_fa_templates.h"

IQK_FA_CASE(iqk_fa_64_64) {

auto type_k = ggml_type(int_type_k);
auto type_v = ggml_type(int_type_v);

stride_q /= sizeof(float); // q stride as float
auto ck = (const char *)k;
auto cv = (const char *)v;
auto cm = (const char *)mask;

#ifdef __AVX512BF16__
if (type_k == GGML_TYPE_BF16) {
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
if (nk%64 == 0) {
iqk_flash_helper_T<64, 64, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
return true;
}
iqk_flash_helper_T<64, 64, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
return true;
}
#endif

if (nk%128 == 0) {
return iqk_flash_helper_T<64, 64, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
}
if (nk%64 == 0) {
return iqk_flash_helper_T<64, 64, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
}

return iqk_flash_helper_T<64, 64, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);

}

#endif
45 changes: 45 additions & 0 deletions ggml/src/iqk/fa/iqk_fa_96_96.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include "iqk/iqk_config.h"

#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION

#include "iqk/fa/iqk_fa_templates.h"

IQK_FA_CASE(iqk_fa_96_96) {

auto type_k = ggml_type(int_type_k);
auto type_v = ggml_type(int_type_v);

stride_q /= sizeof(float); // q stride as float
auto ck = (const char *)k;
auto cv = (const char *)v;
auto cm = (const char *)mask;

#ifdef __AVX512BF16__
if (type_k == GGML_TYPE_BF16) {
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
if (nk%64 == 0) {
iqk_flash_helper_T<96, 96, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
return true;
}
iqk_flash_helper_T<96, 96, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
return true;
}
#endif

if (nk%128 == 0) {
return iqk_flash_helper_T<96, 96, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
}
if (nk%64 == 0) {
return iqk_flash_helper_T<96, 96, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
}

return iqk_flash_helper_T<96, 96, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);

}

#endif
Loading