Skip to content

Commit d853036

Browse files
committed
1. remove inputs from signature as they are transient nodes
2. add check for views: view_src should be part of the subgraph
1 parent ba472d1 commit d853036

File tree

3 files changed

+15
-21
lines changed

3 files changed

+15
-21
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2822,8 +2822,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
28222822
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
28232823

28242824
if (ops.size() == topk_moe_ops_with_norm.size() &&
2825-
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx },
2826-
{ node_idx + 3, node_idx + 8 })) {
2825+
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) {
28272826
ggml_tensor * softmax = cgraph->nodes[node_idx];
28282827
ggml_tensor * weights = cgraph->nodes[node_idx+8];
28292828

@@ -2833,7 +2832,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
28332832
}
28342833

28352834
if (ops.size() == topk_moe_ops.size() &&
2836-
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, { node_idx }, { node_idx + 3, node_idx + 4 })) {
2835+
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, { node_idx + 3, node_idx + 4 })) {
28372836
ggml_tensor * softmax = cgraph->nodes[node_idx];
28382837
ggml_tensor * weights = cgraph->nodes[node_idx+4];
28392838
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {

ggml/src/ggml-impl.h

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -651,20 +651,16 @@ GGML_API bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
651651
const int * node_idxs,
652652
int count,
653653
const enum ggml_op * ops,
654-
const int * inputs,
655-
int num_inputs,
656654
const int * outputs,
657655
int num_outputs);
658656

659657
// Returns true if the subgraph formed by {node_idxs} can be fused
660-
// checks whethers all nodes which are not part of inputs/outputs can be elided
658+
// checks whethers all nodes which are not part of outputs can be elided
661659
// by checking if their num_uses are confined to the subgraph
662660
static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
663661
int node_idx,
664662
int count,
665663
const enum ggml_op * ops,
666-
const int * inputs,
667-
int num_inputs,
668664
const int * outputs,
669665
int num_outputs) {
670666
if (node_idx + count > cgraph->n_nodes) {
@@ -677,7 +673,7 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
677673
idxs[i] = node_idx + i;
678674
}
679675

680-
return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, inputs, num_inputs, outputs, num_outputs);
676+
return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
681677
}
682678

683679
#ifdef __cplusplus
@@ -696,10 +692,8 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::
696692
inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
697693
int start_idx,
698694
std::initializer_list<enum ggml_op> ops,
699-
std::initializer_list<int> inputs = {},
700695
std::initializer_list<int> outputs = {}) {
701-
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), inputs.begin(), inputs.size(),
702-
outputs.begin(), outputs.size());
696+
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
703697
}
704698

705699
// expose GGUF internals for test code

ggml/src/ggml.c

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6989,11 +6989,9 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
69896989
const int * node_idxs,
69906990
int count,
69916991
const enum ggml_op * ops,
6992-
const int * inputs,
6993-
int num_inputs,
69946992
const int * outputs,
69956993
int num_outputs) {
6996-
GGML_ASSERT(count < 32 && num_inputs > 0 && num_outputs > 0);
6994+
GGML_ASSERT(count < 32 && outputs && num_outputs > 0);
69976995
int interior_nodes_count = 0;
69986996
int interior_nodes[32];
69996997

@@ -7008,25 +7006,20 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
70087006
return false;
70097007
}
70107008

7011-
if (ggml_find_tensor_node_list(cgraph, inputs, num_inputs, node) != -1) {
7012-
continue;
7013-
}
7014-
70157009
if (ggml_find_tensor_node_list(cgraph, outputs, num_outputs, node) != -1) {
70167010
continue;
70177011
}
70187012

70197013
interior_nodes[interior_nodes_count++] = node_idxs[i];
70207014
}
70217015

7022-
// if interior-node has n-uses, ensure that all of them lie within in this subgraph
70237016
for (int i = 0; i < interior_nodes_count; ++i) {
70247017
const int num_uses = ggml_node_get_use_count(cgraph, interior_nodes[i]);
70257018

70267019
const struct ggml_tensor * node = cgraph->nodes[interior_nodes[i]];
70277020

7021+
// if interior-node has n-uses, ensure that all of them lie within in this subgraph
70287022
int subgraph_uses = 0;
7029-
//check if all uses are within the graph
70307023
for (int j = 0; j < count; ++j) {
70317024
const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
70327025
for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) {
@@ -7039,6 +7032,14 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
70397032
if (subgraph_uses != num_uses) {
70407033
return false;
70417034
}
7035+
7036+
// if node is a view, check if the view src is within the subgraph
7037+
if (node->view_src) {
7038+
const struct ggml_tensor * view_src = node->view_src;
7039+
if (ggml_find_tensor_node_list(cgraph, node_idxs, count, view_src) == -1) {
7040+
return false;
7041+
}
7042+
}
70427043
}
70437044

70447045
return true;

0 commit comments

Comments
 (0)