@@ -7266,7 +7266,7 @@ static void emitTargetCall(
7266
7266
const OpenMPIRBuilder::TargetKernelDefaultBounds &DefaultBounds,
7267
7267
const OpenMPIRBuilder::TargetKernelRuntimeBounds &RuntimeBounds,
7268
7268
Function *OutlinedFn, Constant *OutlinedFnID,
7269
- SmallVectorImpl<Value *> &Args,
7269
+ SmallVectorImpl<Value *> &Args, Value *IfCond,
7270
7270
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7271
7271
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
7272
7272
// Generate a function call to the host fallback implementation of the target
@@ -7283,9 +7283,7 @@ static void emitTargetCall(
7283
7283
bool HasDependencies = Dependencies.size() > 0;
7284
7284
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
7285
7285
7286
- // If we don't have an ID for the target region, it means an offload entry
7287
- // wasn't created. In this case we just run the host fallback directly.
7288
- if (!OutlinedFnID) {
7286
+ auto &&EmitTargetCallElse = [&]() {
7289
7287
if (RequiresOuterTargetTask) {
7290
7288
// Arguments that are intended to be directly forwarded to an
7291
7289
// emitKernelLaunch call are passed as nullptr, since OutlinedFnID=nullptr
@@ -7298,96 +7296,132 @@ static void emitTargetCall(
7298
7296
} else {
7299
7297
Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
7300
7298
}
7301
- return;
7302
- }
7303
-
7304
- OpenMPIRBuilder::TargetDataInfo Info(
7305
- /*RequiresDevicePointerInfo=*/false,
7306
- /*SeparateBeginEndCalls=*/true);
7307
-
7308
- OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
7309
- OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7310
- OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
7311
- RTArgs, MapInfo,
7312
- /*IsNonContiguous=*/true,
7313
- /*ForEndCall=*/false);
7314
-
7315
- SmallVector<Value *, 3> NumTeamsC;
7316
- for (auto [DefNumTeams, RtNumTeams] :
7317
- llvm::zip_equal(DefaultBounds.MaxTeams, RuntimeBounds.MaxTeams)) {
7318
- NumTeamsC.push_back(RtNumTeams ? RtNumTeams
7319
- : Builder.getInt32(DefNumTeams));
7320
- }
7321
-
7322
- // Calculate number of threads: 0 if no clauses specified, otherwise it is the
7323
- // minimum between optional THREAD_LIMIT and MAX_THREADS clauses. Perform a
7324
- // type cast to uint32.
7325
- auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7326
- if (Clause)
7327
- Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
7328
- /*isSigned=*/false);
7329
- return Clause;
7330
7299
};
7331
7300
7332
- auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7333
- if (Clause)
7334
- Result = Result
7335
- ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
7301
+ auto &&EmitTargetCallThen = [&]() {
7302
+ OpenMPIRBuilder::TargetDataInfo Info(
7303
+ /*RequiresDevicePointerInfo=*/false,
7304
+ /*SeparateBeginEndCalls=*/true);
7305
+
7306
+ OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
7307
+ OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7308
+ OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
7309
+ RTArgs, MapInfo,
7310
+ /*IsNonContiguous=*/true,
7311
+ /*ForEndCall=*/false);
7312
+
7313
+ SmallVector<Value *, 3> NumTeamsC;
7314
+ for (auto [DefNumTeams, RtNumTeams] :
7315
+ llvm::zip_equal(DefaultBounds.MaxTeams, RuntimeBounds.MaxTeams)) {
7316
+ NumTeamsC.push_back(RtNumTeams ? RtNumTeams
7317
+ : Builder.getInt32(DefNumTeams));
7318
+ }
7319
+
7320
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is
7321
+ // the minimum between optional THREAD_LIMIT and MAX_THREADS clauses.
7322
+ // Perform a type cast to uint32.
7323
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7324
+ if (Clause)
7325
+ Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
7326
+ /*isSigned=*/false);
7327
+ return Clause;
7328
+ };
7329
+
7330
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7331
+ if (Clause)
7332
+ Result =
7333
+ Result ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
7336
7334
Result, Clause)
7337
7335
: Clause;
7338
- };
7336
+ };
7337
+
7338
+ // TODO: Check if this is the correct handling for multi-dim thread_limit.
7339
+ SmallVector<Value *, 3> NumThreadsC;
7340
+ Value *MaxThreadsClause = InitMaxThreadsClause(RuntimeBounds.MaxThreads);
7339
7341
7340
- // TODO: Check if this is the correct handling for multi-dim thread_limit.
7341
- SmallVector<Value *, 3> NumThreadsC;
7342
- Value *MaxThreadsClause = InitMaxThreadsClause(RuntimeBounds.MaxThreads);
7342
+ for (auto [RtTeamsThreadLimit, RtTargetThreadLimit] : llvm::zip_equal(
7343
+ RuntimeBounds.TeamsThreadLimit, RuntimeBounds.TargetThreadLimit)) {
7344
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause(RtTeamsThreadLimit);
7345
+ Value *NumThreads = InitMaxThreadsClause(RtTargetThreadLimit);
7343
7346
7344
- for (auto [RtTeamsThreadLimit, RtTargetThreadLimit] : llvm::zip_equal(
7345
- RuntimeBounds.TeamsThreadLimit, RuntimeBounds.TargetThreadLimit)) {
7346
- Value *TeamsThreadLimitClause = InitMaxThreadsClause(RtTeamsThreadLimit);
7347
- Value *NumThreads = InitMaxThreadsClause(RtTargetThreadLimit);
7347
+ CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
7348
+ CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
7348
7349
7349
- CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
7350
- CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
7350
+ NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0) );
7351
+ }
7351
7352
7352
- NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
7353
+ unsigned NumTargetItems = Info.NumberOfPtrs;
7354
+ // TODO: Use correct device ID
7355
+ Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
7356
+ uint32_t SrcLocStrSize;
7357
+ Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
7358
+ Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
7359
+ llvm::omp::IdentFlag(0), 0);
7360
+
7361
+ Value *TripCount = RuntimeBounds.LoopTripCount
7362
+ ? Builder.CreateIntCast(RuntimeBounds.LoopTripCount,
7363
+ Builder.getInt64Ty(),
7364
+ /*isSigned=*/false)
7365
+ : Builder.getInt64(0);
7366
+
7367
+ // TODO: Use correct DynCGGroupMem
7368
+ Value *DynCGGroupMem = Builder.getInt32(0);
7369
+ OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, TripCount,
7370
+ NumTeamsC, NumThreadsC,
7371
+ DynCGGroupMem, HasNoWait);
7372
+
7373
+ // The presence of certain clauses on the target directive requires the
7374
+ // explicit generation of the target task.
7375
+ if (RequiresOuterTargetTask) {
7376
+ Builder.restoreIP(OMPBuilder.emitTargetTask(
7377
+ OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
7378
+ RTLoc, AllocaIP, Dependencies, HasNoWait));
7379
+ } else {
7380
+ Builder.restoreIP(OMPBuilder.emitKernelLaunch(
7381
+ Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
7382
+ DeviceID, RTLoc, AllocaIP));
7383
+ }
7384
+ };
7385
+
7386
+ // If we don't have an ID for the target region, it means an offload entry
7387
+ // wasn't created. In this case we just run the host fallback directly.
7388
+ if (!OutlinedFnID) {
7389
+ EmitTargetCallElse();
7390
+ return;
7353
7391
}
7354
7392
7355
- unsigned NumTargetItems = Info.NumberOfPtrs;
7356
- // TODO: Use correct device ID
7357
- Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
7358
- uint32_t SrcLocStrSize;
7359
- Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
7360
- Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
7361
- llvm::omp::IdentFlag(0), 0);
7362
-
7363
- Value *TripCount = RuntimeBounds.LoopTripCount
7364
- ? Builder.CreateIntCast(RuntimeBounds.LoopTripCount,
7365
- Builder.getInt64Ty(),
7366
- /*isSigned=*/false)
7367
- : Builder.getInt64(0);
7368
-
7369
- // TODO: Use correct DynCGGroupMem
7370
- Value *DynCGGroupMem = Builder.getInt32(0);
7371
- OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, TripCount,
7372
- NumTeamsC, NumThreadsC, DynCGGroupMem,
7373
- HasNoWait);
7374
-
7375
- // The presence of certain clauses on the target directive require the
7376
- // explicit generation of the target task.
7377
- if (RequiresOuterTargetTask) {
7378
- Builder.restoreIP(OMPBuilder.emitTargetTask(
7379
- OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
7380
- RTLoc, AllocaIP, Dependencies, HasNoWait));
7381
- } else {
7382
- Builder.restoreIP(OMPBuilder.emitKernelLaunch(
7383
- Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
7384
- DeviceID, RTLoc, AllocaIP));
7393
+ // If there's no IF clause, only generate the kernel launch code path.
7394
+ if (!IfCond) {
7395
+ EmitTargetCallThen();
7396
+ return;
7385
7397
}
7398
+
7399
+ // Create if-else to handle IF clause.
7400
+ llvm::BasicBlock *ThenBlock =
7401
+ BasicBlock::Create(Builder.getContext(), "omp_if.then");
7402
+ llvm::BasicBlock *ElseBlock =
7403
+ BasicBlock::Create(Builder.getContext(), "omp_if.else");
7404
+ llvm::BasicBlock *ContBlock =
7405
+ BasicBlock::Create(Builder.getContext(), "omp_if.end");
7406
+ Builder.CreateCondBr(IfCond, ThenBlock, ElseBlock);
7407
+
7408
+ Function *CurFn = Builder.GetInsertBlock()->getParent();
7409
+
7410
+ // Emit the 'then' code.
7411
+ OMPBuilder.emitBlock(ThenBlock, CurFn);
7412
+ EmitTargetCallThen();
7413
+ OMPBuilder.emitBranch(ContBlock);
7414
+ // Emit the 'else' code.
7415
+ OMPBuilder.emitBlock(ElseBlock, CurFn);
7416
+ EmitTargetCallElse();
7417
+ OMPBuilder.emitBranch(ContBlock);
7418
+ // Emit the continuation block.
7419
+ OMPBuilder.emitBlock(ContBlock, CurFn, /*IsFinished=*/true);
7386
7420
}
7387
7421
7388
7422
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
7389
7423
const LocationDescription &Loc, bool IsSPMD, bool IsOffloadEntry,
7390
- InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
7424
+ Value *IfCond, InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
7391
7425
TargetRegionEntryInfo &EntryInfo,
7392
7426
const TargetKernelDefaultBounds &DefaultBounds,
7393
7427
const TargetKernelRuntimeBounds &RuntimeBounds,
@@ -7415,7 +7449,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
7415
7449
// that represents the target region. Do that now.
7416
7450
if (!Config.isTargetDevice())
7417
7451
emitTargetCall(*this, Builder, AllocaIP, DefaultBounds, RuntimeBounds,
7418
- OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies);
7452
+ OutlinedFn, OutlinedFnID, Args, IfCond, GenMapInfoCB,
7453
+ Dependencies);
7419
7454
return Builder.saveIP();
7420
7455
}
7421
7456
0 commit comments