Skip to content

Commit cfc0217

Browse files
committed
[SYCL] Fix specialization constants struct members
FE crashed on attempt to create initializer for struct with spec constant members because there was no initializers for spec const fields. Added default initialization for spec constants.
1 parent 4c57d4d commit cfc0217

File tree

4 files changed

+96
-20
lines changed

4 files changed

+96
-20
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1455,13 +1455,8 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
14551455
InitExprs.push_back(ILE);
14561456
}
14571457

1458-
void createSpecialMethodCall(const CXXRecordDecl *SpecialClass, Expr *Base,
1459-
const std::string &MethodName,
1460-
FieldDecl *Field) {
1461-
CXXMethodDecl *Method = getMethodByName(SpecialClass, MethodName);
1462-
assert(Method &&
1463-
"The accessor/sampler/stream must have the __init method. Stream"
1464-
" must also have __finalize method");
1458+
CXXMemberCallExpr *createSpecialMethodCall(Expr *Base, CXXMethodDecl *Method,
1459+
FieldDecl *Field) {
14651460
unsigned NumParams = Method->getNumParams();
14661461
llvm::SmallVector<Expr *, 4> ParamDREs(NumParams);
14671462
llvm::ArrayRef<ParmVarDecl *> KernelParameters =
@@ -1485,10 +1480,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
14851480
CXXMemberCallExpr *Call = CXXMemberCallExpr::Create(
14861481
SemaRef.Context, MethodME, ParamStmts, ResultTy, VK, SourceLocation(),
14871482
FPOptionsOverride());
1488-
if (MethodName == FinalizeMethodName)
1489-
FinalizeStmts.push_back(Call);
1490-
else
1491-
BodyStmts.push_back(Call);
1483+
return Call;
14921484
}
14931485

14941486
// FIXME Avoid creation of kernel obj clone.
@@ -1517,8 +1509,12 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
15171509
ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, None);
15181510
InitExprs.push_back(MemberInit.get());
15191511

1520-
createSpecialMethodCall(RecordDecl, MemberExprBases.back(), InitMethodName,
1521-
FD);
1512+
CXXMethodDecl *InitMethod = getMethodByName(RecordDecl, InitMethodName);
1513+
if (InitMethod) {
1514+
CXXMemberCallExpr *InitCall =
1515+
createSpecialMethodCall(MemberExprBases.back(), InitMethod, FD);
1516+
BodyStmts.push_back(InitCall);
1517+
}
15221518
return true;
15231519
}
15241520

@@ -1535,8 +1531,12 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
15351531
ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, None);
15361532
InitExprs.push_back(MemberInit.get());
15371533

1538-
createSpecialMethodCall(RecordDecl, MemberExprBases.back(), InitMethodName,
1539-
nullptr);
1534+
CXXMethodDecl *InitMethod = getMethodByName(RecordDecl, InitMethodName);
1535+
if (InitMethod) {
1536+
CXXMemberCallExpr *InitCall =
1537+
createSpecialMethodCall(MemberExprBases.back(), InitMethod, nullptr);
1538+
BodyStmts.push_back(InitCall);
1539+
}
15401540
return true;
15411541
}
15421542

@@ -1578,14 +1578,27 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
15781578
return handleSpecialType(FD, Ty);
15791579
}
15801580

1581+
bool handleSyclSpecConstantType(FieldDecl *FD, QualType Ty) final {
1582+
return handleSpecialType(FD, Ty);
1583+
}
1584+
15811585
bool handleSyclStreamType(FieldDecl *FD, QualType Ty) final {
15821586
const auto *StreamDecl = Ty->getAsCXXRecordDecl();
15831587
createExprForStructOrScalar(FD);
15841588
size_t NumBases = MemberExprBases.size();
1585-
createSpecialMethodCall(StreamDecl, MemberExprBases[NumBases - 2],
1586-
InitMethodName, FD);
1587-
createSpecialMethodCall(StreamDecl, MemberExprBases[NumBases - 2],
1588-
FinalizeMethodName, FD);
1589+
CXXMethodDecl *InitMethod = getMethodByName(StreamDecl, InitMethodName);
1590+
if (InitMethod) {
1591+
CXXMemberCallExpr *InitCall =
1592+
createSpecialMethodCall(MemberExprBases.back(), InitMethod, FD);
1593+
BodyStmts.push_back(InitCall);
1594+
}
1595+
CXXMethodDecl *FinalizeMethod =
1596+
getMethodByName(StreamDecl, FinalizeMethodName);
1597+
if (FinalizeMethod) {
1598+
CXXMemberCallExpr *FinalizeCall = createSpecialMethodCall(
1599+
MemberExprBases[NumBases - 2], FinalizeMethod, FD);
1600+
FinalizeStmts.push_back(FinalizeCall);
1601+
}
15891602
return true;
15901603
}
15911604

