@@ -169,7 +169,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
169169
170170 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
171171 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
172- // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
172+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
173173 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
174174 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
175175 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
@@ -179,7 +179,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
179179 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
180180 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
181181 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
182- // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
182+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
183183 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_1)
184184 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
185185 // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
@@ -212,30 +212,41 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
212212 FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_F16)
213213 // FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
214214
215+ // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
216+ // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
217+ // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
215218 FATTN_VEC_F16_CASE (256 , GGML_TYPE_F16, GGML_TYPE_F16)
216219
217220 // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
218221 // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
219222
220223#else
221- FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
222-
224+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
225+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
223226 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
224227
225- FATTN_VEC_F16_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_F16)
226- FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_F16)
227- FATTN_VEC_F16_CASE (256 , GGML_TYPE_F16, GGML_TYPE_F16)
228+ // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
229+ // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
230+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
231+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
232+ // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
233+ // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
234+ // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
235+ // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
228236
229237 // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
230238 // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
231239 // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
232240 // FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
233241 // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
234242
235- FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
236- FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
237- FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
238- FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
243+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
244+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL)
245+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_IQ4_NL)
246+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
247+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
248+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
249+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
239250
240251#endif // GGML_CUDA_FA_ALL_QUANTS
241252
@@ -293,7 +304,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
293304
294305 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
295306 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
296- // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
307+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
297308 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
298309 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
299310 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
@@ -303,7 +314,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
303314 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
304315 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
305316 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
306- // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
317+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
307318 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_1)
308319 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
309320 // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
@@ -336,25 +347,41 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
336347 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_F16)
337348 // FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
338349
350+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
351+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
352+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
339353 FATTN_VEC_F32_CASE (256 , GGML_TYPE_F16, GGML_TYPE_F16)
340354
341355 // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
342356 // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
343357#else
344- FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
345-
358+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
359+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
346360 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
347361
348- FATTN_VEC_F32_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_F16)
349- FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_F16)
350- FATTN_VEC_F32_CASE (256 , GGML_TYPE_F16, GGML_TYPE_F16)
362+ // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
363+ // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
364+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
365+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
366+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
367+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
368+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
369+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
351370
352371 // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
353372 // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
354373 // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
355374 // FATTN_VEC_F32_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
356375 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
357376 // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
377+
378+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
379+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL)
380+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_IQ4_NL)
381+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
382+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
383+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
384+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
358385#endif // GGML_CUDA_FA_ALL_QUANTS
359386
360387 GGML_ABORT (" fatal error" );
0 commit comments