@@ -99,10 +99,23 @@ class Util {
9999 // / \param Tmpl whether the class is template instantiation or simple record
100100 static bool isSyclType (const QualType &Ty, StringRef Name, bool Tmpl = false );
101101
102+ // / Checks whether given function is a standard SYCL API function with given
103+ // / name.
104+ // / \param FD the function being checked.
105+ // / \param Name the function name to be checked against.
106+ static bool isSyclFunction (const FunctionDecl *FD, StringRef Name);
107+
102108 // / Checks whether given clang type is a full specialization of the SYCL
103109 // / specialization constant class.
104110 static bool isSyclSpecConstantType (const QualType &Ty);
105111
112+ /// Checks whether the given declaration context matches the given hierarchy
112+ /// of declaration scopes.
113+ /// \param DC the context of the item to be checked.
114+ /// \param Scopes the declaration scopes leading from the item context to the
115+ /// translation unit (excluding the latter)
116+ static bool matchContext (const DeclContext *DC,
117+ ArrayRef<Util::DeclContextDesc> Scopes);
118+
106119 // / Checks whether given clang type is declared in the given hierarchy of
107120 // / declaration contexts.
108121 // / \param Ty the clang type being checked
@@ -165,38 +178,14 @@ static bool IsSyclMathFunc(unsigned BuiltinID) {
165178 case Builtin::BI__builtin_truncl:
166179 case Builtin::BIlroundl:
167180 case Builtin::BI__builtin_lroundl:
168- case Builtin::BIcopysign:
169- case Builtin::BI__builtin_copysign:
170- case Builtin::BIfloor:
171- case Builtin::BI__builtin_floor:
172181 case Builtin::BIfmax:
173182 case Builtin::BI__builtin_fmax:
174183 case Builtin::BIfmin:
175184 case Builtin::BI__builtin_fmin:
176- case Builtin::BInearbyint:
177- case Builtin::BI__builtin_nearbyint:
178- case Builtin::BIrint:
179- case Builtin::BI__builtin_rint:
180- case Builtin::BIround:
181- case Builtin::BI__builtin_round:
182- case Builtin::BItrunc:
183- case Builtin::BI__builtin_trunc:
184- case Builtin::BIcopysignf:
185- case Builtin::BI__builtin_copysignf:
186- case Builtin::BIfloorf:
187- case Builtin::BI__builtin_floorf:
188185 case Builtin::BIfmaxf:
189186 case Builtin::BI__builtin_fmaxf:
190187 case Builtin::BIfminf:
191188 case Builtin::BI__builtin_fminf:
192- case Builtin::BInearbyintf:
193- case Builtin::BI__builtin_nearbyintf:
194- case Builtin::BIrintf:
195- case Builtin::BI__builtin_rintf:
196- case Builtin::BIroundf:
197- case Builtin::BI__builtin_roundf:
198- case Builtin::BItruncf:
199- case Builtin::BI__builtin_truncf:
200189 case Builtin::BIlroundf:
201190 case Builtin::BI__builtin_lroundf:
202191 case Builtin::BI__builtin_fpclassify:
@@ -511,6 +500,21 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
511500 FunctionDecl *FD = WorkList.back ().first ;
512501 FunctionDecl *ParentFD = WorkList.back ().second ;
513502
503+ // To implement rounding-up of a parallel-for range the
504+ // SYCL header implementation modifies the kernel call like this:
505+ // auto Wrapper = [=](TransformedArgType Arg) {
506+ // if (Arg[0] >= NumWorkItems[0])
507+ // return;
508+ // Arg.set_allowed_range(NumWorkItems);
509+ // KernelFunc(Arg);
510+ // };
511+ //
512+ // This transformation leads to a condition where a kernel body
513+ // function becomes callable from a new kernel body function.
514+ // Hence this test.
515+ if ((ParentFD == KernelBody) && isSYCLKernelBodyFunction (FD))
516+ KernelBody = FD;
517+
514518 if ((ParentFD == SYCLKernel) && isSYCLKernelBodyFunction (FD)) {
515519 assert (!KernelBody && " inconsistent call graph - only one kernel body "
516520 " function can be called" );
@@ -2691,15 +2695,63 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
26912695 return !SemaRef.getASTContext ().hasSameType (FD->getType (), Ty);
26922696 }
26932697
2698+ // Sets a flag if the kernel is a parallel_for that calls the
2699+ // free function API "this_item".
2700+ void setThisItemIsCalled (const CXXRecordDecl *KernelObj,
2701+ FunctionDecl *KernelFunc) {
2702+ if (getKernelInvocationKind (KernelFunc) != InvokeParallelFor)
2703+ return ;
2704+
2705+ const CXXMethodDecl *WGLambdaFn = getOperatorParens (KernelObj);
2706+ if (!WGLambdaFn)
2707+ return ;
2708+
2709+ // The call graph for this translation unit.
2710+ CallGraph SYCLCG;
2711+ SYCLCG.addToCallGraph (SemaRef.getASTContext ().getTranslationUnitDecl ());
2712+ using ChildParentPair =
2713+ std::pair<const FunctionDecl *, const FunctionDecl *>;
2714+ llvm::SmallPtrSet<const FunctionDecl *, 16 > Visited;
2715+ llvm::SmallVector<ChildParentPair, 16 > WorkList;
2716+ WorkList.push_back ({WGLambdaFn, nullptr });
2717+
2718+ while (!WorkList.empty ()) {
2719+ const FunctionDecl *FD = WorkList.back ().first ;
2720+ WorkList.pop_back ();
2721+ if (!Visited.insert (FD).second )
2722+ continue ; // We've already seen this Decl
2723+
2724+ // Check whether this call is to sycl::this_item().
2725+ if (Util::isSyclFunction (FD, " this_item" )) {
2726+ Header.setCallsThisItem (true );
2727+ return ;
2728+ }
2729+
2730+ CallGraphNode *N = SYCLCG.getNode (FD);
2731+ if (!N)
2732+ continue ;
2733+
2734+ for (const CallGraphNode *CI : *N) {
2735+ if (auto *Callee = dyn_cast<FunctionDecl>(CI->getDecl ())) {
2736+ Callee = Callee->getMostRecentDecl ();
2737+ if (!Visited.count (Callee))
2738+ WorkList.push_back ({Callee, FD});
2739+ }
2740+ }
2741+ }
2742+ }
2743+
26942744public:
26952745 static constexpr const bool VisitInsideSimpleContainers = false ;
26962746 SyclKernelIntHeaderCreator (Sema &S, SYCLIntegrationHeader &H,
26972747 const CXXRecordDecl *KernelObj, QualType NameType,
2698- StringRef Name, StringRef StableName)
2748+ StringRef Name, StringRef StableName,
2749+ FunctionDecl *KernelFunc)
26992750 : SyclKernelFieldHandler(S), Header(H) {
27002751 bool IsSIMDKernel = isESIMDKernelType (KernelObj);
27012752 Header.startKernel (Name, NameType, StableName, KernelObj->getLocation (),
27022753 IsSIMDKernel);
2754+ setThisItemIsCalled (KernelObj, KernelFunc);
27032755 }
27042756
27052757 bool handleSyclAccessorType (const CXXRecordDecl *RD,
@@ -3147,7 +3199,7 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
31473199 SyclKernelIntHeaderCreator int_header (
31483200 *this , getSyclIntegrationHeader (), KernelObj,
31493201 calculateKernelNameType (Context, KernelCallerFunc), KernelName,
3150- StableName);
3202+ StableName, KernelCallerFunc );
31513203
31523204 KernelObjVisitor Visitor{*this };
31533205 Visitor.VisitRecordBases (KernelObj, kernel_decl, kernel_body, int_header);
@@ -3866,6 +3918,9 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
38663918 O << " __SYCL_DLL_LOCAL\n " ;
38673919 O << " static constexpr bool isESIMD() { return " << K.IsESIMDKernel
38683920 << " ; }\n " ;
3921+ O << " __SYCL_DLL_LOCAL\n " ;
3922+ O << " static constexpr bool callsThisItem() { return " ;
3923+ O << K.CallsThisItem << " ; }\n " ;
38693924 O << " };\n " ;
38703925 CurStart += N;
38713926 }
@@ -3924,6 +3979,12 @@ void SYCLIntegrationHeader::addSpecConstant(StringRef IDName, QualType IDType) {
39243979 SpecConsts.emplace_back (std::make_pair (IDType, IDName.str ()));
39253980}
39263981
3982+ void SYCLIntegrationHeader::setCallsThisItem (bool B) {
3983+ KernelDesc *K = getCurKernelDesc ();
3984+ assert (K && " no kernels" );
3985+ K->CallsThisItem = B;
3986+ }
3987+
39273988SYCLIntegrationHeader::SYCLIntegrationHeader (DiagnosticsEngine &_Diag,
39283989 bool _UnnamedLambdaSupport,
39293990 Sema &_S)
@@ -3991,6 +4052,21 @@ bool Util::isSyclType(const QualType &Ty, StringRef Name, bool Tmpl) {
39914052 return matchQualifiedTypeName (Ty, Scopes);
39924053}
39934054
4055+ bool Util::isSyclFunction (const FunctionDecl *FD, StringRef Name) {
4056+ if (!FD->isFunctionOrMethod () || !FD->getIdentifier () ||
4057+ FD->getName ().empty () || Name != FD->getName ())
4058+ return false ;
4059+
4060+ const DeclContext *DC = FD->getDeclContext ();
4061+ if (DC->isTranslationUnit ())
4062+ return false ;
4063+
4064+ std::array<DeclContextDesc, 2 > Scopes = {
4065+ Util::DeclContextDesc{clang::Decl::Kind::Namespace, " cl" },
4066+ Util::DeclContextDesc{clang::Decl::Kind::Namespace, " sycl" }};
4067+ return matchContext (DC, Scopes);
4068+ }
4069+
39944070bool Util::isAccessorPropertyListType (const QualType &Ty) {
39954071 const StringRef &Name = " accessor_property_list" ;
39964072 std::array<DeclContextDesc, 4 > Scopes = {
@@ -4001,21 +4077,15 @@ bool Util::isAccessorPropertyListType(const QualType &Ty) {
40014077 return matchQualifiedTypeName (Ty, Scopes);
40024078}
40034079
4004- bool Util::matchQualifiedTypeName (const QualType &Ty ,
4005- ArrayRef<Util::DeclContextDesc> Scopes) {
4006- // The idea: check the declaration context chain starting from the type
4080+ bool Util::matchContext (const DeclContext *Ctx ,
4081+ ArrayRef<Util::DeclContextDesc> Scopes) {
4082+ // The idea: check the declaration context chain starting from the item
40074083 // itself. At each step check the context is of expected kind
40084084 // (namespace) and name.
4009- const CXXRecordDecl *RecTy = Ty->getAsCXXRecordDecl ();
4010-
4011- if (!RecTy)
4012- return false ; // only classes/structs supported
4013- const auto *Ctx = cast<DeclContext>(RecTy);
40144085 StringRef Name = " " ;
40154086
40164087 for (const auto &Scope : llvm::reverse (Scopes)) {
40174088 clang::Decl::Kind DK = Ctx->getDeclKind ();
4018-
40194089 if (DK != Scope.first )
40204090 return false ;
40214091
@@ -4029,11 +4099,21 @@ bool Util::matchQualifiedTypeName(const QualType &Ty,
40294099 Name = cast<NamespaceDecl>(Ctx)->getName ();
40304100 break ;
40314101 default :
4032- llvm_unreachable (" matchQualifiedTypeName : decl kind not supported" );
4102+ llvm_unreachable (" matchContext : decl kind not supported" );
40334103 }
40344104 if (Name != Scope.second )
40354105 return false ;
40364106 Ctx = Ctx->getParent ();
40374107 }
40384108 return Ctx->isTranslationUnit ();
40394109}
4110+
4111+ bool Util::matchQualifiedTypeName (const QualType &Ty,
4112+ ArrayRef<Util::DeclContextDesc> Scopes) {
4113+ const CXXRecordDecl *RecTy = Ty->getAsCXXRecordDecl ();
4114+
4115+ if (!RecTy)
4116+ return false ; // only classes/structs supported
4117+ const auto *Ctx = cast<DeclContext>(RecTy);
4118+ return Util::matchContext (Ctx, Scopes);
4119+ }
0 commit comments