Skip to content

Commit 748aa57

Browse files
authored
[MLIR][OpenMP][OMPIRBuilder] Support omp.target IF translation to LLVM IR (#157)
This patch adds missing MLIR to LLVM IR translation support for the `if` clause on `omp.target` operations. This completes the missing piece for Fortran support of `!$omp target if(...)`. The implementation updates `emitTargetCall()` in the OMPIRBuilder to follow clang's support for the `if` clause in `CGOpenMPRuntime::emitTargetCall()`.
1 parent 8688adb commit 748aa57

File tree

4 files changed

+131
-95
lines changed

4 files changed

+131
-95
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -2871,6 +2871,7 @@ class OpenMPIRBuilder {
28712871
/// \param Loc where the target data construct was encountered.
28722872
/// \param IsSPMD whether this is an SPMD target launch.
28732873
/// \param IsOffloadEntry whether it is an offload entry.
2874+
/// \param IfCond value of the IF clause for the TARGET construct or nullptr.
28742875
/// \param CodeGenIP The insertion point where the call to the outlined
28752876
/// function should be emitted.
28762877
/// \param EntryInfo The entry information about the function.
@@ -2884,7 +2885,7 @@ class OpenMPIRBuilder {
28842885
/// \param Dependencies A vector of DependData objects that carry
28852886
/// dependency information as passed in the depend clause.
28862887
InsertPointTy createTarget(const LocationDescription &Loc, bool IsSPMD,
2887-
bool IsOffloadEntry,
2888+
bool IsOffloadEntry, Value *IfCond,
28882889
OpenMPIRBuilder::InsertPointTy AllocaIP,
28892890
OpenMPIRBuilder::InsertPointTy CodeGenIP,
28902891
TargetRegionEntryInfo &EntryInfo,

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

+115-80
Original file line numberDiff line numberDiff line change
@@ -7266,7 +7266,7 @@ static void emitTargetCall(
72667266
const OpenMPIRBuilder::TargetKernelDefaultBounds &DefaultBounds,
72677267
const OpenMPIRBuilder::TargetKernelRuntimeBounds &RuntimeBounds,
72687268
Function *OutlinedFn, Constant *OutlinedFnID,
7269-
SmallVectorImpl<Value *> &Args,
7269+
SmallVectorImpl<Value *> &Args, Value *IfCond,
72707270
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
72717271
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
72727272
// Generate a function call to the host fallback implementation of the target
@@ -7283,9 +7283,7 @@ static void emitTargetCall(
72837283
bool HasDependencies = Dependencies.size() > 0;
72847284
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
72857285

7286-
// If we don't have an ID for the target region, it means an offload entry
7287-
// wasn't created. In this case we just run the host fallback directly.
7288-
if (!OutlinedFnID) {
7286+
auto &&EmitTargetCallElse = [&]() {
72897287
if (RequiresOuterTargetTask) {
72907288
// Arguments that are intended to be directly forwarded to an
72917289
// emitKernelLaunch call are passed as nullptr, since OutlinedFnID=nullptr
@@ -7298,96 +7296,132 @@ static void emitTargetCall(
72987296
} else {
72997297
Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
73007298
}
7301-
return;
7302-
}
7303-
7304-
OpenMPIRBuilder::TargetDataInfo Info(
7305-
/*RequiresDevicePointerInfo=*/false,
7306-
/*SeparateBeginEndCalls=*/true);
7307-
7308-
OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
7309-
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7310-
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
7311-
RTArgs, MapInfo,
7312-
/*IsNonContiguous=*/true,
7313-
/*ForEndCall=*/false);
7314-
7315-
SmallVector<Value *, 3> NumTeamsC;
7316-
for (auto [DefNumTeams, RtNumTeams] :
7317-
llvm::zip_equal(DefaultBounds.MaxTeams, RuntimeBounds.MaxTeams)) {
7318-
NumTeamsC.push_back(RtNumTeams ? RtNumTeams
7319-
: Builder.getInt32(DefNumTeams));
7320-
}
7321-
7322-
// Calculate number of threads: 0 if no clauses specified, otherwise it is the
7323-
// minimum between optional THREAD_LIMIT and MAX_THREADS clauses. Perform a
7324-
// type cast to uint32.
7325-
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7326-
if (Clause)
7327-
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
7328-
/*isSigned=*/false);
7329-
return Clause;
73307299
};
73317300

7332-
auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7333-
if (Clause)
7334-
Result = Result
7335-
? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
7301+
auto &&EmitTargetCallThen = [&]() {
7302+
OpenMPIRBuilder::TargetDataInfo Info(
7303+
/*RequiresDevicePointerInfo=*/false,
7304+
/*SeparateBeginEndCalls=*/true);
7305+
7306+
OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
7307+
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7308+
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
7309+
RTArgs, MapInfo,
7310+
/*IsNonContiguous=*/true,
7311+
/*ForEndCall=*/false);
7312+
7313+
SmallVector<Value *, 3> NumTeamsC;
7314+
for (auto [DefNumTeams, RtNumTeams] :
7315+
llvm::zip_equal(DefaultBounds.MaxTeams, RuntimeBounds.MaxTeams)) {
7316+
NumTeamsC.push_back(RtNumTeams ? RtNumTeams
7317+
: Builder.getInt32(DefNumTeams));
7318+
}
7319+
7320+
// Calculate number of threads: 0 if no clauses specified, otherwise it is
7321+
// the minimum between optional THREAD_LIMIT and MAX_THREADS clauses.
7322+
// Perform a type cast to uint32.
7323+
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7324+
if (Clause)
7325+
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
7326+
/*isSigned=*/false);
7327+
return Clause;
7328+
};
7329+
7330+
auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7331+
if (Clause)
7332+
Result =
7333+
Result ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
73367334
Result, Clause)
73377335
: Clause;
7338-
};
7336+
};
7337+
7338+
// TODO: Check if this is the correct handling for multi-dim thread_limit.
7339+
SmallVector<Value *, 3> NumThreadsC;
7340+
Value *MaxThreadsClause = InitMaxThreadsClause(RuntimeBounds.MaxThreads);
73397341

