Skip to content

Commit

Permalink
Rename the pass to ImplicitBroadcastFolding. Make the pass optional t…
Browse files Browse the repository at this point in the history
…o help with testing.
  • Loading branch information
uazizTT committed Jan 16, 2025
1 parent 32cd902 commit 4ca895a
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 16 deletions.
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
];
}

def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> {
def TTIRImplicitBroadcastFold: Pass<"ttir-implicit-broadcast-fold", "::mlir::ModuleOp"> {
let summary = "Broadcast operation is folded to all the consumers.";
let description = [{
This pass walks through the graph and folds broadcasts operations when it is implicitly supported by the operation.
Expand Down
5 changes: 5 additions & 0 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ struct TTIRToTTNNBackendPipelineOptions
*this, "enable-decomposition-workaround-pass",
llvm::cl::desc("Enable decomposition workaround pass."),
llvm::cl::init(true)};

Option<bool> implicitBroadcastFoldingEnabled{
*this, "enable-implicit-broadcast-folding-pass",
llvm::cl::desc("Enable implicit broadcast folding pass."),
llvm::cl::init(true)};
};

// TTIR to EmitC pipeline options.
Expand Down
15 changes: 8 additions & 7 deletions lib/Dialect/TTIR/Transforms/Broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
namespace mlir::tt::ttir {
#define GEN_PASS_DEF_TTIRBROADCASTFOLD
#define GEN_PASS_DEF_TTIRIMPLICITBROADCASTFOLD
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc"

class TTIRBroadcastFoldRewriter : public RewritePattern {
class TTIRImplicitBroadcastFoldRewriter : public RewritePattern {
public:
TTIRBroadcastFoldRewriter(MLIRContext *ctx)
TTIRImplicitBroadcastFoldRewriter(MLIRContext *ctx)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

LogicalResult matchAndRewrite(Operation *op,
Expand Down Expand Up @@ -49,13 +49,14 @@ class TTIRBroadcastFoldRewriter : public RewritePattern {
}
};

class TTIRBroadcastFold
: public impl::TTIRBroadcastFoldBase<TTIRBroadcastFold> {
class TTIRImplicitBroadcastFold
: public impl::TTIRImplicitBroadcastFoldBase<TTIRImplicitBroadcastFold> {
public:
using impl::TTIRBroadcastFoldBase<TTIRBroadcastFold>::TTIRBroadcastFoldBase;
using impl::TTIRImplicitBroadcastFoldBase<
TTIRImplicitBroadcastFold>::TTIRImplicitBroadcastFoldBase;
void runOnOperation() final {
RewritePatternSet patterns(&getContext());
patterns.add<TTIRBroadcastFoldRewriter>(&getContext());
patterns.add<TTIRImplicitBroadcastFoldRewriter>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));

if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
Expand Down
14 changes: 8 additions & 6 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,22 +120,24 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
createTTNNPipelineDeallocPass(pm, *optionsStruct);
}

void createTTNNPipelineTTIRBroadcastFoldPass(
void createTTNNPipelineTTIRImplicitBroadcastFoldPass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
pm.addPass(mlir::tt::ttir::createTTIRBroadcastFold());
if (options.implicitBroadcastFoldingEnabled) {
pm.addPass(mlir::tt::ttir::createTTIRImplicitBroadcastFold());
}
}

void createTTNNPipelineTTIRBroadcastFoldPassFromString(OpPassManager &pm,
std::string options) {
void createTTNNPipelineTTIRImplicitBroadcastFoldPassFromString(
OpPassManager &pm, std::string options) {
auto optionsStruct =
TTIRToTTNNBackendPipelineOptions::createFromString(options);
createTTNNPipelineTTIRBroadcastFoldPass(pm, *optionsStruct);
createTTNNPipelineTTIRImplicitBroadcastFoldPass(pm, *optionsStruct);
}

void createTTIRToTTNNBackendPipeline(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
createTTNNPipelineTTIRPasses(pm, options);
createTTNNPipelineTTIRBroadcastFoldPass(pm, options);
createTTNNPipelineTTIRImplicitBroadcastFoldPass(pm, options);
createTTNNPipelineLoweringPasses(pm, options);
createTTNNPipelineWorkaroundPass(pm, options);
createTTNNPipelineAnalysisPasses(pm, options);
Expand Down
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/simple_repeat.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-implicit-broadcast-folding-pass=false" %s | FileCheck %s
module {
func.func @main(%arg0: tensor<1x16x32xf32>, %arg1: tensor<1x1x32xf32>) -> tensor<1x16x32xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
Expand Down
2 changes: 1 addition & 1 deletion test/ttmlir/Silicon/TTNN/simple_repeat.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-implicit-broadcast-folding-pass=false system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
module {
Expand Down

0 comments on commit 4ca895a

Please sign in to comment.