diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d3fb19048d9..bead686a4b3 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -49,6 +49,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include #include #include +#include #include #include #include @@ -254,6 +255,8 @@ static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backe static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft); static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft); static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor); +// Forward decl: needed by the transposed-A pipeline-selection guards below. +static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer); static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { /* .get_name = */ ggml_backend_vk_buffer_type_name, /* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer, @@ -700,6 +703,8 @@ struct vk_device_struct { vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_COUNT]; vk_matmul_pipeline pipeline_matmul_id_f32 {}; vk_matmul_pipeline pipeline_matmul_id_bf16 {}; @@ -888,6 +893,7 @@ struct vk_device_struct { bool disable_host_visible_vidmem; bool allow_sysmem_fallback; bool disable_graph_optimize; + bool transpose_a; std::unique_ptr memory_logger; @@ -1006,6 +1012,7 @@ struct vk_mat_mat_push_constants { #define MAT_VEC_FUSION_FLAGS_BIAS1 0x2 #define MAT_VEC_FUSION_FLAGS_SCALE0 0x4 #define MAT_VEC_FUSION_FLAGS_SCALE1 0x8 +#define MAT_VEC_FUSION_FLAGS_TRANSPOSE_A 0x10 struct vk_mat_vec_push_constants { uint32_t ncols; @@ -2027,6 +2034,12 @@ struct ggml_backend_vk_buffer_context { vk_buffer dev_buffer; std::string name; + // Tensors actually repacked into transposed-A layout by set_tensor. + // Other upload paths (set_tensor_async chunked uploads, set_tensor_2d, + // buffer_from_host_ptr) bypass the repack, so pipeline selection must + // consult this set instead of inferring from type/name/shape. + std::unordered_set transposed_a_tensors; + ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) : device(device), dev_buffer(dev_buffer), @@ -3865,8 +3878,18 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + if (device->transpose_a) { + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q4_K], matmul_q4_k_f32_transa, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + if (device->transpose_a) { + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q5_K], matmul_q5_k_f32_transa, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + if (device->transpose_a) { + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q6_K], matmul_q6_k_f32_transa, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q5_1], matmul_q5_1_f32_transa, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3889,8 +3912,18 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + if (device->transpose_a) { + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32_transa, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + if (device->transpose_a) { + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32_transa, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + if (device->transpose_a) { + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32_transa, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32_transa, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3926,6 +3959,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + if (device->transpose_a) { + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32_transa, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32_transa, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32_transa, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32_transa, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + } CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); @@ -3990,8 +4029,18 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + if (device->transpose_a) { + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q4_K], matmul_q4_k_f32_transa, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + } CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + if (device->transpose_a) { + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q5_K], matmul_q5_k_f32_transa, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + } CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + if (device->transpose_a) { + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q6_K], matmul_q6_k_f32_transa, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q5_1], matmul_q5_1_f32_transa, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + } CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -4039,6 +4088,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + if (device->transpose_a) { + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32_transa, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32_transa, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32_transa, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32_transa, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + } CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -4085,6 +4140,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + if (device->transpose_a) { + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q4_K], matmul_id_q4_k_f32_transa, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q5_K], matmul_id_q5_k_f32_transa, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q6_K], matmul_id_q6_k_f32_transa, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q5_1], matmul_id_q5_1_f32_transa, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + } CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -4159,8 +4220,18 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + if (device->transpose_a) { + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32_transa, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + } CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + if (device->transpose_a) { + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32_transa, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + } CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + if (device->transpose_a) { + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32_transa, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_transa[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32_transa, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + } CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -4206,6 +4277,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + if (device->transpose_a) { + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32_transa, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32_transa, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32_transa, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32_transa, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + } CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -4234,6 +4311,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + if (device->transpose_a) { + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32_transa, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32_transa, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32_transa, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_transa[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32_transa, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + } CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -5058,6 +5141,8 @@ static vk_device ggml_vk_get_device(size_t idx) { const char* GGML_VK_DISABLE_GRAPH_OPTIMIZE = getenv("GGML_VK_DISABLE_GRAPH_OPTIMIZE"); device->disable_graph_optimize = GGML_VK_DISABLE_GRAPH_OPTIMIZE != nullptr; + device->transpose_a = getenv("GGML_VK_NO_TRANSPOSE_A") == nullptr; + bool fp16_storage = false; bool fp16_compute = false; bool maintenance4_support = false; @@ -7769,6 +7854,29 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); } + // Use transposed-A pipeline only for tensors that set_tensor actually repacked. + bool src0_transposed_a = false; + if (ctx->device->transpose_a && src0->buffer != nullptr && ggml_backend_buffer_is_vk(src0->buffer)) { + ggml_backend_vk_buffer_context * src0_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + src0_transposed_a = src0_ctx->transposed_a_tensors.count(src0) != 0; + } + if (src0_transposed_a && !qx_needs_dequant) { + vk_matmul_pipeline2 & transa = ctx->device->pipeline_dequant_mul_mat_mat_transa[src0->type]; + if (ctx->device->coopmat_support) { + vk_matmul_pipeline candidate = (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && (ggml_prec)dst->op_params[0] == GGML_PREC_DEFAULT) ? transa.f16acc : transa.f32acc; + if (candidate && !candidate->is_empty()) { + mmp = candidate; + quantize_y = false; + } + } else { + vk_matmul_pipeline candidate = (ctx->device->fp16 && (ggml_prec)dst->op_params[0] == GGML_PREC_DEFAULT) ? transa.f16acc : transa.f32acc; + if (candidate && !candidate->is_empty()) { + mmp = candidate; + quantize_y = false; + } + } + } + // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT @@ -8220,6 +8328,16 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& uint32_t fusion_flags = 0; + // Same guard as in ggml_vk_mul_mat_q_f16. + bool src0_transposed_a_mv = false; + if (ctx->device->transpose_a && !x_non_contig && src0->buffer != nullptr && ggml_backend_buffer_is_vk(src0->buffer)) { + ggml_backend_vk_buffer_context * src0_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + src0_transposed_a_mv = src0_ctx->transposed_a_tensors.count(src0) != 0; + } + if (src0_transposed_a_mv) { + fusion_flags |= MAT_VEC_FUSION_FLAGS_TRANSPOSE_A; + } + vk_subbuffer d_F0 = d_D; if (ctx->num_additional_fused_ops > 0) { const ggml_tensor * add = cgraph->nodes[node_idx + 1]; @@ -8606,6 +8724,29 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); } + // Use transposed-A pipeline only for tensors that set_tensor actually repacked. + bool src0_transposed_a = false; + if (ctx->device->transpose_a && src0->buffer != nullptr && ggml_backend_buffer_is_vk(src0->buffer)) { + ggml_backend_vk_buffer_context * src0_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + src0_transposed_a = src0_ctx->transposed_a_tensors.count(src0) != 0; + } + if (src0_transposed_a && !qx_needs_dequant) { + vk_matmul_pipeline2 & transa = ctx->device->pipeline_dequant_mul_mat_mat_id_transa[src0->type]; + if (ctx->device->coopmat_support) { + vk_matmul_pipeline candidate = (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && (ggml_prec)dst->op_params[0] == GGML_PREC_DEFAULT) ? transa.f16acc : transa.f32acc; + if (candidate && !candidate->is_empty()) { + mmp = candidate; + quantize_y = false; + } + } else { + vk_matmul_pipeline candidate = (ctx->device->fp16 && (ggml_prec)dst->op_params[0] == GGML_PREC_DEFAULT) ? transa.f16acc : transa.f32acc; + if (candidate && !candidate->is_empty()) { + mmp = candidate; + quantize_y = false; + } + } + } + // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT @@ -8995,6 +9136,16 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte uint32_t fusion_flags = 0; + // Same guard as in ggml_vk_mul_mat_id_q_f16. + bool src0_transposed_a_mv = false; + if (ctx->device->transpose_a && !x_non_contig && src0->buffer != nullptr && ggml_backend_buffer_is_vk(src0->buffer)) { + ggml_backend_vk_buffer_context * src0_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + src0_transposed_a_mv = src0_ctx->transposed_a_tensors.count(src0) != 0; + } + if (src0_transposed_a_mv) { + fusion_flags |= MAT_VEC_FUSION_FLAGS_TRANSPOSE_A; + } + if (ctx->num_additional_fused_ops > 0) { const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1]; @@ -9156,7 +9307,12 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co return supported; } +static void ggml_vk_ensure_non_transposed(const ggml_tensor * tensor); + static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst) { + ggml_vk_ensure_non_transposed(k); + ggml_vk_ensure_non_transposed(v); + VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3]; @@ -10498,7 +10654,51 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } +// Revert a transposed-A tensor back to its original layout on the GPU. +// Called by ops that lack transa-aware shaders (e.g. GET_ROWS, FLASH_ATTN_EXT, CPY). +// For non-transposed tensors this is a no-op (zero overhead). +static void ggml_vk_ensure_non_transposed(const ggml_tensor * tensor) { + if (!tensor->buffer || !ggml_backend_buffer_is_vk(tensor->buffer)) { + return; + } + ggml_backend_vk_buffer_context * buf_ctx = + (ggml_backend_vk_buffer_context *)tensor->buffer->context; + if (!buf_ctx->transposed_a_tensors.count(tensor)) { + return; + } + + const size_t block_size = ggml_type_size(tensor->type); + const int64_t n_rows = tensor->ne[1]; + const int64_t blocks_per_row = tensor->ne[0] / ggml_blck_size(tensor->type); + const int64_t n_experts = tensor->ne[2]; + const size_t expert_blocks = n_rows * blocks_per_row; + const size_t expert_size = expert_blocks * block_size; + const size_t total_size = n_experts * expert_size; + + std::vector transposed(total_size); + vk_buffer buf = buf_ctx->dev_buffer; + const uint64_t buf_off = vk_tensor_offset(tensor) + tensor->view_offs; + ggml_vk_buffer_read(buf, buf_off, transposed.data(), total_size); + + std::vector original(total_size); + for (int64_t e = 0; e < n_experts; e++) { + const size_t eo = e * expert_size; + for (int64_t row = 0; row < n_rows; row++) { + for (int64_t kb = 0; kb < blocks_per_row; kb++) { + memcpy(original.data() + eo + (row * blocks_per_row + kb) * block_size, + transposed.data() + eo + (kb * n_rows + row) * block_size, + block_size); + } + } + } + + ggml_vk_buffer_write(buf, buf_off, original.data(), total_size); + buf_ctx->transposed_a_tensors.erase(tensor); +} + static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_vk_ensure_non_transposed(src0); + const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); @@ -11178,6 +11378,8 @@ static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subct } static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + ggml_vk_ensure_non_transposed(src0); + uint32_t ne = (uint32_t)ggml_nelements(src0); if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { // Convert from number of logical elements to 2- or 4-byte units. @@ -13844,6 +14046,19 @@ static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, g ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size); } +static bool ggml_vk_has_transa_pipeline(const vk_device & device, ggml_type type) { + if (type >= GGML_TYPE_COUNT) { + return false; + } + auto non_empty = [](const vk_matmul_pipeline & p) { + return p && !p->is_empty(); + }; + const vk_matmul_pipeline2 & mm = device->pipeline_dequant_mul_mat_mat_transa[type]; + const vk_matmul_pipeline2 & mid = device->pipeline_dequant_mul_mat_mat_id_transa[type]; + return non_empty(mm.f16acc) || non_empty(mm.f32acc) + || non_empty(mid.f16acc) || non_empty(mid.f32acc); +} + static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; @@ -13853,6 +14068,55 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml return; } + // Repack weight blocks from [row, k_block] to [k_block, row] order. + // Token embedding is excluded: it is used as a lookup table, not a matmul operand. + // Guard: only repack when a transa pipeline actually exists for this dtype, otherwise + // dispatch will use a non-transa pipeline and read the data with the wrong layout. + // Views are excluded: graph_copy reads via get_tensor on the parent tensor, not the + // view, so the un-transpose in get_tensor would never trigger (the parent pointer + // is not in transposed_a_tensors). Real model weights are never views. + auto dev = buf_ctx->device.lock(); + if (dev && dev->transpose_a && offset == 0 && tensor->ne[3] == 1 + && tensor->view_src == nullptr + && strstr(tensor->name, "token_embd") == nullptr + && ggml_vk_has_transa_pipeline(dev, tensor->type)) { + const size_t block_size = ggml_type_size(tensor->type); + const int64_t n_rows = tensor->ne[1]; + const int64_t blocks_per_row = tensor->ne[0] / ggml_blck_size(tensor->type); + const size_t total_blocks = n_rows * blocks_per_row; + const int64_t n_experts = tensor->ne[2]; + const size_t expert_size = total_blocks * block_size; + + if (size == n_experts * expert_size) { + std::vector transposed(size); + const uint8_t * src = (const uint8_t *)data; + uint8_t * dst = transposed.data(); + + for (int64_t e = 0; e < n_experts; e++) { + const size_t expert_offset = e * expert_size; + for (int64_t row = 0; row < n_rows; row++) { + for (int64_t kb = 0; kb < blocks_per_row; kb++) { + memcpy(dst + expert_offset + (kb * n_rows + row) * block_size, + src + expert_offset + (row * blocks_per_row + kb) * block_size, + block_size); + } + } + } + + ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs, transposed.data(), size); + buf_ctx->transposed_a_tensors.insert(tensor); + return; + } + } + + // When writing to a view, invalidate any transpose record on its parent. + // The test framework creates a base tensor (transposed in set_tensor above), + // then writes the view with non-transposed data. Without this erasure, + // get_tensor would un-transpose the already-non-transposed view region. + if (tensor->view_src != nullptr) { + buf_ctx->transposed_a_tensors.erase(tensor->view_src); + } + ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } @@ -13867,6 +14131,7 @@ static void ggml_backend_vk_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, g return; } + buf_ctx->transposed_a_tensors.erase(tensor); ggml_vk_buffer_write_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_data, stride_tensor, size, n_copies); } @@ -13881,6 +14146,35 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons vk_buffer buf = buf_ctx->dev_buffer; ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); + + // Un-transpose only tensors that set_tensor actually repacked. + if (buf_ctx->transposed_a_tensors.count(tensor) && offset == 0) { + const size_t block_size = ggml_type_size(tensor->type); + const int64_t n_rows = tensor->ne[1]; + const int64_t blocks_per_row = tensor->ne[0] / ggml_blck_size(tensor->type); + const int64_t n_experts = tensor->ne[2]; + const size_t expert_blocks = n_rows * blocks_per_row; + const size_t expert_size = expert_blocks * block_size; + + if (size == n_experts * expert_size) { + std::vector original(size); + const uint8_t * src = (const uint8_t *)data; + uint8_t * dst = original.data(); + + for (int64_t e = 0; e < n_experts; e++) { + const size_t expert_offset = e * expert_size; + for (int64_t row = 0; row < n_rows; row++) { + for (int64_t kb = 0; kb < blocks_per_row; kb++) { + memcpy(dst + expert_offset + (row * blocks_per_row + kb) * block_size, + src + expert_offset + (kb * n_rows + row) * block_size, + block_size); + } + } + } + + memcpy(data, original.data(), size); + } + } } static void ggml_backend_vk_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, @@ -13922,6 +14216,7 @@ static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, cons static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; + ctx->transposed_a_tensors.clear(); ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 2271be4021b..1fe8533ec1e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -47,9 +47,13 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); } #endif + const bool transa_rt = (p.fusion_flags & MAT_VEC_FUSION_FLAGS_TRANSPOSE_A) != 0; + const uint a_kb = col / QUANT_K; uint ibi = first_row*p.ncols; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib = (ibi + col)/QUANT_K; // block index + const uint ib = transa_rt + ? (a_kb * p.stride_d + (first_row + n)) + : ((ibi + col)/QUANT_K); ibi += p.ncols; #if K_PER_ITER == 8 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl index e8d053cdd43..e70c44b751c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl @@ -4,6 +4,7 @@ #define MAT_VEC_FUSION_FLAGS_BIAS1 0x2 #define MAT_VEC_FUSION_FLAGS_SCALE0 0x4 #define MAT_VEC_FUSION_FLAGS_SCALE1 0x8 +#define MAT_VEC_FUSION_FLAGS_TRANSPOSE_A 0x10 layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; #if defined(A_TYPEV4) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 93fbacc6282..e9bb9f0e817 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -12,8 +12,11 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint y1_idx = i * QUANT_K + y_offset; const uint y2_idx = y1_idx + 128; + const bool transpose_a = (p.fusion_flags & MAT_VEC_FUSION_FLAGS_TRANSPOSE_A) != 0; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; + const uint ib0 = transpose_a ? (a_offset + i * (p.stride_d - 1) + (first_row+n)) + : (a_offset + (first_row+n)*num_blocks_per_row); const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 54d7e1bcdca..f01ba970cc6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -12,8 +12,11 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint y1_idx = i * QUANT_K + y_offset; const uint y2_idx = y1_idx + 128; + const bool transpose_a = (p.fusion_flags & MAT_VEC_FUSION_FLAGS_TRANSPOSE_A) != 0; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; + const uint ib0 = transpose_a ? (a_offset + i * (p.stride_d - 1) + (first_row+n)) + : (a_offset + (first_row+n)*num_blocks_per_row); const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index 3e89d91cbb0..e4eb9fb12fb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -14,8 +14,11 @@ uint csel = 0; void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { const uint y_idx = i * QUANT_K + y_offset; + const bool transpose_a = (p.fusion_flags & MAT_VEC_FUSION_FLAGS_TRANSPOSE_A) != 0; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; + const uint ib0 = transpose_a ? (a_offset + i * (p.stride_d - 1) + (first_row+n)) + : (a_offset + (first_row+n)*num_blocks_per_row); csel ^= 1; if (!all_threads) { // when we don't have enough blocks to use all threads diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 6fe3e2dc043..95cc08a35b2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -66,8 +66,15 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const #endif uint ibi = first_row*p.ncols; + const bool transpose_a = (p.fusion_flags & MAT_VEC_FUSION_FLAGS_TRANSPOSE_A) != 0; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset; +#if defined(DATA_A_QUANT_K) + const uint a_block_idx = (transpose_a ? ((col / QUANT_K) * p.stride_d + (first_row + n)) * (QUANT_K / QUANT_K_Q8_1) + ((col / QUANT_K_Q8_1) & ((QUANT_K / QUANT_K_Q8_1) - 1)) + : (ibi + col)/QUANT_K_Q8_1) + a_offset; +#else + const uint a_block_idx = (transpose_a ? (col / QUANT_K_Q8_1) * p.stride_d + first_row + n + : (ibi + col)/QUANT_K_Q8_1) + a_offset; +#endif ibi += p.ncols; temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 89346e48e06..fb8867a68f3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -242,7 +242,11 @@ void main() { #else batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) + #endif +#ifdef TRANSPOSE_A + 0; +#else (ir * BM * p.stride_a + start_k) / LOAD_VEC_A; +#endif #ifdef MUL_MAT_ID uint pos_b = 0; #else @@ -286,7 +290,9 @@ void main() { barrier(); +#ifndef TRANSPOSE_A pos_a += BK / LOAD_VEC_A; +#endif pos_b += BK / LOAD_VEC_B; #ifdef COOPMAT diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 73595168984..6e0853c576c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -1,7 +1,17 @@ +#ifdef TRANSPOSE_A +#define QUANT_IDX_A \ + const uint _qklva = QUANT_K / LOAD_VEC_A; \ + const uint _k_elem = block + row * LOAD_VEC_A; \ + const uint idx = pos_a + (_k_elem / QUANT_K) * p.M * _qklva + idx_m * _qklva + ((_k_elem / LOAD_VEC_A) % _qklva); +#else +#define QUANT_IDX_A \ + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; +#endif + void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) { #if defined(DATA_A_F32) || defined(DATA_A_F16) #if LOAD_VEC_A == 8 - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]); buf_a[buf_idx ] = aa[0].xy; @@ -9,7 +19,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 2] = aa[1].xy; buf_a[buf_idx + 3] = aa[1].zw; #elif LOAD_VEC_A == 4 - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]); buf_a[buf_idx ] = aa.xy; @@ -28,7 +38,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #endif #elif defined(DATA_A_BF16) #if LOAD_VEC_A == 4 - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx])); buf_a[buf_idx ] = aa.xy; @@ -46,7 +56,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin } #endif #elif defined(DATA_A_Q4_0) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 4; @@ -62,7 +72,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 8] = FLOAT_TYPEV2(v1.xy); buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.zw); #elif defined(DATA_A_Q4_1) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 4; @@ -78,7 +88,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 8 ] = FLOAT_TYPEV2(v1.xy); buf_a[buf_idx + 9 ] = FLOAT_TYPEV2(v1.zw); #elif defined(DATA_A_Q5_0) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; @@ -95,7 +105,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx ] = FLOAT_TYPEV2(v.xz); buf_a[buf_idx + 8] = FLOAT_TYPEV2(v.yw); #elif defined(DATA_A_Q5_1) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 4; @@ -117,7 +127,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 8] = FLOAT_TYPEV2(v0.yw); buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.yw); #elif defined(DATA_A_Q8_0) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint ib = idx / 8; @@ -145,7 +155,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 2] = FLOAT_TYPEV2((bits & 0x10u) != 0u ? d : -d, (bits & 0x20u) != 0u ? d : -d); buf_a[buf_idx + 3] = FLOAT_TYPEV2((bits & 0x40u) != 0u ? d : -d, (bits & 0x80u) != 0u ? d : -d); #elif defined(DATA_A_Q2_K) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint ib = idx / 64; // 4 values per idx @@ -164,7 +174,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); #elif defined(DATA_A_Q3_K) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint ib = idx / 128; // 2 values per idx @@ -188,7 +198,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx] = FLOAT_TYPEV2(dl * (qs.x - hm.x), dl * (qs.y - hm.y)); #elif defined(DATA_A_Q4_K) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint ib = idx / 64; // 4 values per idx @@ -224,7 +234,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m)); buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q5_K) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint ib = idx / 64; // 4 values per idx @@ -263,7 +273,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m)); buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q6_K) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint ib = idx / 128; // 2 values per idx @@ -285,7 +295,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx] = FLOAT_TYPEV2(q.x, q.y); #elif defined(DATA_A_IQ1_S) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint ib = idx / 32; // 8 values per idx @@ -304,7 +314,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); } #elif defined(DATA_A_IQ1_M) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint ib = idx / 32; // 8 values per idx @@ -326,7 +336,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); } #elif defined(DATA_A_IQ2_XXS) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint ib = idx / 32; // 8 values per idx @@ -357,7 +367,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_XS) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint ib = idx / 32; // 8 values per idx @@ -383,7 +393,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_S) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint ib = idx / 32; // 8 values per idx @@ -411,7 +421,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ3_XXS) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint ib = idx / 64; // 4 values per idx @@ -435,7 +445,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z, (sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ3_S) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint ib = idx / 64; // 4 values per idx @@ -457,7 +467,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z, (sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ4_XS) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; const uint ib = idx / 64; // 4 values per idx @@ -475,7 +485,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); #elif defined(DATA_A_IQ4_NL) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; @@ -489,7 +499,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 8] = d * FLOAT_TYPEV2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)], kvalues_iq4nl[vui >> 12]); #elif defined(DATA_A_MXFP4) - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + QUANT_IDX_A const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index e3a9d61a558..7eba6582d03 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -588,6 +588,14 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } + // TRANSPOSE_A variants + if (!coopmat2 && (tname == "q4_k" || tname == "q5_k" || tname == "q6_k" || tname == "q5_1")) { + string_to_spv(shader_name + "_" + tname + "_f32_transa", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"TRANSPOSE_A", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_transa_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}, {"TRANSPOSE_A", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_transa", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"TRANSPOSE_A", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_transa_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}, {"TRANSPOSE_A", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) // Integer dot mmq performs better with f32 accumulators if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {