-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[iree.build] Add `turbine_generate` iree.build rule. (#249)
This supports performing parallelizable aot model export as part of the iree.build tool. By default, it runs actions out of process, meaning that multiple torch.export and torch-mlir imports can be running with true concurrency. --------- Signed-off-by: Stella Laurenzo <[email protected]>
- Loading branch information
1 parent
9e79f4e
commit d1dfb3c
Showing
5 changed files
with
342 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
# Copyright 2024 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from typing import Callable | ||
|
||
from abc import abstractmethod, ABC | ||
from pathlib import Path | ||
import typing | ||
import types | ||
|
||
import inspect | ||
|
||
from iree.build.executor import ActionConcurrency, BuildAction, BuildContext, BuildFile | ||
from iree.turbine.aot.exporter import ExportOutput | ||
|
||
__all__ = [ | ||
"turbine_generate", | ||
] | ||
|
||
|
||
def turbine_generate(
    generator: Callable,
    *args,
    name: str,
    out_of_process: bool = True,
    **kwargs,
):
    """Registers a build action that runs `generator` and saves what it returns.

    Torch-based generation tends to be slow, so by default the action is
    scheduled on the out-of-process action pool, which lets several generation
    steps run concurrently. Because that path goes through pickle, callers must
    observe the following:

    * `generator` must be pickleable: in practice a named, undecorated function
      at module scope, or a named module-scope class with a `__call__` method.
    * `args` and `kwargs` must be pickleable: in practice, primitive values.

    The positional arguments and any keyword arguments not consumed by
    `turbine_generate` itself are forwarded to `generator` unchanged.

    The generator exposes artifacts by returning Python instances whose types
    are declared in its return annotation (a single type, or a `tuple[...]` of
    types for several outputs). Supported types:

    * `ExportOutput` (the result of `aot.export(...)`): `save_mlir()` is called
      on it while still in the subprocess, writing to a file named
      `{name}.mlir` when there is a single return value or `{name}_{n}.mlir`
      when there are several.

    See testing/example_builder.py for an example.

    Args:
        generator: Callable producing the artifacts; its return annotation
            selects how each returned value is saved.
        *args: Positional arguments forwarded to `generator`.
        name: Base name used in the action description and output file names.
        out_of_process: When True (the default) the action uses
            `ActionConcurrency.PROCESS`; otherwise `ActionConcurrency.THREAD`.
        **kwargs: Keyword arguments forwarded to `generator`.

    Returns:
        A list with one entry per declared return value: the output object each
        marshaller registered on the action (for `ExportOutput`, the file
        allocated from the build context).
    """
    # eval_str=True resolves string annotations into real objects so they can
    # be matched against the marshaller registry.
    return_annotation = inspect.signature(generator, eval_str=True).return_annotation
    marshallers = unwrap_return_annotation(return_annotation)
    arity = len(marshallers)

    context = BuildContext.current()
    if out_of_process:
        concurrency = ActionConcurrency.PROCESS
    else:
        concurrency = ActionConcurrency.THREAD
    action = TurbineBuilderAction(
        generator,
        args,
        kwargs,
        desc=f"Export turbine model {name}",
        executor=context.executor,
        concurrency=concurrency,
    )

    # Each marshaller allocates its own output and records it on the action.
    for marshaller in marshallers:
        marshaller.prepare_action(context, name, action, arity)
    return [output for _, output in action.returns]
|
||
|
||
class ReturnMarshaller(ABC):
    """Strategy for turning one declared generator return value into an output.

    One marshaller exists per supported return type. `turbine_generate` calls
    `prepare_action` while the build is being defined, and
    `RemoteGenerator.__call__` calls `save_remote_result` once the generator
    has produced its values.
    """

    @abstractmethod
    def prepare_action(
        self,
        context: BuildContext,
        name: str,
        action: "TurbineBuilderAction",
        return_arity: int,
    ):
        """Declares the output for one return value on `action`.

        Args:
            context: Build context the action is being defined in.
            name: The `name` passed to `turbine_generate`.
            action: Action to register the output on (via `action.returns`).
            return_arity: Total number of return values the generator declares.
        """

    @abstractmethod
    def save_remote_result(self, result, path: Path):
        """Persists `result`, as returned by the generator, to `path`."""
|
||
|
||
class ExportOutputReturnMarshaller(ReturnMarshaller):
    """Marshaller for generators that return an `ExportOutput`.

    Each return value gets one `.mlir` file, written with `save_mlir()`.
    """

    def prepare_action(
        self,
        context: BuildContext,
        name: str,
        action: "TurbineBuilderAction",
        return_arity: int,
    ):
        """Allocates the `.mlir` output file and wires it to `action`."""
        # With several return values, disambiguate by the position this output
        # will take in `action.returns`; a lone output gets the plain name.
        if return_arity > 1:
            file_name = f"{name}_{len(action.returns)}.mlir"
        else:
            file_name = f"{name}.mlir"
        mlir_file = context.allocate_file(file_name)
        action.returns.append((self, mlir_file))
        # The file is produced by the action, so record the dependency.
        mlir_file.deps.add(action)

    def save_remote_result(self, result, path: Path):
        """Writes `result` to `path` as MLIR.

        Raises:
            RuntimeError: If `result` is not an `ExportOutput`.
        """
        if isinstance(result, ExportOutput):
            result.save_mlir(path)
            return
        raise RuntimeError(
            "Turbine generator was declared to return an ExportOutput instance, "
            f"but it returned {type(result)}"
        )
|
||
|
||
# Registry mapping a declared generator return type to the marshaller that
# handles it. `get_return_marshaller` may add entries for subclasses it
# resolves, so this doubles as a lookup cache.
RETURN_MARSHALLERS_BY_TYPE: dict[type, ReturnMarshaller] = {
    ExportOutput: ExportOutputReturnMarshaller(),
}
# Snapshot of the explicitly registered types, taken before any cached subclass
# entries are added; used when reporting which return types are supported.
EXPLICIT_MARSHALLER_TYPES = list(RETURN_MARSHALLERS_BY_TYPE)
|
||
|
||
def get_return_marshaller(t: type) -> ReturnMarshaller:
    """Looks up the marshaller responsible for a declared generator return type.

    An exact match in `RETURN_MARSHALLERS_BY_TYPE` is tried first. Otherwise the
    registered types are scanned for a base class of `t`; a hit is stored back
    into the registry under `t` so later lookups take the fast path.

    Args:
        t: One element of the generator's return annotation.

    Returns:
        The marshaller registered for `t` or for one of its base classes.

    Raises:
        ValueError: If no marshaller supports `t`, including when `t` is not a
            class at all (e.g. a missing annotation, `Optional[...]`, or the
            `...` of a variadic tuple).
    """
    m = RETURN_MARSHALLERS_BY_TYPE.get(t)
    if m is not None:
        return m

    # Do an exhaustive subclass check.
    for k, m in RETURN_MARSHALLERS_BY_TYPE.items():
        try:
            matches = issubclass(t, k)
        except TypeError:
            # `t` is not a class, so issubclass() cannot answer for any
            # registered type. Stop scanning and report the unsupported
            # annotation with the descriptive ValueError below instead of
            # leaking "issubclass() arg 1 must be a class".
            break
        if matches:
            # Cache it. Mutating the dict here is safe because we return
            # immediately and never advance the iterator again.
            RETURN_MARSHALLERS_BY_TYPE[t] = m
            return m
    raise ValueError(
        f"In order to use a callable as a generator in turbine_generate, it must be "
        f"annotated with specific return types. Found '{t}' but only "
        f"{EXPLICIT_MARSHALLER_TYPES} are supported"
    )
|
||
|
||
def unwrap_return_annotation(annot) -> list[ReturnMarshaller]:
    """Maps a generator's return annotation to one marshaller per return value.

    A `tuple[...]` annotation (builtin or old-style `typing.Tuple[...]`) is
    treated as several return values, one per type argument. Anything else is
    treated as a single return value.
    """
    alias_types = [types.GenericAlias]
    # typing._GenericAlias backs the old-style (i.e. `Tuple`) collection
    # aliases. It is private, so probe for it; this special case can be
    # removed eventually.
    legacy_alias_type = getattr(typing, "_GenericAlias", None)
    if legacy_alias_type:
        alias_types.append(legacy_alias_type)

    if isinstance(annot, tuple(alias_types)) and annot.__origin__ is tuple:
        element_types = annot.__args__
    else:
        element_types = [annot]
    return [get_return_marshaller(element) for element in element_types]
|
||
|
||
class RemoteGenerator:
    """Self-contained callable that runs a generator and saves its results.

    Bundles the generator, its arguments and, for every declared return value,
    the marshaller plus the destination path. Instances are built by
    `TurbineBuilderAction._remotable_thunk`.
    """

    def __init__(
        self,
        generation_thunk,
        thunk_args,
        thunk_kwargs,
        return_info: list[tuple[ReturnMarshaller, Path]],
    ):
        self.generation_thunk = generation_thunk
        self.thunk_args = thunk_args
        self.thunk_kwargs = thunk_kwargs
        self.return_info = return_info

    def __call__(self):
        """Invokes the generator and hands each result to its marshaller.

        Raises:
            RuntimeError: If the number of returned values differs from the
                number of declared return values.
        """
        produced = self.generation_thunk(*self.thunk_args, **self.thunk_kwargs)
        # Normalize a single return value to a one-element sequence.
        outputs = produced if isinstance(produced, (tuple, list)) else [produced]

        actual_count = len(outputs)
        declared_count = len(self.return_info)
        if actual_count != declared_count:
            raise RuntimeError(
                f"Turbine generator {self.generation_thunk} returned {actual_count} values, "
                f"but it was declared to return {declared_count}"
            )

        for output, (marshaller, output_path) in zip(outputs, self.return_info):
            marshaller.save_remote_result(output, output_path)
|
||
|
||
class TurbineBuilderAction(BuildAction):
    """Build action that runs a user generator and saves its declared outputs.

    `returns` starts empty and is filled in by the return marshallers'
    `prepare_action` calls, one entry per declared return value.
    """

    def __init__(
        self,
        thunk,
        thunk_args,
        thunk_kwargs,
        concurrency,
        **kwargs,
    ):
        super().__init__(concurrency=concurrency, **kwargs)
        self.thunk = thunk
        self.thunk_args = thunk_args
        self.thunk_kwargs = thunk_kwargs
        self.returns: list[tuple[ReturnMarshaller, BuildFile]] = []

    def _remotable_thunk(self):
        """Builds the `RemoteGenerator` that performs this action's work.

        Each `BuildFile` is replaced by its filesystem path, so the returned
        object carries only the marshaller and a path per output.
        """
        return_info = []
        for marshaller, build_file in self.returns:
            return_info.append((marshaller, build_file.get_fs_path()))
        return RemoteGenerator(
            self.thunk, self.thunk_args, self.thunk_kwargs, return_info
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright 2024 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
"""Test iree.build file for exercising turbine_generate. | ||
Since we can only do in-process testing of real modules (not dynamically loaded) | ||
across process boundaries, this builder must live in the source tree vs the | ||
tests tree. | ||
""" | ||
|
||
import os | ||
|
||
from iree.build import compile, entrypoint, iree_build_main | ||
from iree.turbine.aot.build_actions import * | ||
from iree.turbine.aot import ( | ||
ExportOutput, | ||
FxProgramsBuilder, | ||
export, | ||
externalize_module_parameters, | ||
) | ||
|
||
|
||
def export_simple_model(batch_size: int | None = None) -> ExportOutput:
    """Builds and exports a small two-branch torch module.

    Used as the generator passed to `turbine_generate` in `pipe` below, which
    is why it is a plain module-level function with a declared `ExportOutput`
    return type.

    Args:
        batch_size: Fixed leading dimension for both inputs. When None, the
            example inputs use a batch of 2 and the leading dimension is
            exported as a dynamic "batch" dimension.

    Returns:
        The result of `export(...)` for a program builder holding a single
        exported program, `dynamic_batch`.
    """
    # Imported lazily, inside the generator, rather than at module import time.
    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

            # Two independent Linear+ReLU branches: 64 -> 32 and 128 -> 64.
            self.branch1 = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU())
            self.branch2 = torch.nn.Sequential(
                torch.nn.Linear(128, 64), torch.nn.ReLU()
            )
            # Plain tensor attribute (assigned directly, not via
            # register_buffer); added to the first branch's output in forward.
            self.buffer = torch.ones(32)

        def forward(self, x1, x2):
            out1 = self.branch1(x1)
            out2 = self.branch2(x2)
            return (out1 + self.buffer, out2)

    # Example inputs need a concrete batch even when exporting dynamically.
    example_bs = 2 if batch_size is None else batch_size
    example_args = (torch.randn(example_bs, 64), torch.randn(example_bs, 128))

    # Create a dynamic batch size
    if batch_size is None:
        batch = torch.export.Dim("batch")
        # Specify that the first dimension of each input is that batch size
        dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
    else:
        # Fixed batch: no dynamic dimensions are declared.
        dynamic_shapes = {}

    module = M()
    externalize_module_parameters(module)
    fxb = FxProgramsBuilder(module)
    # Printing the pid shows which process ran the export.
    print(f" [{os.getpid()}] Compiling with dynamic shapes: {dynamic_shapes}")

    # The keys of `dynamic_shapes` match the parameter names x1 and x2 below.
    @fxb.export_program(args=example_args, dynamic_shapes=dynamic_shapes)
    def dynamic_batch(module: M, x1, x2):
        return module.forward(x1, x2)

    return export(fxb)
|
||
|
||
@entrypoint(description="Builds an awesome pipeline")
def pipe():
    """Defines three import-then-compile stages and returns the compile outputs.

    Stage 0 exports with a dynamic batch size and runs in the main process;
    stages 1 and 2 export with fixed batch sizes 10 and 20 and run out of
    process.
    """
    print(f"Main pid: {os.getpid()}")
    compiled = []
    for stage in range(3):
        import_name = f"import_stage{stage}"
        fixed_batch = stage * 10 if stage != 0 else None
        turbine_generate(
            export_simple_model,
            batch_size=fixed_batch,
            name=import_name,
            out_of_process=stage > 0,
        )
        # The compile step picks up the generated MLIR by file name.
        stage_outputs = compile(
            name=f"stage{stage}",
            source=f"{import_name}.mlir",
        )
        compiled.extend(stage_outputs)
    return compiled
|
||
|
||
# Script entry point: when this file is executed directly (rather than
# imported), hand control to the iree.build command-line driver.
if __name__ == "__main__":
    iree_build_main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright 2024 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import io | ||
|
||
from iree.build import * | ||
|
||
|
||
def test_example_builder(tmp_path):
    """Runs the example builder end to end and checks its outputs.

    Verifies that every stage produced a compiled `.vmfb`, and that the
    generated MLIR has a dynamic batch for stage 0 and the expected fixed
    batch sizes for stages 1 and 2.
    """
    from iree.turbine.aot.testing import example_builder

    build_args = (
        f"--output-dir={tmp_path}",
        "--iree-hal-target-device=cpu",
        "--iree-llvmcpu-target-cpu=host",
    )
    iree_build_main(example_builder, args=build_args)

    # Should have compiled three outputs.
    compiled_outputs = (
        "bin/pipe/stage0_cpu-host.vmfb",
        "bin/pipe/stage1_cpu-host.vmfb",
        "bin/pipe/stage2_cpu-host.vmfb",
    )
    for relative_path in compiled_outputs:
        assert (tmp_path / relative_path).exists()

    # Should have generated with a dynamic batch and two fixed batch sizes.
    expected_signatures = {
        "genfiles/pipe/import_stage0.mlir": "!torch.vtensor<[?,64],f32>",
        "genfiles/pipe/import_stage1.mlir": "!torch.vtensor<[10,64],f32>",
        "genfiles/pipe/import_stage2.mlir": "!torch.vtensor<[20,64],f32>",
    }
    for relative_path, expected_type in expected_signatures.items():
        mlir_text = (tmp_path / relative_path).read_text()
        print(mlir_text)
        assert expected_type in mlir_text