diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
index c90e1086a7d8..9411a0d36e27 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
@@ -87,6 +87,18 @@ static RankedTensorType getPermutedTensorType(RankedTensorType type,
   return RankedTensorType::get(permutedShape, type.getElementType());
 }
 
+// Returns true if both the producer and the consumer are fusable ops, meaning
+// a reshape sitting between the two could block fusion.
+static bool isReshapeBlockingFusion(Operation *producer, Operation *consumer) {
+  auto isFusableOp = [](Operation *op) {
+    if (!op) {
+      return false;
+    }
+    return isa_and_nonnull<linalg::LinalgDialect>(op->getDialect());
+  };
+  return isFusableOp(producer) && isFusableOp(consumer);
+}
+
 //===----------------------------------------------------------------------===//
 // Transpose specialization
 //===----------------------------------------------------------------------===//
@@ -324,6 +336,12 @@ class BubbleTransposeThroughCollapseShape
           transposeOp, "transpose input is not a single-use collapse shape");
     }
 
+    if (!isReshapeBlockingFusion(transposeOp,
+                                 collapseOp.getSrc().getDefiningOp())) {
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "transpose not blocking fusion");
+    }
+
     SmallVector<ReassociationIndices> reassociations =
         collapseOp.getReassociationIndices();
 
@@ -521,6 +539,13 @@ class SinkTransposeThroughExpandShape
           expandOp, "expand shape input is not a single-use transpose");
     }
 
+    if (llvm::none_of(expandOp->getUsers(), [&](Operation *consumer) {
+          return isReshapeBlockingFusion(transposeOp, consumer);
+        })) {
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "transpose not blocking fusion");
+    }
+
     auto invPerm = invertPermutationVector(transposeOp.getPermutation());
     SmallVector<ReassociationIndices> reassociations =
         expandOp.getReassociationIndices();
