Skip to content

Commit e2e4852

Browse files
committed
Up FA KV modes
1 parent f0edaf3 commit e2e4852

File tree

4 files changed

+66
-26
lines changed

4 files changed

+66
-26
lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -851,14 +851,19 @@ static void on_no_fattn_vec_case(const int D) {
851851
} else if (D == 128) {
852852
fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
853853
fprintf(stderr, "Supported combinations:\n");
854-
fprintf(stderr, " - K == q4_0, V == q4_0, 4.5 BPV\n");
855-
fprintf(stderr, " - K == iq4_nl, V == iq4_nl, 4.5 BPV\n");
856-
fprintf(stderr, " - K == q6_0, V == q5_0, 6.0 BPV\n");
857-
fprintf(stderr, " - K == q8_0, V == iq4_nl, 6.5 BPV\n");
858-
fprintf(stderr, " - K == q8_0, V == q6_0, 7.5 BPV\n");
859-
fprintf(stderr, " - K == q8_0, V == q8_0, 8.5 BPV\n");
860-
fprintf(stderr, " - K == f16, V == f16, 16.0 BPV\n");
861-
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n");
854+
fprintf(stderr, " - K == q4_0, V == q4_0, 4.5 BPV\n"); //now obsolete, left as a legacy failsafe.
855+
fprintf(stderr, " - K == iq4_nl, V == iq4_nl, 4.5 BPV\n"); //replaces KV q4_0, 1-3% perplexity drop.
856+
fprintf(stderr, " - K == q5_0, V == iq4_nl, 5.0 BPV\n"); //best performance oriented compromise, and recommended for speculative model.
857+
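fprintf(stderr, " - K == q5_1, V == q5_0, 5.5 BPV\n"); //pre K q6_0 most balanced compromise, left as a failsafe. NOTE(review): q5_1 (6.0 bpw) + q5_0 (5.5 bpw) would average 5.75 BPV, not 5.5 - confirm the printed value.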
858+
fprintf(stderr, " - K == q6_0, V == iq4_nl, 5.5 BPV\n"); //replaces KV q5_1-q5_0, 0.1-0.3 perplexity drop.
859+
fprintf(stderr, " - K == q6_0, V == q5_0, 6.0 BPV\n"); //almost equals KV q8_0/q5_0, almost..
860+
fprintf(stderr, " - K == q8_0, V == iq4_nl, 6.5 BPV\n"); //IK's favorite.
861+
fprintf(stderr, " - K == q8_0, V == q5_0, 7.0 BPV\n"); //qualitative compromise.
862+
fprintf(stderr, " - K == q8_0, V == q6_0, 7.5 BPV\n"); //should be optimal, but KV 8/5 or even 6/5 match it.
863+
fprintf(stderr, " - K == q8_0, V == q8_0, 8.5 BPV\n"); //the classic, can't go wrong with this one.
864+
fprintf(stderr, " - K == f16, V == q8_0, 12.25 BPV\n"); //uncompromising quality, non-quantized K (the most sensitive to quantization).
865+
fprintf(stderr, " - K == f16, V == f16, 16.0 BPV\n"); //non-quantized KV cache.
866+
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q6_0, q8_0, and f16.\n");
862867
GGML_ABORT("fatal error");
863868
} else {
864869
fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);

ggml/src/ggml-cuda/fattn-vec-f16.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,10 @@ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
525525
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
526526
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16);
527527

528+
extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
529+
extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
530+
extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
531+
528532
extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
529533
//extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
530534
//extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);

ggml/src/ggml-cuda/fattn-vec-f32.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,10 @@ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
516516
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
517517
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16);
518518

519+
extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
520+
extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
521+
522+
extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
519523
extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
520524
//extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
521525
//extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);

ggml/src/ggml-cuda/fattn.cu

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
169169

170170
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
171171
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
172-
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
172+
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
173173
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
174174
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
175175
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
@@ -179,7 +179,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
179179
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
180180
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
181181
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
182-
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
182+
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
183183
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_1)
184184
//FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
185185
//FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
@@ -212,30 +212,41 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
212212
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
213213
//FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
214214

215+
// FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
216+
// FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
217+
// FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
215218
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
216219

217220
//FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
218221
//FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
219222

220223
#else
221-
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
222-
224+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
225+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
223226
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
224227

225-
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
226-
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
227-
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
228+
// FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
229+
// FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
230+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
231+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
232+
// FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
233+
// FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
234+
// FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
235+
// FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
228236

229237
//FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
230238
//FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
231239
//FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
232240
//FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
233241
//FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
234242

235-
FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
236-
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
237-
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
238-
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
243+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
244+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL)
245+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_IQ4_NL)
246+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
247+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
248+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
249+
// FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
239250

240251
#endif // GGML_CUDA_FA_ALL_QUANTS
241252

@@ -293,7 +304,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
293304

294305
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
295306
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
296-
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
307+
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
297308
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
298309
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
299310
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
@@ -303,7 +314,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
303314
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
304315
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
305316
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
306-
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
317+
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
307318
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_1)
308319
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
309320
//FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
@@ -336,25 +347,41 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
336347
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
337348
//FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
338349

350+
// FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
351+
// FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
352+
// FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
339353
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
340354

341355
//FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
342356
//FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
343357
#else
344-
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
345-
358+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
359+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
346360
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
347361

348-
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
349-
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
350-
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
362+
// FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
363+
// FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
364+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
365+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
366+
// FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
367+
// FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
368+
// FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
369+
// FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
351370

352371
//FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
353372
//FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
354373
//FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
355374
//FATTN_VEC_F32_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
356375
//FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
357376
//FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
377+
378+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
379+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL)
380+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_IQ4_NL)
381+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
382+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
383+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
384+
// FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
358385
#endif // GGML_CUDA_FA_ALL_QUANTS
359386

360387
GGML_ABORT("fatal error");

0 commit comments

Comments
 (0)