Commit 7af22b6

Revert "CUDA: enable FA for FP32 KV cache (ggml-org#16546)"
This reverts commit 9c7185d.
1 parent e88433d · commit 7af22b6

File tree

2 files changed (+14, -14 lines):

ggml/src/ggml-cuda/fattn-vec.cuh
ggml/src/ggml-cuda/fattn.cu

ggml/src/ggml-cuda/fattn-vec.cuh

Lines changed: 7 additions & 2 deletions

@@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm
     const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
     const int nwarps   = nthreads / WARP_SIZE;
     fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
-    const bool need_f16_K = type_K == GGML_TYPE_F16;
-    const bool need_f16_V = type_V == GGML_TYPE_F16;
+    constexpr bool need_f16_K = false;
+    constexpr bool need_f16_V = false;
     constexpr size_t nbytes_shared = 0;
     launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }

@@ -526,6 +526,11 @@ template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * K   = dst->src[1];
+    const ggml_tensor * V   = dst->src[2];
+
+    GGML_ASSERT(K->type == type_K);
+    GGML_ASSERT(V->type == type_V);

     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
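
Note: with need_f16_K / need_f16_V forced back to constexpr false, the launcher presumably no longer converts an F32 KV cache to F16 on the fly (as the flag names suggest), so the dispatcher must hand the kernel tensors whose stored types match the template parameters exactly; the new GGML_ASSERTs make that contract explicit. A minimal stand-in sketch of that contract, with illustrative names only (example_type, example_vec_case are not ggml identifiers):

    // Illustrative sketch only; example_* names are hypothetical, not ggml code.
    #include <cassert>

    enum class example_type { f16, f32 };

    // Stand-in for the templated vector-kernel entry point: the caller must have
    // matched the template parameters to the stored K/V types, because no
    // F32 -> F16 conversion happens behind the scenes anymore.
    template <example_type type_K, example_type type_V>
    void example_vec_case(example_type k_stored, example_type v_stored) {
        assert(k_stored == type_K);   // mirrors GGML_ASSERT(K->type == type_K)
        assert(v_stored == type_V);   // mirrors GGML_ASSERT(V->type == type_V)
        // ... launch the kernel compiled for exactly <type_K, type_V> ...
    }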

ggml/src/ggml-cuda/fattn.cu

Lines changed: 7 additions & 12 deletions

@@ -116,15 +116,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
     }
 }

-#define FATTN_VEC_CASE(D, type_K, type_V)                                                                        \
-    {                                                                                                            \
-        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
-        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
-        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) {                                                     \
-            ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);                                      \
-            return;                                                                                              \
-        }                                                                                                        \
-    }                                                                                                            \
+#define FATTN_VEC_CASE(D, type_K, type_V)                                \
+    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
+        ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);  \
+        return;                                                          \
+    }                                                                    \

 #define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
     FATTN_VEC_CASE( 64, type_K, type_V) \

@@ -257,7 +253,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 #endif // GGML_CUDA_FA_ALL_QUANTS

     switch (K->type) {
-        case GGML_TYPE_F32:
         case GGML_TYPE_F16:
             break;
         case GGML_TYPE_Q4_1:

@@ -283,7 +278,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     // If Turing tensor cores available, use them:
     if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
         if (can_use_vector_kernel) {
-            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+            if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
                     return BEST_FATTN_KERNEL_VEC;
                 }

@@ -330,7 +325,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const

     // If there are no tensor cores available, use the generic tile kernel:
     if (can_use_vector_kernel) {
-        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+        if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
             if (Q->ne[1] == 1) {
                 if (!gqa_opt_applies) {
                     return BEST_FATTN_KERNEL_VEC;
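
Taken together, the fattn.cu hunks restore an exact F16 check at both places shown where the vector kernel can be selected, and drop GGML_TYPE_F32 from the accepted K types, so an F32 KV cache should again fall through to the non-FlashAttention paths. A toy illustration of the restored gate, using hypothetical names rather than the real dispatcher:

    // Hypothetical, simplified version of the eligibility check; kv_type and
    // vector_kernel_eligible are illustrative names, not ggml identifiers.
    #include <cstdio>

    enum class kv_type { f16, f32, q4_1 };

    static bool vector_kernel_eligible(kv_type k, kv_type v) {
        // Restored condition: K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16.
        // The reverted commit had relaxed this to "not quantized", which also let F32 through.
        return k == kv_type::f16 && v == kv_type::f16;
    }

    int main() {
        std::printf("F16/F16 eligible: %d\n", vector_kernel_eligible(kv_type::f16, kv_type::f16)); // 1
        std::printf("F32/F32 eligible: %d\n", vector_kernel_eligible(kv_type::f32, kv_type::f32)); // 0
        return 0;
    }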
