diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 48d6838f08fc11..cec8013e3aee8b 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1127,13 +1127,16 @@ def TargetOp : OpenMP_Op<"target", traits = [ let hasVerifier = 1; let extraClassDeclaration = [{ - /// Returns the innermost OpenMP dialect operation nested inside of this - /// operation's region. For an operation to be detected as captured, it must - /// be inside a (possibly multi-level) nest of OpenMP dialect operation's + /// Returns the innermost OpenMP dialect operation captured by this target + /// construct. For an operation to be detected as captured, it must be + /// inside a (possibly multi-level) nest of OpenMP dialect operation's /// regions where none of these levels contain other operations considered /// not-allowed for these purposes (i.e. only terminator operations are /// allowed from the OpenMP dialect, and other dialect's operations are /// allowed as long as they don't have a memory write effect). + /// + /// If there are omp.loop_nest operations in the sequence of nested + /// operations, the top level one will be the one captured. Operation *getInnermostCapturedOmpOp(); /// Tells whether this target region represents a single worksharing loop diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 00951017a2aef7..089c1e27147e13 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1520,6 +1520,13 @@ Operation *TargetOp::getInnermostCapturedOmpOp() { if (op == *this) return; + // Reset captured op if crossing through an omp.loop_nest, so that the top + // level one will be the one captured. + if (llvm::isa(op)) { + capturedOp = nullptr; + capturedParentRegion = nullptr; + } + bool isOmpDialect = op->getDialect() == ompDialect; bool hasRegions = op->getNumRegions() > 0; @@ -1563,21 +1570,11 @@ Operation *TargetOp::getInnermostCapturedOmpOp() { bool TargetOp::isTargetSPMDLoop() { Operation *capturedOp = getInnermostCapturedOmpOp(); - - // Allow an omp.atomic_update to be captured inside of the loop and still - // consider the parent omp.target operation to be potentially defining an SPMD - // loop. - // TODO: Potentially accept other captured OpenMP dialect operations as well, - // if they are allowed inside of an SPMD loop. - if (isa_and_present(capturedOp)) - capturedOp = capturedOp->getParentOp(); - if (!isa_and_present(capturedOp)) return false; - Operation *workshareOp = capturedOp->getParentOp(); - // Accept optional SIMD leaf construct. + Operation *workshareOp = capturedOp->getParentOp(); if (isa_and_present(workshareOp)) workshareOp = workshareOp->getParentOp();