-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[iree.build] Add turbine_generate
iree.build rule.
#249
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
# Copyright 2024 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from typing import Callable | ||
|
||
from abc import abstractmethod, ABC | ||
from pathlib import Path | ||
import typing | ||
import types | ||
|
||
import inspect | ||
|
||
from iree.build.executor import ActionConcurrency, BuildAction, BuildContext, BuildFile | ||
from iree.turbine.aot.exporter import ExportOutput | ||
|
||
__all__ = [ | ||
"turbine_generate", | ||
] | ||
|
||
|
||
def turbine_generate( | ||
generator: Callable, | ||
*args, | ||
name: str, | ||
out_of_process: bool = True, | ||
**kwargs, | ||
): | ||
"""Invokes a user-defined generator callable as an action, performing turbine | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consider renaming "generator callable" to "exporter callable" and/or explain that it is a callable that returns an exported torch module. For me, "generator callable" resolves to the python language feature, "generator functions", like def fib():
prev = 0
cur = 1
while True:
yield cur
prev, cur = cur, cur+prev There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm running out of daylight before vacation to land changes, but sounds like a reasonable idea. Maybe you could send a patch next week and work with Ean/Scott on a better name? I'll land it like this but it has no use yet -- easy to rename. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sounds good & will do! |
||
import and storing the resulting artifacts as outputs. | ||
|
||
Because torch-based generation is usually quite slow and a bottleneck, this | ||
action takes pains to use the out of process action pool, allowing multiple | ||
generation activities to take place concurrently. Since this requires interacting | ||
with the pickle infrastructure, it puts some constraints on usage: | ||
|
||
* generator must be a pickleable callable. In practice, this means that it must | ||
be a named function at module scope (without decorator) or a named class at | ||
module scope with a `__call__` method. | ||
* args and kwargs must be pickleable. In practice, this means primitive values. | ||
|
||
Arguments to the generator are taken from the positional and unmatched keyword | ||
arguments passed to `turbine_generate`. | ||
|
||
The generator makes artifacts available as outputs by returning corresponding | ||
Python instances (which must be declared as typing parameters for the remoting | ||
to work): | ||
|
||
* `ExportOutput`: The result of calling `aot.export(...)` will result in | ||
`save_mlir()` being called on it while still in the subprocess to write to | ||
a file names `{name}.mlir` if there is one return or `{name}_{n}.mlir` if | ||
multiple. | ||
Comment on lines
+51
to
+54
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this output There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It needs to grow some options for mlirbc. Have some thoughts on how to plug parameter generation through in an ergonomic way. One of:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (but wanted to land this before adding more) |
||
|
||
By default, import is run in a subprocess pool. It can be run in the main | ||
process by passing `out_of_process=False`. | ||
|
||
See testing/example_builder.py for an example. | ||
""" | ||
sig = inspect.signature(generator, eval_str=True) | ||
return_marshallers = unwrap_return_annotation(sig.return_annotation) | ||
|
||
context = BuildContext.current() | ||
action = TurbineBuilderAction( | ||
generator, | ||
args, | ||
kwargs, | ||
desc=f"Export turbine model {name}", | ||
executor=context.executor, | ||
concurrency=( | ||
ActionConcurrency.PROCESS if out_of_process else ActionConcurrency.THREAD | ||
), | ||
) | ||
for rm in return_marshallers: | ||
rm.prepare_action(context, name, action, len(return_marshallers)) | ||
return [r[1] for r in action.returns] | ||
|
||
|
||
class ReturnMarshaller(ABC): | ||
@abstractmethod | ||
def prepare_action( | ||
self, | ||
context: BuildContext, | ||
name: str, | ||
action: "TurbineBuilderAction", | ||
return_arity: int, | ||
): | ||
... | ||
|
||
@abstractmethod | ||
def save_remote_result(self, result, path: Path): | ||
... | ||
|
||
|
||
class ExportOutputReturnMarshaller(ReturnMarshaller): | ||
def prepare_action( | ||
self, | ||
context: BuildContext, | ||
name: str, | ||
action: "TurbineBuilderAction", | ||
return_arity: int, | ||
): | ||
# Need to allocate one file for output. | ||
file_name = ( | ||
f"{name}_{len(action.returns)}.mlir" if return_arity > 1 else f"{name}.mlir" | ||
) | ||
output_file = context.allocate_file(file_name) | ||
action.returns.append((self, output_file)) | ||
output_file.deps.add(action) | ||
|
||
def save_remote_result(self, result, path: Path): | ||
if not isinstance(result, ExportOutput): | ||
raise RuntimeError( | ||
"Turbine generator was declared to return an ExportOutput instance, " | ||
f"but it returned {type(result)}" | ||
) | ||
result.save_mlir(path) | ||
|
||
|
||
RETURN_MARSHALLERS_BY_TYPE: dict[type, ReturnMarshaller] = { | ||
ExportOutput: ExportOutputReturnMarshaller(), | ||
} | ||
EXPLICIT_MARSHALLER_TYPES = list(RETURN_MARSHALLERS_BY_TYPE.keys()) | ||
|
||
|
||
def get_return_marshaller(t: type) -> ReturnMarshaller: | ||
m = RETURN_MARSHALLERS_BY_TYPE.get(t) | ||
if m is not None: | ||
return m | ||
|
||
# Do an exhaustive subclass check. | ||
for k, m in RETURN_MARSHALLERS_BY_TYPE.items(): | ||
if issubclass(t, k): | ||
# Cache it. | ||
RETURN_MARSHALLERS_BY_TYPE[t] = m | ||
return m | ||
raise ValueError( | ||
f"In order to use a callable as a generator in turbine_generate, it must be " | ||
f" annotated with specific return types. Found '{t}' but only " | ||
f"{EXPLICIT_MARSHALLER_TYPES} are supported" | ||
) | ||
Comment on lines
+132
to
+142
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👀 Python magic |
||
|
||
|
||
def unwrap_return_annotation(annot) -> list[ReturnMarshaller]: | ||
# typing._GenericAlias is used to unwrap old-style (i.e. `List`) collection | ||
# aliases. We special case this and it can be removed eventually. | ||
_GenericAlias = getattr(typing, "_GenericAlias", None) | ||
is_generic_alias = isinstance(annot, (types.GenericAlias)) or ( | ||
_GenericAlias and isinstance(annot, _GenericAlias) | ||
) | ||
if is_generic_alias and annot.__origin__ is tuple: | ||
unpacked = annot.__args__ | ||
else: | ||
unpacked = [annot] | ||
return [get_return_marshaller(it) for it in unpacked] | ||
|
||
|
||
class RemoteGenerator: | ||
def __init__( | ||
self, | ||
generation_thunk, | ||
thunk_args, | ||
thunk_kwargs, | ||
return_info: list[tuple[ReturnMarshaller, Path]], | ||
): | ||
self.generation_thunk = generation_thunk | ||
self.thunk_args = thunk_args | ||
self.thunk_kwargs = thunk_kwargs | ||
self.return_info = return_info | ||
|
||
def __call__(self): | ||
results = self.generation_thunk(*self.thunk_args, **self.thunk_kwargs) | ||
if not isinstance(results, (tuple, list)): | ||
results = [results] | ||
if len(results) != len(self.return_info): | ||
raise RuntimeError( | ||
f"Turbine generator {self.generation_thunk} returned {len(results)} values, " | ||
f"but it was declared to return {len(self.return_info)}" | ||
) | ||
for result, (marshaller, output_path) in zip(results, self.return_info): | ||
marshaller.save_remote_result(result, output_path) | ||
|
||
|
||
class TurbineBuilderAction(BuildAction): | ||
def __init__( | ||
self, | ||
thunk, | ||
thunk_args, | ||
thunk_kwargs, | ||
concurrency, | ||
**kwargs, | ||
): | ||
super().__init__(concurrency=concurrency, **kwargs) | ||
self.thunk = thunk | ||
self.thunk_args = thunk_args | ||
self.thunk_kwargs = thunk_kwargs | ||
self.returns: list[tuple[ReturnMarshaller, BuildFile]] = [] | ||
|
||
def _remotable_thunk(self): | ||
remotable_return_info = [ | ||
(marshaller, bf.get_fs_path()) for marshaller, bf in self.returns | ||
] | ||
return RemoteGenerator( | ||
self.thunk, self.thunk_args, self.thunk_kwargs, remotable_return_info | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright 2024 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
"""Test iree.build file for exercising turbine_generate. | ||
|
||
Since we can only do in-process testing of real modules (not dynamically loaded) | ||
across process boundaries, this builder must live in the source tree vs the | ||
tests tree. | ||
""" | ||
|
||
import os | ||
|
||
from iree.build import compile, entrypoint, iree_build_main | ||
from iree.turbine.aot.build_actions import * | ||
from iree.turbine.aot import ( | ||
ExportOutput, | ||
FxProgramsBuilder, | ||
export, | ||
externalize_module_parameters, | ||
) | ||
|
||
|
||
def export_simple_model(batch_size: int | None = None) -> ExportOutput: | ||
import torch | ||
|
||
class M(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
self.branch1 = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU()) | ||
self.branch2 = torch.nn.Sequential( | ||
torch.nn.Linear(128, 64), torch.nn.ReLU() | ||
) | ||
self.buffer = torch.ones(32) | ||
|
||
def forward(self, x1, x2): | ||
out1 = self.branch1(x1) | ||
out2 = self.branch2(x2) | ||
return (out1 + self.buffer, out2) | ||
|
||
example_bs = 2 if batch_size is None else batch_size | ||
example_args = (torch.randn(example_bs, 64), torch.randn(example_bs, 128)) | ||
|
||
# Create a dynamic batch size | ||
if batch_size is None: | ||
batch = torch.export.Dim("batch") | ||
# Specify that the first dimension of each input is that batch size | ||
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} | ||
else: | ||
dynamic_shapes = {} | ||
|
||
module = M() | ||
externalize_module_parameters(module) | ||
fxb = FxProgramsBuilder(module) | ||
print(f" [{os.getpid()}] Compiling with dynamic shapes: {dynamic_shapes}") | ||
|
||
@fxb.export_program(args=example_args, dynamic_shapes=dynamic_shapes) | ||
def dynamic_batch(module: M, x1, x2): | ||
return module.forward(x1, x2) | ||
|
||
return export(fxb) | ||
|
||
|
||
@entrypoint(description="Builds an awesome pipeline") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🚀 |
||
def pipe(): | ||
print(f"Main pid: {os.getpid()}") | ||
results = [] | ||
for i in range(3): | ||
turbine_generate( | ||
export_simple_model, | ||
batch_size=None if i == 0 else i * 10, | ||
name=f"import_stage{i}", | ||
out_of_process=i > 0, | ||
) | ||
results.extend( | ||
compile( | ||
name=f"stage{i}", | ||
source=f"import_stage{i}.mlir", | ||
) | ||
) | ||
return results | ||
|
||
|
||
if __name__ == "__main__": | ||
iree_build_main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright 2024 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import io | ||
|
||
from iree.build import * | ||
|
||
|
||
def test_example_builder(tmp_path): | ||
from iree.turbine.aot.testing import example_builder | ||
|
||
iree_build_main( | ||
example_builder, | ||
args=( | ||
f"--output-dir={tmp_path}", | ||
"--iree-hal-target-device=cpu", | ||
"--iree-llvmcpu-target-cpu=host", | ||
), | ||
) | ||
|
||
# Should have compiled three outputs. | ||
for output_name in [ | ||
"bin/pipe/stage0_cpu-host.vmfb", | ||
"bin/pipe/stage1_cpu-host.vmfb", | ||
"bin/pipe/stage2_cpu-host.vmfb", | ||
]: | ||
output_path = tmp_path / output_name | ||
assert output_path.exists() | ||
|
||
# Should have generated with a dynamic batch and two fixed batch sizes. | ||
for gen_name, contains_str in [ | ||
("genfiles/pipe/import_stage0.mlir", "!torch.vtensor<[?,64],f32>"), | ||
("genfiles/pipe/import_stage1.mlir", "!torch.vtensor<[10,64],f32>"), | ||
("genfiles/pipe/import_stage2.mlir", "!torch.vtensor<[20,64],f32>"), | ||
]: | ||
gen_path = tmp_path / gen_name | ||
contents = gen_path.read_text() | ||
print(contents) | ||
assert contains_str in contents |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@marbre any ideas for how we could include iree-turbine docs as part of https://iree-python-api.readthedocs.io/ ?
I'd like to include this page under https://iree-python-api.readthedocs.io/en/latest/compiler/build.html as part of iree-org/iree#19019 , and we can also put the rest of the iree-turbine docs up too. We should limit the number of doc sites we have to maintain, even if this crosses repository boundaries a bit :P
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tracking issue: #77