Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ static llvm::cl::opt<bool> clWarnOnUninitializedValues(
"iree-global-opt-enable-warn-on-uninitialized-values",
llvm::cl::desc("Warn on some classes of uses of uninitialized values."),
llvm::cl::init(true));

// Experimental escape hatch: lets transpose propagation move reshapes that sit
// at the edges of the program (where there is no producer/consumer fusion to
// block). Temporary performance workaround; slated for removal.
static llvm::cl::opt<bool> clEnableEdgeReshapePropagation(
    "iree-global-opt-experimental-enable-edge-reshape-propagation",
    llvm::cl::desc(
        "Enables propagation of reshapes on the edges of the program "
        "in transpose propagation. This is a workaround for better "
        "performance and will be removed soon."),
    llvm::cl::init(false));

void buildGlobalOptExprHoistingPassPipeline(
OpPassManager &passManager, const TransformOptions &transformOptions) {
IREE::Util::ExprHoistingOptions options;
Expand Down Expand Up @@ -170,6 +179,8 @@ void buildGlobalOptimizationPassPipeline(
transformOptions.aggressiveTransposePropagation;
options.enableAttentionVTranspose =
clEnableAttentionVTranspose;
options.enableEdgeReshapePropagation =
clEnableEdgeReshapePropagation;
return createPropagateLinalgTransposePass(options);
})
.addPass(IREE::Flow::createCanonicalizePass)
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def PropagateLinalgTransposePass :
/*default=*/"true", "Enable transposition of attention v operand">,
Option<"enableConvolutionPropagation", "enable-aggressive-propagation-through-conv", "bool",
/*default=*/"false", "enable propagation through convolutions">,
Option<"enableEdgeReshapePropagation", "enable-edge-reshape-propagation", "bool",
/*default=*/"false", "Enable propagation of reshapes on the edges of the program">,
];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,11 @@ class BubbleTransposeThroughCollapseShape
: public OpRewritePattern<linalg::TransposeOp> {
public:
using Base::Base;
BubbleTransposeThroughCollapseShape(MLIRContext *ctx,
bool enableEdgeReshapeProp,
PatternBenefit b = 1)
: OpRewritePattern<linalg::TransposeOp>(ctx, b),
enableEdgeReshapePropagation(enableEdgeReshapeProp) {}

LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
Expand All @@ -336,7 +341,8 @@ class BubbleTransposeThroughCollapseShape
transposeOp, "transpose input is not a single-use collapse shape");
}

if (!isReshapeBlockingFusion(transposeOp,
if (!enableEdgeReshapePropagation &&
!isReshapeBlockingFusion(transposeOp,
collapseOp.getSrc().getDefiningOp())) {
return rewriter.notifyMatchFailure(transposeOp,
"transpose not blocking fusion");
Expand Down Expand Up @@ -379,6 +385,9 @@ class BubbleTransposeThroughCollapseShape
rewriter.replaceOp(transposeOp, newReshape);
return success();
}

private:
bool enableEdgeReshapePropagation = true;
};

} // namespace
Expand Down Expand Up @@ -523,6 +532,10 @@ class SinkTransposeThroughExpandShape
: public OpRewritePattern<tensor::ExpandShapeOp> {
public:
using Base::Base;
SinkTransposeThroughExpandShape(MLIRContext *ctx, bool enableEdgeReshapeProp,
PatternBenefit b = 1)
: OpRewritePattern<tensor::ExpandShapeOp>(ctx, b),
enableEdgeReshapePropagation(enableEdgeReshapeProp) {}

LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
PatternRewriter &rewriter) const override {
Expand All @@ -539,7 +552,8 @@ class SinkTransposeThroughExpandShape
expandOp, "expand shape input is not a single-use transpose");
}

