Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable choosing of input type for lowering tests #1776

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions python/test_infra/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import os
import inspect
import torch
from typing import Callable, List, Optional

from ttmlir.dialects import func
Expand All @@ -15,7 +16,7 @@
ttmetal_to_flatbuffer_file,
)

from .ttir_builder import Golden, Operand, Shape, TTIRBuilder
from .ttir_builder import Golden, Operand, Shape, TTIRBuilder, DataType

TT_MLIR_HOME = os.environ.get("TT_MLIR_HOME", "")

Expand All @@ -34,6 +35,7 @@ def _dump_module(module: Module) -> None:
def compile_as_mlir_module(
test_fn: Callable,
inputs_shapes: List[Shape],
inputs_types: Optional[List[torch.dtype]] = None,
module_dump: bool = False,
):
"""
Expand Down Expand Up @@ -106,9 +108,16 @@ def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder):
# `test_fn` so the user can use it to build ops.
builder = TTIRBuilder(ctx, loc)

# Default to all f32s
if inputs_types is None:
inputs_types = [torch.float32] * len(inputs_shapes)

assert inputs_types is not None and len(inputs_shapes) == len(inputs_types)

with ctx, loc:
test_fn_input_types = [
builder.ranked_tensor_type(input_shape) for input_shape in inputs_shapes
builder.ranked_tensor_type(shape, builder.get_type_from_torch_dtype(dtype))
for (shape, dtype) in zip(inputs_shapes, inputs_types)
]

# Wrap everything in a mlir module.
Expand All @@ -119,8 +128,8 @@ def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder):
@func.func(*test_fn_input_types, name=test_fn.__name__)
def decorated_func(*inputs):
# Randomly generate golden tensors for function inputs.
for index, i in enumerate(inputs):
builder.generate_input_golden(i, index)
for index, (operand, dtype) in enumerate(zip(inputs, inputs_types)):
builder.generate_input_golden(operand, dtype, index)

return test_fn(*inputs, builder=builder)

Expand Down Expand Up @@ -256,6 +265,7 @@ def ttmetal_to_flatbuffer(

def compile_to_flatbuffer(
inputs_shapes: List[Shape],
inputs_types: Optional[List[torch.dtype]] = None,
test_name: Optional[str] = None,
targets: List[str] = ["ttmetal", "ttnn"],
module_dump: bool = False,
Expand Down Expand Up @@ -320,12 +330,16 @@ def wrapper():
# both targets are chosen

if "ttmetal" in targets:
module, builder = compile_as_mlir_module(test_fn, inputs_shapes)
module, builder = compile_as_mlir_module(
test_fn, inputs_shapes, inputs_types
)
module = ttir_to_ttmetal(module, builder, test_base + ".mlir")
ttmetal_to_flatbuffer(module, builder, test_base + ".ttm")

if "ttnn" in targets:
module, builder = compile_as_mlir_module(test_fn, inputs_shapes)
module, builder = compile_as_mlir_module(
test_fn, inputs_shapes, inputs_types
)
module = ttir_to_ttnn(module, builder, test_base + ".mlir")
ttnn_to_flatbuffer(module, builder, test_base + ".ttnn")

Expand Down
70 changes: 61 additions & 9 deletions python/test_infra/ttir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,26 +137,32 @@ def get_shape(self, input: Operand) -> Shape:
"""Retrieves shape of operand which is expected to be a shaped type."""
return self._get_type(input).shape

def generate_and_store_random_golden(self, operand: Operand) -> Golden:
def generate_and_store_random_golden(
self, operand: Operand, dtype: torch.dtype = torch.float32
) -> Golden:
"""
Generates random tensor of `operand`s shape, assigns it to a golden,
Generates random tensor of `dtype`s of `operand`s shape, assigns it to a golden,
and maps `operand` to that golden.

