132 changes: 113 additions & 19 deletions ggml-metal.m
@@ -36,6 +36,9 @@
int n_buffers;
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];

int concur_list[GGML_MAX_NODES];
int concur_list_len;

// custom kernels
#define GGML_METAL_DECL_KERNEL(name) \
id<MTLFunction> function_##name; \
@@ -98,6 +101,7 @@ @implementation GGMLMetalClass
ctx->device = MTLCreateSystemDefaultDevice();
ctx->queue = [ctx->device newCommandQueue];
ctx->n_buffers = 0;
ctx->concur_list_len = 0;

// determine if we can use MPS
if (MPSSupportsMTLDevice(ctx->device)) {
@@ -355,11 +359,92 @@ void ggml_metal_get_tensor(
memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
}

// Build a list of node indices that can be encoded concurrently:
// nodes are grouped into "levels"; each level is terminated by a -1 entry,
// which is turned into a memory barrier when the graph is encoded.
void ggml_metal_graph_find_concurrency(
        struct ggml_metal_context * ctx,
        struct ggml_cgraph * gf) {
    int search_depth = gf->n_nodes; // we only search for concurrency within this range to avoid spending too much time
    int nodes_unused[GGML_MAX_NODES];

    for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
    for (int i = 0; i < gf->n_nodes;    i++) {nodes_unused[i]     = 1;}
    ctx->concur_list_len = 0;

    int n_left    = gf->n_nodes;
    int n_start   = 0; // all nodes before n_start in the nodes_unused array have been sorted and stored back into ctx->concur_list
    int level_pos = 0; // in ctx->concur_list, the last layer (level) ends at level_pos

    while (n_left > 0) {
        // number of nodes at a layer (that can be issued concurrently)
        int concurrency = 0;
        for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
            if (nodes_unused[i]) {
                // check if the requirements for gf->nodes[i] are satisfied
                int exe_flag = 1;
                // scan all srcs
                for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
                    struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
                    if (src_cur) {
                        // if the src is a leaf node, it is satisfied
                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}

                        // otherwise this src must be the output of a previously issued node
                        int is_found = 0;
                        // scan 2*search_depth entries back because we also insert barriers
                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
                            if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
                        }
                        if (is_found == 0) {exe_flag = 0; break;}
                    }
                }
                if (exe_flag) {
                    // check if nodes[i]'s data will be overwritten by a node that comes before it:
                    // if node[5] and node[3] write to the same memory region, we cannot issue node[5] before node[3]
                    int64_t data_start = (int64_t) gf->nodes[i]->data;
                    int64_t length     = (int64_t) ggml_nbytes(gf->nodes[i]);
                    for (int j = n_start; j < i; j++) {
                        if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
                                            && gf->nodes[j]->op != GGML_OP_VIEW \
                                            && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
                                            && gf->nodes[j]->op != GGML_OP_PERMUTE) {
                            if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
                                ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
                                continue;
                            } else {
                                exe_flag = 0;
                            }
                        }
                    }
                }
                if (exe_flag) {
                    ctx->concur_list[level_pos + concurrency] = i;
                    nodes_unused[i] = 0;
                    concurrency++;
                    ctx->concur_list_len++;
                }
            }
        }
        n_left -= concurrency;
        // add a barrier (-1) to mark the end of this layer
        ctx->concur_list[level_pos + concurrency] = -1;
        ctx->concur_list_len++;
        // skip the nodes at the front of nodes_unused that have already been sorted
        while (!nodes_unused[n_start]) {n_start++;}
        level_pos += concurrency + 1;
    }

    if (ctx->concur_list_len > GGML_MAX_NODES) {
        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
    }
}

void ggml_metal_graph_compute(
        struct ggml_metal_context * ctx,
        struct ggml_cgraph * gf) {
    metal_printf("%s: evaluating graph\n", __func__);

    if (!ctx->concur_list_len) {
        ggml_metal_graph_find_concurrency(ctx, gf);
    }
Member:
Seems like this will break when computing graphs of different topologies.

Contributor Author:
I think for now we can assume that, during the lifetime of the Metal ctx, the topology of the graph won't change. In the future, when we need to change the graph topology and we also have a mechanism to tell the backend that the topology has changed, we can easily add the necessary code to handle the update.
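
A minimal sketch of what such an invalidation hook could look like, assuming a hypothetical helper name that is not part of this diff:

// Hypothetical: callers notify the backend that the graph topology changed,
// which clears the cached plan so the next compute call re-derives it.
void ggml_metal_graph_topology_changed(struct ggml_metal_context * ctx) {
    ctx->concur_list_len = 0; // forces ggml_metal_graph_find_concurrency to run again
}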

Member:
For llama that's true, but there are other users of ggml.

Contributor Author:
Maybe instead of letting ggml-metal.m call ggml_metal_graph_find_concurrency automatically, we could let llama.cpp decide whether to call ggml_metal_graph_find_concurrency and set metal_ctx->concur_list?

The logic in ggml-metal.m would not be so intrusive: if metal_ctx->concur_list is set, dispatch ops concurrently; otherwise, fall back to the original code path.

This will add backend-specific code to llama.cpp for now, but I imagine that in the future we can have a ggml_backend_graph_optimize() interface to unify them.
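
A rough sketch of how the llama.cpp side could look under that proposal, assuming a hypothetical ggml_metal_if_optimized() accessor that reports whether the concurrency list has been set (neither the helper nor this call site is part of the diff):

// Hypothetical call site in llama.cpp (names are illustrative):
// llama.cpp runs the analysis once; the Metal backend only consumes the result.
if (!ggml_metal_if_optimized(ctx_metal)) {           // has concur_list been set yet?
    ggml_metal_graph_find_concurrency(ctx_metal, gf);
}
ggml_metal_graph_compute(ctx_metal, gf);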

Member:
Sounds good to me. Eventually we will need a better solution, but for now that should do. @ggerganov

Contributor Author:
Fixed.

// create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel

@@ -378,7 +463,7 @@ void ggml_metal_graph_compute(
dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);

for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
const int n_nodes_per_cb = (ctx->concur_list_len + n_cb - 1) / n_cb;

dispatch_async(queue, ^{
size_t offs_src0 = 0;
@@ -390,9 +475,18 @@
id<MTLComputeCommandEncoder> encoder = nil;

const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;

for (int i = node_start; i < node_end; ++i) {
const int node_end = (cb_idx == n_cb - 1) ? ctx->concur_list_len : (cb_idx + 1) * n_nodes_per_cb;

for (int ind = node_start; ind < node_end; ++ind) {
int i = ctx->concur_list[ind];
if (i == -1) {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
continue;
}
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
continue;
}
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

struct ggml_tensor * src0 = gf->nodes[i]->src[0];
@@ -463,7 +557,7 @@ void ggml_metal_graph_compute(
case GGML_OP_ADD:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

if (ggml_nelements(src1) == ne10) {
@@ -484,7 +578,7 @@ void ggml_metal_graph_compute(
case GGML_OP_MUL:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

if (ggml_nelements(src1) == ne10) {
@@ -505,7 +599,7 @@ void ggml_metal_graph_compute(
case GGML_OP_SCALE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

const float scale = *(const float *) src1->data;
@@ -522,7 +616,7 @@ void ggml_metal_graph_compute(
case GGML_OP_SILU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

[encoder setComputePipelineState:ctx->pipeline_silu];
@@ -536,7 +630,7 @@ void ggml_metal_graph_compute(
case GGML_OP_RELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

[encoder setComputePipelineState:ctx->pipeline_relu];
@@ -550,7 +644,7 @@ void ggml_metal_graph_compute(
case GGML_OP_GELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

[encoder setComputePipelineState:ctx->pipeline_gelu];
@@ -564,7 +658,7 @@ void ggml_metal_graph_compute(
case GGML_OP_SOFT_MAX:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

const int nth = 32;
@@ -582,7 +676,7 @@ void ggml_metal_graph_compute(
case GGML_OP_DIAG_MASK_INF:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

const int n_past = ((int32_t *)(dst->op_params))[0];
@@ -645,7 +739,7 @@ void ggml_metal_graph_compute(
}
} else {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

int nth0 = 32;
@@ -772,7 +866,7 @@ void ggml_metal_graph_compute(
case GGML_OP_GET_ROWS:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

switch (src0->type) {
@@ -801,7 +895,7 @@ void ggml_metal_graph_compute(
case GGML_OP_RMS_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

const float eps = 1e-6f;
@@ -823,7 +917,7 @@ void ggml_metal_graph_compute(
case GGML_OP_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

const float eps = 1e-5f;
@@ -845,7 +939,7 @@ void ggml_metal_graph_compute(
case GGML_OP_ALIBI:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

GGML_ASSERT((src0t == GGML_TYPE_F32));
@@ -888,7 +982,7 @@ void ggml_metal_graph_compute(
case GGML_OP_ROPE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

const int n_past = ((int32_t *) dst->op_params)[0];
@@ -932,7 +1026,7 @@ void ggml_metal_graph_compute(
case GGML_OP_CONT:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
}

const int nth = 32;