From e76aa631ed36fc6a95d62af85a9403435b6adab1 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 1 Nov 2024 19:48:11 -0700 Subject: [PATCH] [iree.build] Add `turbine_generate` iree.build rule. --- iree/turbine/aot/build_actions.py | 150 ++++++++++++++++++++++++++++++ tests/aot/example_builder.py | 66 +++++++++++++ 2 files changed, 216 insertions(+) create mode 100644 iree/turbine/aot/build_actions.py create mode 100644 tests/aot/example_builder.py diff --git a/iree/turbine/aot/build_actions.py b/iree/turbine/aot/build_actions.py new file mode 100644 index 00000000..77997d83 --- /dev/null +++ b/iree/turbine/aot/build_actions.py @@ -0,0 +1,150 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from abc import abstractmethod, ABC +import functools +from pathlib import Path +import os +import typing +import types + +import inspect + +from iree.build.executor import ActionConcurrency, BuildAction, BuildContext, BuildFile +from iree.turbine.aot.fx_programs import FxPrograms + +__all__ = [ + "turbine_generate", +] + + +class ReturnMarshaller(ABC): + @abstractmethod + def prepare_action( + self, + context: BuildContext, + name: str, + action: "TurbineBuilderAction", + return_arity: int, + ): + ... + + @abstractmethod + def save_remote_result(self, result, path: Path): + ... + + +class FxProgramsReturnMarshaller(ReturnMarshaller): + def prepare_action( + self, + context: BuildContext, + name: str, + action: "TurbineBuilderAction", + return_arity: int, + ): + # Need to allocate one file for output. + file_name = ( + f"{name}_{len(action.returns)}.mlir" if return_arity > 1 else f"{name}.mlir" + ) + output_file = context.allocate_file(file_name) + action.returns.append((self, output_file)) + output_file.deps.add(action) + + def save_remote_result(self, result, path: Path): + if not isinstance(result, FxPrograms): + raise RuntimeError( + "Turbine generator was declared to return an FxPrograms instance, " + f"but it returned {type(result)}" + ) + import iree.turbine.aot as turbine_aot + + output = turbine_aot.export(result) + output.save_mlir(path) + + +RETURN_MARSHALLERS_BY_TYPE = { + FxPrograms: FxProgramsReturnMarshaller(), +} +EXPLICIT_MARSHALLER_TYPES = list(RETURN_MARSHALLERS_BY_TYPE.keys()) + + +def get_return_marshaller(t: type) -> ReturnMarshaller: + m = RETURN_MARSHALLERS_BY_TYPE.get(t) + if m is not None: + return m + + # Do an exhaustive subclass check. + for k, m in RETURN_MARSHALLERS_BY_TYPE.items(): + if issubclass(t, k): + # Cache it. + RETURN_MARSHALLERS_BY_TYPE[t] = m + return m + raise ValueError( + f"In order to wrap a function with @turbine_builder it must be annotated with " + f"specific return types. Found '{t}' but only {EXPLICIT_MARSHALLER_TYPES} " + f"are supported" + ) + + +def unwrap_return_annotation(annot) -> list[ReturnMarshaller]: + if ( + isinstance(annot, (types.GenericAlias, typing._GenericAlias)) + and annot.__origin__ is tuple + ): + unpacked = annot.__args__ + else: + unpacked = [annot] + return [get_return_marshaller(it) for it in unpacked] + + +def turbine_generate(*, name: str, generator: callable): + sig = inspect.signature(generator, eval_str=True) + return_marshallers = unwrap_return_annotation(sig.return_annotation) + + context = BuildContext.current() + action = TurbineBuilderAction( + generator, + desc=f"Export turbine model {name}", + executor=context.executor, + ) + for rm in return_marshallers: + rm.prepare_action(context, name, action, len(return_marshallers)) + return [r[1] for r in action.returns] + + +class RemoteGenerator: + def __init__( + self, generation_thunk, return_info: list[tuple[ReturnMarshaller, Path]] + ): + self.generation_thunk = generation_thunk + self.return_info = return_info + + def __call__(self): + print("JOB PID:", os.getpid()) + results = self.generation_thunk() + if not isinstance(results, (tuple, list)): + results = [results] + if len(results) != len(self.return_info): + raise RuntimeError( + f"Turbine generator {self.generation_thunk} returned {len(results)} values, " + f"but it was declared to return {len(self.return_info)}" + ) + for result, (marshaller, output_path) in zip(results, self.return_info): + marshaller.save_remote_result(result, output_path) + + +class TurbineBuilderAction(BuildAction): + def __init__(self, thunk, concurrency=ActionConcurrency.PROCESS, **kwargs): + super().__init__(concurrency=concurrency, **kwargs) + self.thunk = thunk + self.returns: list[tuple[ReturnMarshaller, BuildFile]] = [] + + def _remotable_thunk(self): + print("SCHEDULER PID:", os.getpid()) + remotable_return_info = [ + (marshaller, bf.get_fs_path()) for marshaller, bf in self.returns + ] + return RemoteGenerator(self.thunk, remotable_return_info) diff --git a/tests/aot/example_builder.py b/tests/aot/example_builder.py new file mode 100644 index 00000000..b5d0eeac --- /dev/null +++ b/tests/aot/example_builder.py @@ -0,0 +1,66 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Tuple + +from iree.build import * +from iree.turbine.aot.build_actions import * +from iree.turbine.aot import FxProgramsBuilder + + +def export_simple_model() -> FxProgramsBuilder: + import torch + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + self.branch1 = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU()) + self.branch2 = torch.nn.Sequential( + torch.nn.Linear(128, 64), torch.nn.ReLU() + ) + self.buffer = torch.ones(32) + + def forward(self, x1, x2): + out1 = self.branch1(x1) + out2 = self.branch2(x2) + return (out1 + self.buffer, out2) + + example_args = (torch.randn(32, 64), torch.randn(32, 128)) + + # Create a dynamic batch size + batch = torch.export.Dim("batch") + # Specify that the first dimension of each input is that batch size + dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} + + fxb = FxProgramsBuilder(M()) + + @fxb.export_program(args=example_args, dynamic_shapes=dynamic_shapes) + def dynamic_batch(module: M, x1, x2): + return module.forward(x1, x2) + + return fxb + + +@entrypoint(description="Builds an awesome pipeline") +def pipe(): + results = [] + for i in range(3): + turbine_generate( + name=f"import_stage{i}", + generator=export_simple_model, + ) + results.extend( + compile( + name=f"stage{i}", + source=f"import_stage{i}.mlir", + ) + ) + return results + + +if __name__ == "__main__": + iree_build_main()