Skip to content

Commit

Permalink
[flang][OpenMP][DoConcurrent] Simplify loop-nest detection logic
Browse files Browse the repository at this point in the history
With llvm#114020, do-concurrent loop-nests are more conforment to the spec and easier to detect. All we need to do is to check that the only operations inside `loop A` which perfectly wraps `loop B` are:

  * the operations needed to update `loop A`'s iteration variable and
  * `loop B` itself.

This PR simlifies the pass a bit using the above logic and replaces ROCm#127.
  • Loading branch information
ergawy committed Nov 1, 2024
1 parent 8634cb1 commit 77ce502
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 141 deletions.
193 changes: 98 additions & 95 deletions flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ namespace flangomp {
#include "flang/Optimizer/OpenMP/Passes.h.inc"
} // namespace flangomp

#define DEBUG_TYPE "fopenmp-do-concurrent-conversion"
#define DEBUG_TYPE "do-concurrent-conversion"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

namespace Fortran {
namespace lower {
Expand All @@ -45,14 +46,12 @@ namespace internal {
// TODO The following 2 functions are copied from "flang/Lower/OpenMP/Utils.h".
// This duplication is temporary until we find a solution for a shared location
// for these utils that does not introduce circular CMake deps.
mlir::omp::MapInfoOp
createMapInfoOp(mlir::OpBuilder &builder, mlir::Location loc,
mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
llvm::ArrayRef<mlir::Value> bounds,
llvm::ArrayRef<mlir::Value> members,
mlir::ArrayAttr membersIndex, uint64_t mapType,
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
bool partialMap = false) {
mlir::omp::MapInfoOp createMapInfoOp(
mlir::OpBuilder &builder, mlir::Location loc, mlir::Value baseAddr,
mlir::Value varPtrPtr, std::string name, llvm::ArrayRef<mlir::Value> bounds,
llvm::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex,
uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType,
mlir::Type retTy, bool partialMap = false) {
if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
retTy = baseAddr.getType();
Expand Down Expand Up @@ -255,9 +254,21 @@ bool isIndVarUltimateOperand(mlir::Operation *op, fir::DoLoopOp doLoop) {
return false;
}

/// For the \p doLoop parameter, find the operations that declares its induction
/// variable or allocates memory for it.
mlir::Operation *findLoopIndVarMemDecl(fir::DoLoopOp doLoop) {
mlir::Value result = nullptr;
mlir::visitUsedValuesDefinedAbove(
doLoop.getRegion(), [&](mlir::OpOperand *operand) {
if (isIndVarUltimateOperand(operand->getOwner(), doLoop))
result = operand->get();
});

assert(result.getDefiningOp() != nullptr);
return result.getDefiningOp();
}

/// Collect the list of values used inside the loop but defined outside of it.
/// The first item in the returned list is always the loop's induction
/// variable.
void collectLoopLiveIns(fir::DoLoopOp doLoop,
llvm::SmallVectorImpl<mlir::Value> &liveIns) {
llvm::SmallDenseSet<mlir::Value> seenValues;
Expand All @@ -274,9 +285,6 @@ void collectLoopLiveIns(fir::DoLoopOp doLoop,
return;

liveIns.push_back(operand->get());

if (isIndVarUltimateOperand(operand->getOwner(), doLoop))
std::swap(*liveIns.begin(), *liveIns.rbegin());
});
}

Expand Down Expand Up @@ -366,24 +374,78 @@ void collectIndirectConstOpChain(mlir::Operation *link,
opChain.insert(link);
}

/// Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff
/// there are no operations in \p outerloop's other than:
///
/// 1. the operations needed to assing/update \p outerLoop's induction variable.
/// 2. \p innerLoop itself.
///
/// \p return true if \p innerLoop is perfectly nested inside \p outerLoop
/// according to the above definition.
bool isPerfectlyNested(fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) {
mlir::BackwardSliceOptions backwardSliceOptions;
backwardSliceOptions.inclusive = true;
// We will collect the backward slices for innerLoop's LB, UB, and step.
// However, we want to limit the scope of these slices to the scope of
// outerLoop's region.
backwardSliceOptions.filter = [&](mlir::Operation *op) {
return !mlir::areValuesDefinedAbove(op->getResults(),
outerLoop.getRegion());
};

mlir::ForwardSliceOptions forwardSliceOptions;
forwardSliceOptions.inclusive = true;
// We don't care about the outer-loop's induction variable's uses within the
// inner-loop, so we filter out these uses.
forwardSliceOptions.filter = [&](mlir::Operation *op) {
return mlir::areValuesDefinedAbove(op->getResults(), innerLoop.getRegion());
};

llvm::SetVector<mlir::Operation *> indVarSlice;
mlir::getForwardSlice(outerLoop.getInductionVar(), &indVarSlice,
forwardSliceOptions);
llvm::DenseSet<mlir::Operation *> innerLoopSetupOpsSet(indVarSlice.begin(),
indVarSlice.end());

llvm::DenseSet<mlir::Operation *> loopBodySet;
outerLoop.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
if (op == outerLoop)
return mlir::WalkResult::advance();

if (op == innerLoop)
return mlir::WalkResult::skip();

if (op->hasTrait<mlir::OpTrait::IsTerminator>())
return mlir::WalkResult::advance();

loopBodySet.insert(op);
return mlir::WalkResult::advance();
});

bool result = (loopBodySet == innerLoopSetupOpsSet);
mlir::Location loc = outerLoop.getLoc();
LLVM_DEBUG(DBGS() << "Loop pair starting at location " << loc << " is"
<< (result ? "" : " not") << " perfectly nested\n");

return result;
}

/// Starting with `outerLoop` collect a perfectly nested loop nest, if any. This
/// function collects as much as possible loops in the nest; it case it fails to
/// recognize a certain nested loop as part of the nest it just returns the
/// parent loops it discovered before.
mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop,
mlir::LogicalResult collectLoopNest(fir::DoLoopOp currentLoop,
LoopNestToIndVarMap &loopNest) {
assert(outerLoop.getUnordered());
llvm::SmallVector<mlir::Value> outerLoopLiveIns;
collectLoopLiveIns(outerLoop, outerLoopLiveIns);
assert(currentLoop.getUnordered());

while (true) {
loopNest.try_emplace(
outerLoop,
currentLoop,
InductionVariableInfo{
outerLoopLiveIns.front().getDefiningOp(),
std::move(looputils::extractIndVarUpdateOps(outerLoop))});
findLoopIndVarMemDecl(currentLoop),
std::move(looputils::extractIndVarUpdateOps(currentLoop))});

auto directlyNestedLoops = outerLoop.getRegion().getOps<fir::DoLoopOp>();
auto directlyNestedLoops = currentLoop.getRegion().getOps<fir::DoLoopOp>();
llvm::SmallVector<fir::DoLoopOp> unorderedLoops;

for (auto nestedLoop : directlyNestedLoops)
Expand All @@ -398,69 +460,10 @@ mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop,

fir::DoLoopOp nestedUnorderedLoop = unorderedLoops.front();

if ((nestedUnorderedLoop.getLowerBound().getDefiningOp() == nullptr) ||
(nestedUnorderedLoop.getUpperBound().getDefiningOp() == nullptr) ||
(nestedUnorderedLoop.getStep().getDefiningOp() == nullptr))
return mlir::failure();

llvm::SmallVector<mlir::Value> nestedLiveIns;
collectLoopLiveIns(nestedUnorderedLoop, nestedLiveIns);

llvm::DenseSet<mlir::Value> outerLiveInsSet;
llvm::DenseSet<mlir::Value> nestedLiveInsSet;

// Returns a "unified" view of an mlir::Value. This utility checks if the
// value is defined by an op, and if so, return the first value defined by
// that op (if there are many), otherwise just returns the value.
//
// This serves the purpose that if, for example, `%op_res#0` is used in the
// outer loop and `%op_res#1` is used in the nested loop (or vice versa),
// that we detect both as the same value. If we did not do so, we might
// falesely detect that the 2 loops are not perfectly nested since they use
// "different" sets of values.
auto getUnifiedLiveInView = [](mlir::Value liveIn) {
return liveIn.getDefiningOp() != nullptr
? liveIn.getDefiningOp()->getResult(0)
: liveIn;
};

// Re-package both lists of live-ins into sets so that we can use set
// equality to compare the values used in the outerloop vs. the nestd one.

for (auto liveIn : nestedLiveIns)
nestedLiveInsSet.insert(getUnifiedLiveInView(liveIn));

mlir::Value outerLoopIV;
for (auto liveIn : outerLoopLiveIns) {
outerLiveInsSet.insert(getUnifiedLiveInView(liveIn));

// Keep track of the IV of the outerloop. See `isPerfectlyNested` for more
// info on the reason.
if (outerLoopIV == nullptr)
outerLoopIV = getUnifiedLiveInView(liveIn);
}

// For the 2 loops to be perfectly nested, either:
// * both would have exactly the same set of live-in values or,
// * the outer loop would have exactly 1 extra live-in value: the outer
// loop's induction variable; this happens when the outer loop's IV is
// *not* referenced in the nested loop.
bool isPerfectlyNested = [&]() {
if (outerLiveInsSet == nestedLiveInsSet)
return true;

if ((outerLiveInsSet.size() == nestedLiveIns.size() + 1) &&
!nestedLiveInsSet.contains(outerLoopIV))
return true;

return false;
}();

if (!isPerfectlyNested)
if (!isPerfectlyNested(currentLoop, nestedUnorderedLoop))
return mlir::failure();

outerLoop = nestedUnorderedLoop;
outerLoopLiveIns = std::move(nestedLiveIns);
currentLoop = nestedUnorderedLoop;
}

return mlir::success();
Expand Down Expand Up @@ -634,10 +637,6 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
"defining operation.");
}