Returns generated golden.
"""
seed = self._get_seed()
random_tensor = self._generate_random_tensor(self.get_shape(operand), seed)
random_tensor = self._generate_random_tensor(
self.get_shape(operand), dtype, seed
)
golden = Golden(random_tensor, seed)
self._store_golden(operand, golden)
return golden

def generate_input_golden(self, operand: Operand, index: int) -> None:
def generate_input_golden(
self, operand: Operand, dtype: torch.dtype, index: int
) -> None:
"""
Generates random tensor of `input`s shape, assigns it to a golden,
Generates random tensor of `dtype`s of `input`s shape, assigns it to a golden,
and maps `input` to that golden.
"""
self.id_golden_map[f"input_{index}"] = self.generate_and_store_random_golden(
operand
operand, dtype
)

def get_golden_map(self) -> Dict:
Expand Down Expand Up @@ -200,12 +206,26 @@ def _get_seed(self) -> int:
return seed

@staticmethod
def _generate_random_tensor(shape: Shape, seed: int) -> torch.Tensor:
def _generate_random_tensor(
shape: Shape, dtype: torch.dtype, seed: int
) -> torch.Tensor:
"""
Generates random tensor of shape `shape`, using `seed` to seed torch
Generates random tensor of shape `shape`, with type `dtype`, using `seed` to seed torch
random generator.
"""
return torch.randn(shape, generator=torch.manual_seed(seed))

if dtype.is_floating_point:
return torch.randn(shape, generator=torch.manual_seed(seed), dtype=dtype)
else:
min_int = torch.iinfo(dtype).min
max_int = torch.iinfo(dtype).max
return torch.randint(
low=min_int,
high=max_int,
size=shape,
generator=torch.manual_seed(seed),
dtype=dtype,
)

def _get_golden(self, operand: Operand) -> Golden:
"""Retrieves stored golden for `operand`."""
Expand Down Expand Up @@ -259,6 +279,38 @@ def _get_type(self, input: Operand):

return typ

# ----- Utility Conversion ----

def get_type_from_torch_dtype(self, dtype: torch.dtype) -> Type:
"""
Returns a MLIR `Type` obj corresponding to `dtype`
"""
match dtype:
case torch.float16:
return F16Type.get(self._ctx)
case torch.float32:
return F32Type.get(self._ctx)
case torch.float64:
return F64Type.get(self._ctx)
case torch.int8:
return IntegerType.get_signless(8, self._ctx)
case torch.int16:
return IntegerType.get_signless(16, self._ctx)
case torch.int32:
return IntegerType.get_signless(32, self._ctx)
case torch.int64:
return IntegerType.get_signless(64, self._ctx)
case torch.uint8:
return IntegerType.get_unsigned(8, self._ctx)
case torch.uint16:
return IntegerType.get_unsigned(16, self._ctx)
case torch.uint32:
return IntegerType.get_unsigned(32, self._ctx)
case torch.uint64:
return IntegerType.get_unsigned(64, self._ctx)
case _:
raise TypeError(f"Invalid Type {type}")

# ----- Utility factories -----

def ranked_tensor_type(
Expand Down
101 changes: 55 additions & 46 deletions test/python/golden/test_ttir_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# RUN: SYSTEM_DESC_PATH=%system_desc_path% %python %s

import inspect
import torch

from ttmlir.test_utils import compile_to_flatbuffer
from ttmlir.ttir_builder import Operand, TTIRBuilder, Attribute
Expand Down Expand Up @@ -40,11 +41,11 @@ def test_logical_not(in0: Operand, builder: TTIRBuilder):
return builder.logical_not(in0)


# TODO: uncomment once we have control over generated input types (bitwise ops
# don't support floats) (see issue #1765)
# @compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
# def test_bitwise_not(in0: Operand, builder: TTIRBuilder):
# return builder.bitwise_not(in0)
# NOTE: The generated flatbuffer will currently fail to run due to only floats
# being supported by the runtime. See issue #1775 for tracking
@compile_to_flatbuffer([(128, 128)], inputs_types=[torch.int8], targets=["ttnn"])
def test_bitwise_not(in0: Operand, builder: TTIRBuilder):
return builder.bitwise_not(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
Expand Down Expand Up @@ -190,39 +191,46 @@ def test_logical_xor(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.logical_xor(in0, in1)


# TODO: uncomment once we have control over generated input types (bitwise ops
# don't support floats) (see issue #1765)
# @compile_to_flatbuffer(
# [
# (64, 64),
# (64, 64),
# ],
# targets=["ttnn"],
# )
# def test_bitwise_and(in0: Operand, in1: Operand, builder: TTIRBuilder):
# return builder.bitwise_and(in0, in1)
#
#
# @compile_to_flatbuffer(
# [
# (64, 64),
# (64, 64),
# ],
# targets=["ttnn"],
# )
# def test_bitwise_or(in0: Operand, in1: Operand, builder: TTIRBuilder):
# return builder.bitwise_or(in0, in1)
#
#
# @compile_to_flatbuffer(
# [
# (64, 64),
# (64, 64),
# ],
# targets=["ttnn"],
# )
# def test_bitwise_xor(in0: Operand, in1: Operand, builder: TTIRBuilder):
# return builder.bitwise_xor(in0, in1)
# NOTE: The generated flatbuffer will currently fail to run due to only floats
# being supported by the runtime. See issue #1775 for tracking
@compile_to_flatbuffer(
[
(64, 64),
(64, 64),
],
inputs_types=[torch.int8, torch.int8],
targets=["ttnn"],
)
def test_bitwise_and(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.bitwise_and(in0, in1)


# NOTE: The generated flatbuffer will currently fail to run due to only floats
# being supported by the runtime. See issue #1775 for tracking
@compile_to_flatbuffer(
[
(64, 64),
(64, 64),
],
inputs_types=[torch.int8, torch.int8],
targets=["ttnn"],
)
def test_bitwise_or(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.bitwise_or(in0, in1)


# NOTE: The generated flatbuffer will currently fail to run due to only floats
# being supported by the runtime. See issue #1775 for tracking
@compile_to_flatbuffer(
[
(64, 64),
(64, 64),
],
inputs_types=[torch.int8, torch.int8],
targets=["ttnn"],
)
def test_bitwise_xor(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.bitwise_xor(in0, in1)


@compile_to_flatbuffer(
Expand Down Expand Up @@ -346,17 +354,18 @@ def test_minimum(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.minimum(in0, in1)


# TODO: uncomment when we have control over the input types
# @compile_to_flatbuffer(
# [
# (64, 64),
# (64, 64),
# (64, 64),
# ],
# targets=["ttnn"],
# [
# (64, 64),
# (64, 64),
# (64, 64),
# ],
# inputs_types = [torch.int8, torch.float32, torch.float32],
# targets=["ttnn"],
# )
# def test_where(in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder):
# return builder.where(in0, in1, in2)
# return builder.where(in0, in1, in2)
#


@compile_to_flatbuffer(
Expand Down
Loading