Skip to content

Commit

Permalink
remove check for temp allocs
Browse files Browse the repository at this point in the history
  • Loading branch information
ergawy committed Aug 22, 2024
1 parent efbda3c commit 8df3bac
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 59 deletions.
57 changes: 1 addition & 56 deletions flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,70 +341,15 @@ bool isPerfectlyNested(fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) {
mlir::getForwardSlice(outerLoop.getInductionVar(), &indVarSlice,
forwardSliceOptions);

// If any of the bounds is computed using a somewhat elaborate expression, we
// might have to allocate temporaries within the loop-nest. For example,
// `foo(m*n, m/n)` would allocate memory for the results of
// `m*n` and `m/n` inside the loop and pass the results to `foo`.
//
// The mem alloc ops for these temp values will be part of the backward slices
// of the loop bounds. However, the corresponding mem free ops will not.
// We therefore want to find these corresponding mem free ops and add them to
// the set of loop setup ops, since they would still be emitted in a perfectly
// nested loop.
//
// If an op has a MemAlloc effect, search the region for its corresponding
// MemFree op.
auto findMemFreeOp = [](mlir::Operation *possibleMemAllocOp) {
mlir::Operation *correspondingMemFreeOp = nullptr;

// We only care about ops which have `Allocate` effect on memory.
if (!mlir::hasEffect<mlir::MemoryEffects::Allocate>(possibleMemAllocOp))
return correspondingMemFreeOp;

mlir::Region *region = possibleMemAllocOp->getParentRegion();
assert(region);

region->walk([&](mlir::Operation *possibleMemFreeOp) {
// We only search for ops with `Free` effect on memory.
if (!mlir::hasEffect<mlir::MemoryEffects::Free>(possibleMemFreeOp))
return mlir::WalkResult::advance();

for (auto result : possibleMemAllocOp->getResults()) {
for (auto operand : possibleMemFreeOp->getOperands()) {
// If the result of a mem alloc op is the operand of a mem free op,
// then they probably correspond to each other.
if (result == operand) {
if (correspondingMemFreeOp == nullptr) {
correspondingMemFreeOp = possibleMemFreeOp;
break;
}
}
}

if (correspondingMemFreeOp != nullptr)
break;
}

return mlir::WalkResult::advance();
});

return correspondingMemFreeOp;
};

llvm::SetVector<mlir::Operation *> innerLoopSetupOpsVec;
innerLoopSetupOpsVec.set_union(indVarSlice);
innerLoopSetupOpsVec.set_union(lbSlice);
innerLoopSetupOpsVec.set_union(ubSlice);
innerLoopSetupOpsVec.set_union(stepSlice);
llvm::DenseSet<mlir::Operation *> innerLoopSetupOpsSet;

for (mlir::Operation *op : innerLoopSetupOpsVec) {
for (mlir::Operation *op : innerLoopSetupOpsVec)
innerLoopSetupOpsSet.insert(op);
mlir::Operation *memFreeOp = findMemFreeOp(op);

if (memFreeOp != nullptr)
innerLoopSetupOpsSet.insert(memFreeOp);
}

llvm::DenseSet<mlir::Operation *> loopBodySet;
outerLoop.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
Expand Down
16 changes: 13 additions & 3 deletions flang/test/Transforms/DoConcurrent/loop_nest_test.f90
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,30 @@ subroutine foo(n)
integer, dimension(n) :: a
integer, dimension(n, n, n) :: b

! NOTE This is for sure a perfect loop nest. However, the way `do-concurrent`
! loops are now emitted by flang is probably not correct. This is being looked
! into at the moment and once we have flang emitting proper loop headers, we
! will revisit this.
!
! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}})
! CHECK-SAME: is perfectly nested
! CHECK-SAME: is not perfectly nested
do concurrent(i=1:n, j=1:bar(n*m, n/m))
a(i) = n
end do

! NOTE same as above.
!
! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}})
! CHECK-SAME: is perfectly nested
! CHECK-SAME: is not perfectly nested
do concurrent(i=bar(n, x):n, j=1:bar(n*m, n/m))
a(i) = n
end do

! NOTE this is **not** a perfect nest since the inner call to `bar` will allocate
! memory for the temp results of `n*m` and `n/m` **inside** the outer loop.
!
! CHECK: Loop pair starting at location loc("{{.*}}":[[# @LINE + 2]]:{{.*}})
! CHECK-SAME: is perfectly nested
! CHECK-SAME: is not perfectly nested
do concurrent(i=bar(n, x):n)
do concurrent(j=1:bar(n*m, n/m))
a(i) = n
Expand Down

0 comments on commit 8df3bac

Please sign in to comment.