Skip to content

Commit 9e7e17c

Browse files
Gasoonjia and Copilot authored
[aoti-backend-consolidation 2/3] backend.py (#15528)
Summary:

This diff consolidates the backend functionality into a single target `//executorch/backends/aoti:aoti_backend` and simplifies the CUDA backend target by making it depend on the consolidated backend target.

The following changes are made in this diff:

* Creation of a new target `//executorch/backends/aoti:aoti_backend` in `fbcode/executorch/backends/aoti/targets.bzl`, which includes the necessary dependencies for the AOTI backend.
* Update of the `//executorch/backends/cuda:cuda_backend` target in `fbcode/executorch/backends/cuda/TARGETS` to depend on the new `//executorch/backends/aoti:aoti_backend` target instead of individual AOTI backend dependencies.
* Creation of a new file `fbcode/executorch/backends/aoti/aoti_backend.py`, which imports the necessary dependencies and passes for the AOTI backend.
* Simplification of the `xplat/executorch/backends/cuda/cuda_backend.py` file by removing unnecessary imports and using the new `AotiBackend` class from the `aoti_backend.py` file.

ghstack-source-id: 319556735
Reviewed By: larryliu0820
Differential Revision: D85704977

---------

Co-authored-by: Copilot <[email protected]>
1 parent 2b91382 commit 9e7e17c

File tree

5 files changed

+378
-375
lines changed

5 files changed

+378
-375
lines changed

backends/aoti/aoti_backend.py

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import contextlib
8+
import os
9+
import typing
10+
from abc import ABC, abstractmethod
11+
from enum import Enum
12+
from typing import Any, Dict, List, Optional, Set
13+
14+
import torch
15+
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
16+
ReplaceViewCopyWithViewPass,
17+
)
18+
from executorch.exir._serialize._named_data_store import NamedDataStore
19+
from executorch.exir._warnings import experimental
20+
from executorch.exir.backend.backend_details import ExportedProgram, PreprocessResult
21+
from executorch.exir.backend.compile_spec_schema import CompileSpec
22+
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
23+
from torch.export.passes import move_to_device_pass
24+
25+
26+
class COMPILE_SPEC_KEYS(Enum):
    """Keys of the ``CompileSpec`` entries used by AOTI-driven backends."""

    # Key of the spec whose value carries the UTF-8 encoded method name.
    METHOD_NAME = "method_name"
