Skip to content

Commit

Permalink
Test Most Elementwise Ops for TTNN
Browse files Browse the repository at this point in the history
This change adds tests for lowerings of most of the remaining
elementwise ops in TTNN. The resulting `.ttnn` files can be used to
golden test these ops as well. This change does not test all elementwise
ops, as there are a few that require a slightly more robust testing
system in place to use.

Closes #1748
  • Loading branch information
ctodTT authored Jan 14, 2025
1 parent 731de13 commit 043d0a5
Show file tree
Hide file tree
Showing 2 changed files with 258 additions and 2 deletions.
95 changes: 94 additions & 1 deletion python/test_infra/ttir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import inspect
from dataclasses import dataclass
from typing import List, Optional, Union, Tuple, Callable, Dict
from typing import List, Optional, Union, Tuple, Callable, Dict, Any
from ttmlir.ir import *
from ttmlir.dialects import ttir, tt, tensor
from ttmlir.passes import create_golden_tensor, DataType
Expand Down Expand Up @@ -389,6 +389,8 @@ def eltwise_proxy(
) -> OpView:
return self.op_proxy(op_golden_function, op_ttir_function, inputs)

# TODO: implement `scatter` & `typecast`

def exp(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.exp, ttir.ExpOp, [in0])

Expand All @@ -398,15 +400,60 @@ def abs(self, in0: Operand) -> OpView:
def logical_not(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.logical_not, ttir.LogicalNotOp, [in0])

def bitwise_not(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.bitwise_not, ttir.BitwiseNotOp, [in0])

def ceil(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.log, ttir.CeilOp, [in0])

def sin(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.sin, ttir.SinOp, [in0])

def cos(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.cos, ttir.CosOp, [in0])

def tan(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.tan, ttir.TanOp, [in0])

def tanh(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.tanh, ttir.TanhOp, [in0])

def log(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.log, ttir.LogOp, [in0])

def log1p(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.log1p, ttir.Log1pOp, [in0])

def expm1(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.expm1, ttir.Expm1Op, [in0])

def sign(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.sign, ttir.SignOp, [in0])

def is_finite(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.isfinite, ttir.IsFiniteOp, [in0])

def floor(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.floor, ttir.FloorOp, [in0])

def where(self, in0: Operand, in1: Operand, in2: Operand) -> OpView:
return self.eltwise_proxy(torch.where, ttir.WhereOp, [in0, in1, in2])

def neg(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.neg, ttir.NegOp, [in0])

def relu(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.relu, ttir.ReluOp, [in0])

def gelu(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.nn.functional.gelu, ttir.GeluOp, [in0])

def sqrt(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.sqrt, ttir.SqrtOp, [in0])

def cbrt(self, in0: Operand) -> OpView:
return self.eltwise_proxy(lambda x: torch.pow(x, 1 / 3), ttir.CbrtOp, [in0])

def rsqrt(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.rsqrt, ttir.RsqrtOp, [in0])

Expand All @@ -428,6 +475,18 @@ def logical_and(self, in0: Operand, in1: Operand) -> OpView:
def logical_or(self, in0: Operand, in1: Operand) -> OpView:
return self.eltwise_proxy(torch.logical_or, ttir.LogicalOrOp, [in0, in1])

def logical_xor(self, in0: Operand, in1: Operand) -> OpView:
return self.eltwise_proxy(torch.logical_xor, ttir.LogicalXorOp, [in0, in1])

def bitwise_and(self, in0: Operand, in1: Operand) -> OpView:
return self.eltwise_proxy(torch.bitwise_and, ttir.BitwiseAndOp, [in0, in1])

def bitwise_or(self, in0: Operand, in1: Operand) -> OpView:
return self.eltwise_proxy(torch.bitwise_or, ttir.BitwiseOrOp, [in0, in1])

def bitwise_xor(self, in0: Operand, in1: Operand) -> OpView:
return self.eltwise_proxy(torch.bitwise_xor, ttir.BitwiseXorOp, [in0, in1])

def subtract(self, in0: Operand, in1: Operand) -> OpView:
return self.eltwise_proxy(torch.subtract, ttir.SubtractOp, [in0, in1])

Expand All @@ -452,9 +511,43 @@ def lt(self, in0: Operand, in1: Operand) -> OpView:
def div(self, in0: Operand, in1: Operand) -> OpView:
return self.eltwise_proxy(torch.div, ttir.DivOp, [in0, in1])

def remainder(self, in0: Operand, in1: Operand) -> OpView:
return self.eltwise_proxy(torch.remainder, ttir.RemainderOp, [in0, in1])

def maximum(self, in0: Operand, in1: Operand) -> OpView:
return self.eltwise_proxy(torch.maximum, ttir.MaximumOp, [in0, in1])

def minimum(self, in0: Operand, in1: Operand) -> OpView:
return self.eltwise_proxy(torch.minimum, ttir.MinimumOp, [in0, in1])

def leaky_relu(self, in0: Operand, parameter: float = 0.01) -> OpView:
# TODO: reconcile this naming mismatch
ttir_kwargs = {"parameter": parameter}
golden_kwargs = {"negative_slope": parameter}
return self.op_proxy(
torch.nn.functional.leaky_relu,
ttir.LeakyReluOp,
[in0],
golden_kwargs=golden_kwargs,
ttir_kwargs=ttir_kwargs,
)

