From c9d71476a643d738897fcb7458d0456c97ff00ef Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Thu, 30 Oct 2025 19:11:44 +0000 Subject: [PATCH 1/2] vulkan: fix shmem overrun in mmq id shader --- ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp | 4 ++++ ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl | 2 +- tests/test-backend-ops.cpp | 3 +++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index 8b238ac4bc1..d955b4fc7af 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -82,9 +82,13 @@ layout (constant_id = 10) const uint WARP = 32; #include "mul_mmq_shmem_types.glsl" +#ifdef MUL_MAT_ID +#define BK_STEP 1 +#else #ifndef BK_STEP #define BK_STEP 4 #endif +#endif // Shared memory cache shared block_a_cache buf_a[BM * BK_STEP]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl index 72fec440490..1c0f5306f38 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl @@ -27,7 +27,7 @@ struct block_a_cache { #elif defined(DATA_A_Q8_0) #define QUANT_R_MMQ 1 // AMD likes 4, Intel likes 1 and Nvidia likes 2 -#define BK_STEP 1 +// #define BK_STEP 1 struct block_a_cache { int32_t qs[32/4]; FLOAT_TYPE dm; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 92361d6f0f4..fa12c06ccdd 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6880,6 +6880,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1)); test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3)); + // gpt-oss issue with Vulkan mmq_id + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880)); + for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { for (int n_mats : {4, 8}) { From dfd8ec07d72b5922d2edd4ca13394083f94f37be Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 30 Oct 2025 23:06:56 +0200 Subject: [PATCH 2/2] metal : fix mul_mm_id --- ggml/src/ggml-metal/ggml-metal-device.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 1a3c7873b74..5607deaf414 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_ char name[256]; snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20); - snprintf(name, 256, "%s", base); + snprintf(name, 256, "%s_ne02=%d", base, ne02); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); if (res) {