Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 61 additions & 17 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1392,34 +1392,78 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v
GGML_UNUSED(op);
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
ggml_metal_library_t lib,
ggml_op op,
int32_t n_fuse,
bool row) {
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
char base[256];
char name[256];

const char * op_str = "undefined";
switch (op) {
case GGML_OP_ADD: op_str = "add"; break;
case GGML_OP_SUB: op_str = "sub"; break;
case GGML_OP_MUL: op_str = "mul"; break;
case GGML_OP_DIV: op_str = "div"; break;
int op_num = -1;

switch (op->op) {
case GGML_OP_ADD: op_num = 0; break;
case GGML_OP_SUB: op_num = 1; break;
case GGML_OP_MUL: op_num = 2; break;
case GGML_OP_DIV: op_num = 3; break;
default: GGML_ABORT("fatal error");
};

if (row) {
snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
} else {
snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t1_str = ggml_type_name(op->src[1]->type);
const char * t_str = ggml_type_name(op->type);

const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);

const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;

snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();

ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

ggml_metal_cv_free(cv);
}

snprintf(name, 256, "%s", base);
res.c4 = is_c4;
res.cnt = is_rb;

return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {
char base[256];
char name[256];

int op_num = -1;

switch (op) {
case GGML_OP_ADD: op_num = 0; break;
case GGML_OP_SUB: op_num = 1; break;
case GGML_OP_MUL: op_num = 2; break;
case GGML_OP_DIV: op_num = 3; break;
default: GGML_ABORT("fatal error");
};

snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32");
snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();

ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1);
ggml_metal_cv_set_bool (cv, false, FC_BIN + 2);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

ggml_metal_cv_free(cv);
}

return res;
Expand Down
6 changes: 5 additions & 1 deletion ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ struct ggml_metal_pipeline_with_params {
int nr1;

size_t smem;

bool c4;
bool cnt;
};

int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
Expand Down Expand Up @@ -134,7 +137,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse );
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one (ggml_metal_library_t lib, enum ggml_op op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
Expand Down
10 changes: 7 additions & 3 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta

struct ggml_metal_pipeline_with_params res = {
/*.pipeline =*/ nil,
/*.nsg =*/ 0,
/*.nr0 =*/ 0,
/*.nr1 =*/ 0,
/*.nsg =*/ 0,
/*.smem =*/ 0,
/*.c4 =*/ false,
/*.cnt =*/ false,
};

res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
Expand All @@ -362,10 +364,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
struct ggml_metal_pipeline_with_params res = {
/*.pipeline =*/ nil,
/*.nsg =*/ 0,
/*.nr0 =*/ 0,
/*.nr1 =*/ 0,
/*.nsg =*/ 0,
/*.smem =*/ 0,
/*.c4 =*/ false,
/*.cnt =*/ false,
};

[lib->lock lock];
Expand Down Expand Up @@ -1054,7 +1058,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32;
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
case GGML_OP_REPEAT:
case GGML_OP_SCALE:
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
#define FC_SSM_CONV 900
#define FC_SOLVE_TRI 1000
#define FC_COUNT_EQUAL 1100
#define FC_BIN 1200

// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
Expand Down
33 changes: 14 additions & 19 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
/*.o1 =*/ { 0 },
};

auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
Expand Down Expand Up @@ -2895,8 +2895,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));

bool bcast_row = false;

ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
Expand Down Expand Up @@ -2990,18 +2988,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {

struct ggml_metal_pipeline_with_params pipeline;

if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(op->src[0]));

// src1 is a row
GGML_ASSERT(ne11 == 1);

pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);

bcast_row = true;
} else {
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
}
pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);

if (n_fuse > 1) {
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
Expand All @@ -3015,20 +3002,28 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
}
}

if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne10 = ne10/4;
args.ne0 = ne0/4;
}

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);

if (bcast_row) {
const int64_t n = ggml_nelements(op)/4;
if (pipeline.cnt) {
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);

ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
} else {
int nth = 32;
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

int nth = 1;

while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
while (2*nth < args.ne0 && nth < nth_max) {
nth *= 2;
}

Expand Down
Loading
Loading