Commit 0bb6648

Nexesenex and ikawrakow committed: Allow q8_0 KV cache for head size 256 (#330)
Co-Authored-By: Kawrakow <[email protected]>
1 parent e2b482d

8 files changed: 91 additions, 18 deletions

ggml/src/ggml-cuda/fattn.cu

Lines changed: 18 additions & 18 deletions

@@ -211,9 +211,9 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
     FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
     //FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)

-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
     // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)

     //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
@@ -224,14 +224,14 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
     // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

-    // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
-    // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
     // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)

     //FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
     //FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
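
The same four head-size-256 cases are enabled below in the fp32 dispatcher, ggml_cuda_flash_attn_ext_vec_f32. As an assumption carried over from upstream llama.cpp behavior: the f32 vector kernels accumulate in fp32 and are selected when the op requests full precision or the device lacks fast fp16 arithmetic, so the enabled (K, V) type combinations have to match across both dispatchers for either path to serve a q8_0 cache.
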
@@ -346,9 +346,9 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
     FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
     //FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)

-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
     // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
     FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)

     //FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
@@ -358,14 +358,14 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
     // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

-    // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
-    // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
     // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)

     //FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
     //FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
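
For context: FATTN_VEC_F16_CASE (and its F32 twin) is a dispatch helper, not a kernel by itself. A minimal sketch of the pattern, assuming the macro matches its upstream llama.cpp definition, where Q, K, V, ctx, and dst are locals of the enclosing dispatch function:

// Sketch, assuming the upstream llama.cpp definition: each enabled case
// compares the head size D and the K/V cache types against one compiled
// template instance and, on a match, launches it and returns. Commenting a
// case out therefore removes that (D, type_K, type_V) kernel entirely.
#define FATTN_VEC_F16_CASE(D, type_K, type_V)                                 \
    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {      \
        ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst);   \
        return;                                                               \
    }                                                                         \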

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 43 additions & 0 deletions

@@ -4085,6 +4085,49 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
             return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
+        /* case GGML_OP_FLASH_ATTN_EXT: {
+#ifndef FLASH_ATTN_AVAILABLE
+            return false;
+#endif // FLASH_ATTN_AVAILABLE
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+                if (!turing_mma_available(cc)) {
+                    return false;
+                }
+                const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
+                return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
+            }
+            // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
+            if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
+                && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
+                return false;
+            }
+            if (op->src[1]->ne[0] == 256 && op->src[2]->ne[0] == 256 &&
+                (op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_Q8_0) &&
+                (op->src[2]->type == GGML_TYPE_F16 || op->src[2]->type == GGML_TYPE_Q8_0)) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 192) {
+                return false;
+            }
+            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
+                return false;
+            }
+            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 128) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            if (op->src[3] && op->src[3]->ne[2] != 1) {
+                return false;
+            }
+            return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
+                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        } */
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
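
The block added above is kept commented out as a reference copy of a hand-rolled supports_op fallback; the live check still delegates to ggml_cuda_flash_attn_ext_supported. The clause that matters for this commit is the head-size-256 one: K and V must both be 256 wide, and each may independently be F16 or Q8_0, i.e. exactly the four combinations instantiated in fattn.cu. A minimal sketch of that clause as a standalone predicate (the helper name is hypothetical, not part of the commit):

#include "ggml.h"

// Hypothetical helper illustrating the new head-size-256 acceptance rule:
// K (src[1]) and V (src[2]) must both have head size 256, and each cache
// tensor may be either F16 or Q8_0, giving the four (K, V) combinations
// this commit compiles in.
static bool fattn_vec_hs256_supported(const ggml_tensor * K, const ggml_tensor * V) {
    const bool k_ok = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0;
    const bool v_ok = V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q8_0;
    return K->ne[0] == 256 && V->ne[0] == 256 && k_ok && v_ok;
}

In practice this is what lets a model with 256-wide attention heads run flash attention with a quantized KV cache, e.g. via llama.cpp-style flags such as -fa -ctk q8_0 -ctv q8_0 (flag names assumed from the llama.cpp CLI this fork derives from).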

ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-q8_0.cu (new file; this and the five paths below are inferred from the generate_cu_files.py naming convention)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);

ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-f16.cu (new file)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);

ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-q8_0.cu (new file)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);

ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-q8_0.cu (new file)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);

ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-f16.cu (new file)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);

ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-q8_0.cu (new file)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
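
Each of the six new translation units pins down exactly one (head size, K type, V type) combination, which keeps individual nvcc invocations small and lets the build compile the variants in parallel. A minimal sketch of what DECL_FATTN_VEC_F16_CASE expands to, assuming the upstream llama.cpp definition (DECL_FATTN_VEC_F32_CASE is analogous):

// Sketch, assuming the upstream llama.cpp definition: the macro emits an
// explicit instantiation of the shared launch template declared in
// fattn-vec-f16.cuh, so each generated .cu file compiles exactly one
// kernel variant that the dispatcher in fattn.cu can then link against.
#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V)                            \
    template void ggml_cuda_flash_attn_ext_vec_f16_case                       \
    <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst)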
