[iree.build] Add turbine_generate iree.build rule.
stellaraccident committed Nov 2, 2024
1 parent ee62366 commit e76aa63
Showing 2 changed files with 216 additions and 0 deletions.
iree/turbine/aot/build_actions.py (+150, -0)
@@ -0,0 +1,150 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from abc import abstractmethod, ABC
import inspect
from pathlib import Path
import types
import typing

from iree.build.executor import ActionConcurrency, BuildAction, BuildContext, BuildFile
from iree.turbine.aot.fx_programs import FxPrograms

__all__ = [
"turbine_generate",
]


class ReturnMarshaller(ABC):
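    """Marshals one declared return value of a generator into a build output."""
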
@abstractmethod
def prepare_action(
self,
context: BuildContext,
name: str,
action: "TurbineBuilderAction",
return_arity: int,
):
...

@abstractmethod
def save_remote_result(self, result, path: Path):
...


class FxProgramsReturnMarshaller(ReturnMarshaller):
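    """Saves an FxPrograms return value by exporting it to a .mlir file."""
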
def prepare_action(
self,
context: BuildContext,
name: str,
action: "TurbineBuilderAction",
return_arity: int,
):
# Need to allocate one file for output.
file_name = (
f"{name}_{len(action.returns)}.mlir" if return_arity > 1 else f"{name}.mlir"
)
output_file = context.allocate_file(file_name)
action.returns.append((self, output_file))
output_file.deps.add(action)

def save_remote_result(self, result, path: Path):
if not isinstance(result, FxPrograms):
raise RuntimeError(
"Turbine generator was declared to return an FxPrograms instance, "
f"but it returned {type(result)}"
)
import iree.turbine.aot as turbine_aot

output = turbine_aot.export(result)
output.save_mlir(path)


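# Maps declared return types to marshaller instances. Subclass matches are
# resolved and cached lazily by get_return_marshaller.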
RETURN_MARSHALLERS_BY_TYPE = {
FxPrograms: FxProgramsReturnMarshaller(),
}
EXPLICIT_MARSHALLER_TYPES = list(RETURN_MARSHALLERS_BY_TYPE.keys())


def get_return_marshaller(t: type) -> ReturnMarshaller:
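    """Returns the marshaller registered for type `t`, honoring subclasses."""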
m = RETURN_MARSHALLERS_BY_TYPE.get(t)
if m is not None:
return m

# Do an exhaustive subclass check.
for k, m in RETURN_MARSHALLERS_BY_TYPE.items():
if issubclass(t, k):
# Cache it.
RETURN_MARSHALLERS_BY_TYPE[t] = m
return m
    raise ValueError(
        "A generator used with turbine_generate must be annotated with "
        f"specific return types. Found '{t}' but only {EXPLICIT_MARSHALLER_TYPES} "
        "are supported"
    )


def unwrap_return_annotation(annot) -> list[ReturnMarshaller]:
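    """Resolves a return annotation to a list of marshallers.

    A tuple annotation (`tuple[...]` or `typing.Tuple[...]`) yields one
    marshaller per element; any other annotation yields a single marshaller.
    """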
if (
isinstance(annot, (types.GenericAlias, typing._GenericAlias))
and annot.__origin__ is tuple
):
unpacked = annot.__args__
else:
unpacked = [annot]
return [get_return_marshaller(it) for it in unpacked]


def turbine_generate(*, name: str, generator: typing.Callable):
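    """Declares a build action that invokes `generator` and exports its results.

    One output file is allocated per return value declared in the generator's
    return annotation; the allocated BuildFiles are returned in order.
    """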
sig = inspect.signature(generator, eval_str=True)
return_marshallers = unwrap_return_annotation(sig.return_annotation)

context = BuildContext.current()
action = TurbineBuilderAction(
generator,
desc=f"Export turbine model {name}",
executor=context.executor,
)
for rm in return_marshallers:
rm.prepare_action(context, name, action, len(return_marshallers))
return [r[1] for r in action.returns]


class RemoteGenerator:
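    """Self-contained callable that runs the generation thunk and saves each
    result to its corresponding output path."""
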
def __init__(
self, generation_thunk, return_info: list[tuple[ReturnMarshaller, Path]]
):
self.generation_thunk = generation_thunk
self.return_info = return_info

def __call__(self):
print("JOB PID:", os.getpid())
results = self.generation_thunk()
if not isinstance(results, (tuple, list)):
results = [results]
if len(results) != len(self.return_info):
raise RuntimeError(
f"Turbine generator {self.generation_thunk} returned {len(results)} values, "
f"but it was declared to return {len(self.return_info)}"
)
for result, (marshaller, output_path) in zip(results, self.return_info):
marshaller.save_remote_result(result, output_path)


class TurbineBuilderAction(BuildAction):
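    """Build action wrapping a generator thunk; defaults to process-level concurrency."""
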
def __init__(self, thunk, concurrency=ActionConcurrency.PROCESS, **kwargs):
super().__init__(concurrency=concurrency, **kwargs)
self.thunk = thunk
self.returns: list[tuple[ReturnMarshaller, BuildFile]] = []

def _remotable_thunk(self):
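        # Resolve BuildFiles to plain filesystem paths so the RemoteGenerator
        # does not hold references to scheduler-side BuildFile objects.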
print("SCHEDULER PID:", os.getpid())
remotable_return_info = [
(marshaller, bf.get_fs_path()) for marshaller, bf in self.returns
]
return RemoteGenerator(self.thunk, remotable_return_info)
tests/aot/example_builder.py (+66, -0)
@@ -0,0 +1,66 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from iree.build import *
from iree.turbine.aot.build_actions import *
from iree.turbine.aot import FxProgramsBuilder


def export_simple_model() -> FxProgramsBuilder:
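    """Builds a small two-branch torch module and exports a dynamic-batch program."""
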
import torch

class M(torch.nn.Module):
def __init__(self):
super().__init__()

self.branch1 = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU())
self.branch2 = torch.nn.Sequential(
torch.nn.Linear(128, 64), torch.nn.ReLU()
)
self.buffer = torch.ones(32)

def forward(self, x1, x2):
out1 = self.branch1(x1)
out2 = self.branch2(x2)
return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = torch.export.Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

fxb = FxProgramsBuilder(M())

@fxb.export_program(args=example_args, dynamic_shapes=dynamic_shapes)
def dynamic_batch(module: M, x1, x2):
return module.forward(x1, x2)

return fxb


@entrypoint(description="Exports a simple model three times and compiles each stage")
def pipe():
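    # Export the model three times and compile each generated .mlir file.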
results = []
for i in range(3):
turbine_generate(
name=f"import_stage{i}",
generator=export_simple_model,
)
results.extend(
compile(
name=f"stage{i}",
source=f"import_stage{i}.mlir",
)
)
return results


if __name__ == "__main__":
iree_build_main()
