Skip to content

Commit 49d11b1

Browse files
committed
init
1 parent 5a594a7 commit 49d11b1

File tree

3 files changed

+55
-6
lines changed

3 files changed

+55
-6
lines changed

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,16 @@ def ops_to_not_decompose(
111111
do_not_decompose = []
112112
op_support = OperatorsSupportedForCoreMLBackend()
113113
for node in ep.graph.nodes:
114-
if (
115-
node.op == "call_function"
116-
and isinstance(node.target, torch._ops.OpOverload)
117-
and op_support.is_node_supported(None, node)
114+
if node.op == "call_function" and isinstance(
115+
node.target, torch._ops.OpOverload
118116
):
119-
do_not_decompose.append(node.target)
117+
try:
118+
if op_support.is_node_supported(None, node):
119+
do_not_decompose.append(node.target)
120+
except Exception as e:
121+
# CoreML's op_support.is_node_supported will sometimes throw
122+
# for unsupported ops, rather than returning False
123+
logger.warning(
124+
f"Encountered exception when checking node support: {e}"
125+
)
120126
return do_not_decompose, None

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,28 @@ def test_vit_skip_conv(self):
8282

8383
def test_ops_to_not_decompose(self):
8484
class Model(torch.nn.Module):
85+
def __init__(self) -> None:
86+
super().__init__()
87+
88+
buffer = torch.ones(1)
89+
self.register_buffer("buffer", buffer)
90+
8591
def forward(self, q, k, v, mask):
86-
return torch.ops.aten.scaled_dot_product_attention.default(
92+
out = torch.ops.aten.scaled_dot_product_attention.default(
8793
q, k, v, attn_mask=mask
8894
)
8995

96+
# Add non-functional and alias ops
97+
# These will be removed by ExecuTorch in non-decomposition
98+
# table because they cannot be functionalized
99+
out = out.transpose(1, 2)
100+
out = out.view(1, -1)
101+
out = out.permute(0, 1)
102+
out = out.add_(self.buffer)
103+
out = torch.ops.aten.view_copy.default(out, (-1,))
104+
out = out.select(0, 0)
105+
return out
106+
90107
model = Model()
91108
model.eval()
92109

@@ -107,6 +124,9 @@ def forward(self, q, k, v, mask):
107124
edge_program_manager = executorch.exir.to_edge_transform_and_lower(
108125
ep, partitioner=[coreml_partitioner]
109126
)
127+
print(
128+
format_delegated_graph(edge_program_manager.exported_program().graph_module)
129+
)
110130
self.assertTrue(
111131
"executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
112132
in format_delegated_graph(

exir/program/_program.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
2727
from executorch.exir.error import ExportError
2828
from executorch.exir.graph_module import get_control_flow_submodules
29+
from executorch.exir.operator.convert import _pybind_schema_to_native_schema
2930
from executorch.exir.pass_base import PassBase
3031
from executorch.exir.pass_manager import PassType
3132
from executorch.exir.passes import (
@@ -836,6 +837,9 @@ def _replace_aten_ops_with_transformed_ops(
836837
ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose(
837838
program
838839
)
840+
ops_set_to_not_decompose = _remove_invalid_ops_for_not_decompose(
841+
ops_set_to_not_decompose
842+
)
839843

840844
for op_aten in ops_set_to_not_decompose:
841845
_register_no_decomp_op(op_aten)
@@ -965,6 +969,21 @@ def _sanity_check_graph_for_non_decomp_ops(
965969
logging.warning(warning_str)
966970

967971

972+
def _remove_invalid_ops_for_not_decompose(
973+
ops_to_not_decompose: List[torch._ops.OpOverload],
974+
) -> List[torch._ops.OpOverload]:
975+
def keep(op):
976+
schema = op._schema
977+
native_schema = _pybind_schema_to_native_schema(schema)
978+
if native_schema.is_mutable:
979+
return False
980+
if native_schema.aliased_return_names() != [None]:
981+
return False
982+
return True
983+
984+
return list(filter(keep, ops_to_not_decompose))
985+
986+
968987
def _gen_edge_manager_for_partitioners(
969988
partitioner: Dict[str, List[Partitioner]],
970989
aten_programs: Dict[str, ExportedProgram],
@@ -992,6 +1011,9 @@ def _gen_edge_manager_for_partitioners(
9921011
all_ops_no_decomp = set()
9931012
for curr_partitioner in partitioner.get(name, []):
9941013
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
1014+
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1015+
curr_ops_no_decomp
1016+
)
9951017
all_ops_no_decomp |= set(curr_ops_no_decomp)
9961018

9971019
table = _default_decomposition_table()
@@ -1113,6 +1135,7 @@ def to_edge_transform_and_lower(
11131135
curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
11141136
program
11151137
)
1138+
curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
11161139
ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
11171140
_sanity_check_graph_for_non_decomp_ops(
11181141
name,

0 commit comments

Comments
 (0)