llvm::SmallVector<mlir::Value> outermostLoopLiveIns;
looputils::collectLoopLiveIns(doLoop, outermostLoopLiveIns);
assert(!outermostLoopLiveIns.empty());

looputils::LoopNestToIndVarMap loopNest;
bool hasRemainingNestedLoops =
failed(looputils::collectLoopNest(doLoop, loopNest));
Expand All @@ -646,15 +645,19 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
"Some `do concurent` loops are not perfectly-nested. "
"These will be serialzied.");

llvm::SmallVector<mlir::Value> loopNestLiveIns;
looputils::collectLoopLiveIns(loopNest.back().first, loopNestLiveIns);
assert(!loopNestLiveIns.empty());

llvm::SetVector<mlir::Value> locals;
looputils::collectLoopLocalValues(loopNest.back().first, locals);
// We do not want to map "loop-local" values to the device through
// `omp.map.info` ops. Therefore, we remove them from the list of live-ins.
outermostLoopLiveIns.erase(llvm::remove_if(outermostLoopLiveIns,
[&](mlir::Value liveIn) {
return locals.contains(liveIn);
}),
outermostLoopLiveIns.end());
loopNestLiveIns.erase(llvm::remove_if(loopNestLiveIns,
[&](mlir::Value liveIn) {
return locals.contains(liveIn);
}),
loopNestLiveIns.end());

