Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
char name[256];

snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_ne02=%d", base, ne02);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,13 @@ layout (constant_id = 10) const uint WARP = 32;

#include "mul_mmq_shmem_types.glsl"

#ifdef MUL_MAT_ID
#define BK_STEP 1
#else
#ifndef BK_STEP
#define BK_STEP 4
#endif
#endif

// Shared memory cache
shared block_a_cache buf_a[BM * BK_STEP];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct block_a_cache {
#elif defined(DATA_A_Q8_0)
#define QUANT_R_MMQ 1
// AMD likes 4, Intel likes 1 and Nvidia likes 2
#define BK_STEP 1
// #define BK_STEP 1
struct block_a_cache {
int32_t qs[32/4];
FLOAT_TYPE dm;
Expand Down
3 changes: 3 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6880,6 +6880,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));

// gpt-oss issue with Vulkan mmq_id
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));

for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
for (int n_mats : {4, 8}) {
Expand Down
Loading