From d0e4fa24e0b39351455d29589a38f178079667e3 Mon Sep 17 00:00:00 2001 From: ergawy Date: Fri, 26 Jul 2024 23:36:45 -0500 Subject: [PATCH] [flang][OpenMP][DoConcurrent] Simplify loop-nest detection logic With https://github.com/llvm/llvm-project/pull/114020, do-concurrent loop-nests are more conforment to the spec and easier to detect. All we need to do is to check that the only operations inside `loop A` which perfectly wraps `loop B` are: * the operations needed to update `loop A`'s iteration variable and * `loop B` itself. This PR simlifies the pass a bit using the above logic and replaces https://github.com/ROCm/llvm-project/pull/127. --- .../OpenMP/DoConcurrentConversion.cpp | 205 +++++++++--------- .../DoConcurrent/loop_nest_test.f90 | 87 ++++++++ .../multiple_iteration_ranges.f90 | 46 ---- 3 files changed, 193 insertions(+), 145 deletions(-) create mode 100644 flang/test/Transforms/DoConcurrent/loop_nest_test.f90 diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp index e2a109126810dd..dcc0ad0d53d0ed 100644 --- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp @@ -36,7 +36,8 @@ namespace flangomp { #include "flang/Optimizer/OpenMP/Passes.h.inc" } // namespace flangomp -#define DEBUG_TYPE "fopenmp-do-concurrent-conversion" +#define DEBUG_TYPE "do-concurrent-conversion" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") namespace Fortran { namespace lower { @@ -45,14 +46,12 @@ namespace internal { // TODO The following 2 functions are copied from "flang/Lower/OpenMP/Utils.h". // This duplication is temporary until we find a solution for a shared location // for these utils that does not introduce circular CMake deps. -mlir::omp::MapInfoOp -createMapInfoOp(mlir::OpBuilder &builder, mlir::Location loc, - mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name, - llvm::ArrayRef bounds, - llvm::ArrayRef members, - mlir::ArrayAttr membersIndex, uint64_t mapType, - mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, - bool partialMap = false) { +mlir::omp::MapInfoOp createMapInfoOp( + mlir::OpBuilder &builder, mlir::Location loc, mlir::Value baseAddr, + mlir::Value varPtrPtr, std::string name, llvm::ArrayRef bounds, + llvm::ArrayRef members, mlir::ArrayAttr membersIndex, + uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType, + mlir::Type retTy, bool partialMap = false) { if (auto boxTy = llvm::dyn_cast(baseAddr.getType())) { baseAddr = builder.create(loc, baseAddr); retTy = baseAddr.getType(); @@ -255,9 +254,24 @@ bool isIndVarUltimateOperand(mlir::Operation *op, fir::DoLoopOp doLoop) { return false; } +/// For the \p doLoop parameter, find the operations that declares its induction +/// variable or allocates memory for it. +mlir::Operation *findLoopIndVarMemDecl(fir::DoLoopOp doLoop) { + mlir::Value result = nullptr; + mlir::visitUsedValuesDefinedAbove( + doLoop.getRegion(), [&](mlir::OpOperand *operand) { + if (isIndVarUltimateOperand(operand->getOwner(), doLoop)) { + assert(result == nullptr && + "loop can have only one induction variable"); + result = operand->get(); + } + }); + + assert(result != nullptr && result.getDefiningOp() != nullptr); + return result.getDefiningOp(); +} + /// Collect the list of values used inside the loop but defined outside of it. -/// The first item in the returned list is always the loop's induction -/// variable. void collectLoopLiveIns(fir::DoLoopOp doLoop, llvm::SmallVectorImpl &liveIns) { llvm::SmallDenseSet seenValues; @@ -274,9 +288,6 @@ void collectLoopLiveIns(fir::DoLoopOp doLoop, return; liveIns.push_back(operand->get()); - - if (isIndVarUltimateOperand(operand->getOwner(), doLoop)) - std::swap(*liveIns.begin(), *liveIns.rbegin()); }); } @@ -366,24 +377,78 @@ void collectIndirectConstOpChain(mlir::Operation *link, opChain.insert(link); } +/// Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff +/// there are no operations in \p outerloop's other than: +/// +/// 1. the operations needed to assing/update \p outerLoop's induction variable. +/// 2. \p innerLoop itself. +/// +/// \p return true if \p innerLoop is perfectly nested inside \p outerLoop +/// according to the above definition. +bool isPerfectlyNested(fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) { + mlir::BackwardSliceOptions backwardSliceOptions; + backwardSliceOptions.inclusive = true; + // We will collect the backward slices for innerLoop's LB, UB, and step. + // However, we want to limit the scope of these slices to the scope of + // outerLoop's region. + backwardSliceOptions.filter = [&](mlir::Operation *op) { + return !mlir::areValuesDefinedAbove(op->getResults(), + outerLoop.getRegion()); + }; + + mlir::ForwardSliceOptions forwardSliceOptions; + forwardSliceOptions.inclusive = true; + // We don't care about the outer-loop's induction variable's uses within the + // inner-loop, so we filter out these uses. + forwardSliceOptions.filter = [&](mlir::Operation *op) { + return mlir::areValuesDefinedAbove(op->getResults(), innerLoop.getRegion()); + }; + + llvm::SetVector indVarSlice; + mlir::getForwardSlice(outerLoop.getInductionVar(), &indVarSlice, + forwardSliceOptions); + llvm::DenseSet innerLoopSetupOpsSet(indVarSlice.begin(), + indVarSlice.end()); + + llvm::DenseSet loopBodySet; + outerLoop.walk([&](mlir::Operation *op) { + if (op == outerLoop) + return mlir::WalkResult::advance(); + + if (op == innerLoop) + return mlir::WalkResult::skip(); + + if (mlir::isa(op)) + return mlir::WalkResult::advance(); + + loopBodySet.insert(op); + return mlir::WalkResult::advance(); + }); + + bool result = (loopBodySet == innerLoopSetupOpsSet); + mlir::Location loc = outerLoop.getLoc(); + LLVM_DEBUG(DBGS() << "Loop pair starting at location " << loc << " is" + << (result ? "" : " not") << " perfectly nested\n"); + + return result; +} + /// Starting with `outerLoop` collect a perfectly nested loop nest, if any. This /// function collects as much as possible loops in the nest; it case it fails to /// recognize a certain nested loop as part of the nest it just returns the /// parent loops it discovered before. -mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop, +mlir::LogicalResult collectLoopNest(fir::DoLoopOp currentLoop, LoopNestToIndVarMap &loopNest) { - assert(outerLoop.getUnordered()); - llvm::SmallVector outerLoopLiveIns; - collectLoopLiveIns(outerLoop, outerLoopLiveIns); + assert(currentLoop.getUnordered()); while (true) { loopNest.try_emplace( - outerLoop, + currentLoop, InductionVariableInfo{ - outerLoopLiveIns.front().getDefiningOp(), - std::move(looputils::extractIndVarUpdateOps(outerLoop))}); + findLoopIndVarMemDecl(currentLoop), + std::move(looputils::extractIndVarUpdateOps(currentLoop))}); - auto directlyNestedLoops = outerLoop.getRegion().getOps(); + auto directlyNestedLoops = currentLoop.getRegion().getOps(); llvm::SmallVector unorderedLoops; for (auto nestedLoop : directlyNestedLoops) @@ -398,69 +463,10 @@ mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop, fir::DoLoopOp nestedUnorderedLoop = unorderedLoops.front(); - if ((nestedUnorderedLoop.getLowerBound().getDefiningOp() == nullptr) || - (nestedUnorderedLoop.getUpperBound().getDefiningOp() == nullptr) || - (nestedUnorderedLoop.getStep().getDefiningOp() == nullptr)) + if (!isPerfectlyNested(currentLoop, nestedUnorderedLoop)) return mlir::failure(); - llvm::SmallVector nestedLiveIns; - collectLoopLiveIns(nestedUnorderedLoop, nestedLiveIns); - - llvm::DenseSet outerLiveInsSet; - llvm::DenseSet nestedLiveInsSet; - - // Returns a "unified" view of an mlir::Value. This utility checks if the - // value is defined by an op, and if so, return the first value defined by - // that op (if there are many), otherwise just returns the value. - // - // This serves the purpose that if, for example, `%op_res#0` is used in the - // outer loop and `%op_res#1` is used in the nested loop (or vice versa), - // that we detect both as the same value. If we did not do so, we might - // falesely detect that the 2 loops are not perfectly nested since they use - // "different" sets of values. - auto getUnifiedLiveInView = [](mlir::Value liveIn) { - return liveIn.getDefiningOp() != nullptr - ? liveIn.getDefiningOp()->getResult(0) - : liveIn; - }; - - // Re-package both lists of live-ins into sets so that we can use set - // equality to compare the values used in the outerloop vs. the nestd one. - - for (auto liveIn : nestedLiveIns) - nestedLiveInsSet.insert(getUnifiedLiveInView(liveIn)); - - mlir::Value outerLoopIV; - for (auto liveIn : outerLoopLiveIns) { - outerLiveInsSet.insert(getUnifiedLiveInView(liveIn)); - - // Keep track of the IV of the outerloop. See `isPerfectlyNested` for more - // info on the reason. - if (outerLoopIV == nullptr) - outerLoopIV = getUnifiedLiveInView(liveIn); - } - - // For the 2 loops to be perfectly nested, either: - // * both would have exactly the same set of live-in values or, - // * the outer loop would have exactly 1 extra live-in value: the outer - // loop's induction variable; this happens when the outer loop's IV is - // *not* referenced in the nested loop. - bool isPerfectlyNested = [&]() { - if (outerLiveInsSet == nestedLiveInsSet) - return true; - - if ((outerLiveInsSet.size() == nestedLiveIns.size() + 1) && - !nestedLiveInsSet.contains(outerLoopIV)) - return true; - - return false; - }(); - - if (!isPerfectlyNested) - return mlir::failure(); - - outerLoop = nestedUnorderedLoop; - outerLoopLiveIns = std::move(nestedLiveIns); + currentLoop = nestedUnorderedLoop; } return mlir::success(); @@ -634,10 +640,6 @@ class DoConcurrentConversion : public mlir::OpConversionPattern { "defining operation."); } - llvm::SmallVector outermostLoopLiveIns; - looputils::collectLoopLiveIns(doLoop, outermostLoopLiveIns); - assert(!outermostLoopLiveIns.empty()); - looputils::LoopNestToIndVarMap loopNest; bool hasRemainingNestedLoops = failed(looputils::collectLoopNest(doLoop, loopNest)); @@ -646,15 +648,19 @@ class DoConcurrentConversion : public mlir::OpConversionPattern { "Some `do concurent` loops are not perfectly-nested. " "These will be serialzied."); + llvm::SmallVector loopNestLiveIns; + looputils::collectLoopLiveIns(loopNest.back().first, loopNestLiveIns); + assert(!loopNestLiveIns.empty()); + llvm::SetVector locals; looputils::collectLoopLocalValues(loopNest.back().first, locals); // We do not want to map "loop-local" values to the device through // `omp.map.info` ops. Therefore, we remove them from the list of live-ins. - outermostLoopLiveIns.erase(llvm::remove_if(outermostLoopLiveIns, - [&](mlir::Value liveIn) { - return locals.contains(liveIn); - }), - outermostLoopLiveIns.end()); + loopNestLiveIns.erase(llvm::remove_if(loopNestLiveIns, + [&](mlir::Value liveIn) { + return locals.contains(liveIn); + }), + loopNestLiveIns.end()); looputils::sinkLoopIVArgs(rewriter, loopNest); @@ -669,12 +675,12 @@ class DoConcurrentConversion : public mlir::OpConversionPattern { // The outermost loop will contain all the live-in values in all nested // loops since live-in values are collected recursively for all nested // ops. - for (mlir::Value liveIn : outermostLoopLiveIns) + for (mlir::Value liveIn : loopNestLiveIns) targetClauseOps.mapVars.push_back( genMapInfoOpForLiveIn(rewriter, liveIn)); - targetOp = genTargetOp(doLoop.getLoc(), rewriter, mapper, - outermostLoopLiveIns, targetClauseOps); + targetOp = genTargetOp(doLoop.getLoc(), rewriter, mapper, loopNestLiveIns, + targetClauseOps); genTeamsOp(doLoop.getLoc(), rewriter); } @@ -1062,10 +1068,11 @@ class DoConcurrentConversionPass context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device, concurrentLoopsToSkip); mlir::ConversionTarget target(*context); - target.addLegalDialect< - fir::FIROpsDialect, hlfir::hlfirDialect, mlir::arith::ArithDialect, - mlir::func::FuncDialect, mlir::omp::OpenMPDialect, - mlir::cf::ControlFlowDialect, mlir::math::MathDialect>(); + target + .addLegalDialect(); target.addDynamicallyLegalOp([&](fir::DoLoopOp op) { return !op.getUnordered() || concurrentLoopsToSkip.contains(op); diff --git a/flang/test/Transforms/DoConcurrent/loop_nest_test.f90 b/flang/test/Transforms/DoConcurrent/loop_nest_test.f90 new file mode 100644 index 00000000000000..c73a8be64bed63 --- /dev/null +++ b/flang/test/Transforms/DoConcurrent/loop_nest_test.f90 @@ -0,0 +1,87 @@ +! Tests loop-nest detection algorithm for do-concurrent mapping. + +! REQUIRES: asserts + +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host \ +! RUN: -mmlir -debug %s -o - 2> %t.log || true + +! RUN: FileCheck %s < %t.log + +program main + implicit none + +contains + +subroutine foo(n) + implicit none + integer :: n, m + integer :: i, j, k + integer :: x + integer, dimension(n) :: a + integer, dimension(n, n, n) :: b + + ! CHECK: Loop pair starting at location + ! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested + do concurrent(i=1:n, j=1:bar(n*m, n/m)) + a(i) = n + end do + + ! CHECK: Loop pair starting at location + ! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested + do concurrent(i=bar(n, x):n, j=1:bar(n*m, n/m)) + a(i) = n + end do + + ! CHECK: Loop pair starting at location + ! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested + do concurrent(i=bar(n, x):n) + do concurrent(j=1:bar(n*m, n/m)) + a(i) = n + end do + end do + + ! CHECK: Loop pair starting at location + ! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested + do concurrent(i=1:n) + x = 10 + do concurrent(j=1:m) + b(i,j,k) = i * j + k + end do + end do + + ! CHECK: Loop pair starting at location + ! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested + do concurrent(i=1:n) + do concurrent(j=1:m) + b(i,j,k) = i * j + k + end do + x = 10 + end do + + ! CHECK: Loop pair starting at location + ! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested + do concurrent(i=1:n) + do concurrent(j=1:m) + b(i,j,k) = i * j + k + x = 10 + end do + end do + + ! CHECK: Loop pair starting at location + ! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested + do concurrent(i=bar(n, x):n, j=1:bar(n*m, n/m), k=1:bar(n*m, bar(n*m, n/m))) + a(i) = n + end do + + +end subroutine + +pure function bar(n, m) + implicit none + integer, intent(in) :: n, m + integer :: bar + + bar = n + m +end function + +end program main diff --git a/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90 b/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90 index cc3e04306da1f2..6269c8a62d5f43 100644 --- a/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90 +++ b/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90 @@ -8,22 +8,6 @@ ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=device %t/multi_range.f90 -o - \ ! RUN: | FileCheck %s --check-prefixes=DEVICE,COMMON -! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host %t/perfectly_nested.f90 -o - \ -! RUN: | FileCheck %s --check-prefixes=HOST,COMMON - -! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=device %t/perfectly_nested.f90 -o - \ -! RUN: | FileCheck %s --check-prefixes=DEVICE,COMMON - -! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host %t/partially_nested.f90 -o - \ -! RUN: | FileCheck %s --check-prefixes=HOST,COMMON - -! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=device %t/partially_nested.f90 -o - \ -! RUN: | FileCheck %s --check-prefixes=DEVICE,COMMON - -! This is temporarily disabled since the IR for `do concurrent` loops is different after -! https://github.com/llvm/llvm-project/pull/114020. This will be enabled again soon. -! XFAIL: true - !--- multi_range.f90 program main integer, parameter :: n = 10 @@ -36,36 +20,6 @@ program main end do end -!--- perfectly_nested.f90 -program main - integer, parameter :: n = 10 - integer, parameter :: m = 20 - integer, parameter :: l = 30 - integer :: a(n, m, l) - - do concurrent(i=1:n) - do concurrent(j=1:m) - do concurrent(k=1:l) - a(i,j,k) = i * j + k - end do - end do - end do -end - -!--- partially_nested.f90 -program main - integer, parameter :: n = 10 - integer, parameter :: m = 20 - integer, parameter :: l = 30 - integer :: a(n, m, l) - - do concurrent(i=1:n, j=1:m) - do concurrent(k=1:l) - a(i,j,k) = i * j + k - end do - end do -end - ! DEVICE: omp.target ! DEVICE: omp.teams