def clamp(
self,
in0: Operand,
min_arg: Optional[float] = None,
max_arg: Optional[float] = None,
) -> OpView:
kwargs = {"min": min_arg, "max": max_arg}
return self.op_proxy(
torch.clamp,
ttir.ClampOp,
[in0],
ttir_kwargs=kwargs,
golden_kwargs=kwargs,
organize_ttir_args=lambda i, o, _: (self._get_type(o), i[0], o),
)

def matmul(
self, in0: Operand, in1: Operand, bias: Optional[Operand] = None
) -> OpView:
Expand Down
165 changes: 164 additions & 1 deletion test/python/golden/test_ttir_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,29 @@
import inspect

from ttmlir.test_utils import compile_to_flatbuffer
from ttmlir.ttir_builder import Operand, TTIRBuilder
from ttmlir.ttir_builder import Operand, TTIRBuilder, Attribute


@compile_to_flatbuffer([(128, 128)])
def test_exp(in0: Operand, builder: TTIRBuilder):
return builder.exp(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_expm1(in0: Operand, builder: TTIRBuilder):
return builder.expm1(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_ceil(in0: Operand, builder: TTIRBuilder):
return builder.ceil(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_floor(in0: Operand, builder: TTIRBuilder):
return builder.floor(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_abs(in0: Operand, builder: TTIRBuilder):
return builder.abs(in0)
Expand All @@ -25,21 +40,83 @@ def test_logical_not(in0: Operand, builder: TTIRBuilder):
return builder.logical_not(in0)


# TODO: uncomment once we have control over generated input types (bitwise ops
# don't support floats) (see issue #1765)
# @compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
# def test_bitwise_not(in0: Operand, builder: TTIRBuilder):
# return builder.bitwise_not(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_neg(in0: Operand, builder: TTIRBuilder):
return builder.neg(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_sign(in0: Operand, builder: TTIRBuilder):
return builder.sign(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_sin(in0: Operand, builder: TTIRBuilder):
return builder.sin(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_cos(in0: Operand, builder: TTIRBuilder):
return builder.cos(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_tan(in0: Operand, builder: TTIRBuilder):
return builder.tan(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_tanh(in0: Operand, builder: TTIRBuilder):
return builder.tanh(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_log(in0: Operand, builder: TTIRBuilder):
return builder.log(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_log1p(in0: Operand, builder: TTIRBuilder):
return builder.log1p(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_relu(in0: Operand, builder: TTIRBuilder):
return builder.relu(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_gelu(in0: Operand, builder: TTIRBuilder):
return builder.gelu(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_clamp(in0: Operand, builder: TTIRBuilder):
return builder.clamp(in0, max_arg=1.0, min_arg=0.0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_leaky_relu(in0: Operand, builder: TTIRBuilder):
return builder.leaky_relu(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_sqrt(in0: Operand, builder: TTIRBuilder):
return builder.sqrt(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_cbrt(in0: Operand, builder: TTIRBuilder):
return builder.cbrt(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_rsqrt(in0: Operand, builder: TTIRBuilder):
return builder.rsqrt(in0)
Expand All @@ -55,6 +132,11 @@ def test_reciprocal(in0: Operand, builder: TTIRBuilder):
return builder.reciprocal(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
def test_is_finite(in0: Operand, builder: TTIRBuilder):
return builder.is_finite(in0)


@compile_to_flatbuffer(
[
(64, 128),
Expand Down Expand Up @@ -97,6 +179,52 @@ def test_logical_or(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.logical_or(in0, in1)


@compile_to_flatbuffer(
[
(64, 64),
(64, 64),
],
targets=["ttnn"],
)
def test_logical_xor(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.logical_xor(in0, in1)


# TODO: uncomment once we have control over generated input types (bitwise ops
# don't support floats) (see issue #1765)
# @compile_to_flatbuffer(
# [
# (64, 64),
# (64, 64),
# ],
# targets=["ttnn"],
# )
# def test_bitwise_and(in0: Operand, in1: Operand, builder: TTIRBuilder):
# return builder.bitwise_and(in0, in1)
#
#
# @compile_to_flatbuffer(
# [
# (64, 64),
# (64, 64),
# ],
# targets=["ttnn"],
# )
# def test_bitwise_or(in0: Operand, in1: Operand, builder: TTIRBuilder):
# return builder.bitwise_or(in0, in1)
#
#
# @compile_to_flatbuffer(
# [
# (64, 64),
# (64, 64),
# ],
# targets=["ttnn"],
# )
# def test_bitwise_xor(in0: Operand, in1: Operand, builder: TTIRBuilder):
# return builder.bitwise_xor(in0, in1)


@compile_to_flatbuffer(
[
(64, 64),
Expand Down Expand Up @@ -185,6 +313,17 @@ def test_div(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.div(in0, in1)


@compile_to_flatbuffer(
[
(64, 64),
(64, 64),
],
targets=["ttnn"],
)
def test_remainder(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.remainder(in0, in1)


@compile_to_flatbuffer(
[
(64, 64),
Expand All @@ -196,6 +335,30 @@ def test_maximum(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.maximum(in0, in1)


@compile_to_flatbuffer(
[
(64, 64),
(64, 64),
],
targets=["ttnn"],
)
def test_minimum(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.minimum(in0, in1)


# TODO: uncomment when we have control over the input types
# @compile_to_flatbuffer(
# [
# (64, 64),
# (64, 64),
# (64, 64),
# ],
# targets=["ttnn"],
# )
# def test_where(in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder):
# return builder.where(in0, in1, in2)


@compile_to_flatbuffer(
[
(32, 32),
Expand Down

0 comments on commit 043d0a5

Please sign in to comment.