Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ggml/src/ggml-cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ if (CUDAToolkit_FOUND)
template-instances/fattn-vec-instance-q8_0-turbo3_0.cu
template-instances/fattn-vec-instance-turbo2_0-turbo2_0.cu
template-instances/fattn-vec-instance-turbo2_0-q8_0.cu
template-instances/fattn-vec-instance-q8_0-turbo2_0.cu)
template-instances/fattn-vec-instance-q8_0-turbo2_0.cu
template-instances/fattn-vec-instance-turbo3_0-turbo2_0.cu
template-instances/fattn-vec-instance-turbo2_0-turbo3_0.cu)
endif()

ggml_add_backend_library(ggml-cuda
Expand Down
9 changes: 9 additions & 0 deletions ggml/src/ggml-cuda/fattn-vec.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -659,3 +659,12 @@ extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_Q8_0);
extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0);
extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0);
extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0);

// Mixed turbo3/turbo2 KV cache types
// extern declarations of the vec flash-attention cases for both mixed orderings:
// first turbo3 K with turbo2 V, then turbo2 K with turbo3 V, each for the
// first-argument values 64, 128 and 256.
// NOTE(review): DECL_FATTN_VEC_CASE is not visible in this hunk; presumably each
// extern line needs a matching non-extern DECL_FATTN_VEC_CASE in a
// template-instances/*.cu file that is listed in the build - confirm.
extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0);
extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0);
extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0);

extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0);
extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0);
extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0);
18 changes: 12 additions & 6 deletions ggml/src/ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,10 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_Q8_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0)

// Mixed turbo3/turbo2 KV cache types
FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0)

GGML_ABORT("fatal error");
}

Expand Down Expand Up @@ -368,12 +372,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const

#ifndef GGML_CUDA_FA_ALL_QUANTS
if (K->type != V->type) {
// Allow mixed turbo/q8_0 KV types
const bool turbo_q8_mix = (K->type == GGML_TYPE_TURBO3_0 && V->type == GGML_TYPE_Q8_0) ||
(K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_TURBO3_0) ||
(K->type == GGML_TYPE_TURBO2_0 && V->type == GGML_TYPE_Q8_0) ||
(K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_TURBO2_0);
if (!turbo_q8_mix) {
// Allow mixed turbo KV types
const bool turbo_mix = (K->type == GGML_TYPE_TURBO3_0 && V->type == GGML_TYPE_Q8_0) ||
(K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_TURBO3_0) ||
(K->type == GGML_TYPE_TURBO2_0 && V->type == GGML_TYPE_Q8_0) ||
(K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_TURBO2_0) ||
(K->type == GGML_TYPE_TURBO3_0 && V->type == GGML_TYPE_TURBO2_0) ||
(K->type == GGML_TYPE_TURBO2_0 && V->type == GGML_TYPE_TURBO3_0);
if (!turbo_mix) {
return BEST_FATTN_KERNEL_NONE;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Mixed KV: turbo2 K + turbo3 V
//
// Invokes DECL_FATTN_VEC_CASE with K type GGML_TYPE_TURBO2_0 and V type
// GGML_TYPE_TURBO3_0, once for each first-argument value 64, 128 and 256.
// NOTE(review): the macro is not defined in this file; presumably the first
// argument is the attention head size and each line emits one kernel
// instantiation - confirm against the macro definition.

#include "../fattn-vec.cuh"

DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0);
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Mixed KV: turbo3 K + turbo2 V
//
// Invokes DECL_FATTN_VEC_CASE with K type GGML_TYPE_TURBO3_0 and V type
// GGML_TYPE_TURBO2_0, once for each first-argument value 64, 128 and 256.
// NOTE(review): the macro is not defined in this file; presumably the first
// argument is the attention head size and each line emits one kernel
// instantiation - confirm against the macro definition.

#include "../fattn-vec.cuh"

DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0);
DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0);
DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0);
Loading