From 4ca895abdcd0aeba1f5ca7469d807147e3b22508 Mon Sep 17 00:00:00 2001 From: uazizTT Date: Wed, 15 Jan 2025 19:03:56 -0500 Subject: [PATCH] Rename the pass to ImplicitBroadcastFolding. Make the pass optional to help with testing. --- include/ttmlir/Dialect/TTIR/Transforms/Passes.td | 2 +- .../ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h | 5 +++++ lib/Dialect/TTIR/Transforms/Broadcast.cpp | 15 ++++++++------- lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp | 14 ++++++++------ test/ttmlir/Dialect/TTNN/simple_repeat.mlir | 2 +- test/ttmlir/Silicon/TTNN/simple_repeat.mlir | 2 +- 6 files changed, 24 insertions(+), 16 deletions(-) diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index ec585f788..06a3ddac8 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -112,7 +112,7 @@ def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> { ]; } -def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> { +def TTIRImplicitBroadcastFold: Pass<"ttir-implicit-broadcast-fold", "::mlir::ModuleOp"> { let summary = "Broadcast operation is folded to all the consumers."; let description = [{ This pass walks through the graph and folds broadcasts operations when it is implicitly supported by the operation. diff --git a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h index 3e8e71de8..a0a9da13c 100644 --- a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h +++ b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h @@ -136,6 +136,11 @@ struct TTIRToTTNNBackendPipelineOptions *this, "enable-decomposition-workaround-pass", llvm::cl::desc("Enable decomposition workaround pass."), llvm::cl::init(true)}; + + Option implicitBroadcastFoldingEnabled{ + *this, "enable-implicit-broadcast-folding-pass", + llvm::cl::desc("Enable implicit broadcast folding pass."), + llvm::cl::init(true)}; }; // TTIR to EmitC pipeline options. diff --git a/lib/Dialect/TTIR/Transforms/Broadcast.cpp b/lib/Dialect/TTIR/Transforms/Broadcast.cpp index 4c5aef7d3..e1d5102f5 100644 --- a/lib/Dialect/TTIR/Transforms/Broadcast.cpp +++ b/lib/Dialect/TTIR/Transforms/Broadcast.cpp @@ -7,12 +7,12 @@ #include "ttmlir/Dialect/TTIR/Transforms/Passes.h" #include namespace mlir::tt::ttir { -#define GEN_PASS_DEF_TTIRBROADCASTFOLD +#define GEN_PASS_DEF_TTIRIMPLICITBROADCASTFOLD #include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" -class TTIRBroadcastFoldRewriter : public RewritePattern { +class TTIRImplicitBroadcastFoldRewriter : public RewritePattern { public: - TTIRBroadcastFoldRewriter(MLIRContext *ctx) + TTIRImplicitBroadcastFoldRewriter(MLIRContext *ctx) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} LogicalResult matchAndRewrite(Operation *op, @@ -49,13 +49,14 @@ class TTIRBroadcastFoldRewriter : public RewritePattern { } }; -class TTIRBroadcastFold - : public impl::TTIRBroadcastFoldBase { +class TTIRImplicitBroadcastFold + : public impl::TTIRImplicitBroadcastFoldBase { public: - using impl::TTIRBroadcastFoldBase::TTIRBroadcastFoldBase; + using impl::TTIRImplicitBroadcastFoldBase< + TTIRImplicitBroadcastFold>::TTIRImplicitBroadcastFoldBase; void runOnOperation() final { RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); + patterns.add(&getContext()); FrozenRewritePatternSet patternSet(std::move(patterns)); if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) { diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index da1874a3d..efe4cc9ec 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -120,22 +120,24 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm, createTTNNPipelineDeallocPass(pm, *optionsStruct); } -void createTTNNPipelineTTIRBroadcastFoldPass( +void createTTNNPipelineTTIRImplicitBroadcastFoldPass( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { - pm.addPass(mlir::tt::ttir::createTTIRBroadcastFold()); + if (options.implicitBroadcastFoldingEnabled) { + pm.addPass(mlir::tt::ttir::createTTIRImplicitBroadcastFold()); + } } -void createTTNNPipelineTTIRBroadcastFoldPassFromString(OpPassManager &pm, - std::string options) { +void createTTNNPipelineTTIRImplicitBroadcastFoldPassFromString( + OpPassManager &pm, std::string options) { auto optionsStruct = TTIRToTTNNBackendPipelineOptions::createFromString(options); - createTTNNPipelineTTIRBroadcastFoldPass(pm, *optionsStruct); + createTTNNPipelineTTIRImplicitBroadcastFoldPass(pm, *optionsStruct); } void createTTIRToTTNNBackendPipeline( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { createTTNNPipelineTTIRPasses(pm, options); - createTTNNPipelineTTIRBroadcastFoldPass(pm, options); + createTTNNPipelineTTIRImplicitBroadcastFoldPass(pm, options); createTTNNPipelineLoweringPasses(pm, options); createTTNNPipelineWorkaroundPass(pm, options); createTTNNPipelineAnalysisPasses(pm, options); diff --git a/test/ttmlir/Dialect/TTNN/simple_repeat.mlir b/test/ttmlir/Dialect/TTNN/simple_repeat.mlir index 42261936e..00fddfb78 100644 --- a/test/ttmlir/Dialect/TTNN/simple_repeat.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_repeat.mlir @@ -1,4 +1,4 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-implicit-broadcast-folding-pass=false" %s | FileCheck %s module { func.func @main(%arg0: tensor<1x16x32xf32>, %arg1: tensor<1x1x32xf32>) -> tensor<1x16x32xf32> { // CHECK: %{{[0-9]+}} = "ttnn.repeat" diff --git a/test/ttmlir/Silicon/TTNN/simple_repeat.mlir b/test/ttmlir/Silicon/TTNN/simple_repeat.mlir index ab91af2ee..2000530a0 100644 --- a/test/ttmlir/Silicon/TTNN/simple_repeat.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_repeat.mlir @@ -1,4 +1,4 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-implicit-broadcast-folding-pass=false system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn module {