23 changes: 4 additions & 19 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2821,15 +2821,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);

if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {

if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
return false;
}

for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
}
if (ops.size() == topk_moe_ops_with_norm.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx+8];

@@ -2838,16 +2831,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}

if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {

if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
return false;
}

for (size_t i = 0; i < topk_moe_ops.size(); i++) {
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
}

if (ops.size() == topk_moe_ops.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx+4];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
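For reference, the two call sites above differ only in the op list and in which node offsets are pinned as outputs. A minimal sketch of the with-norm check factored into a helper (the wrapper name is hypothetical; `ggml_cuda_topk_moe_ops`, the offsets `+3`/`+8`, and `ggml_can_fuse_subgraph` come from this PR):

```cpp
// Hypothetical helper mirroring the call above. The outputs list names
// the two nodes whose results must stay visible after fusion (offsets
// +3 and +8 within the matched pattern); ggml_can_fuse_subgraph()
// verifies that every other node's uses are confined to the subgraph.
static bool cuda_can_fuse_topk_moe_with_norm(const struct ggml_cgraph * cgraph, int node_idx) {
    std::initializer_list<enum ggml_op> ops = ggml_cuda_topk_moe_ops(/*with_norm=*/true);
    return ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 });
}
```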
36 changes: 36 additions & 0 deletions ggml/src/ggml-impl.h
@@ -647,6 +647,35 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
}

GGML_API bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
const int * node_idxs,
int count,
const enum ggml_op * ops,
const int * outputs,
int num_outputs);

// Returns true if the subgraph formed by {node_idxs} can be fused;
// checks whether all nodes that are not part of outputs can be elided
// by checking if their num_uses are confined to the subgraph
static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
int node_idx,
int count,
const enum ggml_op * ops,
const int * outputs,
int num_outputs) {
if (node_idx + count > cgraph->n_nodes) {
return false;
}

int idxs[32];

for (int i = 0; i < count; ++i) {
idxs[i] = node_idx + i;
}

return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
}

#ifdef __cplusplus
}
#endif
@@ -660,6 +689,13 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
}

inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
int start_idx,
std::initializer_list<enum ggml_op> ops,
std::initializer_list<int> outputs = {}) {
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
}
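As a usage sketch outside the MoE case (the op sequence here is illustrative, not taken from this PR): a backend that wants to fuse a mat-mul, bias add, and unary activation, keeping only the final result alive, could call the overload like this:

```cpp
#include "ggml-impl.h"

// Illustrative only: check a MUL_MAT -> ADD -> UNARY chain starting at
// node_idx, requiring that only the last node (node_idx + 2) is read
// outside the fused region. The intermediate mat-mul and add results
// must then be elidable, or the call returns false.
static bool can_fuse_mm_bias_act(const struct ggml_cgraph * cgraph, int node_idx) {
    return ggml_can_fuse_subgraph(cgraph, node_idx,
                                  { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_UNARY },
                                  { node_idx + 2 });
}
```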

// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
68 changes: 68 additions & 0 deletions ggml/src/ggml.c
@@ -6964,6 +6964,74 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
GGML_LOG_INFO("========================================\n");
}

static int ggml_find_tensor_node_list(const struct ggml_cgraph * cgraph,
const int * idxs,
int count,
const struct ggml_tensor * tensor) {
GGML_ASSERT(cgraph && idxs);
for (int i = 0; i < count; ++i) {
const int node_idx = idxs[i];

if (node_idx >= cgraph->n_nodes) {
return -1;
}
if (cgraph->nodes[node_idx] == tensor) {
return i;
}
}
return -1;
}

bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
const int * node_idxs,
int count,
const enum ggml_op * ops,
const int * outputs,
int num_outputs) {
GGML_ASSERT(count < 32 && outputs && num_outputs > 0);

for (int i = 0; i < count; ++i) {
if (node_idxs[i] >= cgraph->n_nodes || cgraph->nodes[node_idxs[i]]->op != ops[i]) {
return false;
}

const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];

if (ggml_find_tensor_node_list(cgraph, outputs, num_outputs, node) != -1) {
continue;
}

if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
return false;
}

int subgraph_uses = 0;
for (int j = i + 1; j < count; ++j) {
const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) {
if (other_node->src[src_idx] == node) {
subgraph_uses++;
}
}
}

if (subgraph_uses != ggml_node_get_use_count(cgraph, node_idxs[i])) {
return false;
}

// if node is a view, check if the view_src and all its parent view_srcs are within the subgraph
struct ggml_tensor * view_src = node->view_src;
while (view_src) {
if (ggml_find_tensor_node_list(cgraph, node_idxs, count, view_src) == -1) {
return false;
}
view_src = view_src->view_src;
}
}

return true;
}
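To make the confinement check above concrete, a small worked example (a hypothetical three-node chain, not from this PR):

```c
// Hypothetical subgraph A -> B -> C with outputs = { index of C }:
//
//   A  (op matches ops[0], read only by B)
//   B  (op matches ops[1], read only by C)
//   C  (listed in outputs, so its external uses are not checked)
//
// For A, subgraph_uses = 1 because exactly one later subgraph node (B)
// reads it. If any node outside the subgraph also reads A,
// ggml_node_get_use_count() returns 2, the counts disagree, and fusion
// is rejected: A's intermediate value could not be elided.
```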

// check if node is part of the graph
static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
if (cgraph == NULL) {