@@ -1796,7 +1809,7 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
17961809
cast<ClassTemplateSpecializationDecl>(FieldTy->getAsRecordDecl())
17971810
->getTemplateInstantiationArgs();
17981811
assert(TemplateArgs.size() == 2 &&
1799-
"Incorrect template args for Accessor Type");
1812+
"Incorrect template args for spec constant type");
18001813
// Get specialization constant ID type, which is the second template
18011814
// argument.
18021815
QualType SpecConstIDTy = TemplateArgs.get(1).getAsType().getCanonicalType();
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %clang_cc1 -I %S/Inputs -fsycl -fsycl-is-device -ast-dump %s | FileCheck %s
2+
3+
// This test checks that compiler generates correct initialization for spec
4+
// constants
5+
6+
#include <sycl.hpp>
7+
8+
struct SpecConstantsWrapper{
9+
cl::sycl::experimental::spec_constant<int, class sc_name1> SC1;
10+
cl::sycl::experimental::spec_constant<int, class sc_name2> SC2;
11+
};
12+
13+
14+
int main() {
15+
cl::sycl::experimental::spec_constant<char, class MyInt32Const> SC;
16+
SpecConstantsWrapper W;
17+
cl::sycl::kernel_single_task<class kernel_sc>(
18+
[=]() {
19+
(void)SC;
20+
(void)W;
21+
});
22+
}
23+
24+
// CHECK: FunctionDecl {{.*}}kernel_sc 'void ()'
25+
// CHECK: VarDecl {{.*}}'(lambda at {{.*}}'
26+
// CHECK-NEXT: InitListExpr {{.*}}'(lambda at {{.*}}'
27+
// CHECK-NEXT: CXXConstructExpr {{.*}}'cl::sycl::experimental::spec_constant<char, class MyInt32Const>':'cl::sycl::experimental::spec_constant<char, MyInt32Const>'
28+
// CHECK-NEXT: InitListExpr {{.*}} 'SpecConstantsWrapper'
29+
// CHECK-NEXT: CXXConstructExpr {{.*}} 'cl::sycl::experimental::spec_constant<int, class sc_name1>':'cl::sycl::experimental::spec_constant<int, sc_name1>'
30+
// CHECK-NEXT: CXXConstructExpr {{.*}} 'cl::sycl::experimental::spec_constant<int, class sc_name2>':'cl::sycl::experimental::spec_constant<int, sc_name2>'

sycl/include/CL/sycl/experimental/spec_constant.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ template <typename T, typename ID = T> class spec_constant {
3232
private:
3333
// Implementation defined constructor.
3434
#ifdef __SYCL_DEVICE_ONLY__
35+
public:
3536
spec_constant() {}
37+
private:
3638
#else
3739
spec_constant(T Cst) : Val(Cst) {}
3840
#endif

sycl/test/spec_const/spec_const_hw.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ float foo(
3939
return f32;
4040
}
4141

42+
struct SCWrapper {
43+
SCWrapper(cl::sycl::program &p)
44+
: SC1(p.set_spec_constant<class sc_name1, int>(4)),
45+
SC2(p.set_spec_constant<class sc_name2, int>(2)) {}
46+
47+
cl::sycl::experimental::spec_constant<int, class sc_name1> SC1;
48+
cl::sycl::experimental::spec_constant<int, class sc_name2> SC2;
49+
};
50+
4251
int main(int argc, char **argv) {
4352
val = argc + 16;
4453

@@ -61,6 +70,7 @@ int main(int argc, char **argv) {
6170
std::cout << "val = " << val << "\n";
6271
cl::sycl::program program1(q.get_context());
6372
cl::sycl::program program2(q.get_context());
73+
cl::sycl::program program3(q.get_context());
6474

6575
int goldi = (int)get_value();
6676
// TODO make this floating point once supported by the compiler
@@ -77,11 +87,17 @@ int main(int argc, char **argv) {
7787
// SYCL RT execution path
7888
program2.build_with_kernel_type<KernelBBBf>("-cl-fast-relaxed-math");
7989

90+
SCWrapper W(program3);
91+
program3.build_with_kernel_type<class KernelWrappedSC>();
92+
int goldw = 6;
93+
8094
std::vector<int> veci(1);
8195
std::vector<float> vecf(1);
96+
std::vector<int> vecw(1);
8297
try {
8398
cl::sycl::buffer<int, 1> bufi(veci.data(), veci.size());
8499
cl::sycl::buffer<float, 1> buff(vecf.data(), vecf.size());
100+
cl::sycl::buffer<int, 1> bufw(vecw.data(), vecw.size());
85101

86102
q.submit([&](cl::sycl::handler &cgh) {
87103
auto acci = bufi.get_access<cl::sycl::access::mode::write>(cgh);
@@ -99,6 +115,15 @@ int main(int argc, char **argv) {
99115
accf[0] = foo(f32);
100116
});
101117
});
118+
119+
q.submit([&](cl::sycl::handler &cgh) {
120+
auto accw = bufw.get_access<cl::sycl::access::mode::write>(cgh);
121+
cgh.single_task<KernelWrappedSC>(
122+
program3.get_kernel<KernelWrappedSC>(),
123+
[=]() {
124+
accw[0] = W.SC1.get() + W.SC2.get();
125+
});
126+
});
102127
} catch (cl::sycl::exception &e) {
103128
std::cout << "*** Exception caught: " << e.what() << "\n";
104129
return 1;
@@ -116,6 +141,12 @@ int main(int argc, char **argv) {
116141
std::cout << "*** ERROR: " << valf << " != " << goldf << "(gold)\n";
117142
passed = false;
118143
}
144+
int valw = vecw[0];
145+
146+
if (valw != goldw) {
147+
std::cout << "*** ERROR: " << valw << " != " << goldw << "(gold)\n";
148+
passed = false;
149+
}
119150
std::cout << (passed ? "passed\n" : "FAILED\n");
120151
return passed ? 0 : 1;
121152
}

0 commit comments

Comments
 (0)