@@ -168,7 +168,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
168168
169169 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
170170 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
171- // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
171+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
172172 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
173173 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
174174 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
@@ -178,7 +178,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
178178 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
179179 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
180180 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
181- // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
181+ FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
182182 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_1)
183183 // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
184184 // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
@@ -211,30 +211,41 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
211211 FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_F16)
212212 // FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
213213
214+ // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
215+ // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
216+ // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
214217 FATTN_VEC_F16_CASE (256 , GGML_TYPE_F16, GGML_TYPE_F16)
215218
216219 // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
217220 // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
218221
219222#else
220- FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
221-
223+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
224+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
222225 FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
223226
224- FATTN_VEC_F16_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_F16)
225- FATTN_VEC_F16_CASE (128 , GGML_TYPE_F16, GGML_TYPE_F16)
226- FATTN_VEC_F16_CASE (256 , GGML_TYPE_F16, GGML_TYPE_F16)
227+ // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
228+ // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
229+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
230+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
231+ // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
232+ // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
233+ // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
234+ // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
227235
228236 // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
229237 // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
230238 // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
231239 // FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
232240 // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
233241
234- FATTN_VEC_F16_CASE (128 , GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
235- FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
236- FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
237- FATTN_VEC_F16_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
242+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
243+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL)
244+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_IQ4_NL)
245+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
246+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
247+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
248+ // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
238249
239250#endif // GGML_CUDA_FA_ALL_QUANTS
240251
@@ -292,7 +303,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
292303
293304 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
294305 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
295- // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
306+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
296307 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
297308 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
298309 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
@@ -302,7 +313,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
302313 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
303314 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
304315 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
305- // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
316+ FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
306317 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_1)
307318 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
308319 // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
@@ -335,25 +346,41 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
335346 FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_F16)
336347 // FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
337348
349+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
350+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
351+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
338352 FATTN_VEC_F32_CASE (256 , GGML_TYPE_F16, GGML_TYPE_F16)
339353
340354 // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
341355 // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
342356#else
343- FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
344-
357+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
358+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
345359 FATTN_VEC_F32_CASE (128 , GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
346360
347- FATTN_VEC_F32_CASE ( 64 , GGML_TYPE_F16, GGML_TYPE_F16)
348- FATTN_VEC_F32_CASE (128 , GGML_TYPE_F16, GGML_TYPE_F16)
349- FATTN_VEC_F32_CASE (256 , GGML_TYPE_F16, GGML_TYPE_F16)
361+ // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
362+ // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
363+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
364+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
365+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
366+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
367+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
368+ // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
350369
351370 // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
352371 // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
353372 // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
354373 // FATTN_VEC_F32_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
355374 // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
356375 // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
376+
377+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
378+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL)
379+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_IQ4_NL)
380+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0)
381+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
382+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
383+ // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0)
357384#endif // GGML_CUDA_FA_ALL_QUANTS
358385
359386 GGML_ABORT (" fatal error" );
0 commit comments