|
26 | 26 | from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap |
27 | 27 | from executorch.exir.error import ExportError |
28 | 28 | from executorch.exir.graph_module import get_control_flow_submodules |
| 29 | +from executorch.exir.operator.convert import _pybind_schema_to_native_schema |
29 | 30 | from executorch.exir.pass_base import PassBase |
30 | 31 | from executorch.exir.pass_manager import PassType |
31 | 32 | from executorch.exir.passes import ( |
@@ -836,6 +837,9 @@ def _replace_aten_ops_with_transformed_ops( |
836 | 837 | ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose( |
837 | 838 | program |
838 | 839 | ) |
| 840 | + ops_set_to_not_decompose = _remove_invalid_ops_for_not_decompose( |
| 841 | + ops_set_to_not_decompose |
| 842 | + ) |
839 | 843 |
|
840 | 844 | for op_aten in ops_set_to_not_decompose: |
841 | 845 | _register_no_decomp_op(op_aten) |
@@ -965,6 +969,21 @@ def _sanity_check_graph_for_non_decomp_ops( |
965 | 969 | logging.warning(warning_str) |
966 | 970 |
|
967 | 971 |
|
| 972 | +def _remove_invalid_ops_for_not_decompose( |
| 973 | + ops_to_not_decompose: List[torch._ops.OpOverload], |
| 974 | +) -> List[torch._ops.OpOverload]: |
| 975 | + def keep(op): |
| 976 | + schema = op._schema |
| 977 | + native_schema = _pybind_schema_to_native_schema(schema) |
| 978 | + if native_schema.is_mutable: |
| 979 | + return False |
| 980 | + if native_schema.aliased_return_names() != [None]: |
| 981 | + return False |
| 982 | + return True |
| 983 | + |
| 984 | + return list(filter(keep, ops_to_not_decompose)) |
| 985 | + |
| 986 | + |
968 | 987 | def _gen_edge_manager_for_partitioners( |
969 | 988 | partitioner: Dict[str, List[Partitioner]], |
970 | 989 | aten_programs: Dict[str, ExportedProgram], |
@@ -992,6 +1011,9 @@ def _gen_edge_manager_for_partitioners( |
992 | 1011 | all_ops_no_decomp = set() |
993 | 1012 | for curr_partitioner in partitioner.get(name, []): |
994 | 1013 | curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program) |
| 1014 | + curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose( |
| 1015 | + curr_ops_no_decomp |
| 1016 | + ) |
995 | 1017 | all_ops_no_decomp |= set(curr_ops_no_decomp) |
996 | 1018 |
|
997 | 1019 | table = _default_decomposition_table() |
@@ -1113,6 +1135,7 @@ def to_edge_transform_and_lower( |
1113 | 1135 | curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose( |
1114 | 1136 | program |
1115 | 1137 | ) |
| 1138 | + curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set) |
1116 | 1139 | ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set) |
1117 | 1140 | _sanity_check_graph_for_non_decomp_ops( |
1118 | 1141 | name, |
|
0 commit comments