28+
29+
30+
@experimental(
    "This API and all of aoti-driven backend related functionality are experimental."
)
class AotiBackend(ABC):
    """
    Base mixin class for AOTInductor-based backends.

    This class provides common functionality for compiling models using AOTInductor
    with different device targets (CUDA, Metal, etc.).

    This is a mixin class, not an actual backend object, for aoti-driven backends.
    Concrete backends (e.g., CudaBackend, MetalBackend) should inherit from both
    BackendDetails and AotiBackend to get the full functionality.
    """

    @classmethod
    @abstractmethod
    def get_device_name(cls) -> str:
        """Return the device name for this backend (e.g., 'cuda', 'metal')."""
        pass

    @classmethod
    @abstractmethod
    def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
        """
        Return the mapping of supported fallback kernels for this backend.

        Only the keys (kernel names) are consulted by
        ``collect_unsupported_fallback_kernels``.
        """
        pass

    @classmethod
    @abstractmethod
    def get_decomposition_table(cls) -> Dict[Any, Any]:
        """
        Return the decomposition table for this backend.

        An empty table makes ``preprocess`` skip the decomposition step.
        """
        pass

    @classmethod
    @abstractmethod
    def get_aoti_compile_options(
        cls, compile_specs: List[CompileSpec]
    ) -> Dict[str, typing.Any]:
        """Return the AOTInductor compilation options for this backend."""
        pass

    @classmethod
    @abstractmethod
    def get_custom_passes(cls) -> List[typing.Any]:
        """Return the list of custom passes to apply after ReplaceViewCopyWithViewPass and before decomposition."""
        pass

    @classmethod
    @contextlib.contextmanager
    def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]):
        """
        Context manager to collect unsupported fallback kernels during compilation.

        Monitors both extern kernel calls and runtime lookup by temporarily
        replacing two ``CppWrapperCpu`` class attributes with wrappers that
        record every kernel name not present in
        ``get_supported_fallback_kernels()`` and then delegate to the original
        implementation. The original attributes are restored on exit, even if
        the body raises.

        Args:
            missing_fallback_kernels: Set that is updated in place with the
                names of the unsupported kernels that were encountered.
        """
        supported_kernels = cls.get_supported_fallback_kernels()

        original_generate_c_shim_extern_kernel_call = (
            CppWrapperCpu.generate_c_shim_extern_kernel_call
        )
        original_generate_fallback_kernel_with_runtime_lookup_aot = (
            CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot
        )

        def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
            self,
            kernel: str,
            args: list[str],
            device: str,
            *,
            debug_args: Optional[list[str]] = None,
            debug_handle: Optional[int] = None,
        ):
            if kernel not in supported_kernels:
                missing_fallback_kernels.add(kernel)

            # Propagate the result so the wrapper stays transparent to callers.
            return original_generate_c_shim_extern_kernel_call(
                self,
                kernel,
                args,
                device,
                debug_args=debug_args,
                debug_handle=debug_handle,
            )

        def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels(
            self,
            op_overload,
            raw_args,
            output_args,
            raw_outputs,
        ):
            # Fall back to the string form when the overload has no `_name`.
            kernel_name = getattr(op_overload, "_name", str(op_overload))
            if kernel_name not in supported_kernels:
                missing_fallback_kernels.add(kernel_name)

            # Propagate the result so the wrapper stays transparent to callers.
            return original_generate_fallback_kernel_with_runtime_lookup_aot(
                self, op_overload, raw_args, output_args, raw_outputs
            )

        try:
            # Patch inside the `try` so the originals are always restored.
            CppWrapperCpu.generate_c_shim_extern_kernel_call = (
                generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
            )
            CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels
            yield
        finally:
            CppWrapperCpu.generate_c_shim_extern_kernel_call = (
                original_generate_c_shim_extern_kernel_call
            )
            CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
                original_generate_fallback_kernel_with_runtime_lookup_aot
            )

    @classmethod
    def preprocess(
        cls,
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """
        Preprocess the edge program and compile it using AOTInductor.
        Weights are always separated from the SO file.

        Args:
            edge_program: The exported program to compile.
            compile_specs: Compile specs; must contain the method-name spec
                (see ``generate_method_name_compile_spec``).

        Returns:
            A PreprocessResult with empty ``processed_bytes``, an empty
            ``debug_handle_map``, and a named data store output holding the
            ``<method_name>_so_blob`` and ``<method_name>_weights_blob`` entries.

        Raises:
            RuntimeError: If the method name is missing from ``compile_specs``,
                if compilation used fallback kernels the backend does not
                support, or if the compiled output lacks the ``.wrapper.so`` or
                ``.wrapper_weights.blob`` file.
        """
        device_name = cls.get_device_name()
        decomposition_table = cls.get_decomposition_table()
        options = cls.get_aoti_compile_options(compile_specs)

        # Resolve the method name once, before the expensive compilation, so a
        # missing spec fails fast and leaves no generated files behind.
        method_name = cls.method_name_from_compile_specs(compile_specs)

        # Move the edge_program to the target device
        device_edge_program = move_to_device_pass(
            edge_program, device_name if device_name != "metal" else "mps"
        )

        # Replace view_copy with view
        ReplaceViewCopyWithViewPass()(device_edge_program.graph_module)

        # Apply custom backend-specific passes
        for custom_pass in cls.get_custom_passes():
            custom_pass(device_edge_program.graph_module)

        # Run decompositions if any
        if decomposition_table:
            device_edge_program = device_edge_program.run_decompositions(
                decomposition_table
            )

        edge_program_module = device_edge_program.module()

        # Grab all input placeholders from the graph
        user_input_names = device_edge_program.graph_signature.user_inputs
        user_input_placeholders = [
            node.meta["val"]
            for node in device_edge_program.graph.nodes
            if node.op == "placeholder" and node.name in user_input_names
        ]

        # Track missing fallback kernels
        missing_fallback_kernels: Set[str] = set()

        # Compile with fallback kernel collection
        with cls.collect_unsupported_fallback_kernels(
            missing_fallback_kernels
        ), torch.no_grad():
            paths = torch._inductor.aot_compile(
                edge_program_module, tuple(user_input_placeholders), options=options
            )

        if missing_fallback_kernels:
            formatted_kernels = "\n  - ".join(sorted(missing_fallback_kernels))
            raise RuntimeError(
                f"Method {method_name} missing fallback kernels ({len(missing_fallback_kernels)} total):\n  - {formatted_kernels}\n"
                "Please add them to the AOTI backend."
            )

        # Extract paths - weights are always separated
        so_path = None
        blob_path = None

        if isinstance(paths, list):
            for path in paths:
                if path.endswith(".wrapper.so"):
                    so_path = path
                elif path.endswith(".wrapper_weights.blob"):
                    blob_path = path
        else:
            so_path = paths

        if so_path is None or blob_path is None:
            raise RuntimeError(
                f"Could not find required files in compiled paths, got {paths}"
            )

        try:
            # Read SO file
            with open(so_path, "rb") as f:
                so_data = f.read()

            # Read weights blob
            with open(blob_path, "rb") as f:
                blob_data = f.read()
        finally:
            # Clean up the generated files even when reading fails. A missing
            # file is ignored here so the original read error is not masked.
            for generated_path in (so_path, blob_path):
                with contextlib.suppress(FileNotFoundError):
                    os.remove(generated_path)

        # Create named data store
        named_data_store = NamedDataStore()

        # Add SO and weights blob separately
        named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None)
        weights_blob_data_type = f"aoti_{device_name}_blob"
        named_data_store.add_named_data(
            method_name + "_weights_blob", blob_data, 1, weights_blob_data_type
        )

        return PreprocessResult(
            processed_bytes=b"",
            debug_handle_map={},
            data_store_output=named_data_store.get_named_data_store_output(),
        )

    @classmethod
    def generate_method_name_compile_spec(
        cls,
        method_name: str,
    ) -> CompileSpec:
        """
        Generate a CompileSpec for the given method name.

        The name is stored UTF-8 encoded under the
        ``COMPILE_SPEC_KEYS.METHOD_NAME`` key.
        """
        return CompileSpec(
            COMPILE_SPEC_KEYS.METHOD_NAME.value,
            method_name.encode("utf-8"),
        )

    @classmethod
    def method_name_from_compile_specs(
        cls,
        compile_specs: List[CompileSpec],
    ) -> str:
        """
        Extract the method name from the compile specs.

        Returns the UTF-8 decoded value of the first spec whose key is
        ``COMPILE_SPEC_KEYS.METHOD_NAME``.

        Raises:
            RuntimeError: If no spec carries the method-name key.
        """
        for spec in compile_specs:
            if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value:
                return spec.value.decode("utf-8")
        raise RuntimeError(
            f"Could not find method name in compile specs: {compile_specs}"
        )

backends/aoti/targets.bzl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,23 @@ def define_common_targets():
1616
],
1717
)
1818

19+
runtime.python_library(
20+
name = "aoti_backend",
21+
srcs = [
22+
"aoti_backend.py",
23+
],
24+
visibility = [
25+
"//executorch/...",
26+
],
27+
deps = [
28+
"//caffe2:torch",
29+
"//executorch/backends/aoti/passes:passes",
30+
"//executorch/exir/_serialize:lib",
31+
"//executorch/exir/backend:backend_details",
32+
"//executorch/exir/backend:compile_spec_schema",
33+
],
34+
)
35+
1936
# AOTI common shims functionality
2037
runtime.cxx_library(
2138
name = "common_shims",

0 commit comments

Comments
 (0)