|  | 
|  | 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. | 
|  | 2 | +# All rights reserved. | 
|  | 3 | +# | 
|  | 4 | +# This source code is licensed under the BSD-style license found in the | 
|  | 5 | +# LICENSE file in the root directory of this source tree. | 
|  | 6 | + | 
|  | 7 | +import contextlib | 
|  | 8 | +import os | 
|  | 9 | +import typing | 
|  | 10 | +from enum import Enum | 
|  | 11 | + | 
|  | 12 | +from typing import Any, Dict, final, List, Optional, Set | 
|  | 13 | + | 
|  | 14 | +import torch | 
|  | 15 | +from executorch.backends.apple.metal.replace_slice_copy_with_slice import ( | 
|  | 16 | +    ReplaceSliceCopyWithSlicePass, | 
|  | 17 | +) | 
|  | 18 | +from executorch.exir._serialize._named_data_store import NamedDataStore | 
|  | 19 | +from executorch.exir._warnings import experimental | 
|  | 20 | +from executorch.exir.backend.backend_details import ( | 
|  | 21 | +    BackendDetails, | 
|  | 22 | +    ExportedProgram, | 
|  | 23 | +    PreprocessResult, | 
|  | 24 | +) | 
|  | 25 | +from executorch.exir.backend.compile_spec_schema import CompileSpec | 
|  | 26 | +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu | 
|  | 27 | +from torch.export.passes import move_to_device_pass | 
|  | 28 | + | 
|  | 29 | + | 
|  | 30 | +# exist fallback operators in et namespace; | 
|  | 31 | +supported_fallback_kernels: Dict[str, Any] = { | 
|  | 32 | +    "aoti_torch_mps_addmm_out": None, | 
|  | 33 | +    "aoti_torch_mps_convolution": None, | 
|  | 34 | +    "aoti_torch_mps_mm_out": None, | 
|  | 35 | +    "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, | 
|  | 36 | +} | 
|  | 37 | + | 
|  | 38 | +# required fallback kernels but not supported | 
|  | 39 | +missing_fallback_kernels: Set[str] = set() | 
|  | 40 | + | 
|  | 41 | + | 
|  | 42 | +class COMPILE_SPEC_KEYS(Enum): | 
|  | 43 | +    METHOD_NAME = "method_name" | 
|  | 44 | + | 
|  | 45 | + | 
|  | 46 | +# context manager for non-fallback guarantee | 
|  | 47 | +# it will raise exception when generating fallback kernels during aoti compile | 
|  | 48 | +@contextlib.contextmanager | 
|  | 49 | +def collect_unsupported_fallback_kernels(): | 
|  | 50 | +    original_generate_c_shim_extern_kernel_call = ( | 
|  | 51 | +        CppWrapperCpu.generate_c_shim_extern_kernel_call | 
|  | 52 | +    ) | 
|  | 53 | + | 
|  | 54 | +    def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( | 
|  | 55 | +        self, | 
|  | 56 | +        kernel: str, | 
|  | 57 | +        args: list[str], | 
|  | 58 | +        device: str, | 
|  | 59 | +        *, | 
|  | 60 | +        debug_args: Optional[list[str]] = None, | 
|  | 61 | +        debug_handle: Optional[int] = None, | 
|  | 62 | +    ): | 
|  | 63 | +        if kernel not in supported_fallback_kernels: | 
|  | 64 | +            missing_fallback_kernels.add(kernel) | 
|  | 65 | + | 
|  | 66 | +        original_generate_c_shim_extern_kernel_call( | 
|  | 67 | +            self, kernel, args, device, debug_args=debug_args, debug_handle=debug_handle | 
|  | 68 | +        ) | 
|  | 69 | + | 
|  | 70 | +    CppWrapperCpu.generate_c_shim_extern_kernel_call = ( | 
|  | 71 | +        generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels | 
|  | 72 | +    ) | 
|  | 73 | +    try: | 
|  | 74 | +        yield | 
|  | 75 | +    finally: | 
|  | 76 | +        CppWrapperCpu.generate_c_shim_extern_kernel_call = ( | 
|  | 77 | +            original_generate_c_shim_extern_kernel_call | 
|  | 78 | +        ) | 
|  | 79 | + | 
|  | 80 | + | 
|  | 81 | +@final | 
|  | 82 | +@experimental( | 
|  | 83 | +    "This API and all of Metal backend related functionality are experimental." | 
|  | 84 | +) | 
|  | 85 | +class MetalBackend(BackendDetails): | 
|  | 86 | +    @staticmethod | 
|  | 87 | +    def preprocess( | 
|  | 88 | +        edge_program: ExportedProgram, | 
|  | 89 | +        compile_specs: List[CompileSpec], | 
|  | 90 | +    ) -> PreprocessResult: | 
|  | 91 | +        print("entering the lowerable parts in MetalBackend.preprocess....") | 
|  | 92 | +        # Move the edge_program from CPU to MPS for aoti compile | 
|  | 93 | +        mps_edge_program = move_to_device_pass(edge_program, "mps") | 
|  | 94 | + | 
|  | 95 | +        # replace slice_copy with slice | 
|  | 96 | +        ReplaceSliceCopyWithSlicePass()(mps_edge_program.graph_module) | 
|  | 97 | + | 
|  | 98 | +        edge_program_module = mps_edge_program.module() | 
|  | 99 | + | 
|  | 100 | +        # Grab all input placeholders from the graph | 
|  | 101 | +        user_input_names = mps_edge_program.graph_signature.user_inputs | 
|  | 102 | +        user_input_placeholders = [] | 
|  | 103 | +        for node in mps_edge_program.graph.nodes: | 
|  | 104 | +            if node.op == "placeholder" and node.name in user_input_names: | 
|  | 105 | +                user_input_placeholders.append(node.meta["val"]) | 
|  | 106 | + | 
|  | 107 | +        # Base options for all devices | 
|  | 108 | +        options: dict[str, typing.Any] = { | 
|  | 109 | +            # Do not link against the full PyTorch/libtorch library | 
|  | 110 | +            "aot_inductor.link_libtorch": False, | 
|  | 111 | +            # Package model constants and other generated files directly in the shared object (.so) file | 
|  | 112 | +            "aot_inductor.package_constants_in_so": True, | 
|  | 113 | +            # Enable maximum automatic tuning for optimal performance | 
|  | 114 | +            "max_autotune": True, | 
|  | 115 | +            # "aot_inductor.debug_compile": True, | 
|  | 116 | +            # "aot_inductor.force_mmap_weights": False, | 
|  | 117 | +        } | 
|  | 118 | + | 
|  | 119 | +        with collect_unsupported_fallback_kernels(): | 
|  | 120 | +            so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type] | 
|  | 121 | +            if len(missing_fallback_kernels) > 0: | 
|  | 122 | +                formatted_kernels = "\n  - ".join(sorted(missing_fallback_kernels)) | 
|  | 123 | +                raise RuntimeError( | 
|  | 124 | +                    f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n  - {formatted_kernels}\n" | 
|  | 125 | +                    "Please add them to the AOTI backend." | 
|  | 126 | +                ) | 
|  | 127 | + | 
|  | 128 | +        # pyre-ignorep[6]: Incompatible parameter type | 
|  | 129 | +        with open(so_path, "rb") as f: | 
|  | 130 | +            so_data = f.read() | 
|  | 131 | + | 
|  | 132 | +        named_data_store = NamedDataStore() | 
|  | 133 | +        method_name = MetalBackend.method_name_from_compile_specs(compile_specs) | 
|  | 134 | +        named_data_store.add_named_data( | 
|  | 135 | +            method_name + "_so_blob", so_data, 1, "aoti_metal_blob" | 
|  | 136 | +        ) | 
|  | 137 | + | 
|  | 138 | +        # Clean up the generated so file; it has been packaged into the NamdeDataStore | 
|  | 139 | +        # pyre-ignorep[6]: Incompatible parameter type | 
|  | 140 | +        os.remove(so_path) | 
|  | 141 | + | 
|  | 142 | +        return PreprocessResult( | 
|  | 143 | +            processed_bytes=b"", | 
|  | 144 | +            debug_handle_map={}, | 
|  | 145 | +            data_store_output=named_data_store.get_named_data_store_output(), | 
|  | 146 | +        ) | 
|  | 147 | + | 
|  | 148 | +    @staticmethod | 
|  | 149 | +    def generate_method_name_compile_spec( | 
|  | 150 | +        method_name: str, | 
|  | 151 | +    ) -> CompileSpec: | 
|  | 152 | +        """ | 
|  | 153 | +        Generates a CompileSpec for the given method name. | 
|  | 154 | +        """ | 
|  | 155 | +        return CompileSpec( | 
|  | 156 | +            COMPILE_SPEC_KEYS.METHOD_NAME.value, | 
|  | 157 | +            method_name.encode("utf-8"), | 
|  | 158 | +        ) | 
|  | 159 | + | 
|  | 160 | +    @staticmethod | 
|  | 161 | +    def method_name_from_compile_specs( | 
|  | 162 | +        compile_specs: List[CompileSpec], | 
|  | 163 | +    ) -> str: | 
|  | 164 | +        """ | 
|  | 165 | +        Returns the method name from the compile specs. | 
|  | 166 | +        """ | 
|  | 167 | +        for spec in compile_specs: | 
|  | 168 | +            if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: | 
|  | 169 | +                return spec.value.decode("utf-8") | 
|  | 170 | +        raise RuntimeError( | 
|  | 171 | +            f"Could not find method name in compile specs: {compile_specs}" | 
|  | 172 | +        ) | 
0 commit comments