Skip to content

Commit

Permalink
Introduce TTIR traits to track partial or full implicit folding of binary eltwise operations. Add pass to apply folding when these traits are present.
Browse files Browse the repository at this point in the history
  • Loading branch information
uazizTT committed Jan 15, 2025
1 parent d129a9e commit 32cd902
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 8 deletions.
16 changes: 8 additions & 8 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ class TTIR_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :
];
}

def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where"> {
def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where", [TTIR_PartiallyBroadcastable]> {
let summary = "Eltwise where op.";
let description = [{
Eltwise where operation.
Expand Down Expand Up @@ -597,7 +597,7 @@ def TTIR_BitwiseXorOp : TTIR_ElementwiseBinaryOp<"bitwise_xor"> {
let hasCanonicalizer = 1;
}

def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum", [TTIR_BinaryIdempotence]> {
def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum", [TTIR_BinaryIdempotence, TTIR_PartiallyBroadcastable]> {
let summary = "Eltwise minimum OP.";
let description = [{
Calculates minimum of input tensors' values element-wise and stores result
Expand All @@ -610,14 +610,14 @@ def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum", [TTIR_BinaryIdempotenc
}];
}

def TTIR_SubtractOp : TTIR_ElementwiseBinaryOp<"subtract"> {
def TTIR_SubtractOp : TTIR_ElementwiseBinaryOp<"subtract", [TTIR_PartiallyBroadcastable]> {
let summary = "Eltwise subtract.";
let description = [{
Eltwise subtract operation.
}];
}

def TTIR_RemainderOp : TTIR_ElementwiseBinaryOp<"remainder"> {
def TTIR_RemainderOp : TTIR_ElementwiseBinaryOp<"remainder", [TTIR_PartiallyBroadcastable]> {
let summary = "Eltwise remainder.";
let description = [{
Performs element-wise remainder of dividend lhs and divisor rhs tensors and produces a
Expand Down Expand Up @@ -1439,28 +1439,28 @@ class TTIR_GenericElementwiseBinaryOp<string mnemonic, list<Trait> traits = []>
}];
}

def TTIR_AddOp : TTIR_GenericElementwiseBinaryOp<"add"> {
def TTIR_AddOp : TTIR_GenericElementwiseBinaryOp<"add", [TTIR_FullyBroadcastable]> {
let summary = "Eltwise add.";
let description = [{
Eltwise add operation.
}];
}

def TTIR_MultiplyOp : TTIR_GenericElementwiseBinaryOp<"multiply"> {
def TTIR_MultiplyOp : TTIR_GenericElementwiseBinaryOp<"multiply", [TTIR_FullyBroadcastable]> {
let summary = "Eltwise multiply.";
let description = [{
Eltwise multiply operation.
}];
}

def TTIR_DivOp : TTIR_GenericElementwiseBinaryOp<"div"> {
def TTIR_DivOp : TTIR_GenericElementwiseBinaryOp<"div", [TTIR_PartiallyBroadcastable]> {
let summary = "Eltwise divide.";
let description = [{
Eltwise divide operation.
}];
}

def TTIR_MaximumOp : TTIR_GenericElementwiseBinaryOp<"maximum"> {
def TTIR_MaximumOp : TTIR_GenericElementwiseBinaryOp<"maximum", [TTIR_PartiallyBroadcastable]> {
let summary = "Eltwise maximum.";
let description = [{
Calculates maximum of input tensors' values element-wise and stores result in output tensor.
Expand Down
10 changes: 10 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,14 @@ def TTIR_GenericRegionOpInterface : OpInterface<"GenericRegionOp"> {
];
}

def TTIR_PartiallyBroadcastable : OpInterface<"partiallyBroadcastable"> {
  // Marker interface: the op implicitly broadcasts its FIRST operand only,
  // so a ttir.broadcast feeding operand 0 may be folded away by the
  // ttir-broadcast-fold pass.
  // NOTE(review): generated C++ interface class names are conventionally
  // PascalCase ("PartiallyBroadcastable") — renaming requires updating all
  // C++ users (e.g. Broadcast.cpp); confirm before changing.
  let cppNamespace = "::mlir::tt::ttir";
}

def TTIR_FullyBroadcastable : OpInterface<"fullyBroadcastable"> {
  // Marker interface: the op implicitly broadcasts ALL of its input operands,
  // so any ttir.broadcast feeding it may be folded away by the
  // ttir-broadcast-fold pass.
  // NOTE(review): generated C++ interface class names are conventionally
  // PascalCase ("FullyBroadcastable") — renaming requires updating all
  // C++ users; confirm before changing.
  let cppNamespace = "::mlir::tt::ttir";
}

#endif
17 changes: 17 additions & 0 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,23 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
];
}

def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> {
  let summary = "Broadcast operation is folded to all the consumers.";
  let description = [{
    This pass walks through the graph and folds broadcast operations when implicit broadcasting is supported by the consuming operation.

    Example:
    %0 = tensor.empty() : tensor<1x16x32xf32>
    %1 = "ttir.broadcast"(%arg1, %0) <{broadcast_dimensions = array<i32: 1, 16, 1>}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
    %2 = tensor.empty() : tensor<1x16x32xf32>
    %3 = "ttir.multiply"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>

    Since multiplyOp supports implicit broadcasting, above broadcast is folded as:
    %0 = tensor.empty() : tensor<1x16x32xf32>
    %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
  }];
}

def TTIRHoistTransform: Pass<"ttir-cpu-hoist-transform", "::mlir::ModuleOp">
{
let summary = "Transform to perform hoist mechanics on any ops marked to be hoisted for CPU lowering";
Expand Down
73 changes: 73 additions & 0 deletions lib/Dialect/TTIR/Transforms/Broadcast.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
namespace mlir::tt::ttir {
#define GEN_PASS_DEF_TTIRBROADCASTFOLD
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc"

// Pattern that folds ttir.broadcast producers into consumers which support
// implicit broadcasting (marked via the partiallyBroadcastable /
// fullyBroadcastable TTIR interfaces).
class TTIRBroadcastFoldRewriter : public RewritePattern {
public:
  TTIRBroadcastFoldRewriter(MLIRContext *ctx)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->hasTrait<partiallyBroadcastable::Trait>()) {
      // This operation can only fold a broadcast feeding operand 0.
      return foldBroadcastOperand(op, /*operandIdx=*/0, rewriter);
    }
    if (op->hasTrait<fullyBroadcastable::Trait>()) {
      // Implicit broadcast is supported on both inputs; fold every broadcast
      // producer we can in a single match.
      bool changed = false;
      for (unsigned operandIdx : {0u, 1u}) {
        if (succeeded(foldBroadcastOperand(op, operandIdx, rewriter))) {
          changed = true;
        }
      }
      return success(changed);
    }
    return failure();
  }

private:
  // Fold the ttir.broadcast (if any) that defines op's operandIdx-th operand,
  // replacing it with its un-broadcast input. Returns failure when the operand
  // is not produced by a broadcast, or when the broadcast has other users —
  // those users may not support implicit broadcasting, so replacing all uses
  // would be incorrect.
  static LogicalResult foldBroadcastOperand(Operation *op, unsigned operandIdx,
                                            PatternRewriter &rewriter) {
    ttir::BroadcastOp broadcastOp =
        op->getOperand(operandIdx).getDefiningOp<ttir::BroadcastOp>();
    if (!broadcastOp || !broadcastOp->hasOneUse()) {
      return failure();
    }
    rewriter.replaceOp(broadcastOp, broadcastOp.getInput());
    return success();
  }
};

// Pass wrapper: drives TTIRBroadcastFoldRewriter to a fixed point over the
// whole module via the greedy pattern rewrite driver.
class TTIRBroadcastFold
    : public impl::TTIRBroadcastFoldBase<TTIRBroadcastFold> {
public:
  using impl::TTIRBroadcastFoldBase<TTIRBroadcastFold>::TTIRBroadcastFoldBase;

  // Dialects whose operations this pass may touch must be pre-loaded.
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::tt::ttir::TTIRDialect>();
    registry.insert<mlir::tt::TTDialect>();
  }

  void runOnOperation() final {
    RewritePatternSet foldPatterns(&getContext());
    foldPatterns.add<TTIRBroadcastFoldRewriter>(&getContext());
    FrozenRewritePatternSet frozen(std::move(foldPatterns));
    // The driver returning failure means it did not converge; surface that.
    if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen))) {
      signalPassFailure();
    }
  }
};

} // namespace mlir::tt::ttir
1 change: 1 addition & 0 deletions lib/Dialect/TTIR/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRTTIRTransforms
Allocate.cpp
Broadcast.cpp
Constant.cpp
Generic.cpp
HoistCPUOps.cpp
Expand Down
13 changes: 13 additions & 0 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,22 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
createTTNNPipelineDeallocPass(pm, *optionsStruct);
}

void createTTNNPipelineTTIRBroadcastFoldPass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
pm.addPass(mlir::tt::ttir::createTTIRBroadcastFold());
}

