@@ -101,19 +101,25 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):
101101 p2_nodes = set (partitions_by_id [p2 ].nodes .keys ())
102102 combined_nodes = p1_nodes .union (p2_nodes )
103103
104- for node in combined_nodes :
105- # Get all downstream nodes that are not in the combined partition
106- external_downstreams = {
107- n
108- for n in self .dependency_viewer .downstreams_of (node )
109- if n not in combined_nodes
110- }
104+ user_nodes = []
105+ # topologically, p2_nodes comes before p1_nodes, so we only
106+ # need to check the downstream nodes of p2.
107+ # Additionally, we don't need to check all the downstream nodes
108+ # of p2, we only need to check the nodes directly outside of p2.
109+ # example:
110+ # partition[a --> b --> c] --> d --> e --> f
111+ # we don't need to check [d, e, f] we only need to check [d] because
112+ # the downstream users of [d] will include [e, f]
113+ for node in p2_nodes :
114+ for user in node .users :
115+ if user not in combined_nodes :
116+ user_nodes .append (user )
111117
118+ for external_node in user_nodes :
112119 # Check if any external downstream nodes have downstream nodes in the combined partition
113- for external_node in external_downstreams :
114- downstream_nodes = self .dependency_viewer .downstreams_of (external_node )
115- if any (n in combined_nodes for n in downstream_nodes ):
116- return False
120+ downstream_nodes = self .dependency_viewer .downstreams_of (external_node )
121+ if any (n in combined_nodes for n in downstream_nodes ):
122+ return False
117123
118124 return True
119125
@@ -133,13 +139,35 @@ def _process_node_groups(
133139 if not self .node_groups :
134140 return group_to_partition_id
135141
136- for i , group in enumerate (self .node_groups ):
137- # Create a partition for each group
142+ node_to_group_index = {}
143+ for idx , group in enumerate (self .node_groups ):
144+ for node in group :
145+ node_to_group_index [node ] = idx
146+
147+ processed_nodes = set ()
148+
149+ # We have to create the partitions in reverse topological order
150+ # so we find the groups as we traverse backwards in the graph
151+ # this likely needs to be combined with the process_remaining_nodes
152+ # TODO: this currently doesn't work with _process_remaining_nodes so
153+ # if a user provides grouped nodes with operatorsupport, then this will
154+ # faile
155+ for node in reversed (self .graph_module .graph .nodes ):
156+ if node not in node_to_group_index :
157+ continue
158+
159+ if node in processed_nodes :
160+ continue
161+
162+ group_idx = node_to_group_index [node ]
163+ group = self .node_groups [group_idx ]
164+
165+ # Create a partition for group
138166 partition_id = next (new_partition_id )
139167 partition = Partition (id = partition_id , nodes = set ())
140168 partitions_by_id [partition_id ] = partition
141169 partitions_order [partition_id ] = partition_id
142- group_to_partition_id [i ] = partition_id
170+ group_to_partition_id [group_idx ] = partition_id
143171
144172 # Add all supported nodes from the group to the partition
145173 for node in group :
@@ -164,6 +192,12 @@ def _process_node_groups(
164192 partition_map [partition_id ].add (target_id )
165193 partition_map [partition_id ].update (partition_map [target_id ])
166194
195+ # all the nodes in the group have now been processed
196+ # so skip if we encoutner them again in our rev topo
197+ # iteration
198+ for node in group :
199+ processed_nodes .add (node )
200+
167201 return group_to_partition_id
168202
169203 def _process_remaining_nodes (
@@ -209,7 +243,6 @@ def _merge_partitions(
209243
210244 # Set to track removed partitions from initial static list so we can skip them
211245 already_merged = set ()
212-
213246 # Try to merge each pair of partitions
214247 for i , p1 in enumerate (partition_ids ):
215248 # Skip if this partition has been already merged
0 commit comments