@@ -116,15 +116,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
     }
 }
 
-#define FATTN_VEC_CASE(D, type_K, type_V) \
-    { \
-        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
-        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
-        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
-            ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
-            return; \
-        } \
-    } \
+#define FATTN_VEC_CASE(D, type_K, type_V) \
+    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
+        ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
+        return; \
+    } \
 
 #define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
     FATTN_VEC_CASE( 64, type_K, type_V) \
@@ -257,7 +253,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 #endif // GGML_CUDA_FA_ALL_QUANTS
 
     switch (K->type) {
-        case GGML_TYPE_F32:
         case GGML_TYPE_F16:
             break;
         case GGML_TYPE_Q4_1:
@@ -283,7 +278,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     // If Turing tensor cores available, use them:
     if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
         if (can_use_vector_kernel) {
-            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+            if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
                     return BEST_FATTN_KERNEL_VEC;
                 }
@@ -330,7 +325,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 
     // If there are no tensor cores available, use the generic tile kernel:
     if (can_use_vector_kernel) {
-        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+        if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
             if (Q->ne[1] == 1) {
                 if (!gqa_opt_applies) {
                     return BEST_FATTN_KERNEL_VEC;
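The net effect of the diff is that the vector-kernel path is now selected only on an exact type match (F16 K/V), with the old F32-accepting type_K_okay/type_V_okay fallback removed both from the FATTN_VEC_CASE macro and from the kernel-selection checks. Below is a minimal standalone sketch, not the actual ggml sources, that illustrates this dispatch pattern; the names tensor_t, dispatch_vec, launch_vec_case and the simplified types are hypothetical stand-ins for the real definitions.

// Standalone illustration of the simplified dispatch: a case matches only
// when the head size D and the exact K/V types line up; F32 K/V no longer
// falls through to the F16 case.
#include <cstdio>

enum ggml_type_t { TYPE_F16, TYPE_F32 };

struct tensor_t {
    ggml_type_t type;
    int         ne0;   // stand-in for Q->ne[0] (head size D)
};

template <int D, ggml_type_t type_K, ggml_type_t type_V>
static void launch_vec_case() {
    std::printf("vec kernel launched: D=%d\n", D);
}

// Mirrors the shape of the simplified FATTN_VEC_CASE: a plain guarded
// dispatch with no type_K_okay/type_V_okay booleans.
#define VEC_CASE(D, type_K, type_V)                                     \
    if (Q->ne0 == (D) && K->type == (type_K) && V->type == (type_V)) {  \
        launch_vec_case<D, type_K, type_V>();                           \
        return true;                                                    \
    }

static bool dispatch_vec(const tensor_t * Q, const tensor_t * K, const tensor_t * V) {
    VEC_CASE( 64, TYPE_F16, TYPE_F16)
    VEC_CASE(128, TYPE_F16, TYPE_F16)
    return false; // no case matched; caller would fall back to another kernel
}

int main() {
    tensor_t Q   = {TYPE_F16, 128};
    tensor_t Kf16 = {TYPE_F16, 128};
    tensor_t Kf32 = {TYPE_F32, 128}; // F32 K/V no longer selects the vec kernel
    std::printf("F16 K/V matched: %d\n", dispatch_vec(&Q, &Kf16, &Kf16));
    std::printf("F32 K/V matched: %d\n", dispatch_vec(&Q, &Kf32, &Kf32));
    return 0;
}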