Skip to content

Commit

Permalink
Add test.
Browse files Browse the repository at this point in the history
Signed-off-by: Stella Laurenzo <[email protected]>
  • Loading branch information
stellaraccident committed Nov 26, 2024
1 parent 8205588 commit 8ca6c64
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 14 deletions.
31 changes: 24 additions & 7 deletions iree/turbine/aot/build_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,15 @@ def unwrap_return_annotation(annot) -> list[ReturnMarshaller]:
return [get_return_marshaller(it) for it in unpacked]


def turbine_generate(*, name: str, generator: callable):
def turbine_generate(generator: callable, *args, name: str, **kwargs):
sig = inspect.signature(generator, eval_str=True)
return_marshallers = unwrap_return_annotation(sig.return_annotation)

context = BuildContext.current()
action = TurbineBuilderAction(
generator,
args,
kwargs,
desc=f"Export turbine model {name}",
executor=context.executor,
)
Expand All @@ -117,14 +119,19 @@ def turbine_generate(*, name: str, generator: callable):

class RemoteGenerator:
def __init__(
self, generation_thunk, return_info: list[tuple[ReturnMarshaller, Path]]
self,
generation_thunk,
thunk_args,
thunk_kwargs,
return_info: list[tuple[ReturnMarshaller, Path]],
):
self.generation_thunk = generation_thunk
self.thunk_args = thunk_args
self.thunk_kwargs = thunk_kwargs
self.return_info = return_info

def __call__(self):
print("JOB PID:", os.getpid())
results = self.generation_thunk()
results = self.generation_thunk(*self.thunk_args, **self.thunk_kwargs)
if not isinstance(results, (tuple, list)):
results = [results]
if len(results) != len(self.return_info):
Expand All @@ -137,14 +144,24 @@ def __call__(self):


class TurbineBuilderAction(BuildAction):
def __init__(self, thunk, concurrency=ActionConcurrency.PROCESS, **kwargs):
def __init__(
self,
thunk,
thunk_args,
thunk_kwargs,
concurrency=ActionConcurrency.PROCESS,
**kwargs,
):
super().__init__(concurrency=concurrency, **kwargs)
self.thunk = thunk
self.thunk_args = thunk_args
self.thunk_kwargs = thunk_kwargs
self.returns: list[tuple[ReturnMarshaller, BuildFile]] = []

def _remotable_thunk(self):
print("SCHEDULER PID:", os.getpid())
remotable_return_info = [
(marshaller, bf.get_fs_path()) for marshaller, bf in self.returns
]
return RemoteGenerator(self.thunk, remotable_return_info)
return RemoteGenerator(
self.thunk, self.thunk_args, self.thunk_kwargs, remotable_return_info
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Tuple
"""Test iree.build file for exercising turbine_generate.
Since we can only do in-process testing of real modules (not dynamically loaded)
across process boundaries, this builder must live in the source tree vs the
tests tree.
"""

from iree.build import *
from iree.turbine.aot.build_actions import *
from iree.turbine.aot import FxProgramsBuilder


def export_simple_model() -> FxProgramsBuilder:
def export_simple_model(batch_size: int | None = None) -> FxProgramsBuilder:
import torch

class M(torch.nn.Module):
Expand All @@ -29,14 +34,19 @@ def forward(self, x1, x2):
out2 = self.branch2(x2)
return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))
example_bs = 2 if batch_size is None else batch_size
example_args = (torch.randn(example_bs, 64), torch.randn(example_bs, 128))

# Create a dynamic batch size
batch = torch.export.Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
if batch_size is None:
batch = torch.export.Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
else:
dynamic_shapes = {}

fxb = FxProgramsBuilder(M())
print("Compiling with dynamic shapes:", dynamic_shapes)

@fxb.export_program(args=example_args, dynamic_shapes=dynamic_shapes)
def dynamic_batch(module: M, x1, x2):
Expand All @@ -50,8 +60,9 @@ def pipe():
results = []
for i in range(3):
turbine_generate(
export_simple_model,
batch_size=None if i == 0 else i * 10,
name=f"import_stage{i}",
generator=export_simple_model,
)
results.extend(
compile(
Expand Down
41 changes: 41 additions & 0 deletions tests/aot/turbine_generate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import io

from iree.build import *


def test_example_builder(tmp_path):
from iree.turbine.aot.testing import example_builder

iree_build_main(
example_builder,
args=(
f"--output-dir={tmp_path}",
"--iree-hal-target-device=cpu",
"--iree-llvmcpu-target-cpu=host",
),
)

# Should have compiled three outputs.
for output_name in [
"bin/pipe/stage0_cpu-host.vmfb",
"bin/pipe/stage1_cpu-host.vmfb",
"bin/pipe/stage2_cpu-host.vmfb",
]:
output_path = tmp_path / output_name
assert output_path.exists()

# Should have generated with a dynamic batch and two fixed batch sizes.
for gen_name, contains_str in [
("genfiles/pipe/import_stage0.mlir", "!torch.vtensor<[?,64],f32>"),
("genfiles/pipe/import_stage1.mlir", "!torch.vtensor<[10,64],f32>"),
("genfiles/pipe/import_stage2.mlir", "!torch.vtensor<[20,64],f32>"),
]:
gen_path = tmp_path / gen_name
contents = gen_path.read_text()
assert contains_str in contents

0 comments on commit 8ca6c64

Please sign in to comment.