From 83776bfbf3af08cfc1a59c248a337c28ba5b157b Mon Sep 17 00:00:00 2001 From: ergawy Date: Fri, 26 Jul 2024 23:36:45 -0500 Subject: [PATCH] [flang][OpenMP] Implement more robust loop-nest detection logic The previous loop-nest detection algorithm fell short, in some cases, to detect whether a pair of `do concurrent` loops are perfectly nested or not. This is a re-implementation using forward and backward slice extraction algorithms to compare the set of ops required to setup the inner loop bounds vs. the set of ops nested in the outer loop other thatn the nested loop itself. --- .../Transforms/DoConcurrentConversion.cpp | 185 +++++++++++++----- .../DoConcurrent/loop_nest_test.f90 | 77 ++++++++ 2 files changed, 210 insertions(+), 52 deletions(-) create mode 100644 flang/test/Transforms/DoConcurrent/loop_nest_test.f90 diff --git a/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp b/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp index 5a93ad1ec655123..07b27642416733f 100644 --- a/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp @@ -36,7 +36,8 @@ namespace fir { #include "flang/Optimizer/Transforms/Passes.h.inc" } // namespace fir -#define DEBUG_TYPE "fopenmp-do-concurrent-conversion" +#define DEBUG_TYPE "do-concurrent-conversion" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") namespace Fortran { namespace lower { @@ -297,6 +298,136 @@ void collectIndirectConstOpChain(mlir::Operation *link, opChain.insert(link); } +/// Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff the +/// only operations in \p outerloop's region are: +/// +/// 1. those operations needed to setup \p innerLoop's LB, UB, and step values, +/// 2. the operations needed to assing/update \p outerLoop's induction variable. +/// 3. \p innerLoop itself. +/// +/// \p return true if \p innerLoop is perfectly nested inside \p outerLoop +/// according to the above definition. +bool isPerfectlyNested(fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) { + mlir::BackwardSliceOptions backwardSliceOptions; + backwardSliceOptions.inclusive = true; + // We will collect the backward slices for innerLoop's LB, UB, and step. + // However, we want to limit the scope of these slices to the scope of + // outerLoop's region. + backwardSliceOptions.filter = [&](mlir::Operation *op) { + return !mlir::areValuesDefinedAbove(op->getResults(), + outerLoop.getRegion()); + }; + + llvm::SetVector lbSlice; + mlir::getBackwardSlice(innerLoop.getLowerBound(), &lbSlice, + backwardSliceOptions); + + llvm::SetVector ubSlice; + mlir::getBackwardSlice(innerLoop.getUpperBound(), &ubSlice, + backwardSliceOptions); + + llvm::SetVector stepSlice; + mlir::getBackwardSlice(innerLoop.getStep(), &stepSlice, backwardSliceOptions); + + mlir::ForwardSliceOptions forwardSliceOptions; + forwardSliceOptions.inclusive = true; + // We don't care of the outer loop's induction variable's uses within the + // inner loop, so we filter out these uses. + forwardSliceOptions.filter = [&](mlir::Operation *op) { + return mlir::areValuesDefinedAbove(op->getResults(), innerLoop.getRegion()); + }; + + llvm::SetVector indVarSlice; + mlir::getForwardSlice(outerLoop.getInductionVar(), &indVarSlice, + forwardSliceOptions); + + // If any of the bounds is computed using a somewhat elaborate expression, we + // might have to allocate temproaries within the loop-nest. For example, + // `foo(m*n, m/n)` would allocate memory for the results of the restuls of + // `m*n` and `m/n` inside the loop and pass the results to `foo`. + // + // The mem alloc ops for these temp values will be part of the backward slices + // of the loop bounds. Howerver, the corresponding mem free ops will not. + // Howerver, we want to find these corresponding mem free ops and them to the + // set of loop setup ops since they would still be emitted in a perfectly + // nested loop. + // + // If an op has a MemAlloc effect, search the region for its corresponding + // MemFree op. + auto findMemFreeOp = [](mlir::Operation *possibleMemAllocOp) { + mlir::Operation *correspondingMemFreeOp = nullptr; + + // We only care about ops which have `Allocate` effect on memory. + if (!mlir::hasEffect(possibleMemAllocOp)) + return correspondingMemFreeOp; + + mlir::Region *region = possibleMemAllocOp->getParentRegion(); + assert(region); + + region->walk([&](mlir::Operation *possibleMemFreeOp) { + // We only search for ops with `Free` effect on memory. + if (!mlir::hasEffect(possibleMemFreeOp)) + return mlir::WalkResult::advance(); + + for (auto result : possibleMemAllocOp->getResults()) { + for (auto operand : possibleMemFreeOp->getOperands()) { + // If the result of a mem alloc op is the operand of a mem free op, + // then they probably correspond to each other. + if (result == operand) { + if (correspondingMemFreeOp == nullptr) { + correspondingMemFreeOp = possibleMemFreeOp; + break; + } + } + } + + if (correspondingMemFreeOp != nullptr) + break; + } + + return mlir::WalkResult::advance(); + }); + + return correspondingMemFreeOp; + }; + + llvm::SetVector innerLoopSetupOpsVec; + innerLoopSetupOpsVec.set_union(indVarSlice); + innerLoopSetupOpsVec.set_union(lbSlice); + innerLoopSetupOpsVec.set_union(ubSlice); + innerLoopSetupOpsVec.set_union(stepSlice); + llvm::DenseSet innerLoopSetupOpsSet; + + for (mlir::Operation *op : innerLoopSetupOpsVec) { + innerLoopSetupOpsSet.insert(op); + mlir::Operation *memFreeOp = findMemFreeOp(op); + + if (memFreeOp != nullptr) + innerLoopSetupOpsSet.insert(memFreeOp); + } + + llvm::DenseSet loopBodySet; + outerLoop.walk([&](mlir::Operation *op) { + if (op == outerLoop) + return mlir::WalkResult::advance(); + + if (op == innerLoop) + return mlir::WalkResult::skip(); + + if (op->hasTrait()) + return mlir::WalkResult::advance(); + + loopBodySet.insert(op); + return mlir::WalkResult::advance(); + }); + + bool result = (loopBodySet == innerLoopSetupOpsSet); + LLVM_DEBUG(DBGS() << "Loop pair starting at location " << outerLoop.getLoc() + << " is" << (result ? "" : " not") + << " perfectly nested\n"); + return result; +} + /// Starting with `outerLoop` collect a perfectly nested loop nest, if any. This /// function collects as much as possible loops in the nest; it case it fails to /// recognize a certain nested loop as part of the nest it just returns the @@ -337,57 +468,7 @@ mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop, llvm::SmallVector nestedLiveIns; collectLoopLiveIns(nestedUnorderedLoop, nestedLiveIns); - llvm::DenseSet outerLiveInsSet; - llvm::DenseSet nestedLiveInsSet; - - // Returns a "unified" view of an mlir::Value. This utility checks if the - // value is defined by an op, and if so, return the first value defined by - // that op (if there are many), otherwise just returns the value. - // - // This serves the purpose that if, for example, `%op_res#0` is used in the - // outer loop and `%op_res#1` is used in the nested loop (or vice versa), - // that we detect both as the same value. If we did not do so, we might - // falesely detect that the 2 loops are not perfectly nested since they use - // "different" sets of values. - auto getUnifiedLiveInView = [](mlir::Value liveIn) { - return liveIn.getDefiningOp() != nullptr - ? liveIn.getDefiningOp()->getResult(0) - : liveIn; - }; - - // Re-package both lists of live-ins into sets so that we can use set - // equality to compare the values used in the outerloop vs. the nestd one. - - for (auto liveIn : nestedLiveIns) - nestedLiveInsSet.insert(getUnifiedLiveInView(liveIn)); - - mlir::Value outerLoopIV; - for (auto liveIn : outerLoopLiveIns) { - outerLiveInsSet.insert(getUnifiedLiveInView(liveIn)); - - // Keep track of the IV of the outerloop. See `isPerfectlyNested` for more - // info on the reason. - if (outerLoopIV == nullptr) - outerLoopIV = getUnifiedLiveInView(liveIn); - } - - // For the 2 loops to be perfectly nested, either: - // * both would have exactly the same set of live-in values or, - // * the outer loop would have exactly 1 extra live-in value: the outer - // loop's induction variable; this happens when the outer loop's IV is - // *not* referenced in the nested loop. - bool isPerfectlyNested = [&]() { - if (outerLiveInsSet == nestedLiveInsSet) - return true; - - if ((outerLiveInsSet.size() == nestedLiveIns.size() + 1) && - !nestedLiveInsSet.contains(outerLoopIV)) - return true; - - return false; - }(); - - if (!isPerfectlyNested) + if (!isPerfectlyNested(outerLoop, nestedUnorderedLoop)) return mlir::failure(); outerLoop = nestedUnorderedLoop; diff --git a/flang/test/Transforms/DoConcurrent/loop_nest_test.f90 b/flang/test/Transforms/DoConcurrent/loop_nest_test.f90 new file mode 100644 index 000000000000000..2fc614794f4cc3b --- /dev/null +++ b/flang/test/Transforms/DoConcurrent/loop_nest_test.f90 @@ -0,0 +1,77 @@ +! Tests loop-nest detection algorithm for do-concurrent mapping. + +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host \ +! RUN: -mmlir -debug %s -o - &> %t.log || true + +! RUN: FileCheck %s < %t.log + +program main + implicit none + +contains + +subroutine foo(n) + implicit none + integer :: n, m + integer :: i, j, k + integer :: x + integer, dimension(n) :: a + integer, dimension(n, n, n) :: b + + ! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}}) + ! CHECK-SAME: is perfectly nested + do concurrent(i=1:n, j=1:bar(n*m, n/m)) + a(i) = n + end do + + ! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}}) + ! CHECK-SAME: is perfectly nested + do concurrent(i=bar(n, x):n, j=1:bar(n*m, n/m)) + a(i) = n + end do + + ! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}}) + ! CHECK-SAME: is perfectly nested + do concurrent(i=bar(n, x):n) + do concurrent(j=1:bar(n*m, n/m)) + a(i) = n + end do + end do + + ! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}}) + ! CHECK-SAME: is not perfectly nested + do concurrent(i=1:n) + x = 10 + do concurrent(j=1:m) + b(i,j,k) = i * j + k + end do + end do + + ! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}}) + ! CHECK-SAME: is not perfectly nested + do concurrent(i=1:n) + do concurrent(j=1:m) + b(i,j,k) = i * j + k + end do + x = 10 + end do + + ! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}}) + ! CHECK-SAME: is perfectly nested + do concurrent(i=1:n) + do concurrent(j=1:m) + b(i,j,k) = i * j + k + x = 10 + end do + end do +end subroutine + +pure function bar(n, m) + implicit none + integer, intent(in) :: n, m + integer :: bar + + bar = n + m +end function + +end program main