Skip to content

Commit

Permalink
[iree.build] Add turbine_generate iree.build rule. (#249)
Browse files Browse the repository at this point in the history
This supports performing parallelizable AOT model export as part of the
iree.build tool. By default, it runs actions out of process, meaning
that multiple torch.export and torch-mlir imports can run with true
concurrency.

---------

Signed-off-by: Stella Laurenzo <[email protected]>
  • Loading branch information
stellaraccident authored Nov 27, 2024
1 parent 9e79f4e commit d1dfb3c
Show file tree
Hide file tree
Showing 5 changed files with 342 additions and 2 deletions.
4 changes: 2 additions & 2 deletions iree-requirements-ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@
# Uncomment to skip versions from PyPI (so _only_ nightly versions).
# --no-index

iree-base-compiler==3.1.0rc20241122
iree-base-runtime==3.1.0rc20241122
iree-base-compiler==3.1.0rc20241127
iree-base-runtime==3.1.0rc20241127
206 changes: 206 additions & 0 deletions iree/turbine/aot/build_actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable

from abc import abstractmethod, ABC
from pathlib import Path
import typing
import types

import inspect

from iree.build.executor import ActionConcurrency, BuildAction, BuildContext, BuildFile
from iree.turbine.aot.exporter import ExportOutput

# Names exported by `from iree.turbine.aot.build_actions import *`.
__all__ = [
    "turbine_generate",
]


def turbine_generate(
    generator: Callable,
    *args,
    name: str,
    out_of_process: bool = True,
    **kwargs,
):
    """Registers a build action that runs a user-defined generator callable,
    performing turbine import and storing the resulting artifacts as outputs.

    Torch-based generation is usually slow and a bottleneck, so by default the
    action is scheduled on the out-of-process action pool, which lets several
    generation activities run concurrently. Because this goes through the
    pickle infrastructure, some constraints apply:

    * `generator` must be a pickleable callable. In practice, this means a
      named function at module scope (without decorator) or a named class at
      module scope with a `__call__` method.
    * `args` and `kwargs` must be pickleable. In practice, this means
      primitive values.

    The positional and unmatched keyword arguments given to `turbine_generate`
    are forwarded to the generator.

    The generator publishes artifacts by returning Python instances whose
    types are declared in its return annotation (the annotation is required
    for the remoting to work):

    * `ExportOutput`: The result of calling `aot.export(...)`. `save_mlir()`
      is called on it while still in the subprocess, writing to a file named
      `{name}.mlir` if there is one return or `{name}_{n}.mlir` if there are
      multiple.

    Pass `out_of_process=False` to run the import in the main process instead
    of the subprocess pool.

    See testing/example_builder.py for an example.

    Returns:
        The list of output files allocated for the action, one per declared
        return value, in declaration order.
    """
    # eval_str=True resolves string annotations into real objects.
    signature = inspect.signature(generator, eval_str=True)
    marshallers = unwrap_return_annotation(signature.return_annotation)
    arity = len(marshallers)

    context = BuildContext.current()
    if out_of_process:
        concurrency = ActionConcurrency.PROCESS
    else:
        concurrency = ActionConcurrency.THREAD
    action = TurbineBuilderAction(
        generator,
        args,
        kwargs,
        desc=f"Export turbine model {name}",
        executor=context.executor,
        concurrency=concurrency,
    )

    # Each marshaller appends a (marshaller, output file) pair to action.returns.
    for marshaller in marshallers:
        marshaller.prepare_action(context, name, action, arity)
    return [build_file for _, build_file in action.returns]


class ReturnMarshaller(ABC):
    """Interface for handling one declared return value of a generator.

    `turbine_generate` calls `prepare_action` once per declared return while
    the build graph is being set up; `RemoteGenerator` later calls
    `save_remote_result` with the value the generator actually produced.
    """

    @abstractmethod
    def prepare_action(
        self,
        context: BuildContext,
        name: str,
        action: "TurbineBuilderAction",
        return_arity: int,
    ):
        """Prepares `action` to produce the output for this return value.

        Args:
            context: The build context the action belongs to.
            name: The `name` passed to `turbine_generate`.
            action: The action being configured.
            return_arity: Total number of declared return values.
        """

    @abstractmethod
    def save_remote_result(self, result, path: Path):
        """Persists `result`, as returned by the generator, to `path`."""


class ExportOutputReturnMarshaller(ReturnMarshaller):
    """Marshals `ExportOutput` return values by saving them as `.mlir` files."""

    def prepare_action(
        self,
        context: BuildContext,
        name: str,
        action: "TurbineBuilderAction",
        return_arity: int,
    ):
        """Allocates one `.mlir` output file and registers it on `action`."""
        if return_arity > 1:
            # Disambiguate by position: the number of returns registered so far.
            file_name = f"{name}_{len(action.returns)}.mlir"
        else:
            file_name = f"{name}.mlir"
        mlir_file = context.allocate_file(file_name)
        action.returns.append((self, mlir_file))
        mlir_file.deps.add(action)

    def save_remote_result(self, result, path: Path):
        """Writes `result` to `path` via `save_mlir`.

        Raises:
            RuntimeError: If `result` is not an `ExportOutput`.
        """
        if isinstance(result, ExportOutput):
            result.save_mlir(path)
            return
        raise RuntimeError(
            "Turbine generator was declared to return an ExportOutput instance, "
            f"but it returned {type(result)}"
        )


# Maps a declared return type to the marshaller that handles it.
# get_return_marshaller also stores subclass matches here as a cache.
RETURN_MARSHALLERS_BY_TYPE: dict[type, ReturnMarshaller] = {
    ExportOutput: ExportOutputReturnMarshaller(),
}
# Snapshot of the explicitly registered types, taken before any cached
# subclass entries are added; used in error messages.
EXPLICIT_MARSHALLER_TYPES = list(RETURN_MARSHALLERS_BY_TYPE)


def get_return_marshaller(t: type) -> ReturnMarshaller:
    """Returns the marshaller registered for return type `t`.

    An exact match in `RETURN_MARSHALLERS_BY_TYPE` is preferred. Otherwise the
    registry is scanned for a registered base class of `t`, and a hit is
    cached under `t` for subsequent lookups.

    Args:
        t: One element of a generator's return annotation.

    Raises:
        ValueError: If no marshaller supports `t`. This includes annotations
            that are not classes at all (e.g. `None`, `...`, unions or
            parameterized generics).
    """
    m = RETURN_MARSHALLERS_BY_TYPE.get(t)
    if m is not None:
        return m

    # Do an exhaustive subclass check. `issubclass` raises TypeError when `t`
    # is not a class; treat that as "unsupported" so the caller gets the
    # descriptive ValueError below instead of an opaque TypeError. Iterate a
    # snapshot because a hit is inserted into the registry.
    try:
        inherited = next(
            (
                candidate
                for registered, candidate in tuple(RETURN_MARSHALLERS_BY_TYPE.items())
                if issubclass(t, registered)
            ),
            None,
        )
    except TypeError:
        inherited = None
    if inherited is not None:
        # Cache it.
        RETURN_MARSHALLERS_BY_TYPE[t] = inherited
        return inherited

    raise ValueError(
        f"In order to use a callable as a generator in turbine_generate, it must be "
        f" annotated with specific return types. Found '{t}' but only "
        f"{EXPLICIT_MARSHALLER_TYPES} are supported"
    )


def unwrap_return_annotation(annot) -> list[ReturnMarshaller]:
    """Resolves a generator's return annotation into its marshallers.

    A `tuple[...]` or `typing.Tuple[...]` annotation declares multiple
    returns and yields one marshaller per element type, in order. Any other
    annotation is treated as a single return value. A bare `typing.Tuple`
    has no element types and therefore yields no marshallers.

    Args:
        annot: The evaluated return annotation of the generator.

    Returns:
        One marshaller per declared return value.

    Raises:
        ValueError: Propagated from `get_return_marshaller` for unsupported
            types.
    """
    # typing.get_origin/get_args are the public way to inspect both new-style
    # (`tuple[...]`) and old-style (`typing.Tuple[...]`) aliases, so no private
    # typing internals are needed.
    if typing.get_origin(annot) is tuple:
        unpacked = typing.get_args(annot)
    else:
        unpacked = (annot,)
    return [get_return_marshaller(it) for it in unpacked]


class RemoteGenerator:
    """Callable bundle of a generator, its arguments and its output targets.

    Calling an instance runs the generator and hands each returned value to
    its marshaller together with the path it should be saved to.
    """

    def __init__(
        self,
        generation_thunk,
        thunk_args,
        thunk_kwargs,
        return_info: list[tuple[ReturnMarshaller, Path]],
    ):
        self.generation_thunk = generation_thunk
        self.thunk_args = thunk_args
        self.thunk_kwargs = thunk_kwargs
        # One (marshaller, output path) pair per declared return value.
        self.return_info = return_info

    def __call__(self):
        """Runs the generator and saves every result.

        Raises:
            RuntimeError: If the number of returned values differs from the
                number of declared returns.
        """
        produced = self.generation_thunk(*self.thunk_args, **self.thunk_kwargs)
        # A single (non tuple/list) return value is treated as a 1-element list.
        values = produced if isinstance(produced, (tuple, list)) else [produced]

        actual, declared = len(values), len(self.return_info)
        if actual != declared:
            raise RuntimeError(
                f"Turbine generator {self.generation_thunk} returned {actual} values, "
                f"but it was declared to return {declared}"
            )

        for value, (marshaller, output_path) in zip(values, self.return_info):
            marshaller.save_remote_result(value, output_path)


class TurbineBuilderAction(BuildAction):
    """Build action that invokes a generator callable and saves its returns."""

    def __init__(
        self,
        thunk,
        thunk_args,
        thunk_kwargs,
        concurrency,
        **kwargs,
    ):
        """Stores the generator and its arguments.

        `concurrency` and any remaining keyword arguments are forwarded to
        `BuildAction.__init__`.
        """
        super().__init__(concurrency=concurrency, **kwargs)
        self.thunk, self.thunk_args, self.thunk_kwargs = (
            thunk,
            thunk_args,
            thunk_kwargs,
        )
        # Filled in by ReturnMarshaller.prepare_action, one entry per return.
        self.returns: list[tuple[ReturnMarshaller, BuildFile]] = []

    def _remotable_thunk(self):
        """Builds the `RemoteGenerator` that performs this action's work.

        Each output `BuildFile` is replaced by the path from `get_fs_path()`.
        NOTE(review): presumably this is the hook iree.build calls to obtain
        the callable it runs for the action - confirm against BuildAction.
        """
        return_info = []
        for marshaller, build_file in self.returns:
            return_info.append((marshaller, build_file.get_fs_path()))
        return RemoteGenerator(
            self.thunk, self.thunk_args, self.thunk_kwargs, return_info
        )
88 changes: 88 additions & 0 deletions iree/turbine/aot/testing/example_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Test iree.build file for exercising turbine_generate.
Since in-process testing that crosses process boundaries only works with real
modules (not dynamically loaded ones), this builder must live in the source
tree rather than the tests tree.
"""

import os

from iree.build import compile, entrypoint, iree_build_main
from iree.turbine.aot.build_actions import *
from iree.turbine.aot import (
ExportOutput,
FxProgramsBuilder,
export,
externalize_module_parameters,
)


def export_simple_model(batch_size: int | None = None) -> ExportOutput:
    """Exports a small two-branch torch module.

    Args:
        batch_size: Fixed batch size to export with. When None, the first
            dimension of both inputs is exported as a dynamic "batch"
            dimension, using example inputs of batch size 2.

    Returns:
        The `ExportOutput` produced by `export`.
    """
    # Imported lazily so torch is only loaded where the generator runs.
    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

            self.branch1 = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU())
            self.branch2 = torch.nn.Sequential(
                torch.nn.Linear(128, 64), torch.nn.ReLU()
            )
            self.buffer = torch.ones(32)

        def forward(self, x1, x2):
            first = self.branch1(x1)
            second = self.branch2(x2)
            return (first + self.buffer, second)

    example_bs = batch_size if batch_size is not None else 2
    example_args = tuple(torch.randn(example_bs, width) for width in (64, 128))

    # With no explicit batch size, mark the first dimension of each input as
    # the shared dynamic batch dimension; otherwise export fully static shapes.
    dynamic_shapes = {}
    if batch_size is None:
        batch = torch.export.Dim("batch")
        dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

    module = M()
    externalize_module_parameters(module)
    fxb = FxProgramsBuilder(module)
    print(f" [{os.getpid()}] Compiling with dynamic shapes: {dynamic_shapes}")

    @fxb.export_program(args=example_args, dynamic_shapes=dynamic_shapes)
    def dynamic_batch(module: M, x1, x2):
        return module.forward(x1, x2)

    return export(fxb)


@entrypoint(description="Builds an awesome pipeline")
def pipe():
    """Generates and compiles three variants of the simple model.

    Stage 0 is generated in-process with a dynamic batch size; stages 1 and 2
    are generated out of process with fixed batch sizes 10 and 20. Each
    generated `import_stage{i}.mlir` is then compiled as `stage{i}`.
    """
    print(f"Main pid: {os.getpid()}")
    results = []
    for stage in range(3):
        import_name = f"import_stage{stage}"
        batch_size = stage * 10 if stage else None
        turbine_generate(
            export_simple_model,
            batch_size=batch_size,
            name=import_name,
            out_of_process=stage > 0,
        )
        compiled = compile(
            name=f"stage{stage}",
            source=f"{import_name}.mlir",
        )
        results.extend(compiled)
    return results


# Run the iree.build command line driver when executed as a script.
if __name__ == "__main__":
    iree_build_main()
4 changes: 4 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ explicit_package_bases = True
mypy_path = $MYPY_CONFIG_FILE_DIR
packages = iree.turbine

# Missing typing stubs for iree.build.
[mypy-iree.build.*]
ignore_missing_imports = True

# Missing typing stubs for iree.compiler.
[mypy-iree.compiler.*]
ignore_missing_imports = True
Expand Down
42 changes: 42 additions & 0 deletions tests/aot/turbine_generate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import io

from iree.build import *


def test_example_builder(tmp_path):
    """Builds the example pipeline into `tmp_path` and checks its outputs."""
    from iree.turbine.aot.testing import example_builder

    build_args = (
        f"--output-dir={tmp_path}",
        "--iree-hal-target-device=cpu",
        "--iree-llvmcpu-target-cpu=host",
    )
    iree_build_main(example_builder, args=build_args)

    # Should have compiled three outputs.
    for stage in range(3):
        vmfb_path = tmp_path / f"bin/pipe/stage{stage}_cpu-host.vmfb"
        assert vmfb_path.exists()

    # Should have generated with a dynamic batch and two fixed batch sizes.
    expected_input_types = (
        "!torch.vtensor<[?,64],f32>",
        "!torch.vtensor<[10,64],f32>",
        "!torch.vtensor<[20,64],f32>",
    )
    for stage, contains_str in enumerate(expected_input_types):
        gen_path = tmp_path / f"genfiles/pipe/import_stage{stage}.mlir"
        contents = gen_path.read_text()
        print(contents)
        assert contains_str in contents

0 comments on commit d1dfb3c

Please sign in to comment.