diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp index 238c866fe461..10224cea9444 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp @@ -101,29 +101,39 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand, return true; } +void getBackwardSliceIncludingUsesFromAbove( + Operation *op, SetVector<Operation *> *backwardSlice, + const BackwardSliceOptions &options) { + BackwardSliceOptions wrappedOptions(options); + wrappedOptions.filter = [&](Operation *op) { + if (!options.filter(op)) { + return false; + } + BackwardSliceOptions regionOptions(wrappedOptions); + regionOptions.inclusive = true; + mlir::visitUsedValuesDefinedAbove( + op->getRegions(), [&](OpOperand *operand) { + getBackwardSlice(operand->get(), backwardSlice, regionOptions); + }); + return true; + }; + + getBackwardSlice(op, backwardSlice, wrappedOptions); +} + bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup, const DominanceInfo &dominanceInfo, Operation *seedOp) { assert(dominanceInfo.properlyDominates(seedOp, op) && op->getParentRegion() == seedOp->getParentRegion()); + BackwardSliceOptions options; // Limit the slice to the seed to make sure the slice is small. options.filter = [&](Operation *op) { return !dominanceInfo.properlyDominates(op, seedOp); }; llvm::SetVector<Operation *> slice; - getBackwardSlice(op, &slice, options); - - // `getBackwardSlice` doesnt track uses from within an ops region, so make - // sure there are no values defined above.
- for (Operation *sliceOp : slice) { - bool usesValuesFromAbove = false; - mlir::visitUsedValuesDefinedAbove( - sliceOp->getRegions(), [&](void *) { usesValuesFromAbove = true; }); - if (usesValuesFromAbove) { - return false; - } - } + getBackwardSliceIncludingUsesFromAbove(op, &slice, options); return !llvm::any_of(currGroup, [&](Operation *groupedOp) { return slice.contains(groupedOp); diff --git a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h index 6526badfea31..3b40dbb4a523 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h +++ b/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h @@ -29,6 +29,14 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand, bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup, const DominanceInfo &dominanceInfo, Operation *seedOp); +/// Wraps `filter` so that operations used in a region of an op also get +/// included in the backward slice. This breaks the topological sorting of the +/// original `getBackwardSlice`, but that isn't necessary for the uses here. +/// TODO: Upstream this as a part of `getBackwardSlice`. +void getBackwardSliceIncludingUsesFromAbove( + Operation *op, SetVector<Operation *> *backwardSlice, + const BackwardSliceOptions &options); + /// Moves the operands and transitive defs for each op in `operations` directly /// after `insertionPoint`. Note: this does not check if it is legal to move the /// operands.
@@ -44,16 +52,11 @@ moveOperandDefs(RewriterBase &rewriter, ArrayRef<Operation *> operations, return !dominanceInfo.properlyDominates(op, insertionPoint) && !ignoreOperationsSet.contains(op); }; - // Set inclusive to true cause the slice is computed from the operand, and - // we want to include the defining op (which is the point here) - options.inclusive = true; llvm::SetVector<Operation *> slice; for (auto op : operations) { assert(insertionPoint->getBlock() == op->getBlock()); - for (auto operand : op->getOperands()) { - getBackwardSlice(operand, &slice, options); - } + getBackwardSliceIncludingUsesFromAbove(op, &slice, options); } mlir::topologicalSort(slice); diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir index c76fa0653635..36c02decdac5 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir @@ -152,11 +152,11 @@ util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor< %8 = arith.addf %arg2, %cst : f32 linalg.yield %8 : f32 } -> tensor<5x5xf32> - // expected-note @below {{prior use here}} %collapsed = tensor.collapse_shape %4 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { ^bb0(%arg2: f32, %arg3: f32): - %8 = arith.subf %arg2, %cst_0 : f32 + %cst_2 = arith.constant 2.000000e+00 : f32 + %8 = arith.subf %arg2, %cst_2 : f32 linalg.yield %8 : f32 } -> tensor<5x5xf32> util.return %5, %collapsed: tensor<5x5xf32>, tensor<25xf32> @@ -164,3 +164,58 @@ util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor< // CHECK-LABEL: util.func public @fuse_by_moving_consumer // CHECK: linalg.generic // CHECK-NOT: linalg.generic
+ +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +util.func public @dont_fuse_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) { + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 2.000000e+00 : f32 + %cst_1 = arith.constant 3.000000e+00 : f32 + %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { + ^bb0(%in: f32, %out: f32): + %2 = arith.addf %in, %cst : f32 + linalg.yield %2 : f32 + } -> tensor<5x5xf32> + %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> + %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { + ^bb0(%in: f32, %out: f32): + %c2 = arith.constant 2 : index + %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32> + %2 = arith.addf %extracted, %extracted : f32 + linalg.yield %2 : f32 + } -> tensor<5x5xf32> + util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32> +} + +// CHECK-LABEL: util.func public @dont_fuse_use_from_above +// CHECK: linalg.generic +// CHECK: linalg.generic + + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +util.func public @do_fuse_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) { + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 2.000000e+00 : f32 + %cst_1 = arith.constant 3.000000e+00 : f32 + %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { + ^bb0(%in: f32, %out: f32): + %2 = arith.addf %in, %cst : f32 + linalg.yield %2 : f32 + } -> tensor<5x5xf32> + %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> + %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} 
ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { + ^bb0(%in: f32, %out: f32): + %c2 = arith.constant 2 : index + %extracted = tensor.extract %arg0[%c2, %c2] : tensor<5x5xf32> + %2 = arith.addf %extracted, %extracted : f32 + linalg.yield %2 : f32 + } -> tensor<5x5xf32> + util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32> +} + +// CHECK-LABEL: util.func public @do_fuse_use_from_above +// CHECK: linalg.generic +// CHECK-NOT: linalg.generic