if (llvm::none_of(expandOp->getUsers(), [&](Operation *consumer) {
if (!enableEdgeReshapePropagation &&
llvm::none_of(expandOp->getUsers(), [&](Operation *consumer) {
return isReshapeBlockingFusion(transposeOp, consumer);
})) {
return rewriter.notifyMatchFailure(transposeOp,
Expand Down Expand Up @@ -588,6 +602,9 @@ class SinkTransposeThroughExpandShape
rewriter.replaceOp(expandOp, originalReshape);
return success();
}

private:
bool enableEdgeReshapePropagation = true;
};

// Fuses a transpose with the input of a linalg.generic op or contraction op.
Expand Down Expand Up @@ -1072,7 +1089,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
if (!testBubblingOnly) {
RewritePatternSet sinkingPatterns(context);
sinkingPatterns.insert<SinkTransposeThroughExtractSlice>(context);
sinkingPatterns.insert<SinkTransposeThroughExpandShape>(context);
sinkingPatterns.insert<SinkTransposeThroughExpandShape>(
context, enableEdgeReshapePropagation);
populateNamedOpSinkingPatterns(context, sinkingPatterns);
populateCommonCanonicalizationPatterns(context, sinkingPatterns);
sinkingPatterns.add<SinkTransposeThroughUnaryElementwiseInput>(
Expand Down Expand Up @@ -1118,7 +1136,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
return false;
}

if (llvm::none_of(
if (!enableEdgeReshapePropagation &&
llvm::none_of(
consumer->getUsers(), [&](Operation *expandConsumer) {
return isReshapeBlockingFusion(producer, expandConsumer);
})) {
Expand Down Expand Up @@ -1148,7 +1167,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
}
bubblingPatterns.insert<FuseTransposeWithProducerLinalgOp>(
context, enableAggressivePropagation, enableConvolutionPropagation);
bubblingPatterns.insert<BubbleTransposeThroughCollapseShape>(context);
bubblingPatterns.insert<BubbleTransposeThroughCollapseShape>(
context, enableEdgeReshapePropagation);
bubblingPatterns.add<BubbleTransposeThroughUnaryElementwiseDpsInit>(
context, /*benefit=*/2);
bubblingPatterns.insert<ComposeTransposes>(context);
Expand Down Expand Up @@ -1197,7 +1217,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
return false;
}

if (!isReshapeBlockingFusion(producer->getOperand(0).getDefiningOp(),
if (!enableEdgeReshapePropagation &&
!isReshapeBlockingFusion(producer->getOperand(0).getDefiningOp(),
consumer)) {
return false;
}
Expand All @@ -1209,7 +1230,8 @@ void PropagateLinalgTransposePass::runOnOperation() {
linalg::populateFoldReshapeOpsByExpansionPatterns(sinkingPatterns,
reshapePropagationFn);
sinkingPatterns.insert<SinkTransposeThroughExtractSlice>(context);
sinkingPatterns.insert<SinkTransposeThroughExpandShape>(context);
sinkingPatterns.insert<SinkTransposeThroughExpandShape>(
context, enableEdgeReshapePropagation);
sinkingPatterns.insert<FuseTransposeWithLinalgOpConsumer>(
context, enableAggressivePropagation, enableConvolutionPropagation);
sinkingPatterns.insert<ComposeTransposes>(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{test-sinking-only=true}))" --split-input-file %s | FileCheck %s --check-prefix=SINK
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{test-bubbling-only=true}))" --split-input-file %s | FileCheck %s --check-prefix=BUBBLE
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{enable-aggressive-propagation-through-conv=true}))" --split-input-file %s | FileCheck %s --check-prefix=CONV
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-global-opt-propagate-linalg-transpose{enable-edge-reshape-propagation=true}))" %s -o - | FileCheck %s --check-prefix=ENABLE-EDGE-PROP

util.func public @specialize_transpose_op(%arg0 : tensor<1x2x3xf32>,
%empty : tensor<3x2x1xf32>) -> tensor<3x2x1xf32> {
Expand Down Expand Up @@ -819,6 +820,11 @@ util.func @dont_propagate_edge_reshapes(%arg0: tensor<10x10x10xi32>) -> tensor<1
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK: %[[VAL:.+]] = linalg.transpose ins(%[[COLLAPSED]]
// CHECK: util.return %[[VAL]]
// ENABLE-EDGE-PROP-LABEL: util.func public @dont_propagate_edge_reshapes
// ENABLE-EDGE-PROP-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// ENABLE-EDGE-PROP: %[[TRANSPOSED:.+]] = linalg.transpose ins(%[[ARG0]]
// ENABLE-EDGE-PROP: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[TRANSPOSED]]
// ENABLE-EDGE-PROP: util.return %[[COLLAPSED]]

// -----

Expand All @@ -833,3 +839,7 @@ util.func public @dont_sink_through_edge_expand_shape(%arg0 : tensor<2x3x4xf32>)
// SINK: %[[TRANSPOSE:.+]] = linalg.transpose
// SINK: %[[RES:.+]] = tensor.expand_shape %[[TRANSPOSE]]
// SINK: util.return %[[RES]] : tensor<1x3x4x2xf32>
// ENABLE-EDGE-PROP-LABEL: util.func public @dont_sink_through_edge_expand_shape
// ENABLE-EDGE-PROP: %[[EXP:.+]] = tensor.expand_shape
// ENABLE-EDGE-PROP: %[[RES:.+]] = linalg.transpose
// ENABLE-EDGE-PROP: util.return %[[RES]]
Loading