looputils::sinkLoopIVArgs(rewriter, loopNest);

Expand All @@ -669,12 +672,12 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
// The outermost loop will contain all the live-in values in all nested
// loops since live-in values are collected recursively for all nested
// ops.
for (mlir::Value liveIn : outermostLoopLiveIns)
for (mlir::Value liveIn : loopNestLiveIns)
targetClauseOps.mapVars.push_back(
genMapInfoOpForLiveIn(rewriter, liveIn));

targetOp = genTargetOp(doLoop.getLoc(), rewriter, mapper,
outermostLoopLiveIns, targetClauseOps);
targetOp = genTargetOp(doLoop.getLoc(), rewriter, mapper, loopNestLiveIns,
targetClauseOps);
genTeamsOp(doLoop.getLoc(), rewriter);
}

Expand Down
87 changes: 87 additions & 0 deletions flang/test/Transforms/DoConcurrent/loop_nest_test.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
! Tests loop-nest detection algorithm for do-concurrent mapping.

! REQUIRES: asserts

! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host \
! RUN: -mmlir -debug %s -o - 2> %t.log || true

! RUN: FileCheck %s < %t.log

program main
implicit none

contains

subroutine foo(n)
implicit none
integer :: n, m
integer :: i, j, k
integer :: x
integer, dimension(n) :: a
integer, dimension(n, n, n) :: b

! CHECK: Loop pair starting at location
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested
do concurrent(i=1:n, j=1:bar(n*m, n/m))
a(i) = n
end do

! CHECK: Loop pair starting at location
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested
do concurrent(i=bar(n, x):n, j=1:bar(n*m, n/m))
a(i) = n
end do

! CHECK: Loop pair starting at location
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
do concurrent(i=bar(n, x):n)
do concurrent(j=1:bar(n*m, n/m))
a(i) = n
end do
end do

! CHECK: Loop pair starting at location
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
do concurrent(i=1:n)
x = 10
do concurrent(j=1:m)
b(i,j,k) = i * j + k
end do
end do

! CHECK: Loop pair starting at location
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
do concurrent(i=1:n)
do concurrent(j=1:m)
b(i,j,k) = i * j + k
end do
x = 10
end do

! CHECK: Loop pair starting at location
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
do concurrent(i=1:n)
do concurrent(j=1:m)
b(i,j,k) = i * j + k
x = 10
end do
end do

! CHECK: Loop pair starting at location
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested
do concurrent(i=bar(n, x):n, j=1:bar(n*m, n/m), k=1:bar(n*m, bar(n*m, n/m)))
a(i) = n
end do


end subroutine

pure function bar(n, m)
implicit none
integer, intent(in) :: n, m
integer :: bar

bar = n + m
end function

end program main
Loading

0 comments on commit 77ce502

Please sign in to comment.