|
| 1 | +#include "iqk/iqk_config.h" |
| 2 | + |
| 3 | +#if defined IQK_IMPLEMENT && defined GGML_IQK_FLASH_ATTENTION |
| 4 | + |
| 5 | +#include "iqk/fa/iqk_fa_templates.h" |
| 6 | + |
| 7 | +namespace { |
| 8 | + |
| 9 | +template <int step_k, typename KHelper, typename VHelper> |
| 10 | +inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, |
| 11 | + int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, |
| 12 | + const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { |
| 13 | + auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) { |
| 14 | + nq1 -= n; |
| 15 | + if (nq1 == 0) return true; |
| 16 | + q += n*stride_q; |
| 17 | + mask += n*stride_m; |
| 18 | + qkv += n*stride_qkv; |
| 19 | + if (M && S) { M += n; S += n; } |
| 20 | + return false; |
| 21 | + }; |
| 22 | + if (nq1 >= 16) { |
| 23 | + int n_step = nq1/16; |
| 24 | + FlashAttn<576, 512, 16, step_k> fa(scale, softcap); |
| 25 | + fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); |
| 26 | + if (update(16*n_step)) return; |
| 27 | + } |
| 28 | + if (nq1 >= 8) { |
| 29 | + int n_step = nq1/8; |
| 30 | + FlashAttn<576, 512, 8, step_k> fa(scale, softcap); |
| 31 | + fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); |
| 32 | + if (update(8*n_step)) return; |
| 33 | + } |
| 34 | + if (nq1 >= 4) { |
| 35 | + int n_step = nq1/4; |
| 36 | + FlashAttn<576, 512, 4, step_k> fa(scale, softcap); |
| 37 | + fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); |
| 38 | + if (update(4*n_step)) return; |
| 39 | + } |
| 40 | + if (nq1 >= 2) { |
| 41 | + int n_step = nq1/2; |
| 42 | + FlashAttn<576, 512, 2, step_k> fa(scale, softcap); |
| 43 | + fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); |
| 44 | + if (update(2*n_step)) return; |
| 45 | + } |
| 46 | + FlashAttn<576, 512, 1, step_k> fa(scale, softcap); |
| 47 | + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); |
| 48 | +} |
| 49 | + |
| 50 | +template <int step_k> |
| 51 | +inline bool iqk_deepseek_helper(ggml_type type_k, |
| 52 | + int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, |
| 53 | + const float * q, const char * k, const char * v, const char * mask, |
| 54 | + float scale, float softcap, float * qkv, float * M, float * S) { |
| 55 | + if (type_k == GGML_TYPE_Q8_0) { |
| 56 | + HelperQ80 kh((const char *)k, stride_k); |
| 57 | + HelperQ80 vh((const char *)v, stride_v); |
| 58 | + iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); |
| 59 | + return true; |
| 60 | + } |
| 61 | + if (type_k == GGML_TYPE_Q8_0_R8) { |
| 62 | + HelperQ80R8<576> kh((const char *)k, stride_k); |
| 63 | + HelperQ80 vh((const char *)v, stride_v); |
| 64 | + iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); |
| 65 | + return true; |
| 66 | + } |
| 67 | + if (type_k == GGML_TYPE_Q6_0) { |
| 68 | + HelperQ60 kh((const char *)k, stride_k); |
| 69 | + HelperQ60 vh((const char *)v, stride_v); |
| 70 | + iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); |
| 71 | + return true; |
| 72 | + } |
| 73 | +#if GGML_IQK_FA_ALL_QUANTS |
| 74 | + if (type_k == GGML_TYPE_Q8_KV) { |
| 75 | + HelperQ8KV<576> kh((const char *)k, stride_k); |
| 76 | + HelperQ8KV<512> vh((const char *)v, stride_v); |
| 77 | + iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); |
| 78 | + return true; |
| 79 | + } |
| 80 | +#endif |
| 81 | + if (type_k == GGML_TYPE_F16) { |
| 82 | + HelperF16 kh((const char *)k, stride_k); |
| 83 | + HelperF16 vh((const char *)v, stride_v); |
| 84 | + iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); |
| 85 | + return true; |
| 86 | + } |
| 87 | +#ifdef __AVX512BF16__ |
| 88 | + if (type_k == GGML_TYPE_BF16) { |
| 89 | + HelperBF16<576, step_k> kh((const char *)k, stride_k); |
| 90 | + HelperBF16<512, step_k> vh((const char *)v, stride_v); |
| 91 | + if (nq1 % 8 == 0) { |
| 92 | + FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap); |
| 93 | + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); |
| 94 | + } else { |
| 95 | + FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap); |
| 96 | + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); |
| 97 | + } |
| 98 | + return true; |
| 99 | + } |
| 100 | +#endif |
| 101 | + return false; |
| 102 | +} |
| 103 | + |
| 104 | +} |
| 105 | + |
| 106 | +IQK_FA_CASE(iqk_fa_576_512) { |
| 107 | + |
| 108 | + auto type_k = ggml_type(int_type_k); |
| 109 | + auto type_v = ggml_type(int_type_v); |
| 110 | + |
| 111 | + if (!(type_k == type_v || (type_k == GGML_TYPE_Q8_0_R8 && type_v == GGML_TYPE_Q8_0))) { |
| 112 | + return false; |
| 113 | + } |
| 114 | + stride_q /= sizeof(float); // q stride as float |
| 115 | + return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv, |
| 116 | + q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S); |
| 117 | + |
| 118 | +} |
| 119 | + |
| 120 | +#endif |
0 commit comments