Skip to content

Commit

Permalink
Canonicalization of TTIR ops (#1670)
Browse files Browse the repository at this point in the history
Canonicalization can be roughly divided into three forms:
1. Traits (properties) of some ops that allow them to be folded
(Involution, Idempotence)
2. Per op foldings
3. Per op canonicalization (pattern rewriters)

The first one allows us to declaratively add folding for a large class
of ops. While traits like `Involution` already exist in the `MLIR`
infrastructure, they don't account for DPS, so I added new ones that
take care of it.
The second one allows us in practice to define some simple graph
rewritings where we replace the value produced by an op with some existing
value (or constant).
The third one gives us the most freedom, as we can do arbitrary graph
rewritings; we usually use it when we have to create a new op during
rewriting.

I added a `canonicalize` pass both before and after
`ttir-to-ttir-decomposition-pass` in `ttir-to-ttnn-backend-pipeline`, as
there are many graph rewritings during that pass, so we might benefit
from canonicalization both before and after.

I plan to cover one big part of canonicalization in the future, and
that's constant folding. I will also write a short document on adding a
canonicalization for new and existing ops. While this PR covers a lot of
patterns, it's definitely not an exhaustive list. This MLIR
[doc](https://mlir.llvm.org/docs/Canonicalization/) is already a great
source of information; I just believe we might also benefit from the
additional context of the TTIR dialect.

This PR should cover a big part of
#1264; as I said, the
missing part is constant folding.
  • Loading branch information
azecevicTT authored Jan 15, 2025
1 parent 8f194e6 commit d129a9e
Show file tree
Hide file tree
Showing 23 changed files with 634 additions and 244 deletions.
25 changes: 24 additions & 1 deletion include/ttmlir/Dialect/TTIR/IR/TTIRBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define TTMLIR_TTMLIR_DIALECT_TTIR_TTIRDIALECT_TD

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -41,6 +42,28 @@ def TTIR_Dialect : Dialect {
//===----------------------------------------------------------------------===//

class TTIR_Op<string mnemonic, list<Trait> traits = []> :
Op<TTIR_Dialect, mnemonic, !listconcat(traits, [Pure])>;
Op<TTIR_Dialect, mnemonic, !listconcat([Pure], traits)>;

//===----------------------------------------------------------------------===//
// TTIR traits definition.
//===----------------------------------------------------------------------===//

// Operand-count traits: each instantiates the parameterized native trait
// `NOperands` with a fixed count (2, 3 or 4).
// NOTE(review): for DPS ops the count appears to include the output (init)
// operand, since the unary elementwise base class uses TwoOperands -- confirm.
def TwoOperands : ParamNativeOpTrait<"NOperands", "2">;
def ThreeOperands : ParamNativeOpTrait<"NOperands", "3">;
def FourOperands : ParamNativeOpTrait<"NOperands", "4">;

// Base class for TTIR-specific native op traits.
//   name   - name of the trait class, forwarded to NativeOpTrait.
//   traits - list of traits forwarded to NativeOpTrait alongside the name
//            (presumably prerequisites of this trait -- confirm).
// The trait class is expected to live in the ::mlir::tt::ttir::OpTrait namespace.
class TTIR_Trait<string name, list<Trait> traits = []> : NativeOpTrait<name, traits> {
let cppNamespace = "::mlir::tt::ttir::OpTrait";
}

// Involution is a property of an operation where applying the operation twice results in the original value.
// Example: not(not(x)) == x
def TTIR_Involution : TTIR_Trait<"TTIRInvolution", [DestinationStyleOpInterface, TwoOperands, NoMemoryEffect]>;
// Idempotence is a property of an operation where applying the operation twice results in the same value as applying it once.
// Example: abs(abs(x)) == abs(x)
def TTIR_Idempotence : TTIR_Trait<"TTIRIdempotence", [DestinationStyleOpInterface, TwoOperands, NoMemoryEffect]>;
// BinaryIdempotence is a property of a binary operation where applying the operation to two identical operands results in that same value.
// Example: and(x, x) == x
def TTIR_BinaryIdempotence : TTIR_Trait<"TTIRBinaryIdempotence", [DestinationStyleOpInterface, ThreeOperands, NoMemoryEffect]>;

#endif
8 changes: 5 additions & 3 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
#ifndef TTMLIR_DIALECT_TTIR_IR_TTIROPS_H
#define TTMLIR_DIALECT_TTIR_IR_TTIROPS_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h"
#include "ttmlir/Dialect/TTIR/IR/TTIRTraits.h"

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"

#include "TTIROpsInterfaces.h"

#include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.h.inc"

Expand Down
70 changes: 44 additions & 26 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/OpBase.td"

class TTIR_DPSOp<string mnemonic, list<Trait> traits = []> :
TTIR_Op<mnemonic, !listconcat(traits, [TTIROpInterface, DestinationStyleOpInterface])> {
TTIR_Op<mnemonic, !listconcat([TTIROpInterface, DestinationStyleOpInterface], traits)> {
let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
}];
Expand Down Expand Up @@ -98,8 +98,9 @@ def TTIR_GetDimensionSizeOp : TTIR_Op<"get_dimension_size"> {

let results = (outs AnyRankedTensor:$result);

let hasFolder = 1;
let hasVerifier = 1;

let hasFolder = 1;
}

def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpInterface]> {
Expand Down Expand Up @@ -186,12 +187,8 @@ def TTIR_DeallocOp : TTIR_Op<"dealloc"> {
// TTIR top level named ops
//===----------------------------------------------------------------------===//

def TwoOperands : ParamNativeOpTrait<"NOperands", "2">;
def ThreeOperands : ParamNativeOpTrait<"NOperands", "3">;
def FourOperands : ParamNativeOpTrait<"NOperands", "4">;

class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
TTIR_DPSOp<mnemonic, !listconcat(traits, [AttrSizedOperandSegments, TTIR_Broadcastable])> {
TTIR_DPSOp<mnemonic, !listconcat([AttrSizedOperandSegments, TTIR_Broadcastable], traits)> {

let description = [{
Base class for elementwise operations. Elementwise operations can take inputs with different shape,
Expand All @@ -204,7 +201,7 @@ class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
}

class TTIR_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [FourOperands])> {
TTIR_ElementwiseOp<mnemonic, !listconcat([FourOperands], traits)> {
let summary = "Eltwise ternary op.";
let description = [{
Eltwise ternary op.
Expand All @@ -227,7 +224,7 @@ def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where"> {
}

class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [TwoOperands])> {
TTIR_ElementwiseOp<mnemonic, !listconcat([TwoOperands], traits)> {
let summary = "Eltwise unary op.";
let description = [{
Eltwise unary op.
Expand All @@ -242,7 +239,7 @@ class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
];
}

def TTIR_AbsOp: TTIR_ElementwiseUnaryOp<"abs"> {
def TTIR_AbsOp: TTIR_ElementwiseUnaryOp<"abs", [TTIR_Idempotence]> {
let summary = "Eltwise absolute op.";
let description = [{
Eltwise absolute operation.
Expand All @@ -256,7 +253,7 @@ def TTIR_CbrtOp: TTIR_ElementwiseUnaryOp<"cbrt"> {
}];
}

def TTIR_CeilOp: TTIR_ElementwiseUnaryOp<"ceil"> {
def TTIR_CeilOp: TTIR_ElementwiseUnaryOp<"ceil", [TTIR_Idempotence]> {
let summary = "Eltwise ceil op.";
let description = [{
Eltwise ceil operation.
Expand All @@ -270,7 +267,7 @@ def TTIR_CosOp: TTIR_ElementwiseUnaryOp<"cos"> {
}];
}

def TTIR_FloorOp: TTIR_ElementwiseUnaryOp<"floor"> {
def TTIR_FloorOp: TTIR_ElementwiseUnaryOp<"floor", [TTIR_Idempotence]> {
let summary = "Eltwise floor op.";
let description = [{
Eltwise floor operation.
Expand Down Expand Up @@ -298,7 +295,7 @@ def TTIR_LogicalNotOp: TTIR_ElementwiseUnaryOp<"logical_not"> {
}];
}

def TTIR_BitwiseNotOp : TTIR_ElementwiseUnaryOp<"bitwise_not"> {
def TTIR_BitwiseNotOp : TTIR_ElementwiseUnaryOp<"bitwise_not", [TTIR_Involution]> {
let summary = "Eltwise bitwise NOT.";
let description = [{
Performs element-wise NOT of tensor `operand` and produces a `result` tensor.
Expand All @@ -311,7 +308,7 @@ def TTIR_BitwiseNotOp : TTIR_ElementwiseUnaryOp<"bitwise_not"> {
}];
}

def TTIR_NegOp: TTIR_ElementwiseUnaryOp<"neg"> {
def TTIR_NegOp: TTIR_ElementwiseUnaryOp<"neg", [TTIR_Involution]> {
let summary = "Eltwise negate op.";
let description = [{
Eltwise negate operation.
Expand All @@ -332,14 +329,17 @@ def TTIR_TanhOp: TTIR_ElementwiseUnaryOp<"tanh"> {
}];
}

def TTIR_ReciprocalOp : TTIR_ElementwiseUnaryOp<"reciprocal"> {
// TODO (azecevic): What should we do with 0.0 case?
// 1/0.0 = inf, 1/inf = 0.0, but TTNN isn't IEEE754 compliant.
// https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Handling_Special_Value/special_values.md
def TTIR_ReciprocalOp : TTIR_ElementwiseUnaryOp<"reciprocal", [TTIR_Involution]> {
let summary = "Eltwise reciprocal.";
let description = [{
Eltwise reciprocal operation.
}];
}

def TTIR_ReluOp : TTIR_ElementwiseUnaryOp<"relu"> {
def TTIR_ReluOp : TTIR_ElementwiseUnaryOp<"relu", [TTIR_Idempotence]> {
let summary = "Eltwise ReLU.";
let description = [{
Eltwise ReLU operation.
Expand All @@ -360,7 +360,7 @@ def TTIR_SigmoidOp: TTIR_ElementwiseUnaryOp<"sigmoid"> {
}];
}

def TTIR_SignOp: TTIR_ElementwiseUnaryOp<"sign"> {
def TTIR_SignOp: TTIR_ElementwiseUnaryOp<"sign", [TTIR_Idempotence]> {
let summary = "Eltwise sign operation.";
let description = [{
Returns the sign of the `operand` element-wise and produces a `result`
Expand Down Expand Up @@ -469,7 +469,7 @@ def TTIR_LeakyReluOp : TTIR_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> {
}

class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, !listconcat(traits, [ThreeOperands])> {
TTIR_ElementwiseOp<mnemonic, !listconcat([ThreeOperands], traits)> {
let summary = "Eltwise binary op.";
let description = [{
Eltwise binary op.
Expand All @@ -484,56 +484,62 @@ class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
];
}

// TODO (azecevic): NaN != NaN, otherwise eq(x, x) == 1.
def TTIR_EqualOp : TTIR_ElementwiseBinaryOp<"eq"> {
let summary = "Eltwise equal to.";
let description = [{
Eltwise equal to operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise ne(x, x) == 0.
def TTIR_NotEqualOp : TTIR_ElementwiseBinaryOp<"ne"> {
let summary = "Eltwise not equal to.";
let description = [{
Eltwise not equal to operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise ge(x, x) == 1.
def TTIR_GreaterEqualOp : TTIR_ElementwiseBinaryOp<"ge"> {
let summary = "Eltwise greater than or equal to.";
let description = [{
Eltwise greater than or equal to operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise gt(x, x) == 0.
def TTIR_GreaterThanOp : TTIR_ElementwiseBinaryOp<"gt"> {
let summary = "Eltwise greater than.";
let description = [{
Eltwise greater than operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise le(x, x) == 1.
def TTIR_LessEqualOp : TTIR_ElementwiseBinaryOp<"le"> {
let summary = "Eltwise less than or equal to.";
let description = [{
Eltwise less than or equal to operation.
}];
}

// TODO (azecevic): NaN != NaN, otherwise lt(x, x) == 0.
def TTIR_LessThanOp : TTIR_ElementwiseBinaryOp<"lt"> {
let summary = "Eltwise less than.";
let description = [{
Eltwise less than operation.
}];
}

def TTIR_LogicalAndOp : TTIR_ElementwiseBinaryOp<"logical_and"> {
def TTIR_LogicalAndOp : TTIR_ElementwiseBinaryOp<"logical_and", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise logical and.";
let description = [{
Eltwise logical and operation.
}];
}

def TTIR_LogicalOrOp : TTIR_ElementwiseBinaryOp<"logical_or"> {
def TTIR_LogicalOrOp : TTIR_ElementwiseBinaryOp<"logical_or", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise logical or.";
let description = [{
Eltwise logical or operation.
Expand All @@ -547,7 +553,7 @@ def TTIR_LogicalXorOp : TTIR_ElementwiseBinaryOp<"logical_xor"> {
}];
}

def TTIR_BitwiseAndOp : TTIR_ElementwiseBinaryOp<"bitwise_and"> {
def TTIR_BitwiseAndOp : TTIR_ElementwiseBinaryOp<"bitwise_and", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise bitwise AND.";
let description = [{
Performs element-wise bitwise AND of two tensors `lhs` and `rhs`
Expand All @@ -561,7 +567,7 @@ def TTIR_BitwiseAndOp : TTIR_ElementwiseBinaryOp<"bitwise_and"> {
}];
}

def TTIR_BitwiseOrOp : TTIR_ElementwiseBinaryOp<"bitwise_or"> {
def TTIR_BitwiseOrOp : TTIR_ElementwiseBinaryOp<"bitwise_or", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise bitwise OR.";
let description = [{
Performs element-wise bitwise OR of two tensors `lhs` and `rhs`
Expand All @@ -587,9 +593,11 @@ def TTIR_BitwiseXorOp : TTIR_ElementwiseBinaryOp<"bitwise_xor"> {
%result = "ttir.bitwise_xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
}];

let hasCanonicalizer = 1;
}

def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum"> {
def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum", [TTIR_BinaryIdempotence]> {
let summary = "Eltwise minimum OP.";
let description = [{
Calculates minimum of input tensors' values element-wise and stores result
Expand Down Expand Up @@ -625,7 +633,7 @@ def TTIR_RemainderOp : TTIR_ElementwiseBinaryOp<"remainder"> {
}

class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :
TTIR_DPSOp<mnemonic, !listconcat(traits, [TTIR_GenericRegionOpInterface])> {
TTIR_DPSOp<mnemonic, !listconcat([TTIR_GenericRegionOpInterface], traits)> {

let summary = "Reduction op.";
let description = [{
Expand Down Expand Up @@ -772,6 +780,8 @@ def TTIR_TransposeOp : TTIR_DPSOp<"transpose"> {
}];

let hasVerifier = 1;

let hasCanonicalizer = 1;
}

def TTIR_ConcatOp : TTIR_DPSOp<"concat"> {
Expand Down Expand Up @@ -854,6 +864,8 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
}];

let hasVerifier = 1;

let hasFolder = 1;
}

def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
Expand Down Expand Up @@ -1245,6 +1257,8 @@ def TTIR_ReverseOp : TTIR_DPSOp<"reverse", [AllShapesMatch<["input", "result"]>]
}];

let hasVerifier = 1;

let hasCanonicalizer = 1;
}

def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
Expand Down Expand Up @@ -1314,6 +1328,8 @@ def TTIR_LinearOp : TTIR_DPSOp<"linear"> {
}];

let hasVerifier = 1;

let hasCanonicalizeMethod = 1;
}

// ANCHOR: adding_an_op_matmul_ttir
Expand Down Expand Up @@ -1362,14 +1378,16 @@ def TTIR_PermuteOp : TTIR_DPSOp<"permute"> {
}];

let hasVerifier = 1;

let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// TTIR top level generic ops
//===----------------------------------------------------------------------===//

class TTIR_GenericElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseUnaryOp<mnemonic, !listconcat(traits, [TTIR_GenericRegionOpInterface])> {
TTIR_ElementwiseUnaryOp<mnemonic, !listconcat([TTIR_GenericRegionOpInterface], traits)> {

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
Expand Down Expand Up @@ -1400,7 +1418,7 @@ def TTIR_ExpOp: TTIR_GenericElementwiseUnaryOp<"exp"> {
}

class TTIR_GenericElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseBinaryOp<mnemonic, !listconcat(traits, [TTIR_GenericRegionOpInterface])> {
TTIR_ElementwiseBinaryOp<mnemonic, !listconcat([TTIR_GenericRegionOpInterface], traits)> {

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
Expand Down
Loading

0 comments on commit d129a9e

Please sign in to comment.