7340-
// TODO: Check if this is the correct handling for multi-dim thread_limit.
7341-
SmallVector<Value *, 3> NumThreadsC;
7342-
Value *MaxThreadsClause = InitMaxThreadsClause(RuntimeBounds.MaxThreads);
7342+
for (auto [RtTeamsThreadLimit, RtTargetThreadLimit] : llvm::zip_equal(
7343+
RuntimeBounds.TeamsThreadLimit, RuntimeBounds.TargetThreadLimit)) {
7344+
Value *TeamsThreadLimitClause = InitMaxThreadsClause(RtTeamsThreadLimit);
7345+
Value *NumThreads = InitMaxThreadsClause(RtTargetThreadLimit);
73437346

7344-
for (auto [RtTeamsThreadLimit, RtTargetThreadLimit] : llvm::zip_equal(
7345-
RuntimeBounds.TeamsThreadLimit, RuntimeBounds.TargetThreadLimit)) {
7346-
Value *TeamsThreadLimitClause = InitMaxThreadsClause(RtTeamsThreadLimit);
7347-
Value *NumThreads = InitMaxThreadsClause(RtTargetThreadLimit);
7347+
CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
7348+
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
73487349

7349-
CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
7350-
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
7350+
NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
7351+
}
73517352

7352-
NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
7353+
unsigned NumTargetItems = Info.NumberOfPtrs;
7354+
// TODO: Use correct device ID
7355+
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
7356+
uint32_t SrcLocStrSize;
7357+
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
7358+
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
7359+
llvm::omp::IdentFlag(0), 0);
7360+
7361+
Value *TripCount = RuntimeBounds.LoopTripCount
7362+
? Builder.CreateIntCast(RuntimeBounds.LoopTripCount,
7363+
Builder.getInt64Ty(),
7364+
/*isSigned=*/false)
7365+
: Builder.getInt64(0);
7366+
7367+
// TODO: Use correct DynCGGroupMem
7368+
Value *DynCGGroupMem = Builder.getInt32(0);
7369+
OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, TripCount,
7370+
NumTeamsC, NumThreadsC,
7371+
DynCGGroupMem, HasNoWait);
7372+
7373+
// The presence of certain clauses on the target directive requires the
7374+
// explicit generation of the target task.
7375+
if (RequiresOuterTargetTask) {
7376+
Builder.restoreIP(OMPBuilder.emitTargetTask(
7377+
OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
7378+
RTLoc, AllocaIP, Dependencies, HasNoWait));
7379+
} else {
7380+
Builder.restoreIP(OMPBuilder.emitKernelLaunch(
7381+
Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
7382+
DeviceID, RTLoc, AllocaIP));
7383+
}
7384+
};
7385+
7386+
// If we don't have an ID for the target region, it means an offload entry
7387+
// wasn't created. In this case we just run the host fallback directly.
7388+
if (!OutlinedFnID) {
7389+
EmitTargetCallElse();
7390+
return;
73537391
}
73547392

