Skip to content

Commit

Permalink
Merge branch 'main' into kmitrovic/temp
Browse files Browse the repository at this point in the history
  • Loading branch information
kmitrovicTT authored Jan 15, 2025
2 parents 9ee8657 + d129a9e commit c7ed7b2
Show file tree
Hide file tree
Showing 25 changed files with 635 additions and 247 deletions.
25 changes: 24 additions & 1 deletion include/ttmlir/Dialect/TTIR/IR/TTIRBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define TTMLIR_TTMLIR_DIALECT_TTIR_TTIRDIALECT_TD

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -41,6 +42,28 @@ def TTIR_Dialect : Dialect {
//===----------------------------------------------------------------------===//

class TTIR_Op<string mnemonic, list<Trait> traits = []> :
Op<TTIR_Dialect, mnemonic, !listconcat(traits, [Pure])>;
Op<TTIR_Dialect, mnemonic, !listconcat([Pure], traits)>;

//===----------------------------------------------------------------------===//
// TTIR traits definition.
//===----------------------------------------------------------------------===//

def TwoOperands : ParamNativeOpTrait<"NOperands", "2">;
def ThreeOperands : ParamNativeOpTrait<"NOperands", "3">;
def FourOperands : ParamNativeOpTrait<"NOperands", "4">;

class TTIR_Trait<string name, list<Trait> traits = []> : NativeOpTrait<name, traits> {
let cppNamespace = "::mlir::tt::ttir::OpTrait";
}

// Involution is property of an operation where applying the operation twice results in the original value.
// Example: not(not(x)) == x
def TTIR_Involution : TTIR_Trait<"TTIRInvolution", [DestinationStyleOpInterface, TwoOperands, NoMemoryEffect]>;
// Idempotence is property of an operation where applying the operation twice results in the same value as applying it once.
// Example: abs(abs(x)) == abs(x)
def TTIR_Idempotence : TTIR_Trait<"TTIRIdempotence", [DestinationStyleOpInterface, TwoOperands, NoMemoryEffect]>;
// BinaryIdempotence is property of a binary operation where applying the operation on the same value results in the same value.
// Example: and(x, x) == x
def TTIR_BinaryIdempotence : TTIR_Trait<"TTIRBinaryIdempotence", [DestinationStyleOpInterface, ThreeOperands, NoMemoryEffect]>;

#endif
8 changes: 5 additions & 3 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
#ifndef TTMLIR_DIALECT_TTIR_IR_TTIROPS_H
#define TTMLIR_DIALECT_TTIR_IR_TTIROPS_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h"
#include "ttmlir/Dialect/TTIR/IR/TTIRTraits.h"

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"

#include "TTIROpsInterfaces.h"

#include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.h.inc"

Expand Down
70 changes: 44 additions & 26 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/OpBase.td"

class TTIR_DPSOp<string mnemonic, list<Trait> traits = []> :
TTIR_Op<mnemonic, !listconcat(traits, [TTIROpInterface, DestinationStyleOpInterface])> {
TTIR_Op<mnemonic, !listconcat([TTIROpInterface, DestinationStyleOpInterface], traits)> {
let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
}];
Expand Down Expand Up @@ -98,8 +98,9 @@ def TTIR_GetDimensionSizeOp : TTIR_Op<"get_dimension_size"> {

let results = (outs AnyRankedTensor:$result);

let hasFolder = 1;
let hasVerifier = 1;

let hasFolder = 1;
}

def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpInterface]> {
Expand Down Expand Up @@ -186,12 +187,8 @@ def TTIR_DeallocOp : TTIR_Op<"dealloc"> {
// TTIR top level named ops
//===----------------------------------------------------------------------===//

def TwoOperands : ParamNativeOpTrait<"NOperands", "2">;
def ThreeOperands : ParamNativeOpTrait<"NOperands", "3">;
def FourOperands : ParamNativeOpTrait<"NOperands", "4">;

class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
TTIR_DPSOp<mnemonic, !listconcat(traits, [AttrSizedOperandSegments, TTIR_Broadcastable])> {
TTIR_DPSOp<mnemonic, !listconcat([AttrSizedOperandSegments, TTIR_Broadcastable], traits)> {

let description = [{
Base class for elementwise operations. Elementwise operations can take inputs with different shape,
Expand All @@ -204,7 +201,7 @@ class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
}

class TTIR_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [FourOperands])> {
TTIR_ElementwiseOp<mnemonic, !listconcat([FourOperands], traits)> {
let summary = "Eltwise ternary op.";
let description = [{
Eltwise ternary op.
Expand All @@ -227,7 +224,7 @@ def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where"> {
}

class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [TwoOperands])> {
TTIR_ElementwiseOp<mnemonic, !listconcat([TwoOperands], traits)> {
let summary = "Eltwise unary op.";
let description = [{
Eltwise unary op.
Expand All @@ -242,7 +239,7 @@ class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
];
}

def TTIR_AbsOp: TTIR_ElementwiseUnaryOp<"abs"> {
def TTIR_AbsOp: TTIR_ElementwiseUnaryOp<"abs", [TTIR_Idempotence]> {
let summary = "Eltwise absolute op.";
let description = [{
Eltwise absolute operation.
Expand All @@ -256,7 +253,7 @@ def TTIR_CbrtOp: TTIR_ElementwiseUnaryOp<"cbrt"> {
}];
}

def TTIR_CeilOp: TTIR_ElementwiseUnaryOp<"ceil"> {
def TTIR_CeilOp: TTIR_ElementwiseUnaryOp<"ceil", [TTIR_Idempotence]> {
let summary = "Eltwise ceil op.";
let description = [{
Eltwise ceil operation.
Expand All @@ -270,7 +267,7 @@ def TTIR_CosOp: TTIR_ElementwiseUnaryOp<"cos"> {
}];
}

def TTIR_FloorOp: TTIR_ElementwiseUnaryOp<"floor"> {
def TTIR_FloorOp: TTIR_ElementwiseUnaryOp<"floor", [TTIR_Idempotence]> {
let summary = "Eltwise floor op.";
let description = [{
Eltwise floor operation.
Expand Down Expand Up @@ -298,7 +295,7 @@ def TTIR_LogicalNotOp: TTIR_ElementwiseUnaryOp<"logical_not"> {
}];
}

def TTIR_BitwiseNotOp : TTIR_ElementwiseUnaryOp<"bitwise_not"> {
def TTIR_BitwiseNotOp : TTIR_ElementwiseUnaryOp<"bitwise_not", [TTIR_Involution]> {
let summary = "Eltwise bitwise NOT.";
let description = [{
Performs element-wise NOT of tensor `operand` and produces a `result` tensor.
Expand All @@ -311,7 +308,7 @@ def TTIR_BitwiseNotOp : TTIR_ElementwiseUnaryOp<"bitwise_not"> {
}];
}

def TTIR_NegOp: TTIR_ElementwiseUnaryOp<"neg"> {
def TTIR_NegOp: TTIR_ElementwiseUnaryOp<"neg", [TTIR_Involution]> {
let summary = "Eltwise negate op.";
let description = [{
Eltwise negate operation.
Expand All @@ -332,14 +329,17 @@ def TTIR_TanhOp: TTIR_ElementwiseUnaryOp<"tanh"> {
}];
}

def TTIR_ReciprocalOp : TTIR_ElementwiseUnaryOp<"reciprocal"> {
// TODO (azecevic): What should we do with 0.0 case?
// 1/0.0 = inf, 1/inf = 0.0, but TTNN isn't IEEE754 compliant.
// https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Handling_Special_Value/special_values.md
def TTIR_ReciprocalOp : TTIR_ElementwiseUnaryOp<"reciprocal", [TTIR_Involution]> {
let summary = "Eltwise reciprocal.";
let description = [{
Eltwise reciprocal operation.
}];
}

def TTIR_ReluOp : TTIR_ElementwiseUnaryOp<"relu"> {
def TTIR_ReluOp : TTIR_ElementwiseUnaryOp<"relu", [TTIR_Idempotence]> {
let summary = "Eltwise ReLU.";
let description = [{
Eltwise ReLU operation.
Expand All @@ -360,7 +360,7 @@ def TTIR_SigmoidOp: TTIR_ElementwiseUnaryOp<"sigmoid"> {
}];
}

def TTIR_SignOp: TTIR_ElementwiseUnaryOp<"sign"> {
def TTIR_SignOp: TTIR_ElementwiseUnaryOp<"sign", [TTIR_Idempotence]> {
let summary = "Eltwise sign operation.";
let description = [{
Returns the sign of the `operand` element-wise and produces a `result`
Expand Down Expand Up @@ -469,7 +469,7 @@ def TTIR_LeakyReluOp : TTIR_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> {
}

class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [ThreeOperands])> {
TTIR_ElementwiseOp<mnemonic, !listconcat([ThreeOperands], traits)> {
let summary = "Eltwise binary op.";
let description = [{
Eltwise binary op.
Expand All @@ -484,56 +484,62 @@ class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
];
}

// TODO (azecevic): NaN != NaN, otherwise eq(x, x) == 1.
def TTIR_EqualOp : TTIR_ElementwiseBinaryOp<"eq"> {
let summary = "Eltwise equal to.";
let description = [{
Eltwise equal to operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise ne(x, x) == 0.
def TTIR_NotEqualOp : TTIR_ElementwiseBinaryOp<"ne"> {
let summary = "Eltwise not equal to.";
let description = [{
Eltwise not equal to operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise ge(x, x) == 1.
def TTIR_GreaterEqualOp : TTIR_ElementwiseBinaryOp<"ge"> {
let summary = "Eltwise greater than or equal to.";
let description = [{
Eltwise greater than or equal to operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise gt(x, x) == 0.
def TTIR_GreaterThanOp : TTIR_ElementwiseBinaryOp<"gt"> {
let summary = "Eltwise greater than.";
let description = [{
Eltwise greater than operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise le(x, x) == 1.
def TTIR_LessEqualOp : TTIR_ElementwiseBinaryOp<"le"> {
let summary = "Eltwise less than or equal to.";
let description = [{
Eltwise less than or equal to operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise lt(x, x) == 0.
def TTIR_LessThanOp : TTIR_ElementwiseBinaryOp<"lt"> {
let summary = "Eltwise less than.";
let description = [{
Eltwise less than operation.
}];
}

def TTIR_LogicalAndOp : TTIR_ElementwiseBinaryOp<"logical_and"> {
def TTIR_LogicalAndOp : TTIR_ElementwiseBinaryOp<"logical_and", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise logical and.";
let description = [{
Eltwise logical and operation.
}];
}

def TTIR_LogicalOrOp : TTIR_ElementwiseBinaryOp<"logical_or"> {
def TTIR_LogicalOrOp : TTIR_ElementwiseBinaryOp<"logical_or", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise logical or.";
let description = [{
Eltwise logical or operation.
Expand All @@ -547,7 +553,7 @@ def TTIR_LogicalXorOp : TTIR_ElementwiseBinaryOp<"logical_xor"> {
}];
}

def TTIR_BitwiseAndOp : TTIR_ElementwiseBinaryOp<"bitwise_and"> {
def TTIR_BitwiseAndOp : TTIR_ElementwiseBinaryOp<"bitwise_and", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise bitwise AND.";
let description = [{
Performs element-wise bitwise AND of two tensors `lhs` and `rhs`
Expand All @@ -561,7 +567,7 @@ def TTIR_BitwiseAndOp : TTIR_ElementwiseBinaryOp<"bitwise_and"> {
}];
}

def TTIR_BitwiseOrOp : TTIR_ElementwiseBinaryOp<"bitwise_or"> {
def TTIR_BitwiseOrOp : TTIR_ElementwiseBinaryOp<"bitwise_or", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise bitwise OR.";
let description = [{
Performs element-wise bitwise OR of two tensors `lhs` and `rhs`
Expand All @@ -587,9 +593,11 @@ def TTIR_BitwiseXorOp : TTIR_ElementwiseBinaryOp<"bitwise_xor"> {
%result = "ttir.bitwise_xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
}];

let hasCanonicalizer = 1;
}

def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum"> {
def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise minimum OP.";
let description = [{
Calculates minimum of input tensors' values element-wise and stores result
Expand Down Expand Up @@ -625,7 +633,7 @@ def TTIR_RemainderOp : TTIR_ElementwiseBinaryOp<"remainder"> {
}

class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :
TTIR_DPSOp<mnemonic, !listconcat(traits, [TTIR_GenericRegionOpInterface])> {
TTIR_DPSOp<mnemonic, !listconcat([TTIR_GenericRegionOpInterface], traits)> {

let summary = "Reduction op.";
let description = [{
Expand Down Expand Up @@ -772,6 +780,8 @@ def TTIR_TransposeOp : TTIR_DPSOp<"transpose"> {
}];

let hasVerifier = 1;

let hasCanonicalizer = 1;
}

def TTIR_ConcatOp : TTIR_DPSOp<"concat"> {
Expand Down Expand Up @@ -854,6 +864,8 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
}];

let hasVerifier = 1;

let hasFolder = 1;
}

def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
Expand Down Expand Up @@ -1245,6 +1257,8 @@ def TTIR_ReverseOp : TTIR_DPSOp<"reverse", [AllShapesMatch<["input", "result"]>]
}];

let hasVerifier = 1;

let hasCanonicalizer = 1;
}

def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
Expand Down Expand Up @@ -1314,6 +1328,8 @@ def TTIR_LinearOp : TTIR_DPSOp<"linear"> {
}];

let hasVerifier = 1;

let hasCanonicalizeMethod = 1;
}

// ANCHOR: adding_an_op_matmul_ttir
Expand Down Expand Up @@ -1362,14 +1378,16 @@ def TTIR_PermuteOp : TTIR_DPSOp<"permute"> {
}];

let hasVerifier = 1;

let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// TTIR top level generic ops
//===----------------------------------------------------------------------===//

class TTIR_GenericElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseUnaryOp<mnemonic, !listconcat(traits, [TTIR_GenericRegionOpInterface])> {
TTIR_ElementwiseUnaryOp<mnemonic, !listconcat([TTIR_GenericRegionOpInterface], traits)> {

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
Expand Down Expand Up @@ -1400,7 +1418,7 @@ def TTIR_ExpOp: TTIR_GenericElementwiseUnaryOp<"exp"> {
}

class TTIR_GenericElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseBinaryOp<mnemonic, !listconcat(traits, [TTIR_GenericRegionOpInterface])> {
TTIR_ElementwiseBinaryOp<mnemonic, !listconcat([TTIR_GenericRegionOpInterface], traits)> {

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
Expand Down
Loading

0 comments on commit c7ed7b2

Please sign in to comment.