Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 71 additions & 67 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3308,6 +3308,71 @@ static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int nod
return true;
}

// returns whether the write (out) nodes overwrite the read nodes in operation
static bool ggml_cuda_check_fusion_memory_ranges(const ggml_cgraph * cgraph,
const int node_idx,
const int node_count,
const int * out_nodes,
const int out_count,
const bool is_topk_moe = false) {
auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
const int64_t a_start = (int64_t) a->data;
const int64_t a_end = a_start + ggml_backend_buft_get_alloc_size(a->buffer->buft, a);

const int64_t b_start = (int64_t) b->data;
const int64_t b_end = b_start + ggml_backend_buft_get_alloc_size(b->buffer->buft, b);

if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
return true;
}

return false;
};

bool is_ok = true;
// exception for topk-moe, as each row is read entirely before writing
if (ggml_nrows(cgraph->nodes[node_idx]) == 1 && is_topk_moe) {
Comment on lines +3333 to +3334
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// exception for topk-moe, as each row is read entirely before writing
if (ggml_nrows(cgraph->nodes[node_idx]) == 1 && is_topk_moe) {
// for topk-moe each thread only writes back to its own memory so an in-place operation is okay
if (is_topk_moe && ggml_nrows(cgraph->nodes[node_idx]) == 1) {

To me this wording is more clear. Also I thing the code is more intuitive if you put the flag at the beginning of the conditional.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the suggested comment is correct, because we read n_experts and write back n_expert_used

return true;
}

for (int i = 0; i < out_count; ++i) {
const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];

for (int j = node_idx; j < node_idx + node_count; ++j) {
// Loop over all srcs of all nodes in the fusion. If the src overlaps
// the destination and the src is not an intermediate node that's being
// elided, then disable fusion.

for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];

if (!src || src->op == GGML_OP_NONE) {
continue;
}

if (nodes_overlap(dst, src)) {
bool found = false;

for (int k = node_idx; k < j; ++k) {
if (cgraph->nodes[k] == src) {
found = true;
break;
}
}

if (!found) {
is_ok = false;
break;
}
}
}
}
}

return is_ok;
}


static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
int node_idx,
std::initializer_list<enum ggml_op> ops,
Expand Down Expand Up @@ -3337,7 +3402,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
const ggml_tensor * glu = cgraph->nodes[node_idx + 4];

if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
return true;
int out_nodes[] = { node_idx + 4 };
return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1);
}
}

Expand All @@ -3348,7 +3414,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];

if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
return true;
int out_nodes[] = { node_idx + 2 };
return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1);
}
}

Expand Down Expand Up @@ -3474,69 +3541,6 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
return false;
}

// returns whether the write (out) nodes overwrite the read nodes in operation
static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph,
int node_idx,
int node_count,
int * out_nodes,
int out_count) {
auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
const int64_t a_start = (int64_t) a->data;
const int64_t a_end = a_start + ggml_nbytes(a);

const int64_t b_start = (int64_t) b->data;
const int64_t b_end = b_start + ggml_nbytes(b);

if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
return true;
}

return false;
};

bool is_ok = true;
// for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok
if (ggml_nrows(cgraph->nodes[node_idx]) == 1) {
return true;
}

for (int i = 0; i < out_count; ++i) {
const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];

for (int j = node_idx; j < node_idx + node_count; ++j) {
// Loop over all srcs of all nodes in the fusion. If the src overlaps
// the destination and the src is not an intermediate node that's being
// elided, then disable fusion.

for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];

if (!src || src->op == GGML_OP_NONE) {
continue;
}

if (nodes_overlap(dst, src)) {
bool found = false;

for (int k = node_idx; k < j; ++k) {
if (cgraph->nodes[k] == src) {
found = true;
break;
}
}

if (!found) {
is_ok = false;
break;
}
}
}
}
}

return is_ok;
}

static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
bool graph_evaluated_or_captured = false;

Expand Down Expand Up @@ -3734,7 +3738,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud

if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
Expand All @@ -3750,7 +3754,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
int out_nodes[2] = { i + 1, i + 5 };
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
Expand Down
Loading