diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 173677a2637a9..a9c7c6c2fb0b4 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -486,6 +486,7 @@ struct vk_device_struct { vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT]; vk_pipeline pipeline_matmul_split_k_reduce; vk_pipeline pipeline_quantize_q8_1; @@ -2448,8 +2449,11 @@ static void ggml_vk_load_shaders(vk_device& device) { l_warptile_id, m_warptile_id, s_warptile_id, l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int, + l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k, l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, - l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; + l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid, + l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int, + l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k; std::array l_wg_denoms, m_wg_denoms, s_wg_denoms, l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms, l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k, @@ -2512,10 +2516,16 @@ static void ggml_vk_load_shaders(vk_device& device) { m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + // Integer MMQ has a smaller shared memory profile, but heavier register use l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + // K-quants use even more registers, mitigate by setting WMITER to 1 + l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 }; + l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 }; m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 }; s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 }; @@ -2524,10 +2534,18 @@ static void ggml_vk_load_shaders(vk_device& device) { m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 }; s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 }; + l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 }; + m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 }; + s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 }; + + l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 }; + m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 }; + s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 }; + // chip specific tuning if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; - m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; } l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; @@ -2912,18 +2930,15 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ -#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) { \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ if (device->mul_mat ## ID ## _m[TYPE]) { \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ if (device->mul_mat ## ID ## _s[TYPE]) { \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ // Create 2 variants, {f16,f32} accumulator @@ -2962,11 +2977,19 @@ static void ggml_vk_load_shaders(vk_device& device) { #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { - CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); } #endif @@ -2996,6 +3019,24 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + + CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + + CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + } +#endif } else { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); @@ -3022,6 +3063,24 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + + CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + + CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + } +#endif } #undef CREATE_MM2 #undef CREATE_MMQ @@ -3086,6 +3145,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + + CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); } #endif @@ -3145,7 +3210,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } // reusing CREATE_MM from the fp32 path if ((device->coopmat2 || device->coopmat_support) -#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && !device->coopmat_bf16_support #endif ) { @@ -4928,7 +4993,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte // MMQ if (src1_type == GGML_TYPE_Q8_1) { - vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; + vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { return nullptr; @@ -5075,6 +5140,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co } } + // MMQ + if (src1_type == GGML_TYPE_Q8_1) { + vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc; + + if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { + return nullptr; + } + + return pipelines; + } + GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); switch (src0_type) { @@ -6880,10 +6956,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; - vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0; + + // Check for mmq first + vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr; + + if (mmp == nullptr) { + // Fall back to f16 dequant mul mat + mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); + quantize_y = false; + } const bool qx_needs_dequant = mmp == nullptr || x_non_contig; - const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig); if (qx_needs_dequant) { // Fall back to dequant + f16 mulmat @@ -6893,8 +6978,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); - const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); + const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8; vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); @@ -6907,12 +6992,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; - const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); const uint64_t ids_sz = nbi2; const uint64_t d_sz = sizeof(float) * d_ne; vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; + vk_pipeline to_q8_1 = nullptr; if (x_non_contig) { to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); @@ -6927,9 +7013,16 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true); + } + if (dryrun) { const uint64_t x_sz_upd = x_sz * ne02 * ne03; - const uint64_t y_sz_upd = y_sz * ne12 * ne13; + uint64_t y_sz_upd = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; + } if ( (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { @@ -6938,7 +7031,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { ctx->prealloc_size_x = x_sz_upd; } - if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { ctx->prealloc_size_y = y_sz_upd; } @@ -6950,6 +7043,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (qy_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); + } return; } @@ -6986,6 +7082,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (qy_needs_dequant) { d_Y = ctx->prealloc_y; GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); + } else if (quantize_y) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); } else { d_Y = d_Qy; y_buf_offset = qy_buf_offset; @@ -7017,6 +7116,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ctx->prealloc_y_last_tensor_used = src1; } } + if (quantize_y) { + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } uint32_t stride_batch_x = ne00*ne01; uint32_t stride_batch_y = ne10*ne11; @@ -7025,14 +7135,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); } - if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) { stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); } + uint32_t y_sz_total = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; + } + // compute ggml_vk_matmul_id( ctx, subctx, pipeline, - { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total }, { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, ne01, ne21, ne10, ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 0d98f5a9d6bf1..09676a623ba63 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { #if defined(DATA_A_MXFP4) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); - return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]); + return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5; } vec4 dequantize4(uint ib, uint iqs, uint a_offset) { vec2 v0 = dequantize(ib, iqs, a_offset); @@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]); const uint scales = data_a[a_offset + ib].scales[scalesi]; - const vec2 d = vec2(data_a[a_offset + ib].d); + const vec2 dm = vec2(data_a[a_offset + ib].dm); - return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); + return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4); } vec2 get_dm(uint ib, uint a_offset) { return vec2(1, 0); @@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint is = 2 * n + b; // 0..7 const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 - const vec2 loadd = vec2(data_a[a_offset + ib].d); + const vec2 loadd = vec2(data_a[a_offset + ib].dm); const uint scidx0 = (is < 4) ? is : (is + 4); const uint scidx1 = (is < 4) ? is : (is - 4); @@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint8_t hm = uint8_t(1 << (iqs / 16)); - const vec2 loadd = vec2(data_a[a_offset + ib].d); + const vec2 loadd = vec2(data_a[a_offset + ib].dm); const uint scidx0 = (is < 4) ? is : (is + 4); const uint scidx1 = (is < 4) ? is : (is - 4); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 67baedf7c6147..8ac6482dc944b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2 float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); - const f16vec2 d = bl.block.d; + const f16vec2 dm = bl.block.dm; const uint idx = coordInBlock[1]; const uint scalesi = (idx & 0xF0) >> 4; // 0..15 @@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2 qs = unpack8(qs)[idx & 1]; const uint scales = bl.block.scales[scalesi]; - float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4); + float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4); return ret; } @@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords uint32_t qs = bl.block.qs[iqs]; qs >>= shift; qs &= 0xF; - float16_t ret = float16_t(kvalues_mxfp4[qs] * d); + float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5); return ret; } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp index ffba5a77ddf53..3194ba291f311 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp @@ -26,7 +26,7 @@ void main() { const float d = e8m0_to_fp32(data_a[ib].e); [[unroll]] for (uint l = 0; l < 8; ++l) { - data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]); - data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]); + data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF])); + data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4])); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp index 58dc2e5dfde9d..dc05a78348909 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -24,8 +24,8 @@ void main() { const uint ql_idx = 32 * ip + il; const uint8_t qs = data_a[i].qs[32 * ip + il]; - FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); - FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); + FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x); + FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y); data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4)); data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4)); data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4)); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp index 8b7be557e9548..0f23dc0a349f6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -20,8 +20,8 @@ void main() { const uint is = 2 * il; const uint n = 4; - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y); const uint y_idx = ib * QUANT_K + 64 * il + n * ir; const uint qs_idx = 32*il + n * ir; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp index 6bc04670fc593..970469a601cc6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -19,8 +19,8 @@ void main() { const uint ir = tid % 16; const uint is = 2 * il; - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y); const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir; const uint qs_idx = 32*il + 2 * ir; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 03ed25d3bfe4e..c460ec5c773b9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -41,9 +41,9 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + vec2 dm = vec2(data_a[ib0 + i].dm); + const FLOAT_TYPE dall = FLOAT_TYPE(dm.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(dm.y); [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 21d07d2e50964..417d96af26498 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -14,9 +14,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].dm.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].dm.y); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 9e46c89a11f50..fb9c1c72419b5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -14,9 +14,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].dm.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].dm.y); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index a20788c4b51e3..d260969f07e88 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE]; #define NUM_WARPS (BLOCK_SIZE / WARP) -#ifdef MUL_MAT_ID -shared u16vec2 row_ids[BN]; -uint _ne1; - -#ifdef MUL_MAT_ID_USE_SUBGROUPS -shared uvec4 ballots_sh[NUM_WARPS]; - -void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { - _ne1 = 0; - uint num_elements = p.nei1 * p.nei0; - uint nei0shift = findLSB(p.nei0); - - uint ids[16]; - uint iter = 0; - - for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { - // prefetch up to 16 elements - if (iter == 0) { - [[unroll]] for (uint k = 0; k < 16; ++k) { - uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; - bool in_range = i < num_elements; - uint ii1; - if (nei0_is_pow2) { - ii1 = i >> nei0shift; - } else { - ii1 = i / p.nei0; - } - uint ii0 = i - ii1 * p.nei0; - ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; - } - } - uint i = j + gl_LocalInvocationIndex; - bool in_range = i < num_elements; - uint ii1; - if (nei0_is_pow2) { - ii1 = i >> nei0shift; - } else { - ii1 = i / p.nei0; - } - uint ii0 = i - ii1 * p.nei0; - uint id = ids[iter++]; - uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - - ballots_sh[gl_SubgroupID] = ballot; - barrier(); - - uint subgroup_base = 0; - uint total = 0; - for (uint k = 0; k < gl_NumSubgroups; ++k) { - if (k == gl_SubgroupID) { - subgroup_base = total; - } - total += subgroupBallotBitCount(ballots_sh[k]); - } - barrier(); - - uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); - if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { - row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1); - } - _ne1 += total; - iter &= 15; - if (_ne1 >= (ic + 1) * BN) { - break; - } - } - barrier(); -} -#endif // MUL_MAT_ID_USE_SUBGROUPS -#endif // MUL_MAT_ID - #ifdef COOPMAT shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif +#include "mul_mm_id_funcs.glsl" #include "mul_mm_funcs.glsl" void main() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 0ebfbd6462c8b..ee5ded2e8d3eb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint ib = idx / 128; // 2 values per idx const uint iqs = idx % 128; // 0..127 - const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 + const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15 const uint scalesi = iqs / 8; // 0..15 const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 - const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); + const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi])); const uint scales = data_a[ib].scales[scalesi]; - const vec2 d = vec2(data_a[ib].d); + const vec2 dm = vec2(data_a[ib].dm); - const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); + const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4); buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_Q3_K) @@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint is = 2 * n + b; // 0..7 const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 - const vec2 loadd = vec2(data_a[ib].d); + const vec2 loadd = vec2(data_a[ib].dm); const uint scidx0 = (is < 4) ? is : (is + 4); const uint scidx1 = (is < 4) ? is : (is - 4); @@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint8_t hm = uint8_t(1 << (iqs / 16)); - const vec2 loadd = vec2(data_a[ib].d); + const vec2 loadd = vec2(data_a[ib].dm); const uint scidx0 = (is < 4) ? is : (is + 4); const uint scidx1 = (is < 4) ? is : (is - 4); @@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint ib = idx / 8; const uint iqs = (idx & 0x07) * 2; - const float d = e8m0_to_fp32(data_a[ib].e); + const float d = e8m0_to_fp32(data_a[ib].e) * 0.5; const uint vui = uint(data_a[ib].qs[iqs]); const uint vui2 = uint(data_a[ib].qs[iqs+1]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl new file mode 100644 index 0000000000000..1d0e84ac94250 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl @@ -0,0 +1,70 @@ +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[BN]; +uint _ne1; + +#ifdef MUL_MAT_ID_USE_SUBGROUPS +shared uvec4 ballots_sh[NUM_WARPS]; + +void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + uint nei0shift = findLSB(p.nei0); + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_LocalInvocationIndex; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + + ballots_sh[gl_SubgroupID] = ballot; + barrier(); + + uint subgroup_base = 0; + uint total = 0; + for (uint k = 0; k < gl_NumSubgroups; ++k) { + if (k == gl_SubgroupID) { + subgroup_base = total; + } + total += subgroupBallotBitCount(ballots_sh[k]); + } + barrier(); + + uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { + row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1); + } + _ne1 += total; + iter &= 15; + if (_ne1 >= (ic + 1) * BN) { + break; + } + } + barrier(); +} +#endif // MUL_MAT_ID_USE_SUBGROUPS +#endif // MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index b5d761c0bab9e..8b238ac4bc117 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -10,10 +10,9 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #endif -#ifdef COOPMAT -#extension GL_KHR_cooperative_matrix : enable -#extension GL_KHR_memory_scope_semantics : enable +#if defined(MUL_MAT_ID_USE_SUBGROUPS) #extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_ballot : enable #endif #ifdef MUL_MAT_ID @@ -24,7 +23,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif #if defined(A_TYPE_PACKED32) layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; #endif @@ -76,40 +78,27 @@ layout (constant_id = 10) const uint WARP = 32; #define BK 32 -#ifdef COOPMAT -#define SHMEM_STRIDE (BK / 4 + 4) -#else -#define SHMEM_STRIDE (BK / 4 + 1) -#endif +#define MMQ_SHMEM -shared int32_t buf_a_qs[BM * SHMEM_STRIDE]; +#include "mul_mmq_shmem_types.glsl" -#ifndef COOPMAT -#if QUANT_AUXF == 1 -shared FLOAT_TYPE buf_a_dm[BM]; -#else -shared FLOAT_TYPE_VEC2 buf_a_dm[BM]; -#endif +#ifndef BK_STEP +#define BK_STEP 4 #endif -shared int32_t buf_b_qs[BN * SHMEM_STRIDE]; -#ifndef COOPMAT -shared FLOAT_TYPE_VEC2 buf_b_ds[BN]; -#endif +// Shared memory cache +shared block_a_cache buf_a[BM * BK_STEP]; +shared block_b_cache buf_b[BN * BK_STEP]; +// Register cache +block_a_cache cache_a[WMITER * TM]; +block_b_cache cache_b; -#define LOAD_VEC_A (4 * QUANT_R) +#define LOAD_VEC_A (4 * QUANT_R_MMQ) #define LOAD_VEC_B 16 -#ifdef MUL_MAT_ID -shared u16vec2 row_ids[4096]; -#endif // MUL_MAT_ID - #define NUM_WARPS (BLOCK_SIZE / WARP) -#ifdef COOPMAT -shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; -#endif - +#include "mul_mm_id_funcs.glsl" #include "mul_mmq_funcs.glsl" void main() { @@ -139,26 +128,12 @@ void main() { const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); const uint WSUBM = WM / WMITER; const uint WSUBN = WN / WNITER; - -#ifdef COOPMAT - const uint warp_i = gl_SubgroupID; - - const uint tiw = gl_SubgroupInvocationID; - - const uint cms_per_row = WM / TM; - const uint cms_per_col = WN / TN; - - const uint storestride = WARP / TM; - const uint store_r = tiw % TM; - const uint store_c = tiw / TM; -#else const uint warp_i = gl_LocalInvocationID.x / WARP; const uint tiw = gl_LocalInvocationID.x % WARP; const uint tiwr = tiw % (WSUBM / TM); const uint tiwc = tiw / (WSUBM / TM); -#endif const uint warp_r = warp_i % (BM / WM); const uint warp_c = warp_i / (BM / WM); @@ -172,17 +147,27 @@ void main() { const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK; #ifdef MUL_MAT_ID - uint _ne1 = 0; - for (uint ii1 = 0; ii1 < p.nei1; ii1++) { - for (uint ii0 = 0; ii0 < p.nei0; ii0++) { +#ifdef MUL_MAT_ID_USE_SUBGROUPS + if (bitCount(p.nei0) == 1) { + load_row_ids(expert_idx, true, ic); + } else { + load_row_ids(expert_idx, false, ic); + } +#else + _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) { if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { - row_ids[_ne1] = u16vec2(ii0, ii1); + if (_ne1 >= ic * BN) { + row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1); + } _ne1++; } } } barrier(); +#endif // Workgroup has no work if (ic * BN >= _ne1) return; @@ -209,159 +194,70 @@ void main() { uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK; #endif -#ifdef COOPMAT - coopmat cache_a; - coopmat cache_b; - coopmat cm_result; - - coopmat factors[cms_per_row * cms_per_col]; - - coopmat sums[cms_per_row * cms_per_col]; - - [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { - sums[i] = coopmat(0.0f); - } -#else - int32_t cache_a_qs[WMITER * TM * BK / 4]; - - int32_t cache_b_qs[TN * BK / 4]; - ACC_TYPE sums[WMITER * TM * WNITER * TN]; [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { sums[i] = ACC_TYPE(0.0f); } -#endif -#if QUANT_AUXF == 1 - FLOAT_TYPE cache_a_dm[WMITER * TM]; -#else - FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM]; -#endif - - FLOAT_TYPE_VEC2 cache_b_ds[TN]; - - for (uint block = start_k; block < end_k; block += BK) { + for (uint block = start_k; block < end_k; block += BK * BK_STEP) { [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) { - const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK; - const uint iqs = loadr_a; const uint buf_ib = loadc_a + l; + const uint ib = pos_a_ib + buf_ib * p.stride_a / BK; + const uint iqs = loadr_a; - if (iqs == 0) { -#if QUANT_AUXF == 1 - buf_a_dm[buf_ib] = get_d(ib); -#else - buf_a_dm[buf_ib] = get_dm(ib); -#endif + [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) { + block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs); } -#if QUANT_R == 1 - buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs); -#else - const i32vec2 vals = repack(ib, iqs); - buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x; - buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y; -#endif } [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) { + const uint buf_ib = loadc_b + l; + #ifdef MUL_MAT_ID - const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; - const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; - const uint ib = idx / 8; - const uint iqs = idx & 0x7; + const u16vec2 row_idx = row_ids[buf_ib]; + const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK; #else - const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK; - const uint ib_outer = ib / 4; - const uint ib_inner = ib % 4; - - const uint iqs = loadr_b; + const uint ib = pos_b_ib + buf_ib * p.stride_b / BK; #endif + const uint iqs = loadr_b; - const uint buf_ib = loadc_b + l; - - if (iqs == 0) { - buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); + [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) { + block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs); } - const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; - buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x; - buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y; - buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z; - buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w; } barrier(); - pos_a_ib += 1; - pos_b_ib += 1; + pos_a_ib += BK_STEP; + pos_b_ib += BK_STEP; -#ifdef COOPMAT - [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { - const uint ib_a = warp_r * WM + cm_row * TM; + for (uint k_step = 0; k_step < BK_STEP; k_step++) { // Load from shared into cache - coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); - - // TODO: only cache values that are actually needed - [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) { - cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx]; - } - - [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - const uint ib_b = warp_c * WN + cm_col * TN; - coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); - - // TODO: only cache values that are actually needed - [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) { - cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx]; - } - - cm_result = coopmat(0); - cm_result = coopMatMulAdd(cache_a, cache_b, cm_result); - - [[unroll]] for (uint col = 0; col < TN; col += storestride) { - coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col])); - } - - coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - sums[cm_col * cms_per_row + cm_row] += factors * coopmat(cm_result); - } - } -#else - // Load from shared into cache - [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (uint cr = 0; cr < TM; cr++) { - const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr; - cache_a_dm[wsir * TM + cr] = buf_a_dm[ib]; - [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { - cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k]; - } - } - } + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint reg_ib = wsir * TM + cr; + const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr; - [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (uint cc = 0; cc < TN; cc++) { - const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc; - cache_b_ds[cc] = buf_b_ds[ib]; - [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { - cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k]; + block_a_to_registers(reg_ib, k_step * BM + buf_ib); } } - [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) { - [[unroll]] for (uint cr = 0; cr < TM; cr++) { - const uint cache_a_idx = wsir * TM + cr; - const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; - int32_t q_sum = 0; - [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { - q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k], - cache_b_qs[cc * (BK / 4) + idx_k]); - } + const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc; + block_b_to_registers(ib); - sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1); + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint cache_a_idx = wsir * TM + cr; + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + + sums[sums_idx] += mmq_dot_product(cache_a_idx); + } } } } } -#endif barrier(); } @@ -373,54 +269,6 @@ void main() { const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; #endif -#ifdef COOPMAT -#ifdef MUL_MAT_ID - [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { - [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - - [[unroll]] for (uint col = 0; col < BN; col += storestride) { - const uint row_i = dc + cm_col * TN + col + store_c; - if (row_i >= _ne1) break; - - const u16vec2 row_idx = row_ids[row_i]; - - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); - } - } - } -#else - const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float - - [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { - [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; - - if (is_aligned && is_in_bounds) { - // Full coopMat is within bounds and stride_d is aligned with 16B - coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); - coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); - } else if (is_in_bounds) { - // Full coopMat is within bounds, but stride_d is not aligned - coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - - [[unroll]] for (uint col = 0; col < TN; col += storestride) { - data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); - } - } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { - // Partial coopMat is within bounds - coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - - [[unroll]] for (uint col = 0; col < TN; col += storestride) { - if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { - data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); - } - } - } - } - } -#endif // MUL_MAT_ID -#else [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { @@ -431,19 +279,21 @@ void main() { const uint row_i = dc_warp + cc; if (row_i >= _ne1) break; - const u16vec2 row_idx = row_ids[row_i]; + const u16vec2 row_idx = row_ids[row_i - ic * BN]; #endif // MUL_MAT_ID [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr; #ifdef MUL_MAT_ID - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + cr < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x); + } #else if (dr_warp + cr < p.M && dc_warp + cc < p.N) { - data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x); } #endif // MUL_MAT_ID } } } } -#endif // COOPMAT } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index fe71eb131c807..c0c03fedcc222 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -6,41 +6,89 @@ // Each iqs value maps to a 32-bit integer -#if defined(DATA_A_Q4_0) +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) +// 2-byte loads for Q4_0 blocks (18 bytes) +// 4-byte loads for Q4_1 blocks (20 bytes) i32vec2 repack(uint ib, uint iqs) { - // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4 - const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], - data_a[ib].qs[iqs * 2 + 1]); +#ifdef DATA_A_Q4_0 + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); const uint32_t vui = pack32(quants); return i32vec2( vui & 0x0F0F0F0F, (vui >> 4) & 0x0F0F0F0F); +#else // DATA_A_Q4_1 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +#endif } +#ifdef DATA_A_Q4_0 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); } +#else // DATA_A_Q4_1 +ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +} #endif -#if defined(DATA_A_Q4_1) -i32vec2 repack(uint ib, uint iqs) { - // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4 - const uint32_t vui = data_a_packed32[ib].qs[iqs]; - return i32vec2( vui & 0x0F0F0F0F, - (vui >> 4) & 0x0F0F0F0F); +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { +#ifdef DATA_A_Q4_0 + buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], + data_a_packed16[ib].qs[iqs * 2 + 1])); + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d); + } +#else // DATA_A_Q4_1 + buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + } +#endif } -ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } } -#endif -#if defined(DATA_A_Q5_0) +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + const uint32_t vui = cache_a[ib_a].qs[iqs]; + const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); + + const int32_t qs_b0 = cache_b.qs[iqs]; + const int32_t qs_b1 = cache_b.qs[iqs + 4]; + + q_sum += dotPacked4x8EXT(qs_a.x, qs_b0); + q_sum += dotPacked4x8EXT(qs_a.y, qs_b1); + } + + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM + +#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) +// 2-byte loads for Q5_0 blocks (22 bytes) +// 4-byte loads for Q5_1 blocks (24 bytes) i32vec2 repack(uint ib, uint iqs) { - // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4 - const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], - data_a[ib].qs[iqs * 2 + 1]); + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); const uint32_t vui = pack32(quants); - const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs)); +#ifdef DATA_A_Q5_0 + const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs)); +#else // DATA_A_Q5_1 + const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); +#endif const int32_t v0 = int32_t(vui & 0x0F0F0F0F) | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) @@ -50,40 +98,457 @@ i32vec2 repack(uint ib, uint iqs) { return i32vec2(v0, v1); } +#ifdef DATA_A_Q5_0 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); } +#else // DATA_A_Q5_1 +ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +} #endif -#if defined(DATA_A_Q5_1) -i32vec2 repack(uint ib, uint iqs) { - // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4 - const uint32_t vui = data_a_packed32[ib].qs[iqs]; - const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); - const int32_t v0 = int32_t(vui & 0x0F0F0F0F) - | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { +#ifdef DATA_A_Q5_0 + buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], + data_a_packed16[ib].qs[iqs * 2 + 1])); - const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) - | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d); + buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1])); + } +#else // DATA_A_Q5_1 + buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; - return i32vec2(v0, v1); + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + buf_a[buf_ib].qh = data_a_packed32[ib].qh; + } +#endif } -ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + cache_a[reg_ib].qh = buf_a[buf_ib].qh; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } } + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + const uint32_t vui = cache_a[ib_a].qs[iqs]; + const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs)); + const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + const int32_t qs_b0 = cache_b.qs[iqs]; + const int32_t qs_b1 = cache_b.qs[iqs + 4]; + + q_sum += dotPacked4x8EXT(qs_a0, qs_b0); + q_sum += dotPacked4x8EXT(qs_a1, qs_b1); + } + + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q8_0) +// 2-byte loads for Q8_0 blocks (34 bytes) int32_t repack(uint ib, uint iqs) { - // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4 - return pack32(i16vec2(data_a[ib].qs[iqs * 2 ], - data_a[ib].qs[iqs * 2 + 1])); + return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1])); } ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(float(q_sum) * da * dsb.x); } + +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2], + data_a_packed16[ib].qs[iqs * 2 + 1])); + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + const int32_t qs_a = cache_a[ib_a].qs[iqs]; + const int32_t qs_b = cache_b.qs[iqs]; + + q_sum += dotPacked4x8EXT(qs_a, qs_b); + } + + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM +#endif + +#if defined(DATA_A_MXFP4) +// 1-byte loads for mxfp4 blocks (17 bytes) +i32vec2 repack(uint ib, uint iqs) { + const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], + data_a[ib].qs[iqs * 4 + 1], + data_a[ib].qs[iqs * 4 + 2], + data_a[ib].qs[iqs * 4 + 3])); + + return i32vec2( quants & 0x0F0F0F0F, + (quants >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(da * dsb.x * float(q_sum)); +} + +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], + data_a[ib].qs[iqs * 4 + 1], + data_a[ib].qs[iqs * 4 + 2], + data_a[ib].qs[iqs * 4 + 3])); + + const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); + const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); + + buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])); + buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])); + + if (iqs == 0) { + buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].d = buf_a[buf_ib].d; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + const int32_t qs_a = cache_a[ib_a].qs[iqs]; + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + + return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1); +} +#endif // MMQ_SHMEM +#endif + +// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide +// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants +#if defined(DATA_A_Q2_K) +// 4-byte loads for Q2_K blocks (84 bytes) +int32_t repack(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + + return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303); +} + +uint8_t get_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + return data_a[ib_k].scales[iqs_k / 4]; +} + +ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m))); +} + +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + + // Repack 4x4 quants into one int + const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303; + const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303; + const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303; + const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303; + + buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6); + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + cache_a[reg_ib].scales = buf_a[buf_ib].scales; + + [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t sum_d = 0; + int32_t sum_m = 0; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + const uint8_t scale = cache_a[ib_a].scales[iqs / 4]; + const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits. + const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303); + + sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]); + } + + return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM +#endif + +#if defined(DATA_A_Q3_K) +// 2-byte loads for Q3_K blocks (110 bytes) +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_k = ib / 8; + const uint hm_idx = iqs * QUANT_R_MMQ; + const uint iqs_k = (ib % 8) * 8 + hm_idx; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + const uint hm_shift = iqs_k / 8; + + // Repack 2x4 quants into one int + // Add the 3rd bit instead of subtracting it to allow packing the quants + const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2)); + buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) | + (pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4); + + if (iqs == 0) { + const uint is = iqs_k / 4; + const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) | + (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))); + + buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + float result = 0.0; + int32_t q_sum = 0; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + // Subtract 4 from the quants to correct the 3rd bit offset + const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4)); + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + result += float(cache_a[ib_a].d_scales[0]) * float(q_sum); + q_sum = 0; + + [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) { + const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4)); + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); + + return ACC_TYPE(cache_b.ds.x * result); +} +#endif // MMQ_SHMEM +#endif + +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes) +ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y); +} + +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; + + const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 16) / 8) * 4; + + // Repack 2x4 quants into one int +#if defined(DATA_A_Q4_K) + const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F; + + buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4); +#else // defined(DATA_A_Q5_K) + const uint qh_idx = iqs * QUANT_R_MMQ; + const uint qh_shift = iqs_k / 8; + + buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4)); +#endif + + + if (iqs == 0) { + // Scale index + const uint is = iqs_k / 8; + u8vec2 scale_dm; + if (is < 4) { + scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F); + } else { + scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2), + (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); + } + + buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + + [[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { +#if defined(DATA_A_Q4_K) + const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F); +#else // defined(DATA_A_Q5_K) + const int32_t qs_a = cache_a[ib_a].qs[iqs]; +#endif + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM +#endif + +#ifdef MMQ_SHMEM +void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_outer = ib / 4; + const uint ib_inner = ib % 4; + + if (iqs == 0) { + buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); + } + + const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; + buf_b[buf_ib].qs[iqs * 4 ] = values.x; + buf_b[buf_ib].qs[iqs * 4 + 1] = values.y; + buf_b[buf_ib].qs[iqs * 4 + 2] = values.z; + buf_b[buf_ib].qs[iqs * 4 + 3] = values.w; +} + +void block_b_to_registers(const uint ib) { + cache_b.ds = buf_b[ib].ds; + [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) { + cache_b.qs[iqs] = buf_b[ib].qs[iqs]; + } +} +#endif + +#if defined(DATA_A_Q6_K) +// 2-byte loads for Q6_K blocks (210 bytes) +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16; + const uint ql_shift = ((iqs_k % 32) / 16) * 4; + + const uint qh_idx = (iqs_k / 32) * 8 + iqs; + const uint qh_shift = ((iqs_k % 32) / 8) * 2; + + const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)); + + if (iqs == 0) { + const uint is = iqs_k / 4; + const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]); + + buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + float result = 0.0; + int32_t q_sum = 0; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + const int32_t qs_a = cache_a[ib_a].qs[iqs]; + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + result += float(cache_a[ib_a].d_scales[0]) * float(q_sum); + q_sum = 0; + + [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) { + const int32_t qs_a = cache_a[ib_a].qs[iqs]; + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); + + return ACC_TYPE(cache_b.ds.x * result); +} +#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) @@ -103,3 +568,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) { return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); } #endif + +#if defined(DATA_A_Q2_K) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + const uint ib_k = ib / 8; + return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); +} +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl new file mode 100644 index 0000000000000..72fec44049001 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl @@ -0,0 +1,78 @@ +#if defined(DATA_A_Q4_0) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[16/4]; + FLOAT_TYPE dm; +}; +#elif defined(DATA_A_Q4_1) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[16/4]; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q5_0) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[16/4]; + uint32_t qh; + FLOAT_TYPE dm; +}; +#elif defined(DATA_A_Q5_1) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[16/4]; + uint32_t qh; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q8_0) +#define QUANT_R_MMQ 1 +// AMD likes 4, Intel likes 1 and Nvidia likes 2 +#define BK_STEP 1 +struct block_a_cache { + int32_t qs[32/4]; + FLOAT_TYPE dm; +}; +#elif defined(DATA_A_MXFP4) +#define QUANT_R_MMQ 2 +struct block_a_cache { + int32_t qs[8]; + FLOAT_TYPE d; +}; +#elif defined(DATA_A_Q2_K) +#define QUANT_R_MMQ 4 +struct block_a_cache { + uint32_t qs[2]; + u8vec2 scales; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q3_K) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[4]; + FLOAT_TYPE_VEC2 d_scales; +}; +#elif defined(DATA_A_Q4_K) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[4]; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q5_K) +#define QUANT_R_MMQ 1 +struct block_a_cache { + int32_t qs[8]; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q6_K) +#define QUANT_R_MMQ 1 +struct block_a_cache { + int32_t qs[8]; + FLOAT_TYPE_VEC2 d_scales; +}; +#endif + +struct block_b_cache +{ + int32_t qs[8]; + FLOAT_TYPE_VEC2 ds; +}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 2fa54ce51fc83..02578c77c4f31 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -66,6 +66,7 @@ struct block_q4_0_packed16 #define QUANT_AUXF 1 #define A_TYPE block_q4_0 #define A_TYPE_PACKED16 block_q4_0_packed16 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q4_1 32 @@ -98,6 +99,7 @@ struct block_q4_1_packed32 #define A_TYPE block_q4_1 #define A_TYPE_PACKED16 block_q4_1_packed16 #define A_TYPE_PACKED32 block_q4_1_packed32 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q5_0 32 @@ -123,6 +125,7 @@ struct block_q5_0_packed16 #define QUANT_AUXF 1 #define A_TYPE block_q5_0 #define A_TYPE_PACKED16 block_q5_0_packed16 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q5_1 32 @@ -158,6 +161,7 @@ struct block_q5_1_packed32 #define A_TYPE block_q5_1 #define A_TYPE_PACKED16 block_q5_1_packed16 #define A_TYPE_PACKED32 block_q5_1_packed32 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q8_0 32 @@ -186,6 +190,7 @@ struct block_q8_0_packed32 #define A_TYPE block_q8_0 #define A_TYPE_PACKED16 block_q8_0_packed16 #define A_TYPE_PACKED32 block_q8_0_packed32 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q8_1 32 @@ -226,21 +231,21 @@ struct block_q2_K { uint8_t scales[QUANT_K_Q2_K/16]; uint8_t qs[QUANT_K_Q2_K/4]; - f16vec2 d; + f16vec2 dm; }; struct block_q2_K_packed16 { uint16_t scales[QUANT_K_Q2_K/16/2]; uint16_t qs[QUANT_K_Q2_K/4/2]; - f16vec2 d; + f16vec2 dm; }; struct block_q2_K_packed32 { uint32_t scales[QUANT_K_Q2_K/16/4]; uint32_t qs[QUANT_K_Q2_K/4/4]; - f16vec2 d; + f16vec2 dm; }; #if defined(DATA_A_Q2_K) @@ -249,6 +254,8 @@ struct block_q2_K_packed32 #define A_TYPE block_q2_K #define A_TYPE_PACKED16 block_q2_K_packed16 #define A_TYPE_PACKED32 block_q2_K_packed32 +#define SCALES_PER_32 2 +#define DATA_A_QUANT_K #endif #define QUANT_K_Q3_K 256 @@ -274,27 +281,28 @@ struct block_q3_K_packed16 #define QUANT_R 1 #define A_TYPE block_q3_K #define A_TYPE_PACKED16 block_q3_K_packed16 +#define DATA_A_QUANT_K #endif #define QUANT_K_Q4_K 256 struct block_q4_K { - f16vec2 d; + f16vec2 dm; uint8_t scales[3*QUANT_K_Q4_K/64]; uint8_t qs[QUANT_K_Q4_K/2]; }; struct block_q4_K_packed16 { - f16vec2 d; + f16vec2 dm; uint16_t scales[3*QUANT_K_Q4_K/64/2]; uint16_t qs[QUANT_K_Q4_K/2/2]; }; struct block_q4_K_packed32 { - f16vec2 d; + f16vec2 dm; uint32_t scales[3*QUANT_K_Q4_K/64/4]; uint32_t qs[QUANT_K_Q4_K/2/4]; }; @@ -310,13 +318,14 @@ struct block_q4_K_packed128 #define A_TYPE block_q4_K #define A_TYPE_PACKED16 block_q4_K_packed16 #define A_TYPE_PACKED32 block_q4_K_packed32 +#define DATA_A_QUANT_K #endif #define QUANT_K_Q5_K 256 struct block_q5_K { - f16vec2 d; + f16vec2 dm; uint8_t scales[12]; uint8_t qh[QUANT_K_Q5_K/8]; uint8_t qs[QUANT_K_Q5_K/2]; @@ -324,12 +333,20 @@ struct block_q5_K struct block_q5_K_packed16 { - f16vec2 d; + f16vec2 dm; uint16_t scales[12/2]; uint16_t qh[QUANT_K_Q5_K/8/2]; uint16_t qs[QUANT_K_Q5_K/2/2]; }; +struct block_q5_K_packed32 +{ + f16vec2 dm; + uint32_t scales[12/4]; + uint32_t qh[QUANT_K_Q5_K/8/4]; + uint32_t qs[QUANT_K_Q5_K/2/4]; +}; + struct block_q5_K_packed128 { uvec4 q5k[11]; @@ -340,6 +357,8 @@ struct block_q5_K_packed128 #define QUANT_R 1 #define A_TYPE block_q5_K #define A_TYPE_PACKED16 block_q5_K_packed16 +#define A_TYPE_PACKED32 block_q5_K_packed32 +#define DATA_A_QUANT_K #endif #define QUANT_K_Q6_K 256 @@ -356,7 +375,7 @@ struct block_q6_K_packed16 { uint16_t ql[QUANT_K_Q6_K/2/2]; uint16_t qh[QUANT_K_Q6_K/4/2]; - int8_t scales[QUANT_K_Q6_K/16]; + int16_t scales[QUANT_K_Q6_K/16/2]; float16_t d; }; @@ -365,6 +384,7 @@ struct block_q6_K_packed16 #define QUANT_R 1 #define A_TYPE block_q6_K #define A_TYPE_PACKED16 block_q6_K_packed16 +#define DATA_A_QUANT_K #endif // IQuants @@ -1363,18 +1383,11 @@ struct block_mxfp4 uint8_t qs[QUANT_K_MXFP4/2]; }; -//struct block_mxfp4_packed16 -//{ -// uint8_t e; -// uint16_t qs[QUANT_K_MXFP4/2/2]; -//}; - #if defined(DATA_A_MXFP4) #define QUANT_K QUANT_K_MXFP4 #define QUANT_R QUANT_R_MXFP4 #define QUANT_AUXF 1 #define A_TYPE block_mxfp4 -//#define A_TYPE_PACKED16 block_mxfp4_packed16 #endif #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) @@ -1397,12 +1410,12 @@ void init_iq_shmem(uvec3 wgsize) #endif #if defined(DATA_A_MXFP4) -const FLOAT_TYPE kvalues_mxfp4_const[16] = { - FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f), - FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f) +const int8_t kvalues_mxfp4_const[16] = { + int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12), + int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12), }; -shared FLOAT_TYPE kvalues_mxfp4[16]; +shared int8_t kvalues_mxfp4[16]; #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 0f25ba3453093..1668f5ea4e7b4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -566,7 +566,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) { + // Integer dot mmq performs better with f32 accumulators + if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) { string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); } #endif