diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 27b2da6b616d..445560f0cf27 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -543,20 +543,14 @@ moveFollowingOpIntoDispatchRegion(RewriterBase &rewriter, Operation *target,
   rewriter.setInsertionPoint(body.getTerminator());
   Operation *clonedTarget = rewriter.clone(*target);
 
-  // Replace any operands returned by the `regionOp` with the results yielded
-  // inside of the `regionOp`.
-  for (OpOperand &operand : clonedTarget->getOpOperands()) {
-    if (operand.get().getDefiningOp() != regionOp) {
-      continue;
-    }
-    auto returnOp =
-        cast<Flow::ReturnOp>(regionOp.getBody().front().getTerminator());
-    auto opResult = cast<OpResult>(operand.get());
-    Value yieldedValue = returnOp->getOperand(opResult.getResultNumber());
-    rewriter.modifyOpInPlace(clonedTarget, [&]() {
-      clonedTarget->setOperand(operand.getOperandNumber(), yieldedValue);
-    });
-  }
+  // Replace all of `clonedTarget`'s uses of `regionOp` with the values yielded
+  // from inside the region.
+  auto returnOp =
+      cast<Flow::ReturnOp>(regionOp.getBody().front().getTerminator());
+  rewriter.replaceOpUsesWithIf(
+      regionOp, returnOp.getOperands(), [&](OpOperand &operand) {
+        return clonedTarget->isAncestor(operand.getOwner());
+      });
 
   // Gather all uses of `target`.
   for (auto [index, result] : llvm::enumerate(target->getResults())) {
diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
index 4c16b3bc6707..25163fa85c58 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
@@ -127,6 +127,19 @@ class FusionGroup {
   FailureOr getRootParallelLoopToOpMap(Operation *op) const;
 
   bool isFusable(Operation *op) const {
+    // We only handle fusion across an operation's operands. Don't fuse if the
+    // operation uses values from the fusion group within its body.
+    bool hasUseFromAbove = false;
+    mlir::visitUsedValuesDefinedAbove(
+        op->getRegions(), [&](OpOperand *operand) {
+          if (loopMaps.contains(operand->get().getDefiningOp())) {
+            hasUseFromAbove = true;
+          }
+        });
+    if (hasUseFromAbove) {
+      return false;
+    }
+
     FailureOr maybeMap = getRootParallelLoopToOpMap(op);
     if (failed(maybeMap)) {
       return false;
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
index 866537b75179..c9ac66f7c207 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
@@ -1837,3 +1837,41 @@ util.func public @no_fusion_across_blocks(%arg0: tensor<3x2xf32>) -> tensor
 // CHECK-SAME: ins(%[[DISPATCH0]], %[[FILL]]
 // CHECK: flow.return %[[DIV]]
 // CHECK: util.return %[[DISPATCH1]]
+
+// -----
+
+util.func public @no_fusion_use_from_above(%arg0 : tensor<?x?xf32>,
+    %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  %matmul = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %empty2 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  %consumer = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]}
+      ins(%matmul : tensor<?x?xf32>)
+      outs(%empty2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %c0_idx = arith.constant 0 : index
+    %c1_idx = arith.constant 1 : index
+    %extracted = tensor.extract %matmul[%c0_idx, %c1_idx] : tensor<?x?xf32>
+    %sum = arith.addf %in, %extracted : f32
+    linalg.yield %sum : f32
+  } -> tensor<?x?xf32>
+  util.return %consumer : tensor<?x?xf32>
+}
+// CHECK-LABEL: util.func public @no_fusion_use_from_above(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region
+// CHECK: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
+// CHECK: flow.return %[[MATMUL]]
+// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region
+// CHECK: %[[CONSUMER:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[DISPATCH0]]
+// CHECK: tensor.extract %[[DISPATCH0]]
+// CHECK: flow.return %[[CONSUMER]]
+// CHECK: util.return %[[DISPATCH1]]