Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ggml/src/ggml-metal/ggml-metal-context.m
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *

// enter here only when capturing in order to wait for all computation to finish
// otherwise, we leave the graph to compute asynchronously
- if (!use_capture && ctx->capture_started) {
+ if (use_capture && ctx->capture_started) {
// wait for completion and check status of each command buffer
// needed to detect if the device ran out-of-memory for example (#1881)
{
Expand Down Expand Up @@ -606,6 +606,8 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *

[ctx->capture_scope endScope];
[[MTLCaptureManager sharedCaptureManager] stopCapture];

ctx->capture_started = false;
}
}

Expand Down
4 changes: 3 additions & 1 deletion ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1470,10 +1470,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l

const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);

const bool is_cb = op->src[0]->ne[0] != op->src[1]->ne[0];
const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;

snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
- snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
+ snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d_cb=%d", base, op_num, n_fuse, is_rb, is_cb);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
Expand All @@ -1482,6 +1483,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l
ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2);
ggml_metal_cv_set_bool (cv, is_cb, FC_BIN + 3);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

Expand Down
4 changes: 1 addition & 3 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3180,9 +3180,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);

if (pipeline.cnt) {
- const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
-
- ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+ ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1);
} else {
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

Expand Down
11 changes: 7 additions & 4 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,7 @@ template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_un
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];

template <typename T0, typename T1, typename T>
kernel void kernel_bin_fuse_impl(
Expand All @@ -1124,11 +1125,12 @@ kernel void kernel_bin_fuse_impl(
#define FC_OP FC_bin_op
#define FC_F FC_bin_f
#define FC_RB FC_bin_rb
#define FC_CB FC_bin_cb

if (FC_RB) {
// row broadcast
- const uint i0 = tgpig.x;
- const uint i1 = i0%args.ne10;
+ const uint i0 = tgpig.y*args.ne00 + tgpig.x;
+ const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;

device const T0 * src0_row = (device const T0 *) (src0);
device T * dst_row = (device T *) (dst);
Expand Down Expand Up @@ -1200,7 +1202,7 @@ kernel void kernel_bin_fuse_impl(
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);

for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int i10 = i0%args.ne10;
+ const int i10 = FC_CB ? i0%args.ne10 : i0;

if (FC_OP == 0) {
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
Expand All @@ -1225,7 +1227,7 @@ kernel void kernel_bin_fuse_impl(
}

for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
- const int i10 = i0%args.ne10;
+ const int i10 = FC_CB ? i0%args.ne10 : i0;

T res = src0_ptr[i0];

Expand Down Expand Up @@ -1261,6 +1263,7 @@ kernel void kernel_bin_fuse_impl(
#undef FC_OP
#undef FC_F
#undef FC_RB
#undef FC_CB
}

typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
Expand Down
Loading