diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp index f32fd518823e..4732bc18aa94 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp @@ -68,6 +68,15 @@ static llvm::cl::opt<bool> clWarnOnUninitializedValues( "iree-global-opt-enable-warn-on-uninitialized-values", llvm::cl::desc("Warn on some classes of uses of uninitialized values."), llvm::cl::init(true)); + +static llvm::cl::opt<bool> clEnableEdgeReshapePropagation( + "iree-global-opt-experimental-enable-edge-reshape-propagation", + llvm::cl::desc( + "Enables propagation of reshapes on the edges of the program " + "in transpose propagation. This is a workaround for better performance " + "and will be removed soon."), + llvm::cl::init(false)); + void buildGlobalOptExprHoistingPassPipeline( OpPassManager &passManager, const TransformOptions &transformOptions) { IREE::Util::ExprHoistingOptions options; @@ -170,6 +179,8 @@ void buildGlobalOptimizationPassPipeline( transformOptions.aggressiveTransposePropagation; options.enableAttentionVTranspose = clEnableAttentionVTranspose; + options.enableEdgeReshapePropagation = + clEnableEdgeReshapePropagation; return createPropagateLinalgTransposePass(options); }) .addPass(IREE::Flow::createCanonicalizePass) diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td index 12ffd9457e0c..4d89a5ed2273 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td +++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td @@ -119,6 +119,8 @@ def PropagateLinalgTransposePass : /*default=*/"true", "Enable transposition of attention v operand">, Option<"enableConvolutionPropagation", "enable-aggressive-propagation-through-conv", "bool", /*default=*/"false", "enable propagation through convolutions">, + Option<"enableEdgeReshapePropagation", 
"enable-edge-reshape-propagation", "bool", + /*default=*/"false", "Enable propagation of reshapes on the edges of the program">, ]; } diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp index 4ad0fa4b2326..a8770481fe1a 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp @@ -320,6 +320,11 @@ class BubbleTransposeThroughCollapseShape : public OpRewritePattern { public: using Base::Base; + BubbleTransposeThroughCollapseShape(MLIRContext *ctx, + bool enableEdgeReshapeProp, + PatternBenefit b = 1) + : OpRewritePattern(ctx, b), + enableEdgeReshapePropagation(enableEdgeReshapeProp) {} LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -336,7 +341,8 @@ class BubbleTransposeThroughCollapseShape transposeOp, "transpose input is not a single-use collapse shape"); } - if (!isReshapeBlockingFusion(transposeOp, + if (!enableEdgeReshapePropagation && + !isReshapeBlockingFusion(transposeOp, collapseOp.getSrc().getDefiningOp())) { return rewriter.notifyMatchFailure(transposeOp, "transpose not blocking fusion"); @@ -379,6 +385,9 @@ class BubbleTransposeThroughCollapseShape rewriter.replaceOp(transposeOp, newReshape); return success(); } + +private: + bool enableEdgeReshapePropagation = true; }; } // namespace @@ -523,6 +532,10 @@ class SinkTransposeThroughExpandShape : public OpRewritePattern { public: using Base::Base; + SinkTransposeThroughExpandShape(MLIRContext *ctx, bool enableEdgeReshapeProp, + PatternBenefit b = 1) + : OpRewritePattern(ctx, b), + enableEdgeReshapePropagation(enableEdgeReshapeProp) {} LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, PatternRewriter &rewriter) const override { @@ -539,7 +552,8 @@ class SinkTransposeThroughExpandShape expandOp, "expand shape input 
is not a single-use transpose"); } - if (llvm::none_of(expandOp->getUsers(), [&](Operation *consumer) { + if (!enableEdgeReshapePropagation && + llvm::none_of(expandOp->getUsers(), [&](Operation *consumer) { return isReshapeBlockingFusion(transposeOp, consumer); })) { return rewriter.notifyMatchFailure(transposeOp, @@ -588,6 +602,9 @@ class SinkTransposeThroughExpandShape rewriter.replaceOp(expandOp, originalReshape); return success(); } + +private: + bool enableEdgeReshapePropagation = true; }; // Fuses a transpose with the input of a linalg.generic op or contraction op. @@ -1072,7 +1089,8 @@ void PropagateLinalgTransposePass::runOnOperation() { if (!testBubblingOnly) { RewritePatternSet sinkingPatterns(context); sinkingPatterns.insert(context); - sinkingPatterns.insert(context); + sinkingPatterns.insert( + context, enableEdgeReshapePropagation); populateNamedOpSinkingPatterns(context, sinkingPatterns); populateCommonCanonicalizationPatterns(context, sinkingPatterns); sinkingPatterns.add( @@ -1118,7 +1136,8 @@ void PropagateLinalgTransposePass::runOnOperation() { return false; } - if (llvm::none_of( + if (!enableEdgeReshapePropagation && + llvm::none_of( consumer->getUsers(), [&](Operation *expandConsumer) { return isReshapeBlockingFusion(producer, expandConsumer); })) { @@ -1148,7 +1167,8 @@ void PropagateLinalgTransposePass::runOnOperation() { } bubblingPatterns.insert( context, enableAggressivePropagation, enableConvolutionPropagation); - bubblingPatterns.insert(context); + bubblingPatterns.insert( + context, enableEdgeReshapePropagation); bubblingPatterns.add( context, /*benefit=*/2); bubblingPatterns.insert(context); @@ -1197,7 +1217,8 @@ void PropagateLinalgTransposePass::runOnOperation() { return false; } - if (!isReshapeBlockingFusion(producer->getOperand(0).getDefiningOp(), + if (!enableEdgeReshapePropagation && + !isReshapeBlockingFusion(producer->getOperand(0).getDefiningOp(), consumer)) { return false; } @@ -1209,7 +1230,8 @@ void 
PropagateLinalgTransposePass::runOnOperation() { linalg::populateFoldReshapeOpsByExpansionPatterns(sinkingPatterns, reshapePropagationFn); sinkingPatterns.insert(context); - sinkingPatterns.insert(context); + sinkingPatterns.insert( + context, enableEdgeReshapePropagation); sinkingPatterns.insert( context, enableAggressivePropagation, enableConvolutionPropagation); sinkingPatterns.insert(context); diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir index 4e3bec010d0a..9d3c1fcfe167 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir @@ -3,6 +3,7 @@ // RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{test-sinking-only=true}))" --split-input-file %s | FileCheck %s --check-prefix=SINK // RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{test-bubbling-only=true}))" --split-input-file %s | FileCheck %s --check-prefix=BUBBLE // RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{enable-aggressive-propagation-through-conv=true}))" --split-input-file %s | FileCheck %s --check-prefix=CONV +// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{enable-edge-reshape-propagation=true}))" %s -o - | FileCheck %s --check-prefix=ENABLE-EDGE-PROP util.func public @specialize_transpose_op(%arg0 : tensor<1x2x3xf32>, %empty : tensor<3x2x1xf32>) -> tensor<3x2x1xf32> { @@ -819,6 +820,11 @@ util.func @dont_propagate_edge_reshapes(%arg0: tensor<10x10x10xi32>) -> tensor<1 // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] // CHECK: %[[VAL:.+]] = linalg.transpose ins(%[[COLLAPSED]] // CHECK: util.return %[[VAL]] +// 
ENABLE-EDGE-PROP-LABEL: util.func public @dont_propagate_edge_reshapes +// ENABLE-EDGE-PROP-SAME: %[[ARG0:[0-9a-zA-Z]+]] +// ENABLE-EDGE-PROP: %[[TRANSPOSED:.+]] = linalg.transpose ins(%[[ARG0]] +// ENABLE-EDGE-PROP: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[TRANSPOSED]] +// ENABLE-EDGE-PROP: util.return %[[COLLAPSED]] // ----- @@ -833,3 +839,7 @@ util.func public @dont_sink_through_edge_expand_shape(%arg0 : tensor<2x3x4xf32>) // SINK: %[[TRANSPOSE:.+]] = linalg.transpose // SINK: %[[RES:.+]] = tensor.expand_shape %[[TRANSPOSE]] // SINK: util.return %[[RES]] : tensor<1x3x4x2xf32> +// ENABLE-EDGE-PROP-LABEL: util.func public @dont_sink_through_edge_expand_shape +// ENABLE-EDGE-PROP: %[[EXP:.+]] = tensor.expand_shape +// ENABLE-EDGE-PROP: %[[RES:.+]] = linalg.transpose +// ENABLE-EDGE-PROP: util.return %[[RES]]