Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta
const ggml_type tsrc1 = op->src[1]->type;

const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;

constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;

const bool has_tensor = ggml_metal_device_get_props(ggml_metal_library_get_device(lib))->has_tensor;

const bool bc_out = has_tensor
? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0)
: (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0);

snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
Expand All @@ -694,8 +702,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta
ggml_metal_cv_free(cv);
}

// when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
res.smem = bc_out ? 8192 : 4096 + 2048;
if (has_tensor) {
res.nr0 = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y; // 64
res.nr1 = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; // 32

const size_t NRA = res.nr0;
const size_t smem_a = NRA * N_MM_NK_TOTAL * sizeof(ggml_fp16_t);
res.smem = smem_a;
} else {
res.nr0 = 64;
res.nr1 = 32;

res.smem = bc_out ? 8192 : (4096 + 2048);
}

res.nsg = N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y;

return res;
}
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev

void ggml_metal_library_free(ggml_metal_library_t lib);

ggml_metal_device_t ggml_metal_library_get_device(ggml_metal_library_t lib);

struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);

Expand Down
17 changes: 11 additions & 6 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_wi

struct ggml_metal_library {
id<MTLLibrary> obj;
id<MTLDevice> device;

ggml_metal_device_t dev;
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines

NSLock * lock;
Expand Down Expand Up @@ -251,7 +251,7 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));

res->obj = library;
res->device = device;
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];

Expand Down Expand Up @@ -318,7 +318,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
}

res->obj = library;
res->device = device;
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];

Expand All @@ -341,6 +341,10 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
free(lib);
}

ggml_metal_device_t ggml_metal_library_get_device(ggml_metal_library_t lib) {
return lib->dev;
}

struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
[lib->lock lock];

Expand Down Expand Up @@ -405,7 +409,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
return res;
}

id<MTLComputePipelineState> obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
id<MTLDevice> device = ggml_metal_device_get_obj(lib->dev);
id<MTLComputePipelineState> obj = [device newComputePipelineStateWithFunction:mtl_function error:&error];

[mtl_function release];

Expand Down Expand Up @@ -699,7 +704,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
" auto sB = tB.slice(0, 0); \n"
" mm.run(sB, sA, cT); \n"
" \n"
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(16, 16)); \n"
" \n"
" cT.store(tC); \n"
"}";
Expand Down Expand Up @@ -749,7 +754,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
" auto sB = tB.slice(0, 0); \n"
" mm.run(sB, sA, cT); \n"
" \n"
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(16, 16)); \n"
" \n"
" cT.store(tC); \n"
"}";
Expand Down
9 changes: 9 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@
//
// TODO: for optimal performance, become function of the device and work size

#define SZ_SIMDGROUP 16
#define N_MM_NK 2
#define N_MM_NK_TOTAL (SZ_SIMDGROUP * N_MM_NK)

#define N_MM_BLOCK_X 4
#define N_MM_BLOCK_Y 2
#define N_MM_SIMD_GROUP_X 2
#define N_MM_SIMD_GROUP_Y 2

#define N_R0_Q1_0 8
#define N_SG_Q1_0 2

Expand Down
7 changes: 6 additions & 1 deletion ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2195,7 +2195,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
const size_t smem = pipeline.smem;

ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);

const int nr0 = pipeline.nr0;
const int nr1 = pipeline.nr1;
const int nsg = pipeline.nsg;

ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + nr1 - 1) / nr1), ((ne01 + nr0 - 1) / nr0), ne12 * ne13, 32, nsg, 1);
} else {
auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);

Expand Down
Loading