Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -543,20 +543,14 @@ moveFollowingOpIntoDispatchRegion(RewriterBase &rewriter, Operation *target,
rewriter.setInsertionPoint(body.getTerminator());
Operation *clonedTarget = rewriter.clone(*target);

// Replace any operands returned by the `regionOp` with the results yielded
// inside of the `regionOp`.
for (OpOperand &operand : clonedTarget->getOpOperands()) {
if (operand.get().getDefiningOp() != regionOp) {
continue;
}
auto returnOp =
cast<IREE::Flow::ReturnOp>(regionOp.getBody().front().getTerminator());
auto opResult = cast<OpResult>(operand.get());
Value yieldedValue = returnOp->getOperand(opResult.getResultNumber());
rewriter.modifyOpInPlace(clonedTarget, [&]() {
clonedTarget->setOperand(operand.getOperandNumber(), yieldedValue);
});
}
// Replace all of `clonedTarget` uses of `regionOp` with the values yielded
// from inside the region.
auto returnOp =
cast<IREE::Flow::ReturnOp>(regionOp.getBody().front().getTerminator());
rewriter.replaceOpUsesWithIf(
regionOp, returnOp.getOperands(), [&](OpOperand &operand) {
return clonedTarget->isAncestor(operand.getOwner());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So just to confirm, isAncestor includes the operation as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes!

});

// Gather all uses of `target`.
for (auto [index, result] : llvm::enumerate(target->getResults())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,19 @@ class FusionGroup {
FailureOr<AffineMap> getRootParallelLoopToOpMap(Operation *op) const;

bool isFusable(Operation *op) const {
// We only handle fusion across operation's operands. Don't fuse if the
// operation is using values in the fusion group in it's body.
bool hasUseFromAbove = false;
mlir::visitUsedValuesDefinedAbove(
op->getRegions(), [&](OpOperand *operand) {
if (loopMaps.contains(operand->get().getDefiningOp())) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is loopMaps ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a map that contains all of the operations currently in the "fusion group".

hasUseFromAbove = true;
}
});
if (hasUseFromAbove) {
return false;
}

FailureOr<AffineMap> maybeMap = getRootParallelLoopToOpMap(op);
if (failed(maybeMap)) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1837,3 +1837,41 @@ util.func public @no_fusion_across_blocks(%arg0: tensor<3x2xf32>) -> tensor<f32>
// CHECK-SAME: ins(%[[DISPATCH0]], %[[FILL]]
// CHECK: flow.return %[[DIV]]
// CHECK: util.return %[[DISPATCH1]]

// -----

util.func public @no_fusion_use_from_above(%arg0 : tensor<?x?xf32>,
%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
%matmul = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
%empty2 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
%consumer = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]}
ins(%matmul : tensor<?x?xf32>)
outs(%empty2 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%c0_idx = arith.constant 0 : index
%c1_idx = arith.constant 1 : index
%extracted = tensor.extract %matmul[%c0_idx, %c1_idx] : tensor<?x?xf32>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting example.

First, I dont know why we are lowering it this way. This seems awfully strange lowering. Are we sure we know whats happening with the input lowering of torch-mlir. I dont think we want to end up with this lowering.

It does so happen that the SSA violation is being used to prevent this fusion, but we should prevent this fusion "structurally" as well, i.e. when we decide if this is a fusable consumer, we should mark it as not fusable. due to the tensor.extract usage in the consumer. That would also avoid the issue and seems like the more correct handling of this in my view.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EDIT: I see that you have checked that exactly below. Nice!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test might be a bit over-simplified, making it look stranger than it is. I've lost the exact IR from CI that was causing this issue but I should be able to re-create it on #22341

%sum = arith.addf %in, %extracted : f32
linalg.yield %sum : f32
} -> tensor<?x?xf32>
util.return %consumer : tensor<?x?xf32>
}
// CHECK-LABEL: util.func public @no_fusion_use_from_above(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region
// CHECK: %[[MATMUL:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
// CHECK: flow.return %[[MATMUL]]
// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region
// CHECK: %[[CONSUMER:.+]] = linalg.generic
// CHECK-SAME: ins(%[[DISPATCH0]]
// CHECK: tensor.extract %[[DISPATCH0]]
// CHECK: flow.return %[[CONSUMER]]
// CHECK: util.return %[[DISPATCH1]]
Loading