diff --git a/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp b/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp index 7fe27c947e549d1..dd3b01e793c0932 100644 --- a/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp @@ -36,7 +36,8 @@ namespace fir { #include "flang/Optimizer/Transforms/Passes.h.inc" } // namespace fir -#define DEBUG_TYPE "fopenmp-do-concurrent-conversion" +#define DEBUG_TYPE "do-concurrent-conversion" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") namespace Fortran { namespace lower { @@ -297,6 +298,81 @@ void collectIndirectConstOpChain(mlir::Operation *link, opChain.insert(link); } +/// Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff +/// there are no operations in \p outerloop's other than: +/// +/// 1. those operations needed to setup \p innerLoop's LB, UB, and step values, +/// 2. the operations needed to assing/update \p outerLoop's induction variable. +/// 3. \p innerLoop itself. +/// +/// \p return true if \p innerLoop is perfectly nested inside \p outerLoop +/// according to the above definition. +bool isPerfectlyNested(fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) { + mlir::BackwardSliceOptions backwardSliceOptions; + backwardSliceOptions.inclusive = true; + // We will collect the backward slices for innerLoop's LB, UB, and step. + // However, we want to limit the scope of these slices to the scope of + // outerLoop's region. + backwardSliceOptions.filter = [&](mlir::Operation *op) { + return !mlir::areValuesDefinedAbove(op->getResults(), + outerLoop.getRegion()); + }; + + llvm::SetVector lbSlice; + mlir::getBackwardSlice(innerLoop.getLowerBound(), &lbSlice, + backwardSliceOptions); + + llvm::SetVector ubSlice; + mlir::getBackwardSlice(innerLoop.getUpperBound(), &ubSlice, + backwardSliceOptions); + + llvm::SetVector stepSlice; + mlir::getBackwardSlice(innerLoop.getStep(), &stepSlice, backwardSliceOptions); + + mlir::ForwardSliceOptions forwardSliceOptions; + forwardSliceOptions.inclusive = true; + // We don't care of the outer loop's induction variable's uses within the + // inner loop, so we filter out these uses. + forwardSliceOptions.filter = [&](mlir::Operation *op) { + return mlir::areValuesDefinedAbove(op->getResults(), innerLoop.getRegion()); + }; + + llvm::SetVector indVarSlice; + mlir::getForwardSlice(outerLoop.getInductionVar(), &indVarSlice, + forwardSliceOptions); + + llvm::SetVector innerLoopSetupOpsVec; + innerLoopSetupOpsVec.set_union(indVarSlice); + innerLoopSetupOpsVec.set_union(lbSlice); + innerLoopSetupOpsVec.set_union(ubSlice); + innerLoopSetupOpsVec.set_union(stepSlice); + llvm::DenseSet innerLoopSetupOpsSet; + + for (mlir::Operation *op : innerLoopSetupOpsVec) + innerLoopSetupOpsSet.insert(op); + + llvm::DenseSet loopBodySet; + outerLoop.walk([&](mlir::Operation *op) { + if (op == outerLoop) + return mlir::WalkResult::advance(); + + if (op == innerLoop) + return mlir::WalkResult::skip(); + + if (op->hasTrait()) + return mlir::WalkResult::advance(); + + loopBodySet.insert(op); + return mlir::WalkResult::advance(); + }); + + bool result = (loopBodySet == innerLoopSetupOpsSet); + LLVM_DEBUG(DBGS() << "Loop pair starting at location " << outerLoop.getLoc() + << " is" << (result ? "" : " not") + << " perfectly nested\n"); + return result; +} + /// Starting with `outerLoop` collect a perfectly nested loop nest, if any. This /// function collects as much as possible loops in the nest; it case it fails to /// recognize a certain nested loop as part of the nest it just returns the @@ -337,57 +413,7 @@ mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop, llvm::SmallVector nestedLiveIns; collectLoopLiveIns(nestedUnorderedLoop, nestedLiveIns); - llvm::DenseSet outerLiveInsSet; - llvm::DenseSet nestedLiveInsSet; - - // Returns a "unified" view of an mlir::Value. This utility checks if the - // value is defined by an op, and if so, return the first value defined by - // that op (if there are many), otherwise just returns the value. - // - // This serves the purpose that if, for example, `%op_res#0` is used in the - // outer loop and `%op_res#1` is used in the nested loop (or vice versa), - // that we detect both as the same value. If we did not do so, we might - // falesely detect that the 2 loops are not perfectly nested since they use - // "different" sets of values. - auto getUnifiedLiveInView = [](mlir::Value liveIn) { - return liveIn.getDefiningOp() != nullptr - ? liveIn.getDefiningOp()->getResult(0) - : liveIn; - }; - - // Re-package both lists of live-ins into sets so that we can use set - // equality to compare the values used in the outerloop vs. the nestd one. - - for (auto liveIn : nestedLiveIns) - nestedLiveInsSet.insert(getUnifiedLiveInView(liveIn)); - - mlir::Value outerLoopIV; - for (auto liveIn : outerLoopLiveIns) { - outerLiveInsSet.insert(getUnifiedLiveInView(liveIn)); - - // Keep track of the IV of the outerloop. See `isPerfectlyNested` for more - // info on the reason. - if (outerLoopIV == nullptr) - outerLoopIV = getUnifiedLiveInView(liveIn); - } - - // For the 2 loops to be perfectly nested, either: - // * both would have exactly the same set of live-in values or, - // * the outer loop would have exactly 1 extra live-in value: the outer - // loop's induction variable; this happens when the outer loop's IV is - // *not* referenced in the nested loop. - bool isPerfectlyNested = [&]() { - if (outerLiveInsSet == nestedLiveInsSet) - return true; - - if ((outerLiveInsSet.size() == nestedLiveIns.size() + 1) && - !nestedLiveInsSet.contains(outerLoopIV)) - return true; - - return false; - }(); - - if (!isPerfectlyNested) + if (!isPerfectlyNested(outerLoop, nestedUnorderedLoop)) return mlir::failure(); outerLoop = nestedUnorderedLoop; diff --git a/flang/test/Transforms/DoConcurrent/loop_nest_test.f90 b/flang/test/Transforms/DoConcurrent/loop_nest_test.f90 new file mode 100644 index 000000000000000..a0463ea6fdb6011 --- /dev/null +++ b/flang/test/Transforms/DoConcurrent/loop_nest_test.f90 @@ -0,0 +1,89 @@ +! Tests loop-nest detection algorithm for do-concurrent mapping. + +! REQUIRES: asserts + +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host \ +! RUN: -mmlir -debug %s -o - 2> %t.log || true + +! RUN: FileCheck %s < %t.log + +program main + implicit none + +contains + +subroutine foo(n) + implicit none + integer :: n, m + integer :: i, j, k + integer :: x + integer, dimension(n) :: a + integer, dimension(n, n, n) :: b + + ! NOTE This for sure is a perfect loop nest. However, the way `do-concurrent` + ! loops are now emitted by flang is probably not correct. This is being looked + ! into at the moment and once we have flang emitting proper loop headers, we + ! will revisit this. + ! + ! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}}) + ! CHECK-SAME: is not perfectly nested + do concurrent(i=1:n, j=1:bar(n*m, n/m)) + a(i) = n + end do + + ! NOTE same as above. + ! + ! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}}) + ! CHECK-SAME: is not perfectly nested + do concurrent(i=bar(n, x):n, j=1:bar(n*m, n/m)) + a(i) = n + end do + + ! NOTE This is **not** a perfect nest since the inner call to `bar` will allocate + ! memory for the temp results of `n*m` and `n/m` **inside** the outer loop. + ! + ! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}}) + ! CHECK-SAME: is not perfectly nested + do concurrent(i=bar(n, x):n) + do concurrent(j=1:bar(n*m, n/m)) + a(i) = n + end do + end do + + ! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}}) + ! CHECK-SAME: is not perfectly nested + do concurrent(i=1:n) + x = 10 + do concurrent(j=1:m) + b(i,j,k) = i * j + k + end do + end do + + ! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}}) + ! CHECK-SAME: is not perfectly nested + do concurrent(i=1:n) + do concurrent(j=1:m) + b(i,j,k) = i * j + k + end do + x = 10 + end do + + ! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}}) + ! CHECK-SAME: is perfectly nested + do concurrent(i=1:n) + do concurrent(j=1:m) + b(i,j,k) = i * j + k + x = 10 + end do + end do +end subroutine + +pure function bar(n, m) + implicit none + integer, intent(in) :: n, m + integer :: bar + + bar = n + m +end function + +end program main