@@ -567,7 +567,11 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
567567
568568// return true if the node's results are only used by N other nodes
569569// and can be fused into their calculations.
570- static inline bool ggml_node_has_n_uses (const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
570+ static inline bool ggml_node_has_n_uses_impl (
571+ const struct ggml_cgraph * cgraph,
572+ int node_idx,
573+ int32_t n_uses,
574+ bool allow_views) {
571575 const struct ggml_tensor * node = cgraph->nodes [node_idx];
572576
573577 // check the use count against how many we're replacing
@@ -579,7 +583,14 @@ static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int n
579583 // if node is a view, some other node might be using the intermediate result
580584 // via the view source.
581585 if (node->view_src ) {
582- return false ;
586+ if (!allow_views) {
587+ return false ;
588+ }
589+
590+ size_t src_hash_pos = ggml_hash_find (&cgraph->visited_hash_set , node->view_src );
591+ if (!ggml_bitset_get (cgraph->visited_hash_set .used , src_hash_pos) || cgraph->use_counts [src_hash_pos] != 1 ) {
592+ return false ;
593+ }
583594 }
584595
585596 // If the user requested output for the node, can't fuse
@@ -590,35 +601,83 @@ static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int n
590601 return true ;
591602}
592603
604+ static inline bool ggml_node_has_n_uses (const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
605+ return ggml_node_has_n_uses_impl (cgraph, node_idx, n_uses, false );
606+ }
607+
593608// Returns true if nodes with indices { node_idxs } are the sequence of ggml_ops in ops[]
594609// and are fusable. Nodes are considered fusable according to this function if:
595- // - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).
596- // - all nodes except the last are a src of the following node.
597- // - all nodes are the same shape.
610+ // - every node except the last has exactly one use, and that use is by a later node in node_idxs.
611+ // - a node only takes as src those fused nodes that appear earlier in node_idxs.
598612// TODO: Consider allowing GGML_OP_NONE nodes in between
599613static inline bool ggml_can_fuse_ext (const struct ggml_cgraph * cgraph, const int * node_idxs, const enum ggml_op * ops, int num_ops) {
614+ GGML_ASSERT (num_ops <= 32 );
615+
616+ if (num_ops <= 0 ) {
617+ return false ;
618+ }
619+
620+ struct ggml_tensor * nodes[32 ] = {0 };
621+
600622 for (int i = 0 ; i < num_ops; ++i) {
601- if (node_idxs[i] >= cgraph->n_nodes ) {
623+ const int idx = node_idxs[i];
624+ if (idx >= cgraph->n_nodes ) {
602625 return false ;
603626 }
604627
605- struct ggml_tensor * node = cgraph->nodes [node_idxs[i]];
606- if (node->op != ops[i]) {
607- return false ;
608- }
609- if (i < num_ops - 1 && !ggml_node_has_n_uses (cgraph, node_idxs[i], 1 )) {
628+ nodes[i] = cgraph->nodes [idx];
629+ if (nodes[i]->op != ops[i]) {
610630 return false ;
611631 }
612- if (i > 0 ) {
613- struct ggml_tensor * prev = cgraph->nodes [node_idxs[i - 1 ]];
614- if (node->src [0 ] != prev && node->src [1 ] != prev) {
632+ }
633+
634+ for (int i = 0 ; i < num_ops; ++i) {
635+ struct ggml_tensor * node = nodes[i];
636+
637+ if (i < num_ops - 1 ) {
638+ const bool allow_views = node->view_src != NULL ;
639+ if (!ggml_node_has_n_uses_impl (cgraph, node_idxs[i], 1 , allow_views)) {
615640 return false ;
616641 }
617- if (!ggml_are_same_shape (node, prev)) {
642+ }
643+
644+ for (int j = 0 ; j < GGML_MAX_SRC; ++j) {
645+ struct ggml_tensor * src = node->src [j];
646+ if (!src) {
647+ continue ;
648+ }
649+
650+ int src_pos = -1 ;
651+ for (int k = 0 ; k < num_ops; ++k) {
652+ if (nodes[k] == src) {
653+ src_pos = k;
654+ break ;
655+ }
656+ }
657+
658+ if (src_pos != -1 && src_pos >= i) {
618659 return false ;
619660 }
620661 }
621662 }
663+
664+ for (int i = 0 ; i < num_ops - 1 ; ++i) {
665+ bool has_consumer = false ;
666+ for (int k = i + 1 ; k < num_ops && !has_consumer; ++k) {
667+ struct ggml_tensor * consumer = nodes[k];
668+ for (int s = 0 ; s < GGML_MAX_SRC; ++s) {
669+ if (consumer->src [s] == nodes[i]) {
670+ has_consumer = true ;
671+ break ;
672+ }
673+ }
674+ }
675+
676+ if (!has_consumer) {
677+ return false ;
678+ }
679+ }
680+
622681 return true ;
623682}
624683
0 commit comments