@@ -86,7 +86,7 @@ def __init__(
8686 )
8787 self .node_to_group = collections .defaultdict (int )
8888 self .all_nodes_in_groups = set ()
89- if node_groups :
89+ if self . node_groups :
9090 for i , group in enumerate (self .node_groups ):
9191 for node in group :
9292 # Node is in multiple groups - not allowed
@@ -101,19 +101,25 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):
101101 p2_nodes = set (partitions_by_id [p2 ].nodes .keys ())
102102 combined_nodes = p1_nodes .union (p2_nodes )
103103
104- for node in combined_nodes :
105- # Get all downstream nodes that are not in the combined partition
106- external_downstreams = {
107- n
108- for n in self .dependency_viewer .downstreams_of (node )
109- if n not in combined_nodes
110- }
104+ user_nodes = []
105+ # topologically, p2_nodes comes before p1_nodes, so we only
106+ # need to check the downstream nodes of p2.
107+ # Additionally, we don't need to check all the downstream nodes
108+ # of p2, we only need to check the nodes directly outside of p2.
109+ # example:
110+ # partition[a --> b --> c] --> d --> e --> f
111+ # we don't need to check [d, e, f] we only need to check [d] because
112+ # the downstream users of [d] will include [e, f]
113+ for node in p2_nodes :
114+ for user in node .users :
115+ if user not in combined_nodes :
116+ user_nodes .append (user )
111117
118+ for external_node in user_nodes :
112119 # Check if any external downstream nodes have downstream nodes in the combined partition
113- for external_node in external_downstreams :
114- downstream_nodes = self .dependency_viewer .downstreams_of (external_node )
115- if any (n in combined_nodes for n in downstream_nodes ):
116- return False
120+ downstream_nodes = self .dependency_viewer .downstreams_of (external_node )
121+ if any (n in combined_nodes for n in downstream_nodes ):
122+ return False
117123
118124 return True
119125
@@ -133,13 +139,30 @@ def _process_node_groups(
133139 if not self .node_groups :
134140 return group_to_partition_id
135141
136- for i , group in enumerate (self .node_groups ):
137- # Create a partition for each group
142+ processed_nodes = set ()
143+
144+ # We have to create the partitions in reverse topological order
145+ # so we find the groups as we traverse backwards in the graph
146+ # this likely needs to be combined with the process_remaining_nodes
147+ # TODO: this currently doesn't work with _process_remaining_nodes so
148+ # if a user provides grouped nodes with operatorsupport, then this will
149+ # fail
150+ for node in reversed (self .graph_module .graph .nodes ):
151+ if node not in self .node_to_group :
152+ continue
153+
154+ if node in processed_nodes :
155+ continue
156+
157+ group_idx = self .node_to_group [node ]
158+ group = self .node_groups [group_idx ]
159+
160+ # Create a partition for group
138161 partition_id = next (new_partition_id )
139162 partition = Partition (id = partition_id , nodes = set ())
140163 partitions_by_id [partition_id ] = partition
141164 partitions_order [partition_id ] = partition_id
142- group_to_partition_id [i ] = partition_id
165+ group_to_partition_id [group_idx ] = partition_id
143166
144167 # Add all supported nodes from the group to the partition
145168 for node in group :
@@ -164,6 +187,12 @@ def _process_node_groups(
164187 partition_map [partition_id ].add (target_id )
165188 partition_map [partition_id ].update (partition_map [target_id ])
166189
190+ # all the nodes in the group have now been processed
191+ # so skip if we encounter them again in our rev topo
192+ # iteration
193+ for node in group :
194+ processed_nodes .add (node )
195+
167196 return group_to_partition_id
168197
169198 def _process_remaining_nodes (
@@ -209,7 +238,6 @@ def _merge_partitions(
209238
210239 # Set to track removed partitions from initial static list so we can skip them
211240 already_merged = set ()
212-
213241 # Try to merge each pair of partitions
214242 for i , p1 in enumerate (partition_ids ):
215243 # Skip if this partition has been already merged
0 commit comments