7355-
unsigned NumTargetItems = Info.NumberOfPtrs;
7356-
// TODO: Use correct device ID
7357-
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
7358-
uint32_t SrcLocStrSize;
7359-
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
7360-
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
7361-
llvm::omp::IdentFlag(0), 0);
7362-
7363-
Value *TripCount = RuntimeBounds.LoopTripCount
7364-
? Builder.CreateIntCast(RuntimeBounds.LoopTripCount,
7365-
Builder.getInt64Ty(),
7366-
/*isSigned=*/false)
7367-
: Builder.getInt64(0);
7368-
7369-
// TODO: Use correct DynCGGroupMem
7370-
Value *DynCGGroupMem = Builder.getInt32(0);
7371-
OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, TripCount,
7372-
NumTeamsC, NumThreadsC, DynCGGroupMem,
7373-
HasNoWait);
7374-
7375-
// The presence of certain clauses on the target directive requires the
7376-
// explicit generation of the target task.
7377-
if (RequiresOuterTargetTask) {
7378-
Builder.restoreIP(OMPBuilder.emitTargetTask(
7379-
OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
7380-
RTLoc, AllocaIP, Dependencies, HasNoWait));
7381-
} else {
7382-
Builder.restoreIP(OMPBuilder.emitKernelLaunch(
7383-
Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
7384-
DeviceID, RTLoc, AllocaIP));
7393+
// If there's no IF clause, only generate the kernel launch code path.
7394+
if (!IfCond) {
7395+
EmitTargetCallThen();
7396+
return;
73857397
}
7398+
7399+
// Create if-else to handle IF clause.
7400+
llvm::BasicBlock *ThenBlock =
7401+
BasicBlock::Create(Builder.getContext(), "omp_if.then");
7402+
llvm::BasicBlock *ElseBlock =
7403+
BasicBlock::Create(Builder.getContext(), "omp_if.else");
7404+
llvm::BasicBlock *ContBlock =
7405+
BasicBlock::Create(Builder.getContext(), "omp_if.end");
7406+
Builder.CreateCondBr(IfCond, ThenBlock, ElseBlock);
7407+
7408+
Function *CurFn = Builder.GetInsertBlock()->getParent();
7409+
7410+
// Emit the 'then' code.
7411+
OMPBuilder.emitBlock(ThenBlock, CurFn);
7412+
EmitTargetCallThen();
7413+
OMPBuilder.emitBranch(ContBlock);
7414+
// Emit the 'else' code.
7415+
OMPBuilder.emitBlock(ElseBlock, CurFn);
7416+
EmitTargetCallElse();
7417+
OMPBuilder.emitBranch(ContBlock);
7418+
// Emit the continuation block.
7419+
OMPBuilder.emitBlock(ContBlock, CurFn, /*IsFinished=*/true);
73867420
}
73877421

