From 6386fd34a75fb032ebc8abe8de6f891875b3654b Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Mon, 21 Oct 2024 15:26:55 +0000 Subject: [PATCH 1/3] Fix error in FuseMultiUseElementwiseProducersPass `op` isn't included in `slice` so it must be independently checked to make sure it doesn't use any values defined above. Otherwise, a consumer that uses values defined above may be fused with a producer causing dominance errors. Closes https://github.com/iree-org/iree/issues/18847 Signed-off-by: Ian Wood --- .../compiler/DispatchCreation/FusionUtils.cpp | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp index 238c866fe461..7a5510bb5379 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp @@ -101,11 +101,23 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand, return true; } +static bool usesValuesDefinedAbove(Operation *op) { + bool usesValuesFromAbove = false; + mlir::visitUsedValuesDefinedAbove( + op->getRegions(), [&](void *) { usesValuesFromAbove = true; }); + return usesValuesFromAbove; +} + bool isHorizontalToGroup(Operation *op, ArrayRef currGroup, const DominanceInfo &dominanceInfo, Operation *seedOp) { assert(dominanceInfo.properlyDominates(seedOp, op) && op->getParentRegion() == seedOp->getParentRegion()); + + if (usesValuesDefinedAbove(op)) { + return false; + } + BackwardSliceOptions options; // Limit the slice to the seed to make sure the slice is small. options.filter = [&](Operation *op) { @@ -116,13 +128,8 @@ bool isHorizontalToGroup(Operation *op, ArrayRef currGroup, // `getBackwardSlice` doesnt track uses from within an ops region, so make // sure there are no values defined above. 
- for (Operation *sliceOp : slice) { - bool usesValuesFromAbove = false; - mlir::visitUsedValuesDefinedAbove( - sliceOp->getRegions(), [&](void *) { usesValuesFromAbove = true; }); - if (usesValuesFromAbove) { - return false; - } + if (llvm::any_of(slice, usesValuesDefinedAbove)) { + return false; } return !llvm::any_of(currGroup, [&](Operation *groupedOp) { From 89d886239acab03879f09ddd768be85635f75d21 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Mon, 21 Oct 2024 15:43:47 +0000 Subject: [PATCH 2/3] fixup broken test Signed-off-by: Ian Wood --- .../test/fuse_multiuse_elementwise_producer.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir index c76fa0653635..64e5485cc91b 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir @@ -152,11 +152,11 @@ util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor< %8 = arith.addf %arg2, %cst : f32 linalg.yield %8 : f32 } -> tensor<5x5xf32> - // expected-note @below {{prior use here}} %collapsed = tensor.collapse_shape %4 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { ^bb0(%arg2: f32, %arg3: f32): - %8 = arith.subf %arg2, %cst_0 : f32 + %cst_2 = arith.constant 2.000000e+00 : f32 + %8 = arith.subf %arg2, %cst_2 : f32 linalg.yield %8 : f32 } -> tensor<5x5xf32> util.return %5, %collapsed: tensor<5x5xf32>, tensor<25xf32> From 6dbb65565c2ad36ee7c98f5f6c77a6a622fd3b8a Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Tue, 22 Oct 2024 20:18:14 +0000 Subject: [PATCH 3/3] Include values used in regions Signed-off-by: Ian Wood 
--- .../compiler/DispatchCreation/FusionUtils.cpp | 35 ++++++------ .../compiler/DispatchCreation/FusionUtils.h | 15 +++-- .../fuse_multiuse_elementwise_producer.mlir | 55 +++++++++++++++++++ 3 files changed, 83 insertions(+), 22 deletions(-) diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp index 7a5510bb5379..10224cea9444 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp @@ -101,11 +101,24 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand, return true; } -static bool usesValuesDefinedAbove(Operation *op) { - bool usesValuesFromAbove = false; - mlir::visitUsedValuesDefinedAbove( - op->getRegions(), [&](void *) { usesValuesFromAbove = true; }); - return usesValuesFromAbove; +void getBackwardSliceIncludingUsesFromAbove( + Operation *op, SetVector *backwardSlice, + const BackwardSliceOptions &options) { + BackwardSliceOptions wrappedOptions(options); + wrappedOptions.filter = [&](Operation *op) { + if (!options.filter(op)) { + return false; + } + BackwardSliceOptions regionOptions(wrappedOptions); + regionOptions.inclusive = true; + mlir::visitUsedValuesDefinedAbove( + op->getRegions(), [&](OpOperand *operand) { + getBackwardSlice(operand->get(), backwardSlice, regionOptions); + }); + return true; + }; + + getBackwardSlice(op, backwardSlice, wrappedOptions); } bool isHorizontalToGroup(Operation *op, ArrayRef currGroup, @@ -114,23 +127,13 @@ bool isHorizontalToGroup(Operation *op, ArrayRef currGroup, assert(dominanceInfo.properlyDominates(seedOp, op) && op->getParentRegion() == seedOp->getParentRegion()); - if (usesValuesDefinedAbove(op)) { - return false; - } - BackwardSliceOptions options; // Limit the slice to the seed to make sure the slice is small. 
options.filter = [&](Operation *op) { return !dominanceInfo.properlyDominates(op, seedOp); }; llvm::SetVector slice; - getBackwardSlice(op, &slice, options); - - // `getBackwardSlice` doesnt track uses from within an ops region, so make - // sure there are no values defined above. - if (llvm::any_of(slice, usesValuesDefinedAbove)) { - return false; - } + getBackwardSliceIncludingUsesFromAbove(op, &slice, options); return !llvm::any_of(currGroup, [&](Operation *groupedOp) { return slice.contains(groupedOp); }); diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h index 6526badfea31..3b40dbb4a523 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h @@ -29,6 +29,14 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand, bool isHorizontalToGroup(Operation *op, ArrayRef currGroup, const DominanceInfo &dominanceInfo, Operation *seedOp); +/// Wraps `filter` so that operations used in a region of an op also get +/// included in the backward slice. This breaks the topological sorting of the +/// original `getBackwardSlice` but isn't necessary for the uses here. +/// TODO: Upstream this as a part of `getBackwardSlice`. +void getBackwardSliceIncludingUsesFromAbove( + Operation *op, SetVector *backwardSlice, + const BackwardSliceOptions &options); + /// Moves the operands and transitive defs for each op in `operations` directly /// after `insertionPoint`. Note: this does not check if it is legal to move the /// operands.
@@ -44,16 +52,11 @@ moveOperandDefs(RewriterBase &rewriter, ArrayRef operations, return !dominanceInfo.properlyDominates(op, insertionPoint) && !ignoreOperationsSet.contains(op); }; - // Set inclusive to true cause the slice is computed from the operand, and - // we want to include the defining op (which is the point here) - options.inclusive = true; llvm::SetVector slice; for (auto op : operations) { assert(insertionPoint->getBlock() == op->getBlock()); - for (auto operand : op->getOperands()) { - getBackwardSlice(operand, &slice, options); - } + getBackwardSliceIncludingUsesFromAbove(op, &slice, options); } mlir::topologicalSort(slice); diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir index 64e5485cc91b..36c02decdac5 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir @@ -164,3 +164,58 @@ util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor< // CHECK-LABEL: util.func public @fuse_by_moving_consumer // CHECK: linalg.generic // CHECK-NOT: linalg.generic + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +util.func public @dont_fuse_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) { + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 2.000000e+00 : f32 + %cst_1 = arith.constant 3.000000e+00 : f32 + %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { + ^bb0(%in: f32, %out: f32): + %2 = arith.addf %in, %cst : f32 + linalg.yield %2 : f32 + } -> tensor<5x5xf32> + %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> + %1 = linalg.generic {indexing_maps = 
[#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { + ^bb0(%in: f32, %out: f32): + %c2 = arith.constant 2 : index + %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32> + %2 = arith.addf %extracted, %extracted : f32 + linalg.yield %2 : f32 + } -> tensor<5x5xf32> + util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32> +} + +// CHECK-LABEL: util.func public @dont_fuse_use_from_above +// CHECK: linalg.generic +// CHECK: linalg.generic + + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +util.func public @do_fuse_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) { + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 2.000000e+00 : f32 + %cst_1 = arith.constant 3.000000e+00 : f32 + %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { + ^bb0(%in: f32, %out: f32): + %2 = arith.addf %in, %cst : f32 + linalg.yield %2 : f32 + } -> tensor<5x5xf32> + %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> + %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { + ^bb0(%in: f32, %out: f32): + %c2 = arith.constant 2 : index + %extracted = tensor.extract %arg0[%c2, %c2] : tensor<5x5xf32> + %2 = arith.addf %extracted, %extracted : f32 + linalg.yield %2 : f32 + } -> tensor<5x5xf32> + util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32> +} + +// CHECK-LABEL: util.func public @do_fuse_use_from_above +// CHECK: linalg.generic +// CHECK-NOT: linalg.generic