@@ -583,27 +583,27 @@ static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int n
583583    return  true ;
584584}
585585
586- //  Returns true if nodes [i, i+ops.size())  are the sequence of ggml_ops in ops[]
586+ //  Returns true if nodes with indices { node_idxs }  are the sequence of ggml_ops in ops[]
587587//  and are fusable. Nodes are considered fusable according to this function if:
588588//  - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).
589589//  - all nodes except the last are a src of the following node.
590590//  - all nodes are the same shape.
591591//  TODO: Consider allowing GGML_OP_NONE nodes in between
592- static  inline  bool  ggml_can_fuse (const  struct  ggml_cgraph  * cgraph, int  node_idx, const  enum  ggml_op * ops, int  num_ops) {
593-     if  (node_idx + num_ops > cgraph->n_nodes ) {
594-         return  false ;
595-     }
596- 
592+ static  inline  bool  ggml_can_fuse_ext (const  struct  ggml_cgraph  * cgraph, const  int  * node_idxs, const  enum  ggml_op * ops, int  num_ops) {
597593    for  (int  i = 0 ; i < num_ops; ++i) {
598-         struct  ggml_tensor  * node = cgraph->nodes [node_idx + i];
594+         if  (node_idxs[i] >= cgraph->n_nodes ) {
595+             return  false ;
596+         }
597+ 
598+         struct  ggml_tensor  * node = cgraph->nodes [node_idxs[i]];
599599        if  (node->op  != ops[i]) {
600600            return  false ;
601601        }
602-         if  (i < num_ops - 1  && !ggml_node_has_n_uses (cgraph, node_idx + i , 1 )) {
602+         if  (i < num_ops - 1  && !ggml_node_has_n_uses (cgraph, node_idxs[i] , 1 )) {
603603            return  false ;
604604        }
605605        if  (i > 0 ) {
606-             struct  ggml_tensor  * prev = cgraph->nodes [node_idx +  i - 1 ];
606+             struct  ggml_tensor  * prev = cgraph->nodes [node_idxs[ i - 1 ] ];
607607            if  (node->src [0 ] != prev && node->src [1 ] != prev) {
608608                return  false ;
609609            }
@@ -615,6 +615,22 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
615615    return  true ;
616616}
617617
618+ //  same as above, for sequential indices starting at node_idx
619+ static  inline  bool  ggml_can_fuse (const  struct  ggml_cgraph  * cgraph, int  node_idx, const  enum  ggml_op * ops, int  num_ops) {
620+     assert (num_ops < 32 );
621+ 
622+     if  (node_idx + num_ops > cgraph->n_nodes ) {
623+         return  false ;
624+     }
625+ 
626+     int  idxs[32 ];
627+     for  (int  i = 0 ; i < num_ops; ++i) {
628+         idxs[i] = node_idx + i;
629+     }
630+ 
631+     return  ggml_can_fuse_ext (cgraph, idxs, ops, num_ops);
632+ }
633+ 
618634#ifdef  __cplusplus
619635}
620636#endif 
0 commit comments