From 432125464b59d364297b325c661e9e0438fd8e35 Mon Sep 17 00:00:00 2001 From: Ian Wood <75152913+IanWood1@users.noreply.github.com> Date: Thu, 24 Oct 2024 09:16:14 -0700 Subject: [PATCH 1/2] Revert "[DispatchCreation] Extend multi-use producer fusion (#18551)" This reverts commit 206b60ca59c9dbbca5769694df4714c38cecaced. --- .github/workflows/pkgci_regression_test.yml | 4 +- .../FuseHorizontalContractions.cpp | 61 +++++++++++++-- .../FuseMultiUseElementwiseProducer.cpp | 76 ++++--------------- .../compiler/DispatchCreation/FusionUtils.cpp | 33 -------- .../compiler/DispatchCreation/FusionUtils.h | 44 ----------- .../fuse_multiuse_elementwise_producer.mlir | 25 ------ 6 files changed, 74 insertions(+), 169 deletions(-) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index 0748ec51859b..3cec71e43179 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -220,7 +220,7 @@ jobs: --goldentime-rocm-unet-ms 419.0 \ --goldentime-rocm-clip-ms 18.5 \ --goldentime-rocm-vae-ms 337.0 \ - --goldendispatch-rocm-unet 1545 \ + --goldendispatch-rocm-unet 1551 \ --goldendispatch-rocm-clip 1139 \ --goldendispatch-rocm-vae 248 \ --goldensize-rocm-unet-bytes 2280000 \ @@ -241,7 +241,7 @@ jobs: --goldentime-rocm-unet-ms 95.0 \ --goldentime-rocm-clip-ms 15.5 \ --goldentime-rocm-vae-ms 80.0 \ - --goldendispatch-rocm-unet 1545 \ + --goldendispatch-rocm-unet 1551 \ --goldendispatch-rocm-clip 1139 \ --goldendispatch-rocm-vae 248 \ --goldensize-rocm-unet-bytes 2270000 \ diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp index 845485667d38..a78b6b83876b 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp @@ -7,7 +7,6 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" -#include "iree/compiler/DispatchCreation/FusionUtils.h" #include "iree/compiler/DispatchCreation/Passes.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" @@ -108,6 +107,25 @@ static bool isEmptyFillContractionDAGRootOp( return true; } +/// Check that a given operation is "horizontal" to the group. The operation +/// is horizontal if the `slice` of the operation does not contain any op +/// from the group. +static bool isHorizontalToGroup(Operation *op, + const llvm::SetVector &currGroup, + const DominanceInfo &dominanceInfo, + Operation *seedOp) { + BackwardSliceOptions options; + // Limit the slice to the seed to make sure the slice is small. + options.filter = [&](Operation *op) { + return !dominanceInfo.properlyDominates(op, seedOp); + }; + llvm::SetVector slice; + getBackwardSlice(op, &slice, options); + return !llvm::any_of(currGroup, [&](Operation *groupedOp) { + return slice.contains(groupedOp); + }); +} + /// Get user of operation that is a truncate operation. static std::optional getTruncateOp(Operation *op, @@ -131,8 +149,8 @@ getTruncateOp(Operation *op, if (!checkOperationEquivalence(genericOp, seedTruncateOp.value())) { return std::nullopt; } - if (!isHorizontalToGroup(genericOp, groupedOperations.getArrayRef(), - dominanceInfo, seedTruncateOp.value())) { + if (!isHorizontalToGroup(genericOp, groupedOperations, dominanceInfo, + seedTruncateOp.value())) { return std::nullopt; } } @@ -208,8 +226,7 @@ static std::optional getHorizontalFusionGroupMembers( if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) { return false; } - if (!isHorizontalToGroup(linalgOp, allOps.getArrayRef(), dominanceInfo, - seedOp)) { + if (!isHorizontalToGroup(linalgOp, allOps, dominanceInfo, seedOp)) { return false; } return true; @@ -329,6 +346,40 @@ static AffineMap getConcatenatedIndexingMap(RewriterBase &rewriter, return newIndexingMap.insertResult(rewriter.getAffineDimExpr(0), 0); } +/// During horizontal fusion, there might be operands of the fused operations +/// whose definitions are interspersed between the fused operations. For groups +/// chosen to fuse horizontally, such operations can be moved before the +/// seed contraction operation (where the fused operation is generated). +template +static LogicalResult +moveOperandDefs(RewriterBase &rewriter, ArrayRef operations, + Operation *insertionPoint, DominanceInfo &dominanceInfo, + ArrayRef ignoreOperations = {}) { + BackwardSliceOptions options; + llvm::DenseSet ignoreOperationsSet; + ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end()); + options.filter = [&](Operation *op) { + return !dominanceInfo.properlyDominates(op, insertionPoint) && + !ignoreOperationsSet.contains(op); + }; + // Set inclusive to true cause the slice is computed from the operand, and + // we want to include the defining op (which is the point here) + options.inclusive = true; + + llvm::SetVector slice; + for (auto op : operations) { + for (auto operand : op->getOperands()) { + getBackwardSlice(operand, &slice, options); + } + } + + mlir::topologicalSort(slice); + for (auto op : slice) { + rewriter.moveOpBefore(op, insertionPoint); + } + return success(); +} + /// On finding this pattern /// ``` /// %0 = linalg.matmul ins(%arg0, %arg1) diff --git a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp index d79d5145e77d..9d9d477c9a57 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp @@ -16,13 +16,9 @@ #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" -#include "iree/compiler/DispatchCreation/FusionUtils.h" #include "iree/compiler/DispatchCreation/Passes.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -49,55 +45,25 @@ static llvm::cl::opt clLinalgMaxConstantFoldElements( llvm::cl::desc("Maximum number of elements to try to constant fold."), llvm::cl::init(0)); -static Operation *getMostDominantUse(Operation *op, - const DominanceInfo &dominanceInfo) { - auto uses = op->getUses(); - auto it = llvm::find_if(uses, [&](OpOperand &source) { - Operation *sourceOp = source.getOwner(); - - return llvm::all_of(uses, [&](OpOperand &target) { - Operation *targetOp = target.getOwner(); - return dominanceInfo.dominates(sourceOp, targetOp); - }); - }); - if (it != uses.end()) { - return it->getOwner(); - } - return nullptr; -} - /// Check if any of the use dominates all other uses of the operation. -static Operation *getFusableUse(Operation *op, - const DominanceInfo &dominanceInfo) { +static std::optional getFusableUse(Operation *op, + DominanceInfo &dominanceInfo) { auto uses = op->getUses(); - Operation *fusableUse = nullptr; for (OpOperand &source : uses) { Operation *sourceOp = source.getOwner(); - - bool dominatesAllFusableOps = llvm::all_of(uses, [&](OpOperand &target) { + bool dominatesAllUsers = true; + for (OpOperand &target : uses) { Operation *targetOp = target.getOwner(); - return !isa(targetOp) || - dominanceInfo.dominates(sourceOp, targetOp); - }); - if (dominatesAllFusableOps) { - fusableUse = sourceOp; - break; + if (!dominanceInfo.dominates(sourceOp, targetOp)) { + dominatesAllUsers = false; + break; + } + } + if (dominatesAllUsers) { + return &source; } } - Operation *mostDominantOp = getMostDominantUse(op, dominanceInfo); - if (!fusableUse || !mostDominantOp) { - return nullptr; - } - - // If `fusableUse` dominates all other users, there's nothing else to do. - if (fusableUse == mostDominantOp) { - return fusableUse; - } - - SmallVector users(op->getUsers().begin(), op->getUsers().end()); - return isHorizontalToGroup(fusableUse, users, dominanceInfo, mostDominantOp) - ? fusableUse - : nullptr; + return std::nullopt; } static OpOperand *getFirstUseInConsumer(Operation *producer, @@ -125,7 +91,6 @@ static SmallVector getAllUsesInConsumer(Operation *producer, /// using elementwise fusion. static LogicalResult doMultiUseFusion(Operation *rootOp, llvm::SetVector &fusableOps, - const DominanceInfo &dominanceInfo, RewriterBase &rewriter) { assert(rootOp && "root op cant be null"); @@ -147,20 +112,11 @@ static LogicalResult doMultiUseFusion(Operation *rootOp, Operation *consumerOp = rootOp; OpBuilder::InsertionGuard g(rewriter); for (Operation *producerOp : llvm::reverse(fusedOpsVec)) { - Operation *mostDominantUser = getMostDominantUse(producerOp, dominanceInfo); // Fuse all uses from producer -> consumer. It has been checked // before that all uses are fusable. while (OpOperand *fusedOperand = getFirstUseInConsumer(producerOp, consumerOp)) { rewriter.setInsertionPoint(consumerOp); - - if (consumerOp != mostDominantUser && - failed(moveOperandDefs(rewriter, ArrayRef{consumerOp}, - mostDominantUser, dominanceInfo))) { - return rewriter.notifyMatchFailure(consumerOp, - "failed to move operand defs"); - } - rewriter.moveOpBefore(consumerOp, mostDominantUser); FailureOr fusionResult = linalg::fuseElementwiseOps(rewriter, fusedOperand); if (failed(fusionResult)) { @@ -234,8 +190,9 @@ static FailureOr fuseMultiUseProducers(Operation *funcOp, } // 6. Check that the `genericOp` dominates all uses of `producer`. - Operation *fusableUse = getFusableUse(producer, dominanceInfo); - if (!fusableUse || fusableUse != genericOp) { + std::optional fusableUse = + getFusableUse(producer, dominanceInfo); + if (!fusableUse || fusableUse.value()->getOwner() != genericOp) { continue; } @@ -275,8 +232,7 @@ static FailureOr fuseMultiUseProducers(Operation *funcOp, IRRewriter rewriter(context); for (auto it = fusedOps.rbegin(), ie = fusedOps.rend(); it != ie; ++it) { - if (failed( - doMultiUseFusion(it->first, it->second, dominanceInfo, rewriter))) { + if (failed(doMultiUseFusion(it->first, it->second, rewriter))) { return funcOp->emitOpError("failed multi use fusion"); } } diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp index 238c866fe461..c428091f6cf8 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp @@ -10,11 +10,7 @@ #include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h" #include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" -#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/Dominance.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Transforms/RegionUtils.h" namespace mlir::iree_compiler::DispatchCreation { @@ -101,33 +97,4 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand, return true; } -bool isHorizontalToGroup(Operation *op, ArrayRef currGroup, - const DominanceInfo &dominanceInfo, - Operation *seedOp) { - assert(dominanceInfo.properlyDominates(seedOp, op) && - op->getParentRegion() == seedOp->getParentRegion()); - BackwardSliceOptions options; - // Limit the slice to the seed to make sure the slice is small. - options.filter = [&](Operation *op) { - return !dominanceInfo.properlyDominates(op, seedOp); - }; - llvm::SetVector slice; - getBackwardSlice(op, &slice, options); - - // `getBackwardSlice` doesnt track uses from within an ops region, so make - // sure there are no values defined above. - for (Operation *sliceOp : slice) { - bool usesValuesFromAbove = false; - mlir::visitUsedValuesDefinedAbove( - sliceOp->getRegions(), [&](void *) { usesValuesFromAbove = true; }); - if (usesValuesFromAbove) { - return false; - } - } - - return !llvm::any_of(currGroup, [&](Operation *groupedOp) { - return slice.contains(groupedOp); - }); -} - } // namespace mlir::iree_compiler::DispatchCreation diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h index 6526badfea31..1d9c9306f7ae 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h @@ -10,10 +10,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Analysis/TopologicalSortUtils.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/IR/Dominance.h" #include "mlir/IR/Operation.h" namespace mlir::iree_compiler::DispatchCreation { @@ -23,44 +19,4 @@ namespace mlir::iree_compiler::DispatchCreation { bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand, bool fuseMultiReduction); -/// Check that a given operation is "horizontal" to the group. The operation -/// is horizontal if the program slice of the operation (from op back to seedOp) -/// does not contain any op from the group. -bool isHorizontalToGroup(Operation *op, ArrayRef currGroup, - const DominanceInfo &dominanceInfo, Operation *seedOp); - -/// Moves the operands and transitive defs for each op in `operations` directly -/// after `insertionPoint`. Note: this does not check if it is legal to move the -/// operands. -template -static LogicalResult -moveOperandDefs(RewriterBase &rewriter, ArrayRef operations, - Operation *insertionPoint, const DominanceInfo &dominanceInfo, - ArrayRef ignoreOperations = {}) { - BackwardSliceOptions options; - llvm::DenseSet ignoreOperationsSet; - ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end()); - options.filter = [&](Operation *op) { - return !dominanceInfo.properlyDominates(op, insertionPoint) && - !ignoreOperationsSet.contains(op); - }; - // Set inclusive to true cause the slice is computed from the operand, and - // we want to include the defining op (which is the point here) - options.inclusive = true; - - llvm::SetVector slice; - for (auto op : operations) { - assert(insertionPoint->getBlock() == op->getBlock()); - for (auto operand : op->getOperands()) { - getBackwardSlice(operand, &slice, options); - } - } - - mlir::topologicalSort(slice); - for (auto op : slice) { - rewriter.moveOpBefore(op, insertionPoint); - } - return success(); -} - } // namespace mlir::iree_compiler::DispatchCreation diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir index c76fa0653635..cc3e159ca943 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir @@ -139,28 +139,3 @@ util.func public @math_sin() { // CHECK: %[[GENERIC:.+]]:2 = linalg.generic // CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#0, // CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#1, - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) { - %cst = arith.constant 1.000000e+00 : f32 - %cst_0 = arith.constant 2.000000e+00 : f32 - %cst_1 = arith.constant 3.000000e+00 : f32 - %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - %8 = arith.addf %arg2, %cst : f32 - linalg.yield %8 : f32 - } -> tensor<5x5xf32> - // expected-note @below {{prior use here}} - %collapsed = tensor.collapse_shape %4 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> - %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - %8 = arith.subf %arg2, %cst_0 : f32 - linalg.yield %8 : f32 - } -> tensor<5x5xf32> - util.return %5, %collapsed: tensor<5x5xf32>, tensor<25xf32> -} -// CHECK-LABEL: util.func public @fuse_by_moving_consumer -// CHECK: linalg.generic -// CHECK-NOT: linalg.generic From 7ef93a7b584ef50a6a9cf0840667380bfdd02f6d Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Mon, 28 Oct 2024 16:59:04 +0000 Subject: [PATCH 2/2] Update dispatch count Signed-off-by: Ian Wood --- .github/workflows/pkgci_regression_test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml index fb94905c1b29..9849c574dd72 100644 --- a/.github/workflows/pkgci_regression_test.yml +++ b/.github/workflows/pkgci_regression_test.yml @@ -220,7 +220,7 @@ jobs: --goldentime-rocm-unet-ms 419.0 \ --goldentime-rocm-clip-ms 18.5 \ --goldentime-rocm-vae-ms 337.0 \ - --goldendispatch-rocm-unet 1527 \ + --goldendispatch-rocm-unet 1531 \ --goldendispatch-rocm-clip 1139 \ --goldendispatch-rocm-vae 247 \ --goldensize-rocm-unet-bytes 2280000 \ @@ -241,7 +241,7 @@ jobs: --goldentime-rocm-unet-ms 95.0 \ --goldentime-rocm-clip-ms 15.5 \ --goldentime-rocm-vae-ms 80.0 \ - --goldendispatch-rocm-unet 1527 \ + --goldendispatch-rocm-unet 1531 \ --goldendispatch-rocm-clip 1139 \ --goldendispatch-rocm-vae 247 \ --goldensize-rocm-unet-bytes 2270000 \