
Commit d00f03b

Nexesenex and ikawrakow committed
Allow q8_0 KV cache for head size 256 #330
Co-Authored-By: Kawrakow <[email protected]>
1 parent afc5a3b, commit d00f03b

8 files changed: +91 -18 lines changed

ggml/src/ggml-cuda/fattn.cu
Lines changed: 18 additions & 18 deletions

@@ -211,9 +211,9 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
     FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
     //FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
 
-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
     // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 
     //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
@@ -224,14 +224,14 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
     // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
 
-    // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
-    // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
     // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 
     //FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
     //FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
@@ -346,9 +346,9 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
     FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
     //FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
 
-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
     // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
     FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 
     //FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
@@ -358,14 +358,14 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
     // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
 
-    // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
-    // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
     // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 
     //FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
     //FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
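For orientation: each FATTN_VEC_F16_CASE / FATTN_VEC_F32_CASE line above is a dispatch case for one (head size, K cache type, V cache type) combination, so un-commenting a line is what makes that combination reachable at runtime. A minimal sketch of the macro's shape, assuming it matches the definition in mainline llama.cpp's fattn.cu (the Q, K, V, ctx, and dst names are taken from that surrounding context, not from this diff):

// Sketch, assuming mainline llama.cpp's definition: each enabled case is an
// early-return into a pre-compiled template instance; combinations without a
// matching case fall through to the unsupported-configuration path below.
#define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
        ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
        return;                                                             \
    }                                                                       \

The net effect of the four hunks is that head size 256 can now dispatch with a q8_0-quantized KV cache (Q8_0 for both K and V, or F16 K with Q8_0 V), in both the F16 and F32 accumulator variants, instead of requiring an F16 cache.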

ggml/src/ggml-cuda/ggml-cuda.cu
Lines changed: 43 additions & 0 deletions

@@ -3689,6 +3689,49 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
             return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
+        /* case GGML_OP_FLASH_ATTN_EXT: {
+#ifndef FLASH_ATTN_AVAILABLE
+            return false;
+#endif // FLASH_ATTN_AVAILABLE
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+                if (!turing_mma_available(cc)) {
+                    return false;
+                }
+                const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
+                return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
+            }
+            // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
+            if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
+                && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
+                return false;
+            }
+            if (op->src[1]->ne[0] == 256 && op->src[2]->ne[0] == 256 &&
+                (op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_Q8_0) &&
+                (op->src[2]->type == GGML_TYPE_F16 || op->src[2]->type == GGML_TYPE_Q8_0)) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 192) {
+                return false;
+            }
+            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
+                return false;
+            }
+            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 128) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            if (op->src[3] && op->src[3]->ne[2] != 1) {
+                return false;
+            }
+            return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
+                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        } */
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
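Note that the entire added block sits inside a /* ... */ comment: the live code still returns ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op), and the block is kept as a reference copy of a more explicit support check. The clause relevant to this commit is the head-size-256 test; rendered as a standalone predicate it reads roughly as follows (a hypothetical extraction for illustration; fattn_hs256_q8_0_supported is not a name from the diff):

// Hypothetical standalone form of the head-size-256 clause in the commented
// block above: K (src[1]) and V (src[2]) must both have head size 256, and
// each may independently be F16 or Q8_0.
static bool fattn_hs256_q8_0_supported(const ggml_tensor * op) {
    const ggml_tensor * K = op->src[1];
    const ggml_tensor * V = op->src[2];
    const bool k_ok = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0;
    const bool v_ok = V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q8_0;
    return K->ne[0] == 256 && V->ne[0] == 256 && k_ok && v_ok;
}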
ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-q8_0.cu (new file)
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-f16.cu (new file)
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-q8_0.cu (new file)
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-q8_0.cu (new file)
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-f16.cu (new file)
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-q8_0.cu (new file)
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
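Each of these one-line instance files compiles exactly one kernel template instantiation, which is how the CUDA backend keeps translation units small and builds parallel. In mainline llama.cpp the DECL_FATTN_VEC_F16_CASE / DECL_FATTN_VEC_F32_CASE macros expand to explicit template instantiations; a sketch of that shape, assumed to carry over unchanged to this fork:

// Sketch, assuming the macro matches mainline llama.cpp's fattn-vec-f16.cuh:
// the .cu file forces an explicit instantiation of the kernel-launching
// template for one (head size, K type, V type) triple, which the matching
// FATTN_VEC_F16_CASE in fattn.cu then dispatches to at runtime.
#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V)                          \
    template void ggml_cuda_flash_attn_ext_vec_f16_case                     \
    <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst)

Six instances are added in total, covering the F16/Q8_0, Q8_0/F16, and Q8_0/Q8_0 K/V combinations at head size 256 for both the F16 and F32 kernels. The Q8_0/F16 dispatch lines in fattn.cu remain commented out, so those two instances are compiled but not yet reachable.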
