[MLIR][OpenMP] Correctly handle branching within target captures (#217)
This patch improves the detection of captured OpenMP constructs inside an
`omp.target` operation by also considering potential branches. If a nested
OpenMP construct can be executed in a loop or skipped by means of explicit
MLIR control flow, then it must not be treated as captured.

The following Fortran example results in such a case:
```f90
!$omp target teams
do i = 1, n
  !$omp distribute parallel do
  do j = 1, n
    ...
  end do
  !$omp end distribute parallel do
end do
!$omp end target teams
```

Lowering that code to MLIR creates multiple blocks and branches inside the
`omp.teams` operation's region. Without this change, the kernel is identified
as SPMD during translation to LLVM IR because of how the operations are
nested, even though it is actually a generic kernel. The translation then
tries to use host-evaluated loop bounds that do not exist to calculate the
inner loop's trip count, which crashes the compiler.

```mlir
omp.target map_entries(...) {
  ...
  omp.teams {
    %233 = llvm.trunc %227 : i64 to i32
    llvm.br ^bb1(%233, %225 : i32, i64)
  ^bb1(%234: i32, %235: i64):  // 2 preds: ^bb0, ^bb2
    %236 = llvm.icmp "sgt" %235, %226 : i64
    llvm.cond_br %236, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    llvm.store %234, %arg5 : i32, !llvm.ptr
    omp.parallel ... {
      ...
    } {omp.composite}
    ...
    llvm.br ^bb1(%239, %240 : i32, i64)
  ^bb3:  // pred: ^bb1
    llvm.store %234, %arg5 : i32, !llvm.ptr
    omp.terminator
  }
  omp.terminator
}
```
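The two conditions this patch checks (the nested operation's block must not be reachable again from its own successors, and it must dominate every reachable exit block of the region) can be illustrated on a toy control-flow graph. The following is only a minimal sketch under stated assumptions: `ToyCfg`, `reachableFrom`, `canReexecute`, `dominatesAllExits` and `mayBeCaptured` are hypothetical names that do not exist in MLIR, and dominance is computed naively by deleting a block and re-running reachability rather than through `DominanceInfo`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy CFG, not the MLIR API: succ[b] lists the successors of
// block b, and block 0 is the region's entry block.
using ToyCfg = std::vector<std::vector<std::size_t>>;

// Blocks reachable from `from` (inclusive), pretending `removed` was deleted.
// Pass succ.size() as `removed` to delete nothing.
std::vector<bool> reachableFrom(const ToyCfg &succ, std::size_t from,
                                std::size_t removed) {
  std::vector<bool> seen(succ.size(), false);
  if (from == removed)
    return seen;
  std::vector<std::size_t> worklist{from};
  seen[from] = true;
  while (!worklist.empty()) {
    std::size_t block = worklist.back();
    worklist.pop_back();
    for (std::size_t next : succ[block]) {
      if (next != removed && !seen[next]) {
        seen[next] = true;
        worklist.push_back(next);
      }
    }
  }
  return seen;
}

// True if control can return to `block` after leaving it, i.e. an operation
// in that block may run more than once.
bool canReexecute(const ToyCfg &succ, std::size_t block) {
  assert(block < succ.size());
  for (std::size_t next : succ[block])
    if (reachableFrom(succ, next, succ.size())[block])
      return true;
  return false;
}

// True if every reachable exit block (one without successors) can only be
// reached by passing through `block`, i.e. `block` dominates all exits.
bool dominatesAllExits(const ToyCfg &succ, std::size_t block) {
  assert(block < succ.size());
  std::vector<bool> reachable = reachableFrom(succ, 0, succ.size());
  std::vector<bool> bypassing = reachableFrom(succ, 0, block);
  for (std::size_t b = 0; b < succ.size(); ++b)
    if (reachable[b] && succ[b].empty() && b != block && bypassing[b])
      return false;
  return true;
}

// An operation in `block` is a capture candidate only if both checks pass.
bool mayBeCaptured(const ToyCfg &succ, std::size_t block) {
  return !canReexecute(succ, block) && dominatesAllExits(succ, block);
}
```

Encoding the blocks of the MLIR above as `{{1}, {2, 3}, {1}, {}}` (for `^bb0` to `^bb3`), `mayBeCaptured(cfg, 2)` returns `false` for the block holding `omp.parallel`: `^bb1` can branch back into `^bb2`, and `^bb3` can be reached without passing through `^bb2`.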
skatrak authored Dec 3, 2024
1 parent ae842a7 commit df7e436
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1770,6 +1770,7 @@ LogicalResult TargetOp::verify() {
 Operation *TargetOp::getInnermostCapturedOmpOp() {
   Dialect *ompDialect = (*this)->getDialect();
   Operation *capturedOp = nullptr;
+  DominanceInfo domInfo;

   // Process in pre-order to check operations from outermost to innermost,
   // ensuring we only enter the region of an operation if it meets the criteria
@@ -1787,9 +1788,25 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
     if (!isOmpDialect || !hasRegions)
       return WalkResult::skip();

+    // This operation cannot be captured if it can be executed more than once
+    // (i.e. its block's successors can reach it) or if it's not guaranteed to
+    // be executed before all exits of the region (i.e. it doesn't dominate all
+    // blocks with no successors reachable from the entry block).
+    Region *parentRegion = op->getParentRegion();
+    Block *parentBlock = op->getBlock();
+
+    for (Block *successor : parentBlock->getSuccessors())
+      if (successor->isReachable(parentBlock))
+        return WalkResult::interrupt();
+
+    for (Block &block : *parentRegion)
+      if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
+          !domInfo.dominates(parentBlock, &block))
+        return WalkResult::interrupt();
+
     // Don't capture this op if it has a not-allowed sibling, and stop recursing
     // into nested operations.
-    for (Operation &sibling : op->getParentRegion()->getOps())
+    for (Operation &sibling : parentRegion->getOps())
       if (&sibling != op && !siblingAllowedInCapture(&sibling))
         return WalkResult::interrupt();
