diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 04ad2460d32..bddc73eb5a8 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -824,6 +824,7 @@ struct vk_device_struct { vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_conv_transpose_1d_f32; vk_pipeline pipeline_pool2d_f32; + vk_pipeline pipeline_turbo_wht; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; // [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128 @@ -3447,11 +3448,13 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_TURBO3_0, turbo3_0, FA_SCALAR, ) } else { CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_TURBO3_0, turbo3_0, FA_SCALAR, _fp32) } #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { @@ -4187,7 +4190,10 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TURBO3_0], "dequant_turbo3_0", dequant_turbo3_0_len, dequant_turbo3_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TURBO3_0], "dequant_turbo3_0", dequant_turbo3_0_len, dequant_turbo3_0_data, "main", 2, 5 * sizeof(uint32_t), {128, 1, 1}, {}, 1); + + // TurboQuant WHT + ggml_vk_create_pipeline(device, device->pipeline_turbo_wht, "turbo_wht", turbo_wht_len, turbo_wht_data, "main", 2, 3 * sizeof(uint32_t), {128, 1, 1}, {}, 1); // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -4307,7 +4313,6 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_TURBO3_0], "cpy_f32_turbo3_0", cpy_f32_turbo3_0_rte_len, cpy_f32_turbo3_0_rte_data, "main", 2, 
sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
     } else {
         ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
@@ -4315,7 +4320,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_TURBO3_0], "cpy_f32_turbo3_0", cpy_f32_turbo3_0_len, cpy_f32_turbo3_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
     }

 #define SET_ROWS(itype, rte) \
@@ -7278,6 +7282,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_TURBO3_0:
             return ctx->device->pipeline_cpy_quant_f32[src->type];
         default:
             break;
@@ -10063,7 +10068,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_SET_ROWS:
         {
             uint32_t ne = ggml_nelements(src0);
-            if (ggml_is_quantized(dst->type)) {
+            if (dst->type == GGML_TYPE_TURBO3_0) {
+                ne = ne / 128;
+            } else if (ggml_is_quantized(dst->type)) {
                 // quants run 32 threads each doing QUANT_K elements
                 ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
             } else {
@@ -10834,6 +10841,32 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
     });
 }

+static void ggml_vk_turbo_wht(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    int direction, group_size;
+    memcpy(&direction,  (const int32_t *) dst->op_params + 0, sizeof(int));
+    memcpy(&group_size, (const int32_t *) dst->op_params + 1, sizeof(int));
+    struct { uint32_t ne; uint32_t direction; uint32_t group_size; } pc = {
+        (uint32_t)ggml_nelements(src0), (uint32_t)direction, (uint32_t)group_size,
+    };
+    vk_pipeline pipeline = ctx->device->pipeline_turbo_wht;
+    GGML_ASSERT(pipeline != nullptr);
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+    vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0, false);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
+    // Spread workgroups across Y/Z to stay within maxComputeWorkGroupCount[0]:
+    // elements[0] / 128 (the workgroup width) gives wg0; each row of 512 workgroups uses one Y slice.
+    const uint32_t n_groups = pc.ne / (uint32_t)group_size;
+    std::array<uint32_t, 3> elements;
+    if (n_groups > 262144) {
+        elements = { 512 * (uint32_t)group_size, 512, CEIL_DIV(n_groups, 262144) };
+    } else if (n_groups > 512) {
+        elements = { 512 * (uint32_t)group_size, CEIL_DIV(n_groups, 512), 1 };
+    } else {
+        elements = { pc.ne, 1, 1 };
+    }
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc, elements);
+}
+
 static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
 }
@@ -13015,6 +13048,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_SET_ROWS:
         ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node);

+        break;
+    case GGML_OP_TURBO_WHT:
+        ggml_vk_turbo_wht(ctx, compute_ctx, src0, node);
+
         break;
     case GGML_OP_SILU_BACK:
         ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node);
@@ -15338,6 +15375,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     case GGML_TYPE_F32:
                     case GGML_TYPE_Q4_0:
                     case GGML_TYPE_Q8_0:
+                    case GGML_TYPE_TURBO3_0:
                         // supported in scalar and coopmat2 paths
                         break;
                     case GGML_TYPE_Q4_1:
@@ -15441,7 +15479,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case GGML_TYPE_Q5_1:
                 case GGML_TYPE_Q8_0:
                 case GGML_TYPE_IQ4_NL:
-                case GGML_TYPE_TURBO3_0:
                     return true;
                 default:
                     break;
@@ -15710,6 +15747,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     ggml_is_contiguous(op->src[1]) &&
                     ggml_is_contiguous(op));
         }
+        case GGML_OP_TURBO_WHT:
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[0]->ne[0] % 128 == 0;
         default:
             return false;
     }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp
index 54331e28c82..31a7f5f5434 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp
@@ -1,9 +1,15 @@
 #version 450
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
+#extension GL_KHR_shader_subgroup_shuffle : enable

 #include "rte.glsl"
 #include "types.glsl"

-#if defined(SET_ROWS) && QUANT_K == 1
+#if defined(SET_ROWS) && defined(DATA_A_TURBO3_0)
+layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
+const uint BLOCK_SIZE = 128;
+#elif defined(SET_ROWS) && QUANT_K == 1
 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
 const uint BLOCK_SIZE = 512;
 #else
@@ -185,62 +191,67 @@ void quantize(uint dst_idx, uint src_idx)
 #endif

 #if defined(DATA_A_TURBO3_0)
-void quantize(uint dst_idx, uint src_idx)
-{
-    const float centroids[8] = float[8](
-        -0.190685, -0.117832, -0.065717, -0.021460,
-        0.021460, 0.065717, 0.117832, 0.190685
-    );
-    const float midpoints[7] = float[7](
-        -0.154259, -0.091775, -0.043589, 0.0, 0.043589, 0.091775, 0.154259
-    );
-
-    // Compute L2 norm
-    float norm_sq = 0.0;
-    [[unroll]] for (int j = 0; j < 32; ++j) {
-        float v = data_s[src_idx + j];
-        norm_sq += v * v;
-    }
-    float norm = sqrt(norm_sq);
-    float inv_norm = (norm > 1e-10) ?
(1.0 / norm) : 0.0; - - // Clear output - [[unroll]] for (int j = 0; j < 8; ++j) data_q[dst_idx].qs[j] = uint8_t(0); - [[unroll]] for (int j = 0; j < 4; ++j) data_q[dst_idx].signs[j] = uint8_t(0); - - // Accumulate centroid reconstruction norm for correction - float recon_norm_sq = 0.0; - - // Quantize each element - [[unroll]] for (int j = 0; j < 32; ++j) { - float val = data_s[src_idx + j] * inv_norm; - - // Find nearest centroid via midpoint comparison - uint idx = 0; - if (val < midpoints[0]) idx = 0; - else if (val < midpoints[1]) idx = 1; - else if (val < midpoints[2]) idx = 2; - else if (val < midpoints[3]) idx = 3; - else if (val < midpoints[4]) idx = 4; - else if (val < midpoints[5]) idx = 5; - else if (val < midpoints[6]) idx = 6; - else idx = 7; - - recon_norm_sq += centroids[idx] * centroids[idx]; - - // Pack: low 2 bits to qs, high 1 bit to signs - uint low2 = idx & 0x3; - uint hi1 = (idx >> 2) & 0x1; - data_q[dst_idx].qs[j / 4] |= uint8_t(low2 << ((j % 4) * 2)); - data_q[dst_idx].signs[j / 8] |= uint8_t(hi1 << (j % 8)); - } +const float TS1[128] = float[128]( + -1, 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, -1, + -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, 1, 1, 1, -1, 1, + -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, 1, + 1, -1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, 1, -1, + -1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, 1, -1, 1, -1, 1, + 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, 1 +); + +const float TS2[128] = float[128]( + 1, 1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, + 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, 1, + -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, + -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1 +); + +const float TINV = 0.08838834764831845; // 1 / sqrt(128) + +const float TC[8] = float[8]( + -0.190685, -0.117832, -0.065717, -0.021460, + 0.021460, 0.065717, 0.117832, 0.190685 +); + +const float TM[7] = float[7]( + -0.154259, -0.091775, -0.043589, + 0.0, + 0.043589, 0.091775, 0.154259 +); - // Norm correction: scale so reconstruction matches original norm - float recon_norm = sqrt(recon_norm_sq); - float corrected_norm = (recon_norm > 1e-10) ? (norm / recon_norm) : norm; - data_q[dst_idx].norm = float16_t(corrected_norm); +#if defined(SET_ROWS) + +shared float wht[128]; +shared float sg_acc[16]; +shared float gnrm; + +void quantize_block(uint b, uint o) { + [[unroll]] for (int j = 0; j < 32; ++j) data_q[b].qs[j] = uint8_t(0); + [[unroll]] for (int j = 0; j < 16; ++j) data_q[b].signs[j] = uint8_t(0); + float rs = 0.0; + [[unroll]] for (int j = 0; j < 128; ++j) { + float v = wht[o + j]; + uint i = v < TM[0] ? 0 : v < TM[1] ? 1 : v < TM[2] ? 2 : v < TM[3] ? 3 : + v < TM[4] ? 4 : v < TM[5] ? 5 : v < TM[6] ? 6 : 7; + rs += TC[i] * TC[i]; + uint low2 = i & 0x3; + uint hi1 = (i >> 2) & 0x1; + data_q[b].qs[j / 4] |= uint8_t(low2 << ((j % 4) * 2)); + data_q[b].signs[j / 8] |= uint8_t(hi1 << (j % 8)); + } + float rn = sqrt(rs); + data_q[b].norm = float16_t((rn > 1e-10) ? 
(gnrm / rn) : gnrm);
 }
-
-#endif
+
+#endif // defined(SET_ROWS)
+#endif // defined(DATA_A_TURBO3_0)

 #if defined(DATA_A_IQ4_NL)
 uint best_index(float x) {
@@ -304,7 +315,97 @@ void quantize(uint dst_idx, uint src_idx)
 }
 #endif

-#if defined(SET_ROWS)
+#if defined(SET_ROWS) && defined(DATA_A_TURBO3_0)
+void main() {
+    const uint t = gl_LocalInvocationID.x;
+    const uint g = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint gpr = p.ne00 / 128;
+
+    if (gpr == 0) return;
+    if (g >= p.ne / 128) return;
+
+    uint tmp = g;
+    const uint ig  = tmp % gpr;    tmp /= gpr;
+    const uint i01 = tmp % p.ne01; tmp /= p.ne01;
+    const uint i02 = tmp % p.ne02;
+    const uint i03 = tmp / p.ne02;
+
+    const uint sb = src0_idx(ig * 128, i01, i02, i03) + get_aoffset();
+    const uint i1 = data_i[src1_idx(i01, fastmod(i02, p.ne11), fastmod(i03, p.ne12), 0) + get_boffset()] DATA_I_SWIZZLE;
+    const uint db = dst_idx(ig, i1, i02, i03) + get_doffset();
+
+    // Step 1: load into shared memory
+    wht[t] = data_s[sb + t];
+    barrier();
+
+    // Step 2: L2 norm via subgroup reduction
+    float v2 = wht[t] * wht[t];
+    v2 = subgroupAdd(v2);
+    if (gl_SubgroupInvocationID == 0) sg_acc[gl_SubgroupID] = v2;
+    barrier();
+    if (t == 0) {
+        float total = 0.0;
+        for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc[w];
+        gnrm = sqrt(total);
+    }
+    barrier();
+
+    // Step 3: normalize and apply the pre-scramble signs1
+    wht[t] *= (gnrm > 1e-10) ? (1.0 / gnrm) : 0.0;
+    barrier();
+
+    wht[t] *= TS1[t];
+    barrier();
+
+    // Step 4: butterfly (in-place 128-point WHT)
+    [[unroll]] for (uint h = 1; h < 128; h *= 2) {
+        if ((t % (2 * h)) < h) {
+            float a = wht[t];
+            float b = wht[t + h];
+            wht[t] = a + b;
+            wht[t + h] = a - b;
+        }
+        barrier();
+    }
+
+    // Step 5: apply signs2 + scaling
+    float rv = wht[t] * TINV * TS2[t];
+
+    // Step 6: quantize -- all 128 threads participate
+    uint idx = rv < TM[0] ? 0u : rv < TM[1] ? 1u : rv < TM[2] ? 2u : rv < TM[3] ? 3u :
+               rv < TM[4] ? 4u : rv < TM[5] ? 5u : rv < TM[6] ? 6u : 7u;
+
+    // Pack qs: 4 elements per byte via subgroup shuffle
+    uint sg_lane = gl_SubgroupInvocationID;
+    uint my_low2 = idx & 0x3u;
+    uint qs_byte = 0u;
+    [[unroll]] for (uint k = 0; k < 4; k++) {
+        uint contrib = subgroupShuffle(my_low2, (sg_lane & ~3u) + k);
+        qs_byte |= contrib << (k * 2u);
+    }
+    if (sg_lane % 4u == 0u) {
+        data_q[db].qs[t / 4u] = uint8_t(qs_byte);
+    }
+
+    // Pack signs: 8 elements per byte via subgroup ballot
+    uvec4 ballot = subgroupBallot(((idx >> 2u) & 1u) != 0u);
+    if (sg_lane % 8u == 0u) {
+        const uint word = ballot[sg_lane / 32u]; // select the 32-bit ballot word; subgroups can be wider than 32 lanes
+        data_q[db].signs[t / 8u] = uint8_t((word >> (((sg_lane % 32u) / 8u) * 8u)) & 0xFFu);
+    }
+
+    // Step 7: reconstruction norm via subgroup reduction
+    float rc = TC[idx] * TC[idx];
+    rc = subgroupAdd(rc);
+    if (sg_lane == 0u) sg_acc[gl_SubgroupID] = rc;
+    barrier();
+    if (t == 0u) {
+        float total = 0.0;
+        for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc[w];
+        float rn = sqrt(total);
+        data_q[db].norm = float16_t((rn > 1e-10) ?
(gnrm / rn) : gnrm); + } +} +#elif defined(SET_ROWS) void main() { #ifdef NEEDS_INIT_IQ_SHMEM diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_turbo3_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_turbo3_0.comp index 17b9bd9eb4b..c12dd5f8e95 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_turbo3_0.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_turbo3_0.comp @@ -2,7 +2,9 @@ #include "dequant_head.glsl" -layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; +// 128 elements per block (QK_TURBO3 = 128) +// Each workgroup processes one block. 128 threads, 1 element per thread. +layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {block_turbo3_0 data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; @@ -13,34 +15,22 @@ void main() { 0.021460, 0.065717, 0.117832, 0.190685 ); - const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + // Each workgroup processes one 128-element block + const uint ib = gl_WorkGroupID.x; + const uint j = gl_LocalInvocationID.x; // element index 0..127 - const uint tid = gl_LocalInvocationID.x % 64; - const uint il = tid/32; - const uint ir = tid%32; - const uint ib = 32*i + ir; - if (ib >= p.nel / 32) { - return; - } - - const uint b_idx = 1024*i + 32*ir + 16*il; + if (ib >= p.nel / 128) return; const float norm = float(data_a[ib].norm); - const uint q_start = 16*il; - - [[unroll]] for (uint l = 0; l < 16; ++l) { - const uint j = q_start + l; - - // Extract 2-bit low index from qs (4 per byte) - const uint low2 = (uint(data_a[ib].qs[j / 4]) >> ((j % 4) * 2)) & 0x3; + // Extract 2-bit low index from qs (4 per byte) + const uint low2 = (uint(data_a[ib].qs[j / 4]) >> ((j % 4) * 2)) & 0x3; - // Extract 1-bit high from signs (8 per byte) - const uint hi1 = (uint(data_a[ib].signs[j / 8]) >> (j % 8)) & 0x1; + // Extract 1-bit high from signs (8 per byte) + const uint hi1 = (uint(data_a[ib].signs[j / 8]) >> (j % 8)) & 0x1; - // Combine to 3-bit index - const uint idx = low2 | (hi1 << 2); + // Combine to 3-bit index + const uint idx = low2 | (hi1 << 2); - data_b[b_idx + l] = D_TYPE(centroids[idx] * norm); - } + data_b[ib * 128 + j] = D_TYPE(centroids[idx] * norm); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 11b7dce8578..7b6d07ba432 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -27,10 +27,12 @@ const uint32_t num_subgroups = SubGroupSize == 0 ? 
0 : WorkGroupSize / SubGroupSize;

 layout (binding = 0) readonly buffer Q {float data_q[];};
 layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
+#if !defined(DATA_A_TURBO3_0)
 layout (binding = 1) readonly buffer K {float16_t data_k[];};
 layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
 layout (binding = 2) readonly buffer V {float16_t data_v[];};
 layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
+#endif
 layout (binding = 3) readonly buffer M {float16_t data_m[];};

 // If SubGroupSize is set to 0 then only use shmem reductions
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
index 172d38f034e..9925902c06b 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
@@ -84,6 +84,9 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
 #if defined(DATA_A_F32)
 layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
 layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
+#elif defined(DATA_A_TURBO3_0)
+layout (binding = 1) readonly buffer K_T3 {block_turbo3_0 data_k_t3[];};
+layout (binding = 2) readonly buffer V_T3 {block_turbo3_0 data_v_t3[];};
 #elif defined(A_TYPE_PACKED16)
 layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
 layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
@@ -93,6 +96,11 @@ layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16
 #define BLOCK_SIZE 1
 #endif

+// turbo3: define BLOCK_BYTE_SIZE early (before first use in FA offset computation)
+#if defined(DATA_A_TURBO3_0) && !defined(BLOCK_BYTE_SIZE)
+#define BLOCK_BYTE_SIZE 50 // block_turbo3_0: 2 (norm) + 32 (qs) + 16 (signs) = 50 bytes
+#endif
+
 #if defined(DATA_A_F32)
 #undef BLOCK_SIZE
 #define BLOCK_SIZE 4
@@ -149,6 +157,35 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
 }
 #endif

+#if defined(DATA_A_TURBO3_0)
+const float T3C[8] = float[8](
+    -0.190685, -0.117832, -0.065717, -0.021460,
+    0.021460, 0.065717, 0.117832, 0.190685
+);
+FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+    FLOAT_TYPEV4 r;
+    for (int k = 0; k < 4; k++) {
+        uint j = iqs + uint(k);
+        float nm;
+        uint qb;
+        uint sb;
+        if (binding_idx == BINDING_IDX_K) {
+            nm = float(data_k_t3[a_offset + ib].norm);
+            qb = uint(data_k_t3[a_offset + ib].qs[j / 4]);
+            sb = uint(data_k_t3[a_offset + ib].signs[j / 8]);
+        } else {
+            nm = float(data_v_t3[a_offset + ib].norm);
+            qb = uint(data_v_t3[a_offset + ib].qs[j / 4]);
+            sb = uint(data_v_t3[a_offset + ib].signs[j / 8]);
+        }
+        uint lo = (qb >> ((j % 4) * 2)) & 0x3;
+        uint hi = (sb >> (j % 8)) & 0x1;
+        r[k] = FLOAT_TYPE(T3C[lo | (hi << 2)] * nm);
+    }
+    return r;
+}
+#endif
+
 #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/turbo_wht.comp b/ggml/src/ggml-vulkan/vulkan-shaders/turbo_wht.comp
new file mode 100644
index 00000000000..914875eba7a
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/turbo_wht.comp
@@ -0,0 +1,67 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_control_flow_attributes : require
+
+layout (local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
+
+layout (push_constant) uniform parameter { uint ne; uint direction; uint group_size; } p;
+
+layout (binding = 0) readonly buffer A { float
data_a[]; }; +layout (binding = 1) writeonly buffer D { float data_d[]; }; + +shared float x[128]; + +// Pre-scramble sign vectors applied before and after the WHT. +// direction == 0: pre = S1, post = S2; direction == 1: pre = S2, post = S1. +const float S1[128] = float[128]( + -1, 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, -1, + -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, 1, 1, 1, -1, 1, + -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, 1, + 1, -1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, 1, -1, + -1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, 1, -1, 1, -1, 1, + 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, 1 +); + +const float S2[128] = float[128]( + 1, 1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, + 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, 1, + -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, + -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1 +); + +const float INV_SQRT_128 = 0.08838834764831845; // 1 / sqrt(128) + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint base = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * 128; + + if (base + tid >= p.ne) return; + + const float fs = (p.direction == 0) ? S1[tid] : S2[tid]; + const float ss = (p.direction == 0) ? S2[tid] : S1[tid]; + + x[tid] = data_a[base + tid]; + barrier(); + + x[tid] *= fs; + barrier(); + + [[unroll]] for (uint h = 1; h < 128; h *= 2) { + if ((tid % (2 * h)) < h) { + float a = x[tid]; + float b = x[tid + h]; + x[tid] = a + b; + x[tid + h] = a - b; + } + barrier(); + } + + data_d[base + tid] = x[tid] * INV_SQRT_128 * ss; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index e3635fa01b7..1171192a053 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -6,6 +6,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require #extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_8bit_storage : require #if defined(DATA_A_F32) #define QUANT_K 1 @@ -1696,14 +1697,14 @@ struct block_mxfp4 #define A_TYPE block_mxfp4 #endif -#define QUANT_K_TURBO3_0 32 +#define QUANT_K_TURBO3_0 128 #define QUANT_R_TURBO3_0 1 struct block_turbo3_0 { float16_t norm; - uint8_t qs[8]; // 2-bit centroid indices (4 per byte) - uint8_t signs[4]; // 1-bit high bit of 3-bit index (8 per byte) + uint8_t qs[32]; // 2-bit centroid indices (4 per byte), 128/4 = 32 bytes + uint8_t signs[16]; // 1-bit high bit of 3-bit index (8 per byte), 128/8 = 16 bytes }; #if defined(DATA_A_TURBO3_0) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 90253243ab8..e6eaf85804e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -656,7 +656,7 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", 
"1"}}), fp16, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32" || tname == "turbo3_0") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); @@ -667,7 +667,7 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32" || tname == "turbo3_0") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); @@ -758,11 +758,13 @@ void process_shaders() { string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}); - for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl", "turbo3_0"}) { + for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } + // turbo3_0 copy-from-quant only; copy-to-quant (cpy_f32_turbo3_0) omitted because the non-SET_ROWS quantize() path lacks the WHT transform + string_to_spv("cpy_turbo3_0_f32", "copy_from_quant.comp", {{"DATA_A_TURBO3_0", "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl", "turbo3_0"}) { string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -771,6 +773,9 @@ void process_shaders() { string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); } + // TurboQuant WHT operation + string_to_spv("turbo_wht", "turbo_wht.comp", {}); + auto get_type_str = [](bool f16) { return f16 ? 
"float16_t" : "float"; }; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6ad8a648c8c..ee725ec5c27 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6168,6 +6168,126 @@ struct test_leaky_relu : public test_case { } }; +// GGML_OP_TURBO_WHT +struct test_turbo_wht : public test_case { + const int64_t head_dim; + const int64_t n_heads; + const int direction; // 0=forward, 1=inverse + + std::string vars() override { + return VARS_TO_STR3(head_dim, n_heads, direction); + } + + double max_nmse_err() override { + return 1e-5; // f32 SIMD reduction order varies across GPU backends + } + + test_turbo_wht(int64_t head_dim = 128, int64_t n_heads = 4, int direction = 0) + : head_dim(head_dim), n_heads(n_heads), direction(direction) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_dim, n_heads); + ggml_set_param(a); + ggml_set_name(a, "a"); + ggml_tensor * out = ggml_turbo_wht(ctx, a, direction, 0, nullptr); + ggml_set_name(out, "out"); + return out; + } +}; + +// GGML_OP_TURBO_WHT round-trip: forward then inverse should recover the original +struct test_turbo_wht_roundtrip : public test_case { + const int64_t head_dim; + const int64_t n_heads; + + std::string vars() override { + return VARS_TO_STR2(head_dim, n_heads); + } + + double max_nmse_err() override { + return 1e-5; // two WHT passes compound the f32 reduction error + } + + test_turbo_wht_roundtrip(int64_t head_dim = 128, int64_t n_heads = 4) + : head_dim(head_dim), n_heads(n_heads) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_dim, n_heads); + ggml_set_param(a); + ggml_set_name(a, "a"); + // forward WHT (direction=0), then inverse WHT (direction=1) + ggml_tensor * fwd = ggml_turbo_wht(ctx, a, 0, 0, nullptr); + ggml_tensor * inv = ggml_turbo_wht(ctx, fwd, 1, 0, nullptr); + ggml_set_name(inv, "out"); + return inv; + } +}; + +// Test SET_ROWS with turbo3 destination, then dequantize and compare. +// This validates the full quantization pipeline: f32 -> WHT -> PolarQuant -> turbo3 +// followed by dequantization: turbo3 -> f32. The round-trip error should be bounded. +// Unlike the generic SET_ROWS test (which compares raw quantized bytes), this test +// compares the dequantized f32 output, tolerating the lossy quantization error. 
+struct test_set_rows_turbo3 : public test_case { + const ggml_type type_idx; + const int64_t ne0; // head dim (must be multiple of 128) + const int64_t ne1; // rows in dst + const int r; // rows to write + + std::string vars() override { + return VARS_TO_STR4(type_idx, ne0, ne1, r); + } + + std::string op_desc(ggml_tensor * t) override { + GGML_UNUSED(t); + return "SET_ROWS_TURBO3"; + } + + test_set_rows_turbo3(ggml_type type_idx = GGML_TYPE_I32, + int64_t ne0 = 128, int64_t ne1 = 8, int r = 4) + : type_idx(type_idx), ne0(ne0), ne1(ne1), r(r) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + // dst: the turbo3 KV cache buffer + ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_TURBO3_0, ne0, ne1); + ggml_set_name(dst, "dst"); + + // src: f32 values to quantize into the cache + ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne0, r); + ggml_set_name(src, "src"); + + // row indices + ggml_tensor * row_idxs = ggml_new_tensor_1d(ctx, type_idx, r); + ggml_set_name(row_idxs, "row_idxs"); + + // Write f32 data into turbo3 dst via SET_ROWS (includes WHT + quantize) + ggml_tensor * written = ggml_set_rows(ctx, dst, src, row_idxs); + + // Read it back by dequantizing the written rows to f32 + ggml_tensor * out = ggml_cpy(ctx, written, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne0, ne1)); + ggml_set_name(out, "out"); + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) continue; + init_set_rows_row_ids(t, ne1); + } else { + init_tensor_uniform(t); + } + } + } + + double max_nmse_err() override { + // turbo3 is 3-bit quantization with WHT rotation. + // The round-trip error (f32 -> turbo3 -> f32) is higher than q8_0 + // but bounded. Empirically ~0.02 NMSE for uniform[-1,1] data. 
+ return 0.05; + } +}; + // GGML_OP_FLASH_ATTN_EXT struct test_flash_attn_ext : public test_case { const int64_t hsk; // K head size @@ -8585,6 +8705,38 @@ static std::vector> make_test_cases_eval() { } } + // TURBO_WHT tests + for (int dir : {0, 1}) { + for (int64_t hd : {128, 256, 512}) { + for (int64_t nh : {1, 4, 8}) { + test_cases.emplace_back(new test_turbo_wht(hd, nh, dir)); + } + } + } + + // TURBO_WHT round-trip tests (forward then inverse = identity) + for (int64_t hd : {128, 256, 512}) { + for (int64_t nh : {1, 4, 8}) { + test_cases.emplace_back(new test_turbo_wht_roundtrip(hd, nh)); + } + } + + // SET_ROWS with turbo3 destination: quantize then dequant round-trip + // Small tensors (single-dim dispatch) + for (ggml_type idx_type : {GGML_TYPE_I32, GGML_TYPE_I64}) { + for (int64_t ne0 : {128, 256, 512}) { + for (int r : {1, 4, 7}) { + test_cases.emplace_back(new test_set_rows_turbo3(idx_type, ne0, 16, r)); + } + } + } + // Large tensors -- exercises 2D dispatch grid (>512 workgroups), + // matching actual inference dimensions (4 kv_heads, batch=1024+) + test_cases.emplace_back(new test_set_rows_turbo3(GGML_TYPE_I32, 128, 4096, 1024)); + test_cases.emplace_back(new test_set_rows_turbo3(GGML_TYPE_I32, 256, 2048, 512)); + test_cases.emplace_back(new test_set_rows_turbo3(GGML_TYPE_I32, 512, 1024, 256)); + + for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 320, 512, 576 }) { for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) { if (hsk != 192 && hsk != 320 && hsk != 576 && hsk != hsv) continue; @@ -8612,8 +8764,9 @@ static std::vector> make_test_cases_eval() { for (int nb : { 1, 3, 32, 75, }) { for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue; - for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { - if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue; + for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_TURBO3_0}) { + if (type_KV == GGML_TYPE_TURBO3_0 && hsk < 128) continue; + if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72 && hsk != 128) continue; test_cases.emplace_back(new test_flash_attn_ext( hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV)); // run fewer test cases permuted
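
For reference, here are a few CPU-side sketches of the mechanisms in this patch (helper names are hypothetical; none of this code is part of the patch itself). First, the transform in turbo_wht.comp: an orthonormal 128-point Walsh-Hadamard butterfly wrapped in two ±1 scrambles. Because H·H = 128·I and each scramble squares to the identity, direction 0 followed by direction 1 is exactly the identity, which is what test_turbo_wht_roundtrip exercises (up to f32 reduction order).

```cpp
#include <cassert>
#include <cmath>

// In-place 128-point Walsh-Hadamard butterfly: same pass structure as the
// [[unroll]] loop in turbo_wht.comp (h = 1, 2, 4, ..., 64).
static void wht128(float * x) {
    for (int h = 1; h < 128; h *= 2) {
        for (int i = 0; i < 128; i += 2 * h) {
            for (int j = i; j < i + h; ++j) {
                const float a = x[j];
                const float b = x[j + h];
                x[j]     = a + b;
                x[j + h] = a - b;
            }
        }
    }
}

// direction == 0 computes (1/sqrt(128)) * s_post . H(s_pre . x) with
// (s_pre, s_post) = (S1, S2); direction == 1 swaps the two sign tables.
static void turbo_wht_ref(float * x, const float * s_pre, const float * s_post) {
    const float inv_sqrt_128 = 1.0f / std::sqrt(128.0f);
    for (int i = 0; i < 128; ++i) x[i] *= s_pre[i];
    wht128(x);
    for (int i = 0; i < 128; ++i) x[i] *= inv_sqrt_128 * s_post[i];
}

int main() {
    // Placeholder all-ones scrambles; the real S1/S2 tables live in the shader
    // and behave the same way, since any +-1 scramble squares to the identity.
    float s[128], x[128], y[128];
    for (int i = 0; i < 128; ++i) { s[i] = 1.0f; x[i] = y[i] = float(i % 7) - 3.0f; }
    turbo_wht_ref(y, s, s); // forward (direction == 0)
    turbo_wht_ref(y, s, s); // inverse (direction == 1 swaps s_pre/s_post)
    for (int i = 0; i < 128; ++i) assert(std::fabs(y[i] - x[i]) < 1e-4f);
    return 0;
}
```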
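Second, the turbo3_0 block format itself. This sketch mirrors the per-block quantizer (nearest centroid via midpoint comparison, 2+1 bit packing, norm correction) and the unpacking in dequant_turbo3_0.comp, for input that is assumed to be already WHT-rotated. The layout matches block_turbo3_0 in types.glsl (2 + 32 + 16 = 50 bytes); the only deliberate difference is that norm is kept in f32 here, whereas the shader stores float16_t, which adds a little extra error on the GPU.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// CPU mirror of block_turbo3_0 (types.glsl).
struct block_turbo3_0_ref {
    float   norm;      // stored as float16_t in the shader
    uint8_t qs[32];    // low 2 bits of each 3-bit index, 4 per byte
    uint8_t signs[16]; // high bit of each 3-bit index, 8 per byte
};

static const float kCentroids[8] = {
    -0.190685f, -0.117832f, -0.065717f, -0.021460f,
     0.021460f,  0.065717f,  0.117832f,  0.190685f,
};
static const float kMidpoints[7] = {
    -0.154259f, -0.091775f, -0.043589f, 0.0f, 0.043589f, 0.091775f, 0.154259f,
};

// Quantize 128 already-WHT-rotated values into one block.
static void turbo3_quantize_block(const float * x, block_turbo3_0_ref * b) {
    float norm_sq = 0.0f;
    for (int j = 0; j < 128; ++j) norm_sq += x[j] * x[j];
    const float norm = std::sqrt(norm_sq);
    const float inv  = norm > 1e-10f ? 1.0f / norm : 0.0f;
    for (int j = 0; j < 32; ++j) b->qs[j]    = 0;
    for (int j = 0; j < 16; ++j) b->signs[j] = 0;
    float recon_sq = 0.0f;
    for (int j = 0; j < 128; ++j) {
        const float v = x[j] * inv;
        unsigned idx = 0;
        while (idx < 7 && v >= kMidpoints[idx]) idx++; // nearest centroid via midpoints
        recon_sq += kCentroids[idx] * kCentroids[idx];
        b->qs[j / 4]    |= uint8_t((idx & 0x3) << ((j % 4) * 2));
        b->signs[j / 8] |= uint8_t(((idx >> 2) & 0x1) << (j % 8));
    }
    // Norm correction: scale so the reconstruction keeps the original L2 norm.
    const float recon = std::sqrt(recon_sq);
    b->norm = recon > 1e-10f ? norm / recon : norm;
}

// One element, mirroring dequant_turbo3_0.comp.
static float turbo3_dequant(const block_turbo3_0_ref * b, int j) {
    const unsigned lo = (b->qs[j / 4] >> ((j % 4) * 2)) & 0x3;
    const unsigned hi = (b->signs[j / 8] >> (j % 8)) & 0x1;
    return kCentroids[lo | (hi << 2)] * b->norm;
}

int main() {
    float x[128];
    for (int j = 0; j < 128; ++j) x[j] = std::sin(0.1f * float(j));
    block_turbo3_0_ref b;
    turbo3_quantize_block(x, &b);
    float err = 0.0f, ref = 0.0f;
    for (int j = 0; j < 128; ++j) {
        const float d = turbo3_dequant(&b, j) - x[j];
        err += d * d;
        ref += x[j] * x[j];
    }
    std::printf("round-trip NMSE: %g\n", err / ref); // should sit well under the 0.05 test bound
    return 0;
}
```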
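Third, the sign bytes in the SET_ROWS path come from subgroupBallot(); each 32-bit ballot word covers one 32-lane slice of the subgroup, so a writer lane must read ballot[lane / 32] to stay correct on 64-wide subgroups as well as 32-wide ones. A scalar model of that byte assembly:

```cpp
#include <cassert>
#include <cstdint>

// Models the ballot-based sign packing for one subgroup of `width` lanes
// (32 or 64). bit[l] is lane l's high sign bit; lanes with l % 8 == 0 each
// emit one byte, reading from the 32-bit ballot word that covers lane l.
static void ballot_pack(const uint8_t * bit, int width, uint8_t * out /* width/8 */) {
    uint32_t ballot[2] = {0, 0}; // uvec4.xy suffices for width <= 64
    for (int l = 0; l < width; ++l) {
        ballot[l / 32] |= uint32_t(bit[l] & 1) << (l % 32);
    }
    for (int l = 0; l < width; l += 8) {
        const uint32_t word  = ballot[l / 32];               // word selection
        const uint32_t shift = uint32_t(((l % 32) / 8) * 8); // byte within the word
        out[l / 8] = uint8_t((word >> shift) & 0xFFu);
    }
}

int main() {
    uint8_t bit[64], out[8];
    for (int l = 0; l < 64; ++l) bit[l] = uint8_t((l / 3) & 1);
    ballot_pack(bit, 64, out);
    for (int l = 0; l < 64; ++l) {
        assert(((out[l / 8] >> (l % 8)) & 1) == bit[l]);
    }
    return 0;
}
```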
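Fourth, the same SET_ROWS shader decodes the linear group id g into (ig, i01, i02, i03) in mixed radix (gpr, ne01, ne02) with gpr = ne00 / 128. A quick property check (arbitrary small dims) that the decode inverts the enumeration:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // Radices as in the shader: gpr = ne00 / 128, then ne01, ne02 (ne03 implicit).
    const uint32_t ne00 = 256, ne01 = 3, ne02 = 4, ne03 = 2;
    const uint32_t gpr = ne00 / 128;
    uint32_t g = 0;
    for (uint32_t i03 = 0; i03 < ne03; ++i03)
    for (uint32_t i02 = 0; i02 < ne02; ++i02)
    for (uint32_t i01 = 0; i01 < ne01; ++i01)
    for (uint32_t ig = 0; ig < gpr; ++ig, ++g) {
        uint32_t tmp = g; // decode exactly as the shader does
        assert(tmp % gpr  == ig);  tmp /= gpr;
        assert(tmp % ne01 == i01); tmp /= ne01;
        assert(tmp % ne02 == i02);
        assert(tmp / ne02 == i03);
    }
    return 0;
}
```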
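Finally, the Y/Z workgroup spreading in ggml_vk_turbo_wht pairs with the g = z*262144 + y*512 + x reconstruction in turbo_wht.comp. Under the assumption group_size == 128 (the only size supports_op admits), every group is reached exactly once and overshoot workgroups return early:

```cpp
#include <cassert>
#include <cstdint>

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    // Workgroup counts after ggml_vk_dispatch_pipeline divides elements[]
    // by the {128, 1, 1} denominators, assuming group_size == 128.
    for (uint32_t n_groups : {1u, 512u, 513u, 262144u, 262145u, 1000000u}) {
        uint32_t wg0, wg1, wg2;
        if (n_groups > 262144) {
            wg0 = 512; wg1 = 512; wg2 = ceil_div(n_groups, 262144);
        } else if (n_groups > 512) {
            wg0 = 512; wg1 = ceil_div(n_groups, 512); wg2 = 1;
        } else {
            wg0 = n_groups; wg1 = 1; wg2 = 1;
        }
        // The device reconstructs g = z*262144 + y*512 + x and early-returns
        // when g >= n_groups; the mapping must cover each group exactly once.
        uint64_t covered = 0;
        for (uint32_t z = 0; z < wg2; ++z)
        for (uint32_t y = 0; y < wg1; ++y)
        for (uint32_t x = 0; x < wg0; ++x) {
            const uint32_t g = z * 262144 + y * 512 + x;
            if (g < n_groups) covered++;
        }
        assert(covered == n_groups);
    }
    return 0;
}
```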