-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[iree.build] Add
turbine_generate
iree.build rule.
- Loading branch information
1 parent
ee62366
commit e76aa63
Showing
2 changed files
with
216 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# Copyright 2024 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from abc import abstractmethod, ABC | ||
import functools | ||
from pathlib import Path | ||
import os | ||
import typing | ||
import types | ||
|
||
import inspect | ||
|
||
from iree.build.executor import ActionConcurrency, BuildAction, BuildContext, BuildFile | ||
from iree.turbine.aot.fx_programs import FxPrograms | ||
|
||
__all__ = [ | ||
"turbine_generate", | ||
] | ||
|
||
|
||
class ReturnMarshaller(ABC): | ||
@abstractmethod | ||
def prepare_action( | ||
self, | ||
context: BuildContext, | ||
name: str, | ||
action: "TurbineBuilderAction", | ||
return_arity: int, | ||
): | ||
... | ||
|
||
@abstractmethod | ||
def save_remote_result(self, result, path: Path): | ||
... | ||
|
||
|
||
class FxProgramsReturnMarshaller(ReturnMarshaller): | ||
def prepare_action( | ||
self, | ||
context: BuildContext, | ||
name: str, | ||
action: "TurbineBuilderAction", | ||
return_arity: int, | ||
): | ||
# Need to allocate one file for output. | ||
file_name = ( | ||
f"{name}_{len(action.returns)}.mlir" if return_arity > 1 else f"{name}.mlir" | ||
) | ||
output_file = context.allocate_file(file_name) | ||
action.returns.append((self, output_file)) | ||
output_file.deps.add(action) | ||
|
||
def save_remote_result(self, result, path: Path): | ||
if not isinstance(result, FxPrograms): | ||
raise RuntimeError( | ||
"Turbine generator was declared to return an FxPrograms instance, " | ||
f"but it returned {type(result)}" | ||
) | ||
import iree.turbine.aot as turbine_aot | ||
|
||
output = turbine_aot.export(result) | ||
output.save_mlir(path) | ||
|
||
|
||
RETURN_MARSHALLERS_BY_TYPE = { | ||
FxPrograms: FxProgramsReturnMarshaller(), | ||
} | ||
EXPLICIT_MARSHALLER_TYPES = list(RETURN_MARSHALLERS_BY_TYPE.keys()) | ||
|
||
|
||
def get_return_marshaller(t: type) -> ReturnMarshaller: | ||
m = RETURN_MARSHALLERS_BY_TYPE.get(t) | ||
if m is not None: | ||
return m | ||
|
||
# Do an exhaustive subclass check. | ||
for k, m in RETURN_MARSHALLERS_BY_TYPE.items(): | ||
if issubclass(t, k): | ||
# Cache it. | ||
RETURN_MARSHALLERS_BY_TYPE[t] = m | ||
return m | ||
raise ValueError( | ||
f"In order to wrap a function with @turbine_builder it must be annotated with " | ||
f"specific return types. Found '{t}' but only {EXPLICIT_MARSHALLER_TYPES} " | ||
f"are supported" | ||
) | ||
|
||
|
||
def unwrap_return_annotation(annot) -> list[ReturnMarshaller]: | ||
if ( | ||
isinstance(annot, (types.GenericAlias, typing._GenericAlias)) | ||
and annot.__origin__ is tuple | ||
): | ||
unpacked = annot.__args__ | ||
else: | ||
unpacked = [annot] | ||
return [get_return_marshaller(it) for it in unpacked] | ||
|
||
|
||
def turbine_generate(*, name: str, generator: callable): | ||
sig = inspect.signature(generator, eval_str=True) | ||
return_marshallers = unwrap_return_annotation(sig.return_annotation) | ||
|
||
context = BuildContext.current() | ||
action = TurbineBuilderAction( | ||
generator, | ||
desc=f"Export turbine model {name}", | ||
executor=context.executor, | ||
) | ||
for rm in return_marshallers: | ||
rm.prepare_action(context, name, action, len(return_marshallers)) | ||
return [r[1] for r in action.returns] | ||
|
||
|
||
class RemoteGenerator: | ||
def __init__( | ||
self, generation_thunk, return_info: list[tuple[ReturnMarshaller, Path]] | ||
): | ||
self.generation_thunk = generation_thunk | ||
self.return_info = return_info | ||
|
||
def __call__(self): | ||
print("JOB PID:", os.getpid()) | ||
results = self.generation_thunk() | ||
if not isinstance(results, (tuple, list)): | ||
results = [results] | ||
if len(results) != len(self.return_info): | ||
raise RuntimeError( | ||
f"Turbine generator {self.generation_thunk} returned {len(results)} values, " | ||
f"but it was declared to return {len(self.return_info)}" | ||
) | ||
for result, (marshaller, output_path) in zip(results, self.return_info): | ||
marshaller.save_remote_result(result, output_path) | ||
|
||
|
||
class TurbineBuilderAction(BuildAction): | ||
def __init__(self, thunk, concurrency=ActionConcurrency.PROCESS, **kwargs): | ||
super().__init__(concurrency=concurrency, **kwargs) | ||
self.thunk = thunk | ||
self.returns: list[tuple[ReturnMarshaller, BuildFile]] = [] | ||
|
||
def _remotable_thunk(self): | ||
print("SCHEDULER PID:", os.getpid()) | ||
remotable_return_info = [ | ||
(marshaller, bf.get_fs_path()) for marshaller, bf in self.returns | ||
] | ||
return RemoteGenerator(self.thunk, remotable_return_info) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright 2024 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from typing import Tuple | ||
|
||
from iree.build import * | ||
from iree.turbine.aot.build_actions import * | ||
from iree.turbine.aot import FxProgramsBuilder | ||
|
||
|
||
def export_simple_model() -> FxProgramsBuilder: | ||
import torch | ||
|
||
class M(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
self.branch1 = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU()) | ||
self.branch2 = torch.nn.Sequential( | ||
torch.nn.Linear(128, 64), torch.nn.ReLU() | ||
) | ||
self.buffer = torch.ones(32) | ||
|
||
def forward(self, x1, x2): | ||
out1 = self.branch1(x1) | ||
out2 = self.branch2(x2) | ||
return (out1 + self.buffer, out2) | ||
|
||
example_args = (torch.randn(32, 64), torch.randn(32, 128)) | ||
|
||
# Create a dynamic batch size | ||
batch = torch.export.Dim("batch") | ||
# Specify that the first dimension of each input is that batch size | ||
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} | ||
|
||
fxb = FxProgramsBuilder(M()) | ||
|
||
@fxb.export_program(args=example_args, dynamic_shapes=dynamic_shapes) | ||
def dynamic_batch(module: M, x1, x2): | ||
return module.forward(x1, x2) | ||
|
||
return fxb | ||
|
||
|
||
@entrypoint(description="Builds an awesome pipeline") | ||
def pipe(): | ||
results = [] | ||
for i in range(3): | ||
turbine_generate( | ||
name=f"import_stage{i}", | ||
generator=export_simple_model, | ||
) | ||
results.extend( | ||
compile( | ||
name=f"stage{i}", | ||
source=f"import_stage{i}.mlir", | ||
) | ||
) | ||
return results | ||
|
||
|
||
if __name__ == "__main__": | ||
iree_build_main() |