Skip to content

Commit ab68379

Browse files
committed
[Group Partitioner] leverage group partitioner for config-based partitioner
ghstack-source-id: fdc99b6 ghstack-comment-id: 3115642963 Pull Request resolved: #12845
1 parent 8bae6f7 commit ab68379

File tree

2 files changed

+88
-10
lines changed

2 files changed

+88
-10
lines changed

exir/backend/canonical_partitioners/config_partitioner.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,17 @@
1010
import torch
1111
from executorch.exir.backend.backend_details import ExportedProgram
1212
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
13-
generate_partitions_from_list_of_nodes,
13+
generate_grouped_partitions_from_list_of_nodes,
1414
)
1515
from executorch.exir.backend.partitioner import (
1616
DelegationSpec,
1717
Partitioner,
1818
PartitionResult,
1919
)
20+
21+
from exir.backend.canonical_partitioners.pattern_op_partitioner import (
22+
generate_grouped_partitions_from_list_of_nodes,
23+
)
2024
from torch.fx.passes.infra.partitioner import Partition
2125

2226

@@ -162,23 +166,49 @@ def filter_fn(node: torch.fx.Node) -> bool:
162166
def get_matched_nodes_from_configs(
163167
self, ep: ExportedProgram
164168
) -> List[List[torch.fx.Node]]:
169+
# disjoint set union
170+
parent = {}
171+
172+
def find(x):
173+
parent.setdefault(x, x)
174+
if parent[x] != x:
175+
parent[x] = find(parent[x])
176+
return parent[x]
177+
178+
def union(x, y):
179+
parent[find(x)] = find(y)
180+
165181
# gather supported nodes
166-
matched_nodes = []
167182
gm = ep.graph_module
168183
for node in gm.graph.nodes:
169-
if node.op == "call_function":
170-
target = format_target_name(node.target.__name__)
171-
if target in self.target_partitioner_configs:
172-
node_config = self.target_partitioner_configs[target]
173-
if node_config.check_constraints(node, ep):
174-
matched_nodes.append(node_config.get_partition(node, ep))
184+
if node.op != "call_function":
185+
continue
186+
target = format_target_name(node.target.__name__)
187+
188+
if target not in self.target_partitioner_configs:
189+
continue
190+
191+
node_config = self.target_partitioner_configs[target]
192+
if not node_config.check_constraints(node, ep):
193+
continue
194+
195+
partition = node_config.get_partition(node, ep)
196+
if len(partition) > 0:
197+
parent[partition[0]] = partition[0]
198+
for i in range(1, len(partition)):
199+
union(partition[0], partition[i])
200+
201+
groups = {}
202+
for node in parent.keys():
203+
root = find(node)
204+
groups.setdefault(root, set()).add(node)
175205

176-
return matched_nodes
206+
return [list(group) for group in groups.values()]
177207

178208
def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
179209
matched_nodes = self.get_matched_nodes_from_configs(ep)
180210
# create partitions
181-
partitions = generate_partitions_from_list_of_nodes(
211+
partitions = generate_grouped_partitions_from_list_of_nodes(
182212
ep.graph_module,
183213
matched_nodes,
184214
)

exir/backend/canonical_partitioners/pattern_op_partitioner.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from typing import List, Optional
99

1010
import torch
11+
12+
from executorch.exir.backend.canonical_partitioners.group_partitioner import (
13+
GroupBasedPartitioner,
14+
)
1115
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
1216
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
1317
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
@@ -56,6 +60,50 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
5660
return partition_list
5761

5862

63+
def generate_grouped_partitions_from_list_of_nodes(
64+
graph_module: torch.fx.GraphModule,
65+
pattern_list: Optional[List[List[torch.fx.Node]]] = None,
66+
op_support: Optional[OperatorSupportBase] = None,
67+
) -> List[Partition]:
68+
final_op_support: Optional[OperatorSupportBase] = op_support
69+
70+
if pattern_list is not None:
71+
# Tag all the nodes in these patterns
72+
for node_list in pattern_list:
73+
for node in node_list:
74+
node.meta["match"] = True
75+
76+
class MatchTag(OperatorSupportBase):
77+
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
78+
return node.meta.get("match", False)
79+
80+
final_op_support = (
81+
MatchTag()
82+
if final_op_support is None
83+
else any_chain(final_op_support, MatchTag())
84+
)
85+
86+
assert (
87+
final_op_support is not None
88+
), "Did not give a pattern or OperatorSupportBase instance to partition with"
89+
90+
# Run the CapabilityBasedPartitioner to return the largest possible
91+
# subgraphs containing the nodes with the tags
92+
group_partitioner = GroupBasedPartitioner(
93+
graph_module,
94+
final_op_support,
95+
node_groups=pattern_list,
96+
allows_single_node_partition=True,
97+
)
98+
partition_list = group_partitioner.propose_partitions()
99+
100+
# Remove the metadata field we added
101+
for partition in partition_list:
102+
for node in partition.nodes:
103+
node.meta.pop("match", False)
104+
return partition_list
105+
106+
59107
def generate_pattern_op_partitions(
60108
graph_module: torch.fx.GraphModule,
61109
patterns: Optional[List[torch.fx.Graph]] = None,

0 commit comments

Comments
 (0)