@@ -831,6 +831,9 @@ class KernelObjVisitor {
831831 else if (ElementTy->isStructureOrClassType ())
832832 VisitRecord (Owner, ArrayField, ElementTy->getAsCXXRecordDecl (),
833833 handlers...);
834+ else if (ElementTy->isUnionType ())
835+ VisitUnion (Owner, ArrayField, ElementTy->getAsCXXRecordDecl (),
836+ handlers...);
834837 else if (ElementTy->isArrayType ())
835838 VisitArrayElements (ArrayField, ElementTy, handlers...);
836839 else if (ElementTy->isScalarType ())
@@ -858,6 +861,65 @@ class KernelObjVisitor {
858861 void VisitRecord (const CXXRecordDecl *Owner, ParentTy &Parent,
859862 const CXXRecordDecl *Wrapper, Handlers &... handlers);
860863
864+ // Base case, only calls these when filtered.
865+ template <typename ... FilteredHandlers, typename ParentTy>
866+ std::enable_if_t <(sizeof ...(FilteredHandlers) > 0 )>
867+ VisitUnion (FilteredHandlers &... handlers, const CXXRecordDecl *Owner,
868+ ParentTy &Parent, const CXXRecordDecl *Wrapper) {
869+ (void )std::initializer_list<int >{
870+ (handlers.enterUnion (Owner, Parent), 0 )...};
871+ VisitRecordHelper (Wrapper, Wrapper->fields (), handlers...);
872+ (void )std::initializer_list<int >{
873+ (handlers.leaveUnion (Owner, Parent), 0 )...};
874+ }
875+
876+ // Handle empty base case.
877+ template <typename ParentTy>
878+ void VisitUnion (const CXXRecordDecl *Owner, ParentTy &Parent,
879+ const CXXRecordDecl *Wrapper) {}
880+
881+ template <typename ... FilteredHandlers, typename ParentTy,
882+ typename CurHandler, typename ... Handlers>
883+ std::enable_if_t <!CurHandler::VisitUnionBody &&
884+ (sizeof ...(FilteredHandlers) > 0 )>
885+ VisitUnion (FilteredHandlers &... filtered_handlers,
886+ const CXXRecordDecl *Owner, ParentTy &Parent,
887+ const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
888+ Handlers &... handlers) {
889+ VisitUnion<FilteredHandlers...>(filtered_handlers..., Owner, Parent,
890+ Wrapper, handlers...);
891+ }
892+
893+ template <typename ... FilteredHandlers, typename ParentTy,
894+ typename CurHandler, typename ... Handlers>
895+ std::enable_if_t <CurHandler::VisitUnionBody &&
896+ (sizeof ...(FilteredHandlers) > 0 )>
897+ VisitUnion (FilteredHandlers &... filtered_handlers,
898+ const CXXRecordDecl *Owner, ParentTy &Parent,
899+ const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
900+ Handlers &... handlers) {
901+ VisitUnion<FilteredHandlers..., CurHandler>(
902+ filtered_handlers..., cur_handler, Owner, Parent, Wrapper, handlers...);
903+ }
904+
905+ // Add overloads without having filtered-handlers
906+ // to handle leading-empty argument packs.
907+ template <typename ParentTy, typename CurHandler, typename ... Handlers>
908+ std::enable_if_t <!CurHandler::VisitUnionBody>
909+ VisitUnion (const CXXRecordDecl *Owner, ParentTy &Parent,
910+ const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
911+ Handlers &... handlers) {
912+ VisitUnion (Owner, Parent, Wrapper, handlers...);
913+ }
914+
915+ template <typename ParentTy, typename CurHandler, typename ... Handlers>
916+ std::enable_if_t <CurHandler::VisitUnionBody>
917+ VisitUnion (const CXXRecordDecl *Owner, ParentTy &Parent,
918+ const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
919+ Handlers &... handlers) {
920+ VisitUnion<CurHandler>(cur_handler, Owner, Parent, Wrapper, handlers...);
921+ }
922+
861923 template <typename ... Handlers>
862924 void VisitRecordHelper (const CXXRecordDecl *Owner,
863925 clang::CXXRecordDecl::base_class_const_range Range,
@@ -943,6 +1005,11 @@ class KernelObjVisitor {
9431005 CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl ();
9441006 VisitRecord (Owner, Field, RD, handlers...);
9451007 }
1008+ } else if (FieldTy->isUnionType ()) {
1009+ if (KF_FOR_EACH (handleUnionType, Field, FieldTy)) {
1010+ CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl ();
1011+ VisitUnion (Owner, Field, RD, handlers...);
1012+ }
9461013 } else if (FieldTy->isReferenceType ())
9471014 KF_FOR_EACH (handleReferenceType, Field, FieldTy);
9481015 else if (FieldTy->isPointerType ())
@@ -982,6 +1049,7 @@ class SyclKernelFieldHandler {
9821049 SyclKernelFieldHandler (Sema &S) : SemaRef(S) {}
9831050
9841051public:
1052+ static constexpr const bool VisitUnionBody = false ;
9851053 // Mark these virtual so that we can use override in the implementer classes,
9861054 // despite virtual dispatch never being used.
9871055
@@ -1006,6 +1074,7 @@ class SyclKernelFieldHandler {
10061074 }
10071075 virtual bool handleSyclHalfType (FieldDecl *, QualType) { return true ; }
10081076 virtual bool handleStructType (FieldDecl *, QualType) { return true ; }
1077+ virtual bool handleUnionType (FieldDecl *, QualType) { return true ; }
10091078 virtual bool handleReferenceType (FieldDecl *, QualType) { return true ; }
10101079 virtual bool handlePointerType (FieldDecl *, QualType) { return true ; }
10111080 virtual bool handleArrayType (FieldDecl *, QualType) { return true ; }
@@ -1025,6 +1094,8 @@ class SyclKernelFieldHandler {
10251094 virtual bool leaveStruct (const CXXRecordDecl *, const CXXBaseSpecifier &) {
10261095 return true ;
10271096 }
1097+ virtual bool enterUnion (const CXXRecordDecl *, FieldDecl *) { return true ; }
1098+ virtual bool leaveUnion (const CXXRecordDecl *, FieldDecl *) { return true ; }
10281099
10291100 // The following are used for stepping through array elements.
10301101
@@ -1047,7 +1118,6 @@ class SyclKernelFieldHandler {
10471118class SyclKernelFieldChecker : public SyclKernelFieldHandler {
10481119 bool IsInvalid = false ;
10491120 DiagnosticsEngine &Diag;
1050-
10511121 // Check whether the object should be disallowed from being copied to kernel.
10521122 // Return true if not copyable, false if copyable.
10531123 bool checkNotCopyableToKernel (const FieldDecl *FD, const QualType &FieldTy) {
@@ -1202,6 +1272,65 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
12021272 }
12031273};
12041274
1275+ // A type to check the validity of accessing accessor/sampler/stream
1276+ // types as kernel parameters inside union.
1277+ class SyclKernelUnionChecker : public SyclKernelFieldHandler {
1278+ int UnionCount = 0 ;
1279+ bool IsInvalid = false ;
1280+ DiagnosticsEngine &Diag;
1281+
1282+ public:
1283+ SyclKernelUnionChecker (Sema &S)
1284+ : SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
1285+ bool isValid () { return !IsInvalid; }
1286+ static constexpr const bool VisitUnionBody = true ;
1287+
1288+ bool checkType (SourceLocation Loc, QualType Ty) {
1289+ if (UnionCount) {
1290+ IsInvalid = true ;
1291+ Diag.Report (Loc, diag::err_bad_union_kernel_param_members) << Ty;
1292+ }
1293+ return isValid ();
1294+ }
1295+
1296+ bool enterUnion (const CXXRecordDecl *RD, FieldDecl *FD) {
1297+ ++UnionCount;
1298+ return true ;
1299+ }
1300+
1301+ bool leaveUnion (const CXXRecordDecl *RD, FieldDecl *FD) {
1302+ --UnionCount;
1303+ return true ;
1304+ }
1305+
1306+ bool handleSyclAccessorType (FieldDecl *FD, QualType FieldTy) final {
1307+ return checkType (FD->getLocation (), FieldTy);
1308+ }
1309+
1310+ bool handleSyclAccessorType (const CXXBaseSpecifier &BS,
1311+ QualType FieldTy) final {
1312+ return checkType (BS.getBeginLoc (), FieldTy);
1313+ }
1314+
1315+ bool handleSyclSamplerType (FieldDecl *FD, QualType FieldTy) final {
1316+ return checkType (FD->getLocation (), FieldTy);
1317+ }
1318+
1319+ bool handleSyclSamplerType (const CXXBaseSpecifier &BS,
1320+ QualType FieldTy) final {
1321+ return checkType (BS.getBeginLoc (), FieldTy);
1322+ }
1323+
1324+ bool handleSyclStreamType (FieldDecl *FD, QualType FieldTy) final {
1325+ return checkType (FD->getLocation (), FieldTy);
1326+ }
1327+
1328+ bool handleSyclStreamType (const CXXBaseSpecifier &BS,
1329+ QualType FieldTy) final {
1330+ return checkType (BS.getBeginLoc (), FieldTy);
1331+ }
1332+ };
1333+
12051334// A type to Create and own the FunctionDecl for the kernel.
12061335class SyclKernelDeclCreator : public SyclKernelFieldHandler {
12071336 FunctionDecl *KernelDecl;
@@ -1453,6 +1582,10 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
14531582 return true ;
14541583 }
14551584
1585+ bool handleUnionType (FieldDecl *FD, QualType FieldTy) final {
1586+ return handleScalarType (FD, FieldTy);
1587+ }
1588+
14561589 bool handleSyclHalfType (FieldDecl *FD, QualType FieldTy) final {
14571590 addParam (FD, FieldTy);
14581591 return true ;
@@ -1805,6 +1938,10 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
18051938 return true ;
18061939 }
18071940
1941+ bool handleUnionType (FieldDecl *FD, QualType FieldTy) final {
1942+ return handleScalarType (FD, FieldTy);
1943+ }
1944+
18081945 bool enterStruct (const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final {
18091946 CXXCastPath BasePath;
18101947 QualType DerivedTy (RD->getTypeForDecl (), 0 );
@@ -2012,6 +2149,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
20122149 return true ;
20132150 }
20142151
2152+ bool handleUnionType (FieldDecl *FD, QualType FieldTy) final {
2153+ return handleScalarType (FD, FieldTy);
2154+ }
2155+
20152156 bool handleSyclStreamType (FieldDecl *FD, QualType FieldTy) final {
20162157 addParam (FD, FieldTy, SYCLIntegrationHeader::kind_std_layout);
20172158 return true ;
@@ -2105,14 +2246,14 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
21052246 }
21062247 }
21072248
2108- SyclKernelFieldChecker Checker (*this );
2109-
2249+ SyclKernelFieldChecker FieldChecker (*this );
2250+ SyclKernelUnionChecker UnionChecker (* this );
21102251 KernelObjVisitor Visitor{*this };
21112252 DiagnosingSYCLKernel = true ;
2112- Visitor.VisitRecordBases (KernelObj, Checker );
2113- Visitor.VisitRecordFields (KernelObj, Checker );
2253+ Visitor.VisitRecordBases (KernelObj, FieldChecker, UnionChecker );
2254+ Visitor.VisitRecordFields (KernelObj, FieldChecker, UnionChecker );
21142255 DiagnosingSYCLKernel = false ;
2115- if (!Checker .isValid ())
2256+ if (!FieldChecker. isValid () || !UnionChecker .isValid ())
21162257 KernelFunc->setInvalidDecl ();
21172258}
21182259
0 commit comments