From 1e18da0ec457a1da2b3a1d00249374aa66b9f060 Mon Sep 17 00:00:00 2001 From: Collin Tod Date: Tue, 14 Jan 2025 16:53:25 +0000 Subject: [PATCH] Enable choosing of input type for lowering tests Closes #1765 --- python/test_infra/test_utils.py | 26 +++++-- python/test_infra/ttir_builder.py | 70 ++++++++++++++++--- test/python/golden/test_ttir_ops.py | 101 +++++++++++++++------------- 3 files changed, 136 insertions(+), 61 deletions(-) diff --git a/python/test_infra/test_utils.py b/python/test_infra/test_utils.py index da1957b7f..e28214c81 100644 --- a/python/test_infra/test_utils.py +++ b/python/test_infra/test_utils.py @@ -4,6 +4,7 @@ import os import inspect +import torch from typing import Callable, List, Optional from ttmlir.dialects import func @@ -15,7 +16,7 @@ ttmetal_to_flatbuffer_file, ) -from .ttir_builder import Golden, Operand, Shape, TTIRBuilder +from .ttir_builder import Golden, Operand, Shape, TTIRBuilder, DataType TT_MLIR_HOME = os.environ.get("TT_MLIR_HOME", "") @@ -34,6 +35,7 @@ def _dump_module(module: Module) -> None: def compile_as_mlir_module( test_fn: Callable, inputs_shapes: List[Shape], + inputs_types: Optional[List[torch.dtype]] = None, module_dump: bool = False, ): """ @@ -106,9 +108,16 @@ def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): # `test_fn` so the user can use it to build ops. builder = TTIRBuilder(ctx, loc) + # Default to all f32s + if inputs_types is None: + inputs_types = [torch.float32] * len(inputs_shapes) + + assert inputs_types is not None and len(inputs_shapes) == len(inputs_types) + with ctx, loc: test_fn_input_types = [ - builder.ranked_tensor_type(input_shape) for input_shape in inputs_shapes + builder.ranked_tensor_type(shape, builder.get_type_from_torch_dtype(dtype)) + for (shape, dtype) in zip(inputs_shapes, inputs_types) ] # Wrap everything in a mlir module. @@ -119,8 +128,8 @@ def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): @func.func(*test_fn_input_types, name=test_fn.__name__) def decorated_func(*inputs): # Randomly generate golden tensors for function inputs. - for index, i in enumerate(inputs): - builder.generate_input_golden(i, index) + for index, (operand, dtype) in enumerate(zip(inputs, inputs_types)): + builder.generate_input_golden(operand, dtype, index) return test_fn(*inputs, builder=builder) @@ -256,6 +265,7 @@ def ttmetal_to_flatbuffer( def compile_to_flatbuffer( inputs_shapes: List[Shape], + inputs_types: Optional[List[torch.dtype]] = None, test_name: Optional[str] = None, targets: List[str] = ["ttmetal", "ttnn"], module_dump: bool = False, @@ -320,12 +330,16 @@ def wrapper(): # both targets are chosen if "ttmetal" in targets: - module, builder = compile_as_mlir_module(test_fn, inputs_shapes) + module, builder = compile_as_mlir_module( + test_fn, inputs_shapes, inputs_types + ) module = ttir_to_ttmetal(module, builder, test_base + ".mlir") ttmetal_to_flatbuffer(module, builder, test_base + ".ttm") if "ttnn" in targets: - module, builder = compile_as_mlir_module(test_fn, inputs_shapes) + module, builder = compile_as_mlir_module( + test_fn, inputs_shapes, inputs_types + ) module = ttir_to_ttnn(module, builder, test_base + ".mlir") ttnn_to_flatbuffer(module, builder, test_base + ".ttnn") diff --git a/python/test_infra/ttir_builder.py b/python/test_infra/ttir_builder.py index daae39ac5..40478b230 100644 --- a/python/test_infra/ttir_builder.py +++ b/python/test_infra/ttir_builder.py @@ -137,26 +137,32 @@ def get_shape(self, input: Operand) -> Shape: """Retrieves shape of operand which is expected to be a shaped type.""" return self._get_type(input).shape - def generate_and_store_random_golden(self, operand: Operand) -> Golden: + def generate_and_store_random_golden( + self, operand: Operand, dtype: torch.dtype = torch.float32 + ) -> Golden: """ - Generates random tensor of `operand`s shape, assigns it to a golden, + Generates random tensor of `dtype`s of `operand`s shape, assigns it to a golden, and maps `operand` to that golden. Returns generated golden. """ seed = self._get_seed() - random_tensor = self._generate_random_tensor(self.get_shape(operand), seed) + random_tensor = self._generate_random_tensor( + self.get_shape(operand), dtype, seed + ) golden = Golden(random_tensor, seed) self._store_golden(operand, golden) return golden - def generate_input_golden(self, operand: Operand, index: int) -> None: + def generate_input_golden( + self, operand: Operand, dtype: torch.dtype, index: int + ) -> None: """ - Generates random tensor of `input`s shape, assigns it to a golden, + Generates random tensor of `dtype`s of `input`s shape, assigns it to a golden, and maps `input` to that golden. """ self.id_golden_map[f"input_{index}"] = self.generate_and_store_random_golden( - operand + operand, dtype ) def get_golden_map(self) -> Dict: @@ -200,12 +206,26 @@ def _get_seed(self) -> int: return seed @staticmethod - def _generate_random_tensor(shape: Shape, seed: int) -> torch.Tensor: + def _generate_random_tensor( + shape: Shape, dtype: torch.dtype, seed: int + ) -> torch.Tensor: """ - Generates random tensor of shape `shape`, using `seed` to seed torch + Generates random tensor of shape `shape`, with type `dtype`, using `seed` to seed torch random generator. """ - return torch.randn(shape, generator=torch.manual_seed(seed)) + + if dtype.is_floating_point: + return torch.randn(shape, generator=torch.manual_seed(seed), dtype=dtype) + else: + min_int = torch.iinfo(dtype).min + max_int = torch.iinfo(dtype).max + return torch.randint( + low=min_int, + high=max_int, + size=shape, + generator=torch.manual_seed(seed), + dtype=dtype, + ) def _get_golden(self, operand: Operand) -> Golden: """Retrieves stored golden for `operand`.""" @@ -259,6 +279,38 @@ def _get_type(self, input: Operand): return typ + # ----- Utility Conversion ---- + + def get_type_from_torch_dtype(self, dtype: torch.dtype) -> Type: + """ + Returns a MLIR `Type` obj corresponding to `dtype` + """ + match dtype: + case torch.float16: + return F16Type.get(self._ctx) + case torch.float32: + return F32Type.get(self._ctx) + case torch.float64: + return F64Type.get(self._ctx) + case torch.int8: + return IntegerType.get_signless(8, self._ctx) + case torch.int16: + return IntegerType.get_signless(16, self._ctx) + case torch.int32: + return IntegerType.get_signless(32, self._ctx) + case torch.int64: + return IntegerType.get_signless(64, self._ctx) + case torch.uint8: + return IntegerType.get_unsigned(8, self._ctx) + case torch.uint16: + return IntegerType.get_unsigned(16, self._ctx) + case torch.uint32: + return IntegerType.get_unsigned(32, self._ctx) + case torch.uint64: + return IntegerType.get_unsigned(64, self._ctx) + case _: + raise TypeError(f"Invalid Type {type}") + # ----- Utility factories ----- def ranked_tensor_type( diff --git a/test/python/golden/test_ttir_ops.py b/test/python/golden/test_ttir_ops.py index e2d2bde4a..14fbd10b7 100644 --- a/test/python/golden/test_ttir_ops.py +++ b/test/python/golden/test_ttir_ops.py @@ -5,6 +5,7 @@ # RUN: SYSTEM_DESC_PATH=%system_desc_path% %python %s import inspect +import torch from ttmlir.test_utils import compile_to_flatbuffer from ttmlir.ttir_builder import Operand, TTIRBuilder, Attribute @@ -40,11 +41,11 @@ def test_logical_not(in0: Operand, builder: TTIRBuilder): return builder.logical_not(in0) -# TODO: uncomment once we have control over generated input types (bitwise ops -# don't support floats) (see issue #1765) -# @compile_to_flatbuffer([(128, 128)], targets=["ttnn"]) -# def test_bitwise_not(in0: Operand, builder: TTIRBuilder): -# return builder.bitwise_not(in0) +# NOTE: The generated flatbuffer will currently fail to run due to only floats +# being supported by the runtime. See issue #1775 for tracking +@compile_to_flatbuffer([(128, 128)], inputs_types=[torch.int8], targets=["ttnn"]) +def test_bitwise_not(in0: Operand, builder: TTIRBuilder): + return builder.bitwise_not(in0) @compile_to_flatbuffer([(128, 128)], targets=["ttnn"]) @@ -190,39 +191,46 @@ def test_logical_xor(in0: Operand, in1: Operand, builder: TTIRBuilder): return builder.logical_xor(in0, in1) -# TODO: uncomment once we have control over generated input types (bitwise ops -# don't support floats) (see issue #1765) -# @compile_to_flatbuffer( -# [ -# (64, 64), -# (64, 64), -# ], -# targets=["ttnn"], -# ) -# def test_bitwise_and(in0: Operand, in1: Operand, builder: TTIRBuilder): -# return builder.bitwise_and(in0, in1) -# -# -# @compile_to_flatbuffer( -# [ -# (64, 64), -# (64, 64), -# ], -# targets=["ttnn"], -# ) -# def test_bitwise_or(in0: Operand, in1: Operand, builder: TTIRBuilder): -# return builder.bitwise_or(in0, in1) -# -# -# @compile_to_flatbuffer( -# [ -# (64, 64), -# (64, 64), -# ], -# targets=["ttnn"], -# ) -# def test_bitwise_xor(in0: Operand, in1: Operand, builder: TTIRBuilder): -# return builder.bitwise_xor(in0, in1) +# NOTE: The generated flatbuffer will currently fail to run due to only floats +# being supported by the runtime. See issue #1775 for tracking +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + inputs_types=[torch.int8, torch.int8], + targets=["ttnn"], +) +def test_bitwise_and(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.bitwise_and(in0, in1) + + +# NOTE: The generated flatbuffer will currently fail to run due to only floats +# being supported by the runtime. See issue #1775 for tracking +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + inputs_types=[torch.int8, torch.int8], + targets=["ttnn"], +) +def test_bitwise_or(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.bitwise_or(in0, in1) + + +# NOTE: The generated flatbuffer will currently fail to run due to only floats +# being supported by the runtime. See issue #1775 for tracking +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + inputs_types=[torch.int8, torch.int8], + targets=["ttnn"], +) +def test_bitwise_xor(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.bitwise_xor(in0, in1) @compile_to_flatbuffer( @@ -346,17 +354,18 @@ def test_minimum(in0: Operand, in1: Operand, builder: TTIRBuilder): return builder.minimum(in0, in1) -# TODO: uncomment when we have control over the input types # @compile_to_flatbuffer( -# [ -# (64, 64), -# (64, 64), -# (64, 64), -# ], -# targets=["ttnn"], +# [ +# (64, 64), +# (64, 64), +# (64, 64), +# ], +# inputs_types = [torch.int8, torch.float32, torch.float32], +# targets=["ttnn"], # ) # def test_where(in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder): -# return builder.where(in0, in1, in2) +# return builder.where(in0, in1, in2) +# @compile_to_flatbuffer(