73887422
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
73897423
const LocationDescription &Loc, bool IsSPMD, bool IsOffloadEntry,
7390-
InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
7424+
Value *IfCond, InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
73917425
TargetRegionEntryInfo &EntryInfo,
73927426
const TargetKernelDefaultBounds &DefaultBounds,
73937427
const TargetKernelRuntimeBounds &RuntimeBounds,
@@ -7415,7 +7449,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
74157449
// that represents the target region. Do that now.
74167450
if (!Config.isTargetDevice())
74177451
emitTargetCall(*this, Builder, AllocaIP, DefaultBounds, RuntimeBounds,
7418-
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies);
7452+
OutlinedFn, OutlinedFnID, Args, IfCond, GenMapInfoCB,
7453+
Dependencies);
74197454
return Builder.saveIP();
74207455
}
74217456

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -6017,9 +6017,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
60176017
RuntimeBounds.TeamsThreadLimit.push_back(nullptr);
60186018
RuntimeBounds.MaxTeams.push_back(nullptr);
60196019
Builder.restoreIP(OMPBuilder.createTarget(
6020-
OmpLoc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, Builder.saveIP(),
6021-
Builder.saveIP(), EntryInfo, DefaultBounds, RuntimeBounds, Inputs,
6022-
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
6020+
OmpLoc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, /*IfCond=*/nullptr,
6021+
Builder.saveIP(), Builder.saveIP(), EntryInfo, DefaultBounds,
6022+
RuntimeBounds, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
60236023
OMPBuilder.finalize();
60246024
Builder.CreateRetVoid();
60256025

@@ -6134,9 +6134,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
61346134
RuntimeBounds.TeamsThreadLimit.push_back(nullptr);
61356135
RuntimeBounds.MaxTeams.push_back(nullptr);
61366136
Builder.restoreIP(OMPBuilder.createTarget(
6137-
Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
6138-
EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs, GenMapInfoCB,
6139-
BodyGenCB, SimpleArgAccessorCB));
6137+
Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, /*IfCond=*/nullptr,
6138+
EntryIP, EntryIP, EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs,
6139+
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
61406140

61416141
Builder.CreateRetVoid();
61426142
OMPBuilder.finalize();
@@ -6290,9 +6290,9 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
62906290
RuntimeBounds.TeamsThreadLimit.push_back(nullptr);
62916291
RuntimeBounds.MaxTeams.push_back(nullptr);
62926292
Builder.restoreIP(OMPBuilder.createTarget(
6293-
Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
6294-
EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs, GenMapInfoCB,
6295-
BodyGenCB, SimpleArgAccessorCB));
6293+
Loc, /*IsSPMD=*/false, /*IsOffloadEntry=*/true, /*IfCond=*/nullptr,
6294+
EntryIP, EntryIP, EntryInfo, DefaultBounds, RuntimeBounds, CapturedArgs,
6295+
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
62966296

62976297
Builder.CreateRetVoid();
62986298
OMPBuilder.finalize();

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -3476,10 +3476,6 @@ static bool getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
34763476

34773477
static bool targetOpSupported(Operation &opInst) {
34783478
auto targetOp = cast<omp::TargetOp>(opInst);
3479-
if (targetOp.getIfExpr()) {
3480-
opInst.emitError("If clause not yet supported");
3481-
return false;
3482-
}
34833479

34843480
if (targetOp.getDevice()) {
34853481
opInst.emitError("Device clause not yet supported");
@@ -3955,8 +3951,12 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39553951
if (Value targetThreadLimit = targetOp.getThreadLimit())
39563952
llvmTargetThreadLimit = moduleTranslation.lookupValue(targetThreadLimit);
39573953

3954+
llvm::Value *ifCond = nullptr;
3955+
if (Value targetIfCond = targetOp.getIfExpr())
3956+
ifCond = moduleTranslation.lookupValue(targetIfCond);
3957+
39583958
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget(
3959-
ompLoc, targetOp.isTargetSPMDLoop(), isOffloadEntry, allocaIP,
3959+
ompLoc, targetOp.isTargetSPMDLoop(), isOffloadEntry, ifCond, allocaIP,
39603960
builder.saveIP(), entryInfo, defaultBounds, runtimeBounds, kernelInput,
39613961
genMapInfoCB, bodyCB, argAccessorCB, dds));
39623962

0 commit comments

Comments
 (0)