@@ -1084,6 +1109,13 @@ void PropagateLinalgTransposePass::runOnOperation() {
           if (!isa<tensor::ExpandShapeOp>(consumer)) {
            return false;
          }
+
+          if (llvm::none_of(
+                  consumer->getUsers(), [&](Operation *expandConsumer) {
+                    return isReshapeBlockingFusion(producer, expandConsumer);
+                  })) {
+            return false;
+          }
           // Only propagate if the immediate consumer of the reshape is a
           // transpose.
           return consumer->hasOneUse() &&
@@ -1156,6 +1188,12 @@ void PropagateLinalgTransposePass::runOnOperation() {
          if (!isa<tensor::CollapseShapeOp>(producer)) {
            return false;
          }
+
+          if (!isReshapeBlockingFusion(producer->getOperand(0).getDefiningOp(),
+                                       consumer)) {
+            return false;
+          }
+
           // Require that the immediate producer of the reshape is a transpose.
           return isa_and_nonnull<linalg::TransposeOp>(
               producer->getOperand(0).getDefiningOp());
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
index ae22d08b22e5..0405d1c59ae2 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
@@ -358,40 +358,6 @@ util.func public @sink_through_expand_shape(%arg0: tensor
 
 // -----
 
-util.func public @sink_non_involution_through_expand_shape(%arg0 : tensor<2x3x4xf32>) -> tensor<1x3x4x2xf32> {
-  %empty = tensor.empty(): tensor<3x4x2xf32>
-  %transposed = linalg.transpose ins(%arg0 : tensor<2x3x4xf32>)
-      outs(%empty : tensor<3x4x2xf32>) permutation = [1, 2, 0]
-  %expanded = tensor.expand_shape %transposed [[0, 1], [2], [3]] output_shape [1, 3, 4, 2] : tensor<3x4x2xf32> into tensor<1x3x4x2xf32>
-  util.return %expanded : tensor<1x3x4x2xf32>
-}
-// SINK-LABEL: util.func public @sink_non_involution_through_expand_shape
-// SINK:         %[[EXP:.+]] = tensor.expand_shape {{.*}} {{\[\[}}0], [1, 2], [3]]
-// SINK-SAME:      tensor<2x3x4xf32> into tensor<2x1x3x4xf32>
-// SINK:         %[[RES:.+]] = linalg.transpose ins(%[[EXP]] : tensor<2x1x3x4xf32>
-// SINK-SAME:      outs({{.*}} : tensor<1x3x4x2xf32>)
-// SINK-SAME:      permutation = [1, 2, 3, 0]
-// SINK:         util.return %[[RES]] : tensor<1x3x4x2xf32>
-
-// -----
-
-util.func public @bubble_non_involution_through_collapse_shape(%arg0 : tensor<1x2x3x5x7x11xf32>) -> tensor<35x11x6xf32> {
-  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]] : tensor<1x2x3x5x7x11xf32> into tensor<6x35x11xf32>
-  %empty = tensor.empty(): tensor<35x11x6xf32>
-  %transposed = linalg.transpose ins(%collapsed : tensor<6x35x11xf32>)
-      outs(%empty : tensor<35x11x6xf32>) permutation = [1, 2, 0]
-  util.return %transposed : tensor<35x11x6xf32>
-}
-// BUBBLE-LABEL: util.func public @bubble_non_involution_through_collapse_shape
-// BUBBLE:         %[[T:.+]] = linalg.transpose ins(%{{.*}} : tensor<1x2x3x5x7x11xf32>
-// BUBBLE-SAME:      outs({{.*}} : tensor<5x7x11x1x2x3xf32>)
-// BUBBLE-SAME:      permutation = [3, 4, 5, 0, 1, 2]
-// BUBBLE:         %[[COL:.+]] = tensor.collapse_shape %[[T]] {{\[\[}}0, 1], [2], [3, 4, 5]]
-// BUBBLE-SAME:      tensor<5x7x11x1x2x3xf32> into tensor<35x11x6xf32>
-// BUBBLE:         util.return %[[COL]] : tensor<35x11x6xf32>
-
-// -----
-
 util.func public @propagate_transpose_through_unary_elementwise(%arg0 : tensor<2x3x4xf32>) -> tensor<3x4x2xf32> {
   %empty = tensor.empty(): tensor<3x4x2xf32>
   %transposed = linalg.transpose ins(%arg0 : tensor<2x3x4xf32>)
       outs(%empty : tensor<3x4x2xf32>) permutation = [1, 2, 0]
@@ -799,3 +765,31 @@ util.func public @dont_reshape_reduction(%arg0: tensor<16x4x4xf32>, %arg1: tenso
 // APROP:       %[[V1:.+]] = tensor.collapse_shape %[[V0]]
 // APROP:       %[[V2:.+]] = linalg.matmul ins(%[[V1]]
 // APROP:       util.return %[[V2]]
+
+// -----
+
+util.func @dont_propagate_edge_reshapes(%arg0: tensor<10x10x10xi32>) -> tensor<10x100xi32> {
+  %collapsed = tensor.collapse_shape %arg0[[0, 1], [2]] : tensor<10x10x10xi32> into tensor<100x10xi32>
+  %empty = tensor.empty() : tensor<10x100xi32>
+  %transpose = linalg.transpose ins(%collapsed : tensor<100x10xi32>) outs(%empty : tensor<10x100xi32>) permutation = [1, 0]
+  util.return %transpose : tensor<10x100xi32>
+}
+// CHECK-LABEL: util.func public @dont_propagate_edge_reshapes
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK:         %[[VAL:.+]] = linalg.transpose ins(%[[COLLAPSED]]
+// CHECK:         util.return %[[VAL]]
+
+// -----
+
+util.func public @dont_sink_through_edge_expand_shape(%arg0 : tensor<2x3x4xf32>) -> tensor<1x3x4x2xf32> {
+  %empty = tensor.empty(): tensor<3x4x2xf32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<2x3x4xf32>)
+      outs(%empty : tensor<3x4x2xf32>) permutation = [1, 2, 0]
+  %expanded = tensor.expand_shape %transposed [[0, 1], [2], [3]] output_shape [1, 3, 4, 2] : tensor<3x4x2xf32> into tensor<1x3x4x2xf32>
+  util.return %expanded : tensor<1x3x4x2xf32>
+}
+// SINK-LABEL: util.func public @dont_sink_through_edge_expand_shape
+// SINK:         %[[TRANSPOSE:.+]] = linalg.transpose
+// SINK:         %[[RES:.+]] = tensor.expand_shape %[[TRANSPOSE]]
+// SINK:         util.return %[[RES]] : tensor<1x3x4x2xf32>