Skip to content

Commit 4248bff

Browse files
committed
Up FA KV modes
1 parent dc60af2 commit 4248bff

File tree

4 files changed

+66
-26
lines changed

4 files changed

+66
-26
lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -846,14 +846,19 @@ static void on_no_fattn_vec_case(const int D) {
846846
} else if (D == 128) {
847847
fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
848848
fprintf(stderr, "Supported combinations:\n");
849-
fprintf(stderr, " - K == q4_0, V == q4_0, 4.5 BPV\n");
850-
fprintf(stderr, " - K == iq4_nl, V == iq4_nl, 4.5 BPV\n");
851-
fprintf(stderr, " - K == q6_0, V == q5_0, 6.0 BPV\n");
852-
fprintf(stderr, " - K == q8_0, V == iq4_nl, 6.5 BPV\n");
853-
fprintf(stderr, " - K == q8_0, V == q6_0, 7.5 BPV\n");
854-
fprintf(stderr, " - K == q8_0, V == q8_0, 8.5 BPV\n");
855-
fprintf(stderr, " - K == f16, V == f16, 16.0 BPV\n");
856-
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n");
849+
fprintf(stderr, " - K == q4_0, V == q4_0, 4.5 BPV\n"); //now obsolete, left as a legacy failsafe.
850+
fprintf(stderr, " - K == iq4_nl, V == iq4_nl, 4.5 BPV\n"); //replaces KV q4_0, 1-3% perplexity drop.
851+
fprintf(stderr, "    - K == q5_0, V == iq4_nl, 5.0 BPV\n"); //best performance-oriented compromise, recommended for speculative models.
852+
fprintf(stderr, "    - K == q5_1, V == q5_0,   5.5 BPV\n"); //the most balanced compromise before K q6_0 existed, left as a failsafe.
853+
fprintf(stderr, " - K == q6_0, V == iq4_nl, 5.5 BPV\n"); //replaces KV q5_1-q5_0, 0.1-0.3 perplexity drop.
854+
fprintf(stderr, " - K == q6_0, V == q5_0, 6.0 BPV\n"); //almost equals KV q8_0/q5_0, almost..
855+
fprintf(stderr, " - K == q8_0, V == iq4_nl, 6.5 BPV\n"); //IK's favorite.
856+
fprintf(stderr, " - K == q8_0, V == q5_0, 7.0 BPV\n"); //qualitative compromise.
857+
fprintf(stderr, " - K == q8_0, V == q6_0, 7.5 BPV\n"); //should be optimal, but KV 8/5 or even 6/5 match it.
858+
fprintf(stderr, " - K == q8_0, V == q8_0, 8.5 BPV\n"); //the classic, can't go wrong with this one.
859+
fprintf(stderr, "    - K == f16,  V == q8_0,  12.25 BPV\n"); //uncompromising quality: K (the cache half most sensitive to quantization) is left unquantized.
860+
fprintf(stderr, " - K == f16, V == f16, 16.0 BPV\n"); //non-quantized KV cache.
861+
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q6_0, q8_0, and f16.\n");
857862
GGML_ABORT("fatal error");
858863
} else {
859864
fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);

ggml/src/ggml-cuda/fattn-vec-f16.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,10 @@ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
525525
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
526526
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16);
527527

528+
extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
529+
extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
530+
extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
531+
528532
extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
529533
//extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
530534
//extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);

ggml/src/ggml-cuda/fattn-vec-f32.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,10 @@ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
516516
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
517517
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16);
518518

519+
extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
520+
extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
521+
522+
extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
519523
extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
520524
//extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
521525
//extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);

ggml/src/ggml-cuda/fattn.cu

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
168168

169169
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
170170
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
171-
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
171+
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
172172
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
173173
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
174174
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
@@ -178,7 +178,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
178178
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
179179
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
180180
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
181-
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
181+
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
182182
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_1)
183183
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
184184
//FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
@@ -211,30 +211,41 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
211211
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
212212
//FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
213213

214+
// FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
215+
// FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
216+
// FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
214217
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
215218

216219
//FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
217220
//FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
218221

219222
#else
220-
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
221-
223+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
224+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
222225
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
223226

224-
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
225-
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
226-
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
227+
// FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
228+
// FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
229+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
230+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
231+
// FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
232+
// FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
233+
// FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
234+
// FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
227235

228236
//FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
229237
//FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
230238
//FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
231239
//FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
232240
//FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
233241

234-
FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
235-
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
236-
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
237-
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
242+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
243+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL)
244+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_IQ4_NL)
245+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
246+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
247+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
248+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
238249

239250
#endif // GGML_CUDA_FA_ALL_QUANTS
240251

@@ -292,7 +303,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
292303

293304
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
294305
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
295-
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
306+
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
296307
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
297308
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
298309
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
@@ -302,7 +313,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
302313
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
303314
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
304315
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
305-
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
316+
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
306317
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_1)
307318
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
308319
//FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
@@ -335,25 +346,41 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
335346
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
336347
//FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
337348

349+
// FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
350+
// FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
351+
// FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
338352
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
339353

340354
//FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
341355
//FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
342356
#else
343-
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
344-
357+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
358+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
345359
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
346360

347-
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
348-
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
349-
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
361+
// FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
362+
// FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
363+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
364+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
365+
// FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
366+
// FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
367+
// FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
368+
// FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
350369

351370
//FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
352371
//FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
353372
//FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
354373
//FATTN_VEC_F32_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
355374
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
356375
//FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
376+
377+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
378+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL)
379+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_IQ4_NL)
380+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
381+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
382+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
383+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
357384
#endif // GGML_CUDA_FA_ALL_QUANTS
358385

359386
GGML_ABORT("fatal error");

0 commit comments

Comments
 (0)