Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/include/clang/Basic/BuiltinsSPIRVCommon.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ def length : SPIRVBuiltin<"void(...)", [NoThrow, Const]>;
def smoothstep : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;

def subgroup_ballot : SPIRVBuiltin<"_ExtVector<4, uint32_t>(bool)", [NoThrow, Const]>;
def subgroup_shuffle : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
8 changes: 8 additions & 0 deletions clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
Call->addRetAttr(llvm::Attribute::AttrKind::NoUndef);
return Call;
}
case SPIRV::BI__builtin_spirv_subgroup_shuffle: {
Value *X = EmitScalarExpr(E->getArg(0));
Value *Y = EmitScalarExpr(E->getArg(1));
assert(E->getArg(1)->getType()->hasIntegerRepresentation());
return Builder.CreateIntrinsic(
/*ReturnType=*/getTypes().ConvertType(E->getArg(0)->getType()),
Intrinsic::spv_wave_readlane, {X, Y}, nullptr, "spv.shuffle");
}
case SPIRV::BI__builtin_spirv_num_workgroups:
return Builder.CreateIntrinsic(
/*ReturnType=*/getTypes().ConvertType(E->getType()),
Expand Down
36 changes: 36 additions & 0 deletions clang/lib/Sema/SemaSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,42 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_subgroup_shuffle: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;

ExprResult A =
SemaRef.DefaultFunctionArrayLvalueConversion(TheCall->getArg(0));
if (A.isInvalid())
return true;
TheCall->setArg(0, A.get());

QualType ArgTyA = A.get()->getType();
if (!ArgTyA->isIntegerType() && !ArgTyA->isFloatingType()) {
SemaRef.Diag(A.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
<< /* ordinal */ 1 << /* scalar */ 1 << /* no int */ 0
<< /* no fp */ 0 << ArgTyA;
return true;
}

ExprResult B =
SemaRef.DefaultFunctionArrayLvalueConversion(TheCall->getArg(1));
if (B.isInvalid())
return true;

QualType Uint32Ty =
SemaRef.getASTContext().getIntTypeForBitwidth(32,
/*Signed=*/false);
ExprResult ResB = SemaRef.PerformImplicitConversion(
B.get(), Uint32Ty, AssignmentAction::Passing);
if (ResB.isInvalid())
return true;
TheCall->setArg(1, ResB.get());

QualType RetTy = ArgTyA;
TheCall->setType(RetTy);
break;
}
}
return false;
}
Expand Down
12 changes: 10 additions & 2 deletions clang/test/CodeGenSPIRV/Builtins/subgroup.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@ typedef bool _Bool;
#endif
typedef unsigned __attribute__((ext_vector_type(4))) int4;

// CHECK: @{{.*}}test_subgroup_shuffle{{.*}}(
// CHECK: @{{.*}}test_subgroup_ballot{{.*}}(
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: tail call <4 x i32> @llvm.spv.wave.ballot(i1 %i)
[[clang::sycl_external]] int4 test_subgroup_shuffle(_Bool i) {
[[clang::sycl_external]] int4 test_subgroup_ballot(_Bool i) {
return __builtin_spirv_subgroup_ballot(i);
}

// CHECK: @{{.*}}test_subgroup_shuffle{{.*}}(
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: tail call float @llvm.spv.wave.readlane.f32(float %f, i32 %i)
//
[[clang::sycl_external]] float test_subgroup_shuffle(float f, int i) {
return __builtin_spirv_subgroup_shuffle(f, i);
}
12 changes: 12 additions & 0 deletions clang/test/SemaSPIRV/BuiltIns/subgroup-errors.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,15 @@ void ballot(_Bool c) {
x = __builtin_spirv_subgroup_ballot(x); // expected-error{{parameter of incompatible type}}
int y = __builtin_spirv_subgroup_ballot(c); // expected-error{{with an expression of incompatible type}}
}

void shuffle() {
int x = 0;
long long l = 0;
float f = 0;
int [[clang::ext_vector_type(1)]] v;
(void)__builtin_spirv_subgroup_shuffle(x, x);
(void)__builtin_spirv_subgroup_shuffle(f, f);
(void)__builtin_spirv_subgroup_shuffle(x, x, x); // expected-error{{too many arguments to function call, expected 2, have 3}}
(void)__builtin_spirv_subgroup_shuffle(v, f); // expected-error{{1st argument must be a scalar type}}
(void)__builtin_spirv_subgroup_shuffle(f, v); // expected-error{{to parameter of incompatible type}}
}