From 32cd9025e91bcfba02384492c825ce77fccdc1f4 Mon Sep 17 00:00:00 2001 From: uazizTT Date: Wed, 15 Jan 2025 14:57:24 -0500 Subject: [PATCH] Introduce TTIR traits to track partial or full implicit folding of binary eltwise operations. Add pass to apply folding when these traits are present. --- include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 16 ++-- .../Dialect/TTIR/IR/TTIROpsInterfaces.td | 10 +++ .../ttmlir/Dialect/TTIR/Transforms/Passes.td | 17 +++++ lib/Dialect/TTIR/Transforms/Broadcast.cpp | 73 +++++++++++++++++++ lib/Dialect/TTIR/Transforms/CMakeLists.txt | 1 + lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp | 13 ++++ .../broadcast/ttir_implicit_broadcast.mlir | 12 +++ 7 files changed, 134 insertions(+), 8 deletions(-) create mode 100644 lib/Dialect/TTIR/Transforms/Broadcast.cpp create mode 100644 test/ttmlir/Dialect/TTIR/broadcast/ttir_implicit_broadcast.mlir diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 710f88cfe..cf8b386d5 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -216,7 +216,7 @@ class TTIR_ElementwiseTernaryOp traits = []> : ]; } -def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where"> { +def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where", [TTIR_PartiallyBroadcastable]> { let summary = "Eltwise where op."; let description = [{ Eltwise where operation. 
@@ -597,7 +597,7 @@ def TTIR_BitwiseXorOp : TTIR_ElementwiseBinaryOp<"bitwise_xor"> { let hasCanonicalizer = 1; } -def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum", [TTIR_BinaryIdempotence]> { +def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum", [TTIR_BinaryIdempotence, TTIR_PartiallyBroadcastable]> { let summary = "Eltwise minimum OP."; let description = [{ Calculates minimum of input tensors' values element-wise and stores result @@ -610,14 +610,14 @@ def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum", [TTIR_BinaryIdempotenc }]; } -def TTIR_SubtractOp : TTIR_ElementwiseBinaryOp<"subtract"> { +def TTIR_SubtractOp : TTIR_ElementwiseBinaryOp<"subtract", [TTIR_PartiallyBroadcastable]> { let summary = "Eltwise subtract."; let description = [{ Eltwise subtract operation. }]; } -def TTIR_RemainderOp : TTIR_ElementwiseBinaryOp<"remainder"> { +def TTIR_RemainderOp : TTIR_ElementwiseBinaryOp<"remainder", [TTIR_PartiallyBroadcastable]> { let summary = "Eltwise remainder."; let description = [{ Performs element-wise remainder of dividend lhs and divisor rhs tensors and produces a @@ -1439,28 +1439,28 @@ class TTIR_GenericElementwiseBinaryOp traits = []> }]; } -def TTIR_AddOp : TTIR_GenericElementwiseBinaryOp<"add"> { +def TTIR_AddOp : TTIR_GenericElementwiseBinaryOp<"add", [TTIR_FullyBroadcastable]> { let summary = "Eltwise add."; let description = [{ Eltwise add operation. }]; } -def TTIR_MultiplyOp : TTIR_GenericElementwiseBinaryOp<"multiply"> { +def TTIR_MultiplyOp : TTIR_GenericElementwiseBinaryOp<"multiply", [TTIR_FullyBroadcastable]> { let summary = "Eltwise multiply."; let description = [{ Eltwise multiply operation. }]; } -def TTIR_DivOp : TTIR_GenericElementwiseBinaryOp<"div"> { +def TTIR_DivOp : TTIR_GenericElementwiseBinaryOp<"div", [TTIR_PartiallyBroadcastable]> { let summary = "Eltwise divide."; let description = [{ Eltwise divide operation. 
}]; } -def TTIR_MaximumOp : TTIR_GenericElementwiseBinaryOp<"maximum"> { +def TTIR_MaximumOp : TTIR_GenericElementwiseBinaryOp<"maximum", [TTIR_PartiallyBroadcastable]> { let summary = "Eltwise maximum."; let description = [{ Calculates maximum of input tensors' values element-wise and stores result in output tensor. diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td index 64c314279..30668813f 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td @@ -114,4 +114,14 @@ def TTIR_GenericRegionOpInterface : OpInterface<"GenericRegionOp"> { ]; } +def TTIR_PartiallyBroadcastable : OpInterface<"partiallyBroadcastable"> { + // Supports implicit broadcast for first operand only. + let cppNamespace = "::mlir::tt::ttir"; +} + +def TTIR_FullyBroadcastable : OpInterface<"fullyBroadcastable"> { + // Supports implicit broadcast for all the operands. + let cppNamespace = "::mlir::tt::ttir"; +} + #endif diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index 3922be687..ec585f788 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -112,6 +112,23 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> { ]; } +def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> { + let summary = "Fold broadcast operations into their consumers."; + let description = [{ + This pass walks through the graph and folds broadcast operations when implicit broadcasting is supported by the consuming operation. 
+ + Example: + %0 = tensor.empty() : tensor<1x16x32xf32> + %1 = "ttir.broadcast"(%arg1, %0) <{broadcast_dimensions = array<i64: 1, 16, 1>}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32> + %2 = tensor.empty() : tensor<1x16x32xf32> + %3 = "ttir.multiply"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32> + + Since MultiplyOp supports implicit broadcasting, the above broadcast is folded as: + %0 = tensor.empty() : tensor<1x16x32xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32> + }]; +} + def TTIRHoistTransform: Pass<"ttir-cpu-hoist-transform", "::mlir::ModuleOp"> { let summary = "Transform to perform hoist mechanics on any ops marked to be hoisted for CPU lowering"; diff --git a/lib/Dialect/TTIR/Transforms/Broadcast.cpp b/lib/Dialect/TTIR/Transforms/Broadcast.cpp new file mode 100644 index 000000000..4c5aef7d3 --- /dev/null +++ b/lib/Dialect/TTIR/Transforms/Broadcast.cpp @@ -0,0 +1,73 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "ttmlir/Dialect/TT/IR/TT.h" +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h" +#include <mlir/IR/PatternMatch.h> +namespace mlir::tt::ttir { +#define GEN_PASS_DEF_TTIRBROADCASTFOLD +#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" + +class TTIRBroadcastFoldRewriter : public RewritePattern { +public: + TTIRBroadcastFoldRewriter(MLIRContext *ctx) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + + // First check if the op itself has any broadcastable traits + if (op->hasTrait<partiallyBroadcastable::Trait>()) { + + // This operation can only fold broadcast operation for Operand 0. 
+ ttir::BroadcastOp broadcastOp = + op->getOperand(0).getDefiningOp<ttir::BroadcastOp>(); + if (broadcastOp) { + rewriter.replaceOp(broadcastOp, broadcastOp.getInput()); + return success(); + } + } else if (op->hasTrait<fullyBroadcastable::Trait>()) { + bool changed = false; + // Check all operands for this op + ttir::BroadcastOp broadcastOp0 = + op->getOperand(0).getDefiningOp<ttir::BroadcastOp>(); + ttir::BroadcastOp broadcastOp1 = + op->getOperand(1).getDefiningOp<ttir::BroadcastOp>(); + if (broadcastOp0) { + rewriter.replaceOp(broadcastOp0, broadcastOp0.getInput()); + changed = true; + } else if (broadcastOp1) { + rewriter.replaceOp(broadcastOp1, broadcastOp1.getInput()); + changed = true; + } + return changed ? success() : failure(); + } + + return failure(); + } +}; + +class TTIRBroadcastFold + : public impl::TTIRBroadcastFoldBase<TTIRBroadcastFold> { +public: + using impl::TTIRBroadcastFoldBase<TTIRBroadcastFold>::TTIRBroadcastFoldBase; + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.add<TTIRBroadcastFoldRewriter>(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) { + signalPassFailure(); + return; + } + } + + void getDependentDialects(mlir::DialectRegistry &registry) const override { + registry.insert<mlir::tt::ttir::TTIRDialect>(); + registry.insert<mlir::tt::TTDialect>(); + } +}; + +} // namespace mlir::tt::ttir diff --git a/lib/Dialect/TTIR/Transforms/CMakeLists.txt b/lib/Dialect/TTIR/Transforms/CMakeLists.txt index 1b0164ac4..262b4455b 100644 --- a/lib/Dialect/TTIR/Transforms/CMakeLists.txt +++ b/lib/Dialect/TTIR/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRTTIRTransforms Allocate.cpp + Broadcast.cpp Constant.cpp Generic.cpp HoistCPUOps.cpp diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index c7d91bc01..da1874a3d 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -120,9 +120,22 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm, 
createTTNNPipelineDeallocPass(pm, *optionsStruct); } +void createTTNNPipelineTTIRBroadcastFoldPass( + OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { + pm.addPass(mlir::tt::ttir::createTTIRBroadcastFold()); +} + +void createTTNNPipelineTTIRBroadcastFoldPassFromString(OpPassManager &pm, + std::string options) { + auto optionsStruct = + TTIRToTTNNBackendPipelineOptions::createFromString(options); + createTTNNPipelineTTIRBroadcastFoldPass(pm, *optionsStruct); +} + void createTTIRToTTNNBackendPipeline( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { createTTNNPipelineTTIRPasses(pm, options); + createTTNNPipelineTTIRBroadcastFoldPass(pm, options); createTTNNPipelineLoweringPasses(pm, options); createTTNNPipelineWorkaroundPass(pm, options); createTTNNPipelineAnalysisPasses(pm, options); diff --git a/test/ttmlir/Dialect/TTIR/broadcast/ttir_implicit_broadcast.mlir b/test/ttmlir/Dialect/TTIR/broadcast/ttir_implicit_broadcast.mlir new file mode 100644 index 000000000..7cd2387dc --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/broadcast/ttir_implicit_broadcast.mlir @@ -0,0 +1,12 @@ +// RUN: ttmlir-opt --ttir-broadcast-fold %s | FileCheck %s + +module { + func.func @main(%arg0: tensor<1x16x32xf32>, %arg1: tensor<1x1x32xf32>) -> tensor<1x16x32xf32> { + // CHECK-NOT: ttir.broadcast + %0 = tensor.empty() : tensor<1x16x32xf32> + %1 = "ttir.broadcast"(%arg1, %0) <{broadcast_dimensions = array<i64: 1, 16, 1>}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32> + %2 = tensor.empty() : tensor<1x16x32xf32> + %3 = "ttir.multiply"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32> + return %3 : tensor<1x16x32xf32> + } +}