
[iree.build] Add turbine_generate iree.build rule. #249

Merged: 5 commits, Nov 27, 2024
4 changes: 2 additions & 2 deletions iree-requirements-ci.txt
@@ -10,5 +10,5 @@
# Uncomment to skip versions from PyPI (so _only_ nightly versions).
# --no-index

iree-base-compiler==3.1.0rc20241122
iree-base-runtime==3.1.0rc20241122
iree-base-compiler==3.1.0rc20241127
iree-base-runtime==3.1.0rc20241127
206 changes: 206 additions & 0 deletions iree/turbine/aot/build_actions.py
Member:
@marbre any ideas for how we could include iree-turbine docs as part of https://iree-python-api.readthedocs.io/ ?

I'd like to include this page under https://iree-python-api.readthedocs.io/en/latest/compiler/build.html as part of iree-org/iree#19019 , and we can also put the rest of the iree-turbine docs up too. We should limit the number of doc sites we have to maintain, even if this crosses repository boundaries a bit :P

Member:

Tracking issue: #77

@@ -0,0 +1,206 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable

from abc import abstractmethod, ABC
from pathlib import Path
import typing
import types

import inspect

from iree.build.executor import ActionConcurrency, BuildAction, BuildContext, BuildFile
from iree.turbine.aot.exporter import ExportOutput

__all__ = [
"turbine_generate",
]


def turbine_generate(
generator: Callable,
*args,
name: str,
out_of_process: bool = True,
**kwargs,
):
"""Invokes a user-defined generator callable as an action, performing turbine
Comment:

consider renaming "generator callable" to "exporter callable" and/or explain that it is a callable that returns an exported torch module.

For me, "generator callable" resolves to the Python language feature, "generator functions", like

def fib():
  prev = 0
  cur = 1
  while True:
    yield cur
    prev, cur = cur, cur+prev

Collaborator Author:

I'm running out of daylight before vacation to land changes, but sounds like a reasonable idea. Maybe you could send a patch next week and work with Ean/Scott on a better name? I'll land it like this but it has no use yet -- easy to rename.

Comment:

sounds good & will do!
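To make the distinction in this thread concrete, here is a stdlib-only sketch (`export_model` is a hypothetical stand-in, not part of this PR): `inspect` can tell a true Python generator function apart from the plain named callable `turbine_generate` expects.

```python
import inspect


def fib():
    # A true Python "generator function": calling it returns a lazy generator.
    prev, cur = 0, 1
    while True:
        yield cur
        prev, cur = cur, cur + prev


def export_model(batch_size: int) -> str:
    # The shape turbine_generate expects (body is a stand-in): a plain named
    # module-scope callable that runs to completion and *returns* its result.
    return f"exported(batch_size={batch_size})"


gen = fib()
print([next(gen) for _ in range(5)])              # [1, 1, 2, 3, 5]
print(inspect.isgeneratorfunction(fib))           # True
print(inspect.isgeneratorfunction(export_model))  # False
```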

import and storing the resulting artifacts as outputs.

Because torch-based generation is usually quite slow and a bottleneck, this
action takes pains to use the out-of-process action pool, allowing multiple
generation activities to take place concurrently. Since this requires interacting
with the pickle infrastructure, it puts some constraints on usage:

* generator must be a pickleable callable. In practice, this means that it must
be a named function at module scope (without decorator) or a named class at
module scope with a `__call__` method.
* args and kwargs must be pickleable. In practice, this means primitive values.

Arguments to the generator are taken from the positional and unmatched keyword
arguments passed to `turbine_generate`.

The generator makes artifacts available as outputs by returning corresponding
Python instances (which must be declared in the return type annotation for the
remoting to work):

* `ExportOutput`: The result of calling `aot.export(...)` will have
`save_mlir()` called on it while still in the subprocess, writing to a file
named `{name}.mlir` if there is a single return value or `{name}_{n}.mlir`
if there are multiple.
Comment on lines +51 to +54 (Member):

Can this output .mlirbc? What about .irpa files?

Collaborator Author:

It needs to grow some options for mlirbc. Have some thoughts on how to plug parameter generation through in an ergonomic way. One of:

  • Support another return type and then the generator just returns something that causes the export.
  • Have an extra_outputs= dict that lets you set up additional output files and have that passed into the generator function to do with as you please.

Collaborator Author:

(but wanted to land this before adding more)
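Neither option above is implemented in this PR. As a rough sketch of the second idea (the `extra_outputs` parameter and all names below are hypothetical), the build system would pre-allocate output paths and hand them to the generator to write side artifacts such as an `.irpa` parameter archive:

```python
import tempfile
from pathlib import Path


def generate_with_params(prefix: str, extra_outputs: dict[str, Path]) -> str:
    # Hypothetical generator under the proposed API: the caller pre-allocates
    # the extra output files; the generator writes side artifacts to them.
    extra_outputs["params"].write_bytes(b"fake parameter archive")
    return f"{prefix}: wrote {extra_outputs['params'].name}"


with tempfile.TemporaryDirectory() as td:
    outputs = {"params": Path(td) / "model.irpa"}
    print(generate_with_params("demo", outputs))  # demo: wrote model.irpa
    print(outputs["params"].exists())             # True
```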


By default, import is run in a subprocess pool. It can be run in the main
process by passing `out_of_process=False`.

See testing/example_builder.py for an example.
"""
sig = inspect.signature(generator, eval_str=True)
return_marshallers = unwrap_return_annotation(sig.return_annotation)

context = BuildContext.current()
action = TurbineBuilderAction(
generator,
args,
kwargs,
desc=f"Export turbine model {name}",
executor=context.executor,
concurrency=(
ActionConcurrency.PROCESS if out_of_process else ActionConcurrency.THREAD
),
)
for rm in return_marshallers:
rm.prepare_action(context, name, action, len(return_marshallers))
return [r[1] for r in action.returns]


class ReturnMarshaller(ABC):
@abstractmethod
def prepare_action(
self,
context: BuildContext,
name: str,
action: "TurbineBuilderAction",
return_arity: int,
):
...

@abstractmethod
def save_remote_result(self, result, path: Path):
...


class ExportOutputReturnMarshaller(ReturnMarshaller):
def prepare_action(
self,
context: BuildContext,
name: str,
action: "TurbineBuilderAction",
return_arity: int,
):
# Need to allocate one file for output.
file_name = (
f"{name}_{len(action.returns)}.mlir" if return_arity > 1 else f"{name}.mlir"
)
output_file = context.allocate_file(file_name)
action.returns.append((self, output_file))
output_file.deps.add(action)

def save_remote_result(self, result, path: Path):
if not isinstance(result, ExportOutput):
raise RuntimeError(
"Turbine generator was declared to return an ExportOutput instance, "
f"but it returned {type(result)}"
)
result.save_mlir(path)


RETURN_MARSHALLERS_BY_TYPE: dict[type, ReturnMarshaller] = {
ExportOutput: ExportOutputReturnMarshaller(),
}
EXPLICIT_MARSHALLER_TYPES = list(RETURN_MARSHALLERS_BY_TYPE.keys())


def get_return_marshaller(t: type) -> ReturnMarshaller:
m = RETURN_MARSHALLERS_BY_TYPE.get(t)
if m is not None:
return m

# Do an exhaustive subclass check.
for k, m in RETURN_MARSHALLERS_BY_TYPE.items():
if issubclass(t, k):
# Cache it.
RETURN_MARSHALLERS_BY_TYPE[t] = m
return m
raise ValueError(
f"In order to use a callable as a generator in turbine_generate, it must be "
f"annotated with specific return types. Found '{t}' but only "
f"{EXPLICIT_MARSHALLER_TYPES} are supported"
)
Comment on lines +132 to +142 (Member):
👀 Python magic



def unwrap_return_annotation(annot) -> list[ReturnMarshaller]:
# typing._GenericAlias is used to unwrap old-style (i.e. `List`) collection
# aliases. We special case this and it can be removed eventually.
_GenericAlias = getattr(typing, "_GenericAlias", None)
is_generic_alias = isinstance(annot, types.GenericAlias) or (
_GenericAlias and isinstance(annot, _GenericAlias)
)
if is_generic_alias and annot.__origin__ is tuple:
unpacked = annot.__args__
else:
unpacked = [annot]
return [get_return_marshaller(it) for it in unpacked]


class RemoteGenerator:
def __init__(
self,
generation_thunk,
thunk_args,
thunk_kwargs,
return_info: list[tuple[ReturnMarshaller, Path]],
):
self.generation_thunk = generation_thunk
self.thunk_args = thunk_args
self.thunk_kwargs = thunk_kwargs
self.return_info = return_info

def __call__(self):
results = self.generation_thunk(*self.thunk_args, **self.thunk_kwargs)
if not isinstance(results, (tuple, list)):
results = [results]
if len(results) != len(self.return_info):
raise RuntimeError(
f"Turbine generator {self.generation_thunk} returned {len(results)} values, "
f"but it was declared to return {len(self.return_info)}"
)
for result, (marshaller, output_path) in zip(results, self.return_info):
marshaller.save_remote_result(result, output_path)


class TurbineBuilderAction(BuildAction):
def __init__(
self,
thunk,
thunk_args,
thunk_kwargs,
concurrency,
**kwargs,
):
super().__init__(concurrency=concurrency, **kwargs)
self.thunk = thunk
self.thunk_args = thunk_args
self.thunk_kwargs = thunk_kwargs
self.returns: list[tuple[ReturnMarshaller, BuildFile]] = []

def _remotable_thunk(self):
remotable_return_info = [
(marshaller, bf.get_fs_path()) for marshaller, bf in self.returns
]
return RemoteGenerator(
self.thunk, self.thunk_args, self.thunk_kwargs, remotable_return_info
)
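The pickleability constraints in the docstring above can be checked without iree at all; a stdlib-only illustration of why the generator must be a named, module-scope callable:

```python
import pickle


def module_scope_generator(x: int) -> int:
    # Pickles fine: serialized by qualified name and re-imported in the worker.
    return x * 2


# A lambda has no importable qualified name, so the out-of-process action
# pool could not reconstruct it in a subprocess worker.
try:
    pickle.dumps(lambda x: x * 2)
    print("lambda pickled (unexpected)")
except (pickle.PicklingError, AttributeError):
    print("lambda is not pickleable")

roundtripped = pickle.loads(pickle.dumps(module_scope_generator))
print(roundtripped(21))  # 42
```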
88 changes: 88 additions & 0 deletions iree/turbine/aot/testing/example_builder.py
@@ -0,0 +1,88 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Test iree.build file for exercising turbine_generate.

Since we can only do in-process testing of real modules (not dynamically loaded)
across process boundaries, this builder must live in the source tree vs the
tests tree.
"""

import os

from iree.build import compile, entrypoint, iree_build_main
from iree.turbine.aot.build_actions import *
from iree.turbine.aot import (
ExportOutput,
FxProgramsBuilder,
export,
externalize_module_parameters,
)


def export_simple_model(batch_size: int | None = None) -> ExportOutput:
import torch

class M(torch.nn.Module):
def __init__(self):
super().__init__()

self.branch1 = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU())
self.branch2 = torch.nn.Sequential(
torch.nn.Linear(128, 64), torch.nn.ReLU()
)
self.buffer = torch.ones(32)

def forward(self, x1, x2):
out1 = self.branch1(x1)
out2 = self.branch2(x2)
return (out1 + self.buffer, out2)

example_bs = 2 if batch_size is None else batch_size
example_args = (torch.randn(example_bs, 64), torch.randn(example_bs, 128))

# Create a dynamic batch size
if batch_size is None:
batch = torch.export.Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
else:
dynamic_shapes = {}

module = M()
externalize_module_parameters(module)
fxb = FxProgramsBuilder(module)
print(f" [{os.getpid()}] Compiling with dynamic shapes: {dynamic_shapes}")

@fxb.export_program(args=example_args, dynamic_shapes=dynamic_shapes)
def dynamic_batch(module: M, x1, x2):
return module.forward(x1, x2)

return export(fxb)


@entrypoint(description="Builds an awesome pipeline")
Member:
🚀

def pipe():
print(f"Main pid: {os.getpid()}")
results = []
for i in range(3):
turbine_generate(
export_simple_model,
batch_size=None if i == 0 else i * 10,
name=f"import_stage{i}",
out_of_process=i > 0,
)
results.extend(
compile(
name=f"stage{i}",
source=f"import_stage{i}.mlir",
)
)
return results


if __name__ == "__main__":
iree_build_main()
4 changes: 4 additions & 0 deletions mypy.ini
@@ -4,6 +4,10 @@ explicit_package_bases = True
mypy_path = $MYPY_CONFIG_FILE_DIR
packages = iree.turbine

# Missing typing stubs for iree.build.
[mypy-iree.build.*]
ignore_missing_imports = True

# Missing typing stubs for iree.compiler.
[mypy-iree.compiler.*]
ignore_missing_imports = True
42 changes: 42 additions & 0 deletions tests/aot/turbine_generate_test.py
@@ -0,0 +1,42 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import io

from iree.build import *


def test_example_builder(tmp_path):
from iree.turbine.aot.testing import example_builder

iree_build_main(
example_builder,
args=(
f"--output-dir={tmp_path}",
"--iree-hal-target-device=cpu",
"--iree-llvmcpu-target-cpu=host",
),
)

# Should have compiled three outputs.
for output_name in [
"bin/pipe/stage0_cpu-host.vmfb",
"bin/pipe/stage1_cpu-host.vmfb",
"bin/pipe/stage2_cpu-host.vmfb",
]:
output_path = tmp_path / output_name
assert output_path.exists()

# Should have generated with a dynamic batch and two fixed batch sizes.
for gen_name, contains_str in [
("genfiles/pipe/import_stage0.mlir", "!torch.vtensor<[?,64],f32>"),
("genfiles/pipe/import_stage1.mlir", "!torch.vtensor<[10,64],f32>"),
("genfiles/pipe/import_stage2.mlir", "!torch.vtensor<[20,64],f32>"),
]:
gen_path = tmp_path / gen_name
contents = gen_path.read_text()
print(contents)
assert contains_str in contents