@@ -99,10 +99,23 @@ class Util {
9999 // / \param Tmpl whether the class is template instantiation or simple record
100100 static bool isSyclType (const QualType &Ty, StringRef Name, bool Tmpl = false );
101101
102+ /// Checks whether the given function is a standard SYCL API function with the
103+ /// given name, i.e. it carries that name and is declared directly in cl::sycl.
104+ /// \param FD the function being checked.
105+ /// \param Name the function name to be checked against.
106+ static bool isSyclFunction (const FunctionDecl *FD, StringRef Name);
107+
102108 /// Checks whether the given clang type is a full specialization of the SYCL
103109 /// specialization constant class.
104110 static bool isSyclSpecConstantType (const QualType &Ty);
105111
112+ /// Checks whether a declaration context sits in the given scope hierarchy.
113+ /// \param DC the context of the item to be checked.
114+ /// \param Scopes the expected declaration scopes, outermost first; the last
115+ /// entry is matched against \p DC itself and the translation unit is not listed.
/// \returns true only if every scope matches in kind and name and the parent
/// of the outermost scope is the translation unit.
116+ static bool matchContext (const DeclContext *DC,
117+ ArrayRef<Util::DeclContextDesc> Scopes);
118+
106119 // / Checks whether given clang type is declared in the given hierarchy of
107120 // / declaration contexts.
108121 // / \param Ty the clang type being checked
@@ -487,6 +500,21 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
487500 FunctionDecl *FD = WorkList.back ().first ;
488501 FunctionDecl *ParentFD = WorkList.back ().second ;
489502
503+ // To implement rounding-up of a parallel-for range the
504+ // SYCL header implementation modifies the kernel call like this:
505+ // auto Wrapper = [=](TransformedArgType Arg) {
506+ // if (Arg[0] >= NumWorkItems[0])
507+ // return;
508+ // Arg.set_allowed_range(NumWorkItems);
509+ // KernelFunc(Arg);
510+ // };
511+ //
512+ // This transformation leads to a condition where a kernel body
513+ // function becomes callable from a new kernel body function.
514+ // Hence this test.
515+ if ((ParentFD == KernelBody) && isSYCLKernelBodyFunction (FD))
516+ KernelBody = FD;
517+
490518 if ((ParentFD == SYCLKernel) && isSYCLKernelBodyFunction (FD)) {
491519 assert (!KernelBody && " inconsistent call graph - only one kernel body "
492520 " function can be called" );
@@ -2667,15 +2695,63 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
26672695 return !SemaRef.getASTContext ().hasSameType (FD->getType (), Ty);
26682696 }
26692697
2698+ // Sets a flag if the kernel is a parallel_for that calls the
2699+ // free function API "this_item".
2700+ void setThisItemIsCalled (const CXXRecordDecl *KernelObj,
2701+ FunctionDecl *KernelFunc) {
2702+ if (getKernelInvocationKind (KernelFunc) != InvokeParallelFor)
2703+ return ;
2704+
2705+ const CXXMethodDecl *WGLambdaFn = getOperatorParens (KernelObj);
2706+ if (!WGLambdaFn)
2707+ return ;
2708+
2709+ // The call graph for this translation unit.
2710+ CallGraph SYCLCG;
2711+ SYCLCG.addToCallGraph (SemaRef.getASTContext ().getTranslationUnitDecl ());
2712+ using ChildParentPair =
2713+ std::pair<const FunctionDecl *, const FunctionDecl *>;
2714+ llvm::SmallPtrSet<const FunctionDecl *, 16 > Visited;
2715+ llvm::SmallVector<ChildParentPair, 16 > WorkList;
2716+ WorkList.push_back ({WGLambdaFn, nullptr });
2717+
2718+ while (!WorkList.empty ()) {
2719+ const FunctionDecl *FD = WorkList.back ().first ;
2720+ WorkList.pop_back ();
2721+ if (!Visited.insert (FD).second )
2722+ continue ; // We've already seen this Decl
2723+
2724+ // Check whether this call is to sycl::this_item().
2725+ if (Util::isSyclFunction (FD, " this_item" )) {
2726+ Header.setCallsThisItem (true );
2727+ return ;
2728+ }
2729+
2730+ CallGraphNode *N = SYCLCG.getNode (FD);
2731+ if (!N)
2732+ continue ;
2733+
2734+ for (const CallGraphNode *CI : *N) {
2735+ if (auto *Callee = dyn_cast<FunctionDecl>(CI->getDecl ())) {
2736+ Callee = Callee->getMostRecentDecl ();
2737+ if (!Visited.count (Callee))
2738+ WorkList.push_back ({Callee, FD});
2739+ }
2740+ }
2741+ }
2742+ }
2743+
26702744public:
26712745 static constexpr const bool VisitInsideSimpleContainers = false ;
/// Constructs the handler and registers the kernel with the integration
/// header: \p Name, \p NameType, \p StableName, the location of \p KernelObj
/// and the result of isESIMDKernelType() are passed to Header.startKernel().
/// \param KernelFunc forwarded, together with \p KernelObj, to
/// setThisItemIsCalled().
26722746 SyclKernelIntHeaderCreator (Sema &S, SYCLIntegrationHeader &H,
26732747 const CXXRecordDecl *KernelObj, QualType NameType,
2674- StringRef Name, StringRef StableName)
2748+ StringRef Name, StringRef StableName,
2749+ FunctionDecl *KernelFunc)
26752750 : SyclKernelFieldHandler(S), Header(H) {
26762751 bool IsSIMDKernel = isESIMDKernelType (KernelObj);
26772752 Header.startKernel (Name, NameType, StableName, KernelObj->getLocation (),
26782753 IsSIMDKernel);
// Runs after startKernel(): setCallsThisItem() asserts that a current kernel
// description exists. NOTE(review): presumably startKernel() creates it - confirm.
2754+ setThisItemIsCalled (KernelObj, KernelFunc);
26792755 }
26802756
26812757 bool handleSyclAccessorType (const CXXRecordDecl *RD,
@@ -3123,7 +3199,7 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
31233199 SyclKernelIntHeaderCreator int_header (
31243200 *this , getSyclIntegrationHeader (), KernelObj,
31253201 calculateKernelNameType (Context, KernelCallerFunc), KernelName,
3126- StableName);
3202+ StableName, KernelCallerFunc );
31273203
31283204 KernelObjVisitor Visitor{*this };
31293205 Visitor.VisitRecordBases (KernelObj, kernel_decl, kernel_body, int_header);
@@ -3842,6 +3918,9 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
38423918 O << " __SYCL_DLL_LOCAL\n " ;
38433919 O << " static constexpr bool isESIMD() { return " << K.IsESIMDKernel
38443920 << " ; }\n " ;
3921+ O << " __SYCL_DLL_LOCAL\n " ;
3922+ O << " static constexpr bool callsThisItem() { return " ;
3923+ O << K.CallsThisItem << " ; }\n " ;
38453924 O << " };\n " ;
38463925 CurStart += N;
38473926 }
@@ -3900,6 +3979,12 @@ void SYCLIntegrationHeader::addSpecConstant(StringRef IDName, QualType IDType) {
39003979 SpecConsts.emplace_back (std::make_pair (IDType, IDName.str ()));
39013980}
39023981
3982+ void SYCLIntegrationHeader::setCallsThisItem (bool B) {
3983+ KernelDesc *K = getCurKernelDesc ();
3984+ assert (K && " no kernels" );
3985+ K->CallsThisItem = B;
3986+ }
3987+
39033988SYCLIntegrationHeader::SYCLIntegrationHeader (DiagnosticsEngine &_Diag,
39043989 bool _UnnamedLambdaSupport,
39053990 Sema &_S)
@@ -3967,6 +4052,21 @@ bool Util::isSyclType(const QualType &Ty, StringRef Name, bool Tmpl) {
39674052 return matchQualifiedTypeName (Ty, Scopes);
39684053}
39694054
4055+ bool Util::isSyclFunction (const FunctionDecl *FD, StringRef Name) {
4056+ if (!FD->isFunctionOrMethod () || !FD->getIdentifier () ||
4057+ FD->getName ().empty () || Name != FD->getName ())
4058+ return false ;
4059+
4060+ const DeclContext *DC = FD->getDeclContext ();
4061+ if (DC->isTranslationUnit ())
4062+ return false ;
4063+
4064+ std::array<DeclContextDesc, 2 > Scopes = {
4065+ Util::DeclContextDesc{clang::Decl::Kind::Namespace, " cl" },
4066+ Util::DeclContextDesc{clang::Decl::Kind::Namespace, " sycl" }};
4067+ return matchContext (DC, Scopes);
4068+ }
4069+
39704070bool Util::isAccessorPropertyListType (const QualType &Ty) {
39714071 const StringRef &Name = " accessor_property_list" ;
39724072 std::array<DeclContextDesc, 4 > Scopes = {
@@ -3977,21 +4077,15 @@ bool Util::isAccessorPropertyListType(const QualType &Ty) {
39774077 return matchQualifiedTypeName (Ty, Scopes);
39784078}
39794079
3980- bool Util::matchQualifiedTypeName (const QualType &Ty ,
3981- ArrayRef<Util::DeclContextDesc> Scopes) {
3982- // The idea: check the declaration context chain starting from the type
4080+ bool Util::matchContext (const DeclContext *Ctx ,
4081+ ArrayRef<Util::DeclContextDesc> Scopes) {
4082+ // The idea: check the declaration context chain starting from the item
39834083 // itself. At each step check the context is of expected kind
39844084 // (namespace) and name.
3985- const CXXRecordDecl *RecTy = Ty->getAsCXXRecordDecl ();
3986-
3987- if (!RecTy)
3988- return false ; // only classes/structs supported
3989- const auto *Ctx = cast<DeclContext>(RecTy);
39904085 StringRef Name = " " ;
39914086
39924087 for (const auto &Scope : llvm::reverse (Scopes)) {
39934088 clang::Decl::Kind DK = Ctx->getDeclKind ();
3994-
39954089 if (DK != Scope.first )
39964090 return false ;
39974091
@@ -4005,11 +4099,21 @@ bool Util::matchQualifiedTypeName(const QualType &Ty,
40054099 Name = cast<NamespaceDecl>(Ctx)->getName ();
40064100 break ;
40074101 default :
4008- llvm_unreachable (" matchQualifiedTypeName : decl kind not supported" );
4102+ llvm_unreachable (" matchContext : decl kind not supported" );
40094103 }
40104104 if (Name != Scope.second )
40114105 return false ;
40124106 Ctx = Ctx->getParent ();
40134107 }
40144108 return Ctx->isTranslationUnit ();
40154109}
4110+
4111+ bool Util::matchQualifiedTypeName (const QualType &Ty,
4112+ ArrayRef<Util::DeclContextDesc> Scopes) {
4113+ const CXXRecordDecl *RecTy = Ty->getAsCXXRecordDecl ();
4114+
4115+ if (!RecTy)
4116+ return false ; // only classes/structs supported
4117+ const auto *Ctx = cast<DeclContext>(RecTy);
4118+ return Util::matchContext (Ctx, Scopes);
4119+ }
0 commit comments