Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions exir/backend/backend_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ def generate_debug_handle(ep: ExportedProgram) -> int:
call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
call_submodule_node.replace_all_uses_with(call_delegate_node)
owning_graph_module.graph.erase_node(call_submodule_node)

if is_submodule:
assert len(toplevel_input_specs_to_delete) == 0
assert len(toplevel_output_specs_to_delete) == 0
Expand Down Expand Up @@ -574,26 +573,29 @@ def lower_all_submodules_to_backend(
# The created exported program for the submodules are in the call_module node's meta data
# We just map the method_to_submodule_nodes directly to the method_to_partitioned_exported_programs
method_to_partitioned_program = {
method_name: [node.meta["submodule_program"] for node in call_submodule_nodes]
method_name: [
copy.deepcopy(node.meta["submodule_program"])
for node in call_submodule_nodes
]
for method_name, call_submodule_nodes in method_to_submodules_nodes.items()
}
method_to_compile_specs = {
method_name: [node.meta["compile_spec"] for node in call_submodule_nodes]
for method_name, call_submodule_nodes in method_to_submodules_nodes.items()
}
backend_found = False
for cls in BackendDetails.__subclasses__():
if backend_id == cls.__name__:
method_to_preprocess_result: dict[str, List[PreprocessResult]] = (
cls.preprocess_multimethod(
method_to_partitioned_program, method_to_compile_specs
)
)
backend_found = True

if not backend_found:
backend_name_to_subclass = {
subclass.__name__: subclass for subclass in BackendDetails.__subclasses__()
}
if backend_id not in backend_name_to_subclass:
raise NotImplementedError(f"Backend {backend_id} was not found.")

method_to_preprocess_result: dict[str, List[PreprocessResult]] = (
backend_name_to_subclass[backend_id].preprocess_multimethod(
method_to_partitioned_program, method_to_compile_specs
)
)

for method_name in method_to_preprocess_result.keys():
owning_program = method_to_tagged_edge_program[method_name]
list_of_preprocess_results = method_to_preprocess_result[method_name]
Expand All @@ -612,6 +614,9 @@ def lower_all_submodules_to_backend(
compile_specs=compile_spec,
named_data_store_output=preprocess_result.data_store_output,
)
lowered_module.meta = {
"debug_handle_map": preprocess_result.debug_handle_map,
}
is_submodule = call_submodule_node.meta["is_submodule"]
toplevel_input_specs_to_delete = call_submodule_node.meta[
"toplevel_input_specs_to_delete"
Expand All @@ -633,6 +638,20 @@ def lower_all_submodules_to_backend(
)


def remove_used_metadata(graph: torch.fx.Graph) -> None:
"""
Remove the used metadata from the graph.
"""
for node in graph.nodes:
node.meta.pop("delegation_tag", None)
node.meta.pop("backend_id", None)
node.meta.pop("submodule_program", None)
node.meta.pop("toplevel_input_specs_to_delete", None)
node.meta.pop("toplevel_output_specs_to_delete", None)
node.meta.pop("is_submodule", None)
node.meta.pop("submodule_output_node", None)


@dataclass
class MethodProgramsPartitionerSpec:
"""
Expand Down Expand Up @@ -748,6 +767,7 @@ def to_backend(
if method_name in method_to_tagged_exported_program:
tagged_exported_program = method_to_tagged_exported_program[method_name]
tagged_exported_program._validate()
remove_used_metadata(tagged_exported_program.graph_module.graph)
partitioned_and_lowered_exported_programs[method_name] = ExportedProgram(
root=tagged_exported_program.graph_module,
graph=tagged_exported_program.graph_module.graph,
Expand Down
39 changes: 25 additions & 14 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from executorch.exir._serialize._serialize import serialize_for_executorch
from executorch.exir._serialize.data_serializer import DataSerializer
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.backend_api import (
MethodProgramsPartitionerSpec,
to_backend,
)
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.delegate import executorch_call_delegate, is_lowered_module
Expand Down Expand Up @@ -1239,10 +1242,16 @@ def to_edge_transform_and_lower(
if transform_passes is not None:
edge_manager = edge_manager.transform(transform_passes)

if partitioner is not None:
max_num_partitioners = 0
for partitioner_list in partitioner.values():
max_num_partitioners = max(max_num_partitioners, len(partitioner_list))

for i in range(max_num_partitioners):
method_to_partitioner = {}
for name, partitioner_list in partitioner.items():
for curr_partitioner in partitioner_list:
edge_manager = edge_manager.to_backend({name: curr_partitioner})
if i < len(partitioner_list):
method_to_partitioner[name] = partitioner_list[i]
edge_manager = edge_manager.to_backend(method_to_partitioner)

for name, program in edge_manager._edge_programs.items():
ops_set_to_not_decompose: Set[torch._ops.OpOverload] = set()
Expand Down Expand Up @@ -1475,7 +1484,8 @@ def transform(

@et_logger("to_backend")
def to_backend(
self, partitioner: Union[Partitioner, Dict[str, Partitioner]]
self,
partitioner: Union[Partitioner, Dict[str, Partitioner]],
) -> "EdgeProgramManager":
"""
Returns a semantically-equivalent program to the one given as input,
Expand All @@ -1501,17 +1511,18 @@ def to_backend(
specified subgraphs lowered.
"""
new_edge_programs: Dict[str, ExportedProgram] = {}
if isinstance(partitioner, dict):
for name, program in self._edge_programs.items():
if name in partitioner.keys():
new_edge_programs[name] = to_backend(program, partitioner[name])
else:
new_edge_programs[name] = program
method_to_partitioner: Dict[str, Partitioner] = {}
if not isinstance(partitioner, dict):
method_to_partitioner = {name: partitioner for name in self._edge_programs}
else:
method_to_partitioner = partitioner

else: # apply partitioner to every method
for name, program in self._edge_programs.items():
new_edge_programs[name] = to_backend(program, partitioner)
method_to_programs_and_partitioners = MethodProgramsPartitionerSpec(
self._edge_programs,
method_to_partitioner,
)

new_edge_programs = to_backend(method_to_programs_and_partitioners)
config = EdgeCompileConfig(_check_ir_validity=False)
return EdgeProgramManager(
new_edge_programs,
Expand Down
Loading