@@ -87,6 +87,18 @@ static RankedTensorType getPermutedTensorType(RankedTensorType type,
return RankedTensorType::get(permutedShape, type.getElementType());
}

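// Returns true only when both `producer` and `consumer` are non-null ops from
// the linalg, iree_linalg_ext, or tensor dialects. The patterns below use this
// to detect reshapes that sit between fusable ops; reshapes at the program
// edges (e.g. fed by block arguments) are left untouched.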
static bool isReshapeBlockingFusion(Operation *producer, Operation *consumer) {
  auto isFusableOp = [](Operation *op) {
    if (!op) {
      return false;
    }
    return isa_and_nonnull<linalg::LinalgDialect,
                           IREE::LinalgExt::IREELinalgExtDialect,
                           tensor::TensorDialect>(op->getDialect());
  };
  return isFusableOp(producer) && isFusableOp(consumer);
}

//===----------------------------------------------------------------------===//
// Transpose specialization
//===----------------------------------------------------------------------===//
@@ -324,6 +336,12 @@ class BubbleTransposeThroughCollapseShape
transposeOp, "transpose input is not a single-use collapse shape");
}

if (!isReshapeBlockingFusion(transposeOp,
collapseOp.getSrc().getDefiningOp())) {
return rewriter.notifyMatchFailure(transposeOp,
"transpose not blocking fusion");
}

SmallVector<ReassociationIndices> reassociations =
collapseOp.getReassociationIndices();

@@ -521,6 +539,13 @@ class SinkTransposeThroughExpandShape
expandOp, "expand shape input is not a single-use transpose");
}

if (llvm::none_of(expandOp->getUsers(), [&](Operation *consumer) {
return isReshapeBlockingFusion(transposeOp, consumer);
})) {
return rewriter.notifyMatchFailure(transposeOp,
"transpose not blocking fusion");
}

auto invPerm = invertPermutationVector(transposeOp.getPermutation());
SmallVector<ReassociationIndices> reassociations =
expandOp.getReassociationIndices();
@@ -1084,6 +1109,13 @@ void PropagateLinalgTransposePass::runOnOperation() {
          if (!isa<tensor::ExpandShapeOp>(consumer)) {
            return false;
          }

          if (llvm::none_of(
                  consumer->getUsers(), [&](Operation *expandConsumer) {
                    return isReshapeBlockingFusion(producer, expandConsumer);
                  })) {
            return false;
          }
          // Only propagate if the immediate consumer of the reshape is a
          // transpose.
          return consumer->hasOneUse() &&
@@ -1156,6 +1188,12 @@ void PropagateLinalgTransposePass::runOnOperation() {
          if (!isa<tensor::CollapseShapeOp>(producer)) {
            return false;
          }

          if (!isReshapeBlockingFusion(producer->getOperand(0).getDefiningOp(),
                                       consumer)) {
            return false;
          }

          // Require that the immediate producer of the reshape is a transpose.
          return isa_and_nonnull<linalg::TransposeOp>(
              producer->getOperand(0).getDefiningOp());
@@ -358,40 +358,6 @@ util.func public @sink_through_expand_shape(%arg0: tensor<?x?x?xf32>) -> tensor<

// -----

util.func public @sink_non_involution_through_expand_shape(%arg0 : tensor<2x3x4xf32>) -> tensor<1x3x4x2xf32> {
Collaborator: I think we want to keep this behavior, because it might allow the transpose to fuse with its consumers.

Contributor Author: The point is that these reshapes are at the edge of the program, so there are no producers or consumers for the transpose to fuse with.

Collaborator: I am not sure I follow. The test changes the behavior to something different from what I assume the end state of the code should be. Could you provide more details here?

Contributor Author: The problem is that BubbleUpExpandShapes is not perfect and does not currently sink reshapes through reduction operations. The idea is to not touch reshapes that are at the edges of the program, since they might get stuck on a reduction operation (see the illustrative sketch after this test). This shouldn't harm transpose propagation, because these reshapes are at the edges and won't block any transposes from being propagated through the program. I also added more context on the issue, with the MLIR before and after this pass, in #22312.

Contributor Author: I think the better solution is to make a change to BubbleUpExpandShapes; I'm testing whether there are any regressions in #22341.

%empty = tensor.empty(): tensor<3x4x2xf32>
%transposed = linalg.transpose ins(%arg0 : tensor<2x3x4xf32>)
outs(%empty : tensor<3x4x2xf32>) permutation = [1, 2, 0]
%expanded = tensor.expand_shape %transposed [[0, 1], [2], [3]] output_shape [1, 3, 4, 2] : tensor<3x4x2xf32> into tensor<1x3x4x2xf32>
util.return %expanded : tensor<1x3x4x2xf32>
}
// SINK-LABEL: util.func public @sink_non_involution_through_expand_shape
// SINK: %[[EXP:.+]] = tensor.expand_shape {{.*}} {{\[\[}}0], [1, 2], [3]]
// SINK-SAME: tensor<2x3x4xf32> into tensor<2x1x3x4xf32>
// SINK: %[[RES:.+]] = linalg.transpose ins(%[[EXP]] : tensor<2x1x3x4xf32>
// SINK-SAME: outs({{.*}} : tensor<1x3x4x2xf32>)
// SINK-SAME: permutation = [1, 2, 3, 0]
// SINK: util.return %[[RES]] : tensor<1x3x4x2xf32>
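
A minimal, hypothetical sketch of the failure mode described in the discussion above; the function name, shapes, and reduction here are illustrative and not part of the PR or its tests. Per the discussion, BubbleUpExpandShapes does not currently propagate reshapes through reduction operations, so a reshape that has been moved inward can end up stuck directly in front of one:

util.func public @illustrative_reshape_stuck_on_reduction(%arg0: tensor<4x8x16xf32>) -> tensor<32xf32> {
  // The collapse_shape feeds a reduction; since BubbleUpExpandShapes cannot
  // move it past linalg.reduce, the reshape stays here.
  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<4x8x16xf32> into tensor<32x16xf32>
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<32xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<32xf32>) -> tensor<32xf32>
  %reduced = linalg.reduce ins(%collapsed : tensor<32x16xf32>)
                           outs(%fill : tensor<32xf32>) dimensions = [1]
    (%in: f32, %acc: f32) {
      %sum = arith.addf %in, %acc : f32
      linalg.yield %sum : f32
    }
  util.return %reduced : tensor<32xf32>
}

Per the discussion, the new check leaves edge reshapes like the ones in the removed tests where they are, rather than pushing them toward a situation like this.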

// -----

util.func public @bubble_non_involution_through_collapse_shape(%arg0 : tensor<1x2x3x5x7x11xf32>) -> tensor<35x11x6xf32> {
%collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]] : tensor<1x2x3x5x7x11xf32> into tensor<6x35x11xf32>
Collaborator: Same here. The transpose could fuse with its producers.

%empty = tensor.empty(): tensor<35x11x6xf32>
%transposed = linalg.transpose ins(%collapsed : tensor<6x35x11xf32>)
outs(%empty : tensor<35x11x6xf32>) permutation = [1, 2, 0]
util.return %transposed : tensor<35x11x6xf32>
}
// BUBBLE-LABEL: util.func public @bubble_non_involution_through_collapse_shape
// BUBBLE: %[[T:.+]] = linalg.transpose ins(%{{.*}} : tensor<1x2x3x5x7x11xf32>
// BUBBLE-SAME: outs({{.*}} : tensor<5x7x11x1x2x3xf32>)
// BUBBLE-SAME: permutation = [3, 4, 5, 0, 1, 2]
// BUBBLE: %[[COL:.+]] = tensor.collapse_shape %[[T]] {{\[\[}}0, 1], [2], [3, 4, 5]]
// BUBBLE-SAME: tensor<5x7x11x1x2x3xf32> into tensor<35x11x6xf32>
// BUBBLE: util.return %[[COL]] : tensor<35x11x6xf32>

// -----

util.func public @propagate_transpose_through_unary_elementwise(%arg0 : tensor<2x3x4xf32>) -> tensor<3x4x2xf32> {
%empty = tensor.empty(): tensor<3x4x2xf32>
%transposed = linalg.transpose ins(%arg0 : tensor<2x3x4xf32>)
@@ -799,3 +765,31 @@ util.func public @dont_reshape_reduction(%arg0: tensor<16x4x4xf32>, %arg1: tenso
// APROP: %[[V1:.+]] = tensor.collapse_shape %[[V0]]
// APROP: %[[V2:.+]] = linalg.matmul ins(%[[V1]]
// APROP: util.return %[[V2]]

// -----

util.func @dont_propagate_edge_reshapes(%arg0: tensor<10x10x10xi32>) -> tensor<10x100xi32> {
%collapsed = tensor.collapse_shape %arg0[[0, 1], [2]] : tensor<10x10x10xi32> into tensor<100x10xi32>
%empty = tensor.empty() : tensor<10x100xi32>
%transpose = linalg.transpose ins(%collapsed : tensor<100x10xi32>) outs(%empty : tensor<10x100xi32>) permutation = [1, 0]
util.return %transpose : tensor<10x100xi32>
}
// CHECK-LABEL: util.func public @dont_propagate_edge_reshapes
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK: %[[VAL:.+]] = linalg.transpose ins(%[[COLLAPSED]]
// CHECK: util.return %[[VAL]]

// -----

util.func public @dont_sink_through_edge_expand_shape(%arg0 : tensor<2x3x4xf32>) -> tensor<1x3x4x2xf32> {
%empty = tensor.empty(): tensor<3x4x2xf32>
%transposed = linalg.transpose ins(%arg0 : tensor<2x3x4xf32>)
outs(%empty : tensor<3x4x2xf32>) permutation = [1, 2, 0]
%expanded = tensor.expand_shape %transposed [[0, 1], [2], [3]] output_shape [1, 3, 4, 2] : tensor<3x4x2xf32> into tensor<1x3x4x2xf32>
util.return %expanded : tensor<1x3x4x2xf32>
}
// SINK-LABEL: util.func public @dont_sink_through_edge_expand_shape
// SINK: %[[TRANSPOSE:.+]] = linalg.transpose
// SINK: %[[RES:.+]] = tensor.expand_shape %[[TRANSPOSE]]
// SINK: util.return %[[RES]] : tensor<1x3x4x2xf32>
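
For contrast, a hypothetical input (not one of the PR's tests): here the collapse_shape source is produced by a linalg op, so isReshapeBlockingFusion sees two fusable ops and the new check no longer blocks the pattern; whether the transpose is actually bubbled still depends on the pattern's other conditions (e.g. the reshape having a single use).

util.func public @illustrative_reshape_between_fusable_ops(%arg0: tensor<10x10x10xi32>, %arg1: tensor<10x10x10xi32>) -> tensor<10x100xi32> {
  // The reshape sits between two linalg ops rather than at a program edge.
  %empty3d = tensor.empty() : tensor<10x10x10xi32>
  %add = linalg.add ins(%arg0, %arg1 : tensor<10x10x10xi32>, tensor<10x10x10xi32>)
                    outs(%empty3d : tensor<10x10x10xi32>) -> tensor<10x10x10xi32>
  %collapsed = tensor.collapse_shape %add [[0, 1], [2]] : tensor<10x10x10xi32> into tensor<100x10xi32>
  %empty = tensor.empty() : tensor<10x100xi32>
  %transpose = linalg.transpose ins(%collapsed : tensor<100x10xi32>)
                                outs(%empty : tensor<10x100xi32>) permutation = [1, 0]
  util.return %transpose : tensor<10x100xi32>
}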