void createTTNNPipelineTTIRBroadcastFoldPassFromString(OpPassManager &pm,
std::string options) {
auto optionsStruct =
TTIRToTTNNBackendPipelineOptions::createFromString(options);
createTTNNPipelineTTIRBroadcastFoldPass(pm, *optionsStruct);
}

void createTTIRToTTNNBackendPipeline(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
createTTNNPipelineTTIRPasses(pm, options);
createTTNNPipelineTTIRBroadcastFoldPass(pm, options);
createTTNNPipelineLoweringPasses(pm, options);
createTTNNPipelineWorkaroundPass(pm, options);
createTTNNPipelineAnalysisPasses(pm, options);
Expand Down
12 changes: 12 additions & 0 deletions test/ttmlir/Dialect/TTIR/broadcast/ttir_implicit_broadcast.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: ttmlir-opt --ttir-broadcast-fold %s | FileCheck %s
// Verifies the ttir-broadcast-fold pass folds a ttir.broadcast into a
// consumer (ttir.multiply) that supports implicit broadcasting.
// NOTE: the previous `not ttmlir-opt` RUN line never invoked the pass and
// relied on a parse failure (undefined #device/#system_desc aliases), so the
// CHECK-NOT passed vacuously.

module {
  func.func @main(%arg0: tensor<1x16x32xf32>, %arg1: tensor<1x1x32xf32>) -> tensor<1x16x32xf32> {
    // CHECK-NOT: ttir.broadcast
    // CHECK: "ttir.multiply"
    %0 = tensor.empty() : tensor<1x16x32xf32>
    %1 = "ttir.broadcast"(%arg1, %0) <{broadcast_dimensions = array<i32: 1, 16, 1>}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
    %2 = tensor.empty() : tensor<1x16x32xf32>
    %3 = "ttir.multiply"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
    return %3 : tensor<1x16x32xf32>
  }
}

0 comments on commit 32cd902

Please sign in to comment.