Skip to content

Commit 3c3e2f3

Browse files
committed
[Group Partitioner] Optimize Speed
ghstack-source-id: 4abdb52 ghstack-comment-id: 3115642422 Pull Request resolved: #12844
1 parent f592d85 commit 3c3e2f3

File tree

1 file changed

+48
-15
lines changed

1 file changed

+48
-15
lines changed

exir/backend/canonical_partitioners/group_partitioner.py

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -101,19 +101,25 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):
101101
p2_nodes = set(partitions_by_id[p2].nodes.keys())
102102
combined_nodes = p1_nodes.union(p2_nodes)
103103

104-
for node in combined_nodes:
105-
# Get all downstream nodes that are not in the combined partition
106-
external_downstreams = {
107-
n
108-
for n in self.dependency_viewer.downstreams_of(node)
109-
if n not in combined_nodes
110-
}
104+
user_nodes = []
105+
# topologically, p2_nodes comes before p1_nodes, so we only
106+
# need to check the downstream nodes of p2.
107+
# Additionally, we don't need to check all the downstream nodes
108+
# of p2, we only need to check the nodes directly outside of p2.
109+
# example:
110+
# partition[a --> b --> c] --> d --> e --> f
111+
# we don't need to check [d, e, f] we only need to check [d] because
112+
# the downstream users of [d] will include [e, f]
113+
for node in p2_nodes:
114+
for user in node.users:
115+
if user not in combined_nodes:
116+
user_nodes.append(user)
111117

118+
for external_node in user_nodes:
112119
# Check if any external downstream nodes have downstream nodes in the combined partition
113-
for external_node in external_downstreams:
114-
downstream_nodes = self.dependency_viewer.downstreams_of(external_node)
115-
if any(n in combined_nodes for n in downstream_nodes):
116-
return False
120+
downstream_nodes = self.dependency_viewer.downstreams_of(external_node)
121+
if any(n in combined_nodes for n in downstream_nodes):
122+
return False
117123

118124
return True
119125

@@ -133,13 +139,35 @@ def _process_node_groups(
133139
if not self.node_groups:
134140
return group_to_partition_id
135141

136-
for i, group in enumerate(self.node_groups):
137-
# Create a partition for each group
142+
node_to_group_index = {}
143+
for idx, group in enumerate(self.node_groups):
144+
for node in group:
145+
node_to_group_index[node] = idx
146+
147+
processed_nodes = set()
148+
149+
# We have to create the partitions in reverse topological order
150+
# so we find the groups as we traverse backwards in the graph
151+
# this likely needs to be combined with the process_remaining_nodes
152+
# TODO: this currently doesn't work with _process_remaining_nodes so
153+
# if a user provides grouped nodes with operatorsupport, then this will
154+
# faile
155+
for node in reversed(self.graph_module.graph.nodes):
156+
if node not in node_to_group_index:
157+
continue
158+
159+
if node in processed_nodes:
160+
continue
161+
162+
group_idx = node_to_group_index[node]
163+
group = self.node_groups[group_idx]
164+
165+
# Create a partition for group
138166
partition_id = next(new_partition_id)
139167
partition = Partition(id=partition_id, nodes=set())
140168
partitions_by_id[partition_id] = partition
141169
partitions_order[partition_id] = partition_id
142-
group_to_partition_id[i] = partition_id
170+
group_to_partition_id[group_idx] = partition_id
143171

144172
# Add all supported nodes from the group to the partition
145173
for node in group:
@@ -164,6 +192,12 @@ def _process_node_groups(
164192
partition_map[partition_id].add(target_id)
165193
partition_map[partition_id].update(partition_map[target_id])
166194

195+
# all the nodes in the group have now been processed
196+
# so skip if we encoutner them again in our rev topo
197+
# iteration
198+
for node in group:
199+
processed_nodes.add(node)
200+
167201
return group_to_partition_id
168202

169203
def _process_remaining_nodes(
@@ -209,7 +243,6 @@ def _merge_partitions(
209243

210244
# Set to track removed partitions from initial static list so we can skip them
211245
already_merged = set()
212-
213246
# Try to merge each pair of partitions
214247
for i, p1 in enumerate(partition_ids):
215248
# Skip if this partition has been already merged

0 commit comments

Comments
 (0)