-
Notifications
You must be signed in to change notification settings - Fork 16.1k
[HLSL] Add WaveActiveBallot builtins and lower to DXIL / SPIR-V #174638
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-backend-x86 Author: Joshua Batista (bob80905) ChangesThis PR adds WaveActiveBallot as a builtin function to HLSL. Full diff: https://github.com/llvm/llvm-project/pull/174638.diff 12 Files Affected:
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index a5be656cdacf4..5199578e8f666 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5055,6 +5055,12 @@ def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
let Prototype = "bool(bool)";
}
+def HLSLWaveActiveBallot : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_wave_active_ballot"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "bool(bool)";
+}
+
def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 317e64d595243..1b6c3714f7821 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -829,6 +829,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
}
+ case Builtin::BI__builtin_hlsl_wave_active_ballot: {
+ Value *Op = EmitScalarExpr(E->getArg(0));
+ assert(Op->getType()->isIntegerTy(1) &&
+ "Intrinsic WaveActiveBallot operand must be a bool");
+
+ Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveBallotIntrinsic();
+ return EmitRuntimeCall(
+ Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
+ }
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ba2ca2c358388..7a5643052ed84 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -146,6 +146,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveBallot, wave_ballot)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index b065e5dd8447f..a1fafb40f2c6c 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2410,6 +2410,18 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_any_true)
__attribute__((convergent)) bool WaveActiveAnyTrue(bool Val);
+/// \brief Returns a uint4 containing a bitmask of the evaluation of the
+/// boolean expression for all active lanes in the current wave.
+/// The least-significant bit corresponds to the lane with index zero.
+/// The bits corresponding to inactive lanes will be zero. The bits that
+/// are greater than or equal to WaveGetLaneCount will be zero.
+///
+/// \param Val The boolean expression to evaluate.
+/// \return uint4 bitmask
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_ballot)
+__attribute__((convergent)) bool WaveActiveBallot(bool Val);
+
/// \brief Counts the number of boolean variables which evaluate to true across
/// all active lanes in the current wave.
///
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
new file mode 100644
index 0000000000000..61b077eb1fead
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
@@ -0,0 +1,17 @@
+// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call for int values.
+
+// CHECK-LABEL: define {{.*}}test
+uint4 test(bool p1) {
+ // CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry()
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func <4 x i32> @llvm.spv.wave.ballot(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ]
+ // CHECK-DXIL: %[[RET:.*]] = call <4 x i32> @llvm.dx.wave.ballot(i1 %{{[a-zA-Z0-9]+}})
+ // CHECK: ret <4 x i32> %[[RET]]
+ return WaveActiveBallot(p1);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveBallot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveBallot-errors.hlsl
new file mode 100644
index 0000000000000..ae39068494864
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveBallot-errors.hlsl
@@ -0,0 +1,21 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+bool test_too_few_arg() {
+ return __builtin_hlsl_wave_active_ballot();
+ // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+bool test_too_many_arg(bool p0) {
+ return __builtin_hlsl_wave_active_ballot(p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+struct Foo
+{
+ int a;
+};
+
+bool test_type_check(Foo p0) {
+ return __builtin_hlsl_wave_active_ballot(p0);
+ // expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 8ca93731ffa04..6e6eb2d0ece9d 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -153,6 +153,7 @@ def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 402235ec7cd9c..0e23d19a01c20 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -120,6 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+ def int_spv_wave_ballot : DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_min : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index b221fa2d7fe87..d84fa4a94f489 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -1071,6 +1071,14 @@ def WaveReadLaneAt : DXILOp<117, waveReadLaneAt> {
let stages = [Stages<DXIL1_0, [all_stages]>];
}
+def WaveActiveBallot : DXILOp<118, waveAnyTrue> {
+ let Doc = "returns uint4 containing a bitmask of the evaluation of the boolean expression for all active lanes in the current wave.";
+ let intrinsics = [IntrinSelect<int_dx_wave_ballot>];
+ let arguments = [Int1Ty];
+ let result = OverloadTy;
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
def WaveActiveOp : DXILOp<119, waveActiveOp> {
let Doc = "returns the result of the operation across waves";
let intrinsics = [
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f991938c14dfe..c403b297b57ee 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3803,6 +3803,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll);
case Intrinsic::spv_wave_any:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
+ case Intrinsic::spv_wave_ballot:
+ return selectWaveOpInst(ResVReg, ResType, I,
+ SPIRV::OpGroupNonUniformBallot);
case Intrinsic::spv_wave_is_first_lane:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformElect);
case Intrinsic::spv_wave_reduce_umax:
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
new file mode 100644
index 0000000000000..5c3f8fc2d7643
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
@@ -0,0 +1,10 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+define noundef <4 x i32> @wave_ballot_simple(i1 noundef %p1) {
+entry:
+; CHECK: call <4 x i32> @dx.op.waveAnyTrue.void(i32 118, i1 %p1)
+ %ret = call <4 x i32> @llvm.dx.wave.ballot(i1 %p1)
+ ret <4 x i32> %ret
+}
+
+declare <4 x i32> @llvm.dx.wave.ballot(i1)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll
new file mode 100644
index 0000000000000..6831888f038fd
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll
@@ -0,0 +1,22 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#bool:]] = OpTypeBool
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
+; CHECK-DAG: %[[#bitmask:]] = OpTypeVector %[[#uint]] 4
+; CHECK-DAG: OpCapability GroupNonUniformBallot
+
+; CHECK-LABEL: Begin function test_wave_ballot
+define <4 x i32> @test_wave_ballot(i1 %p1) #0 {
+entry:
+; CHECK: %[[#param:]] = OpFunctionParameter %[[#bool]]
+; CHECK: %{{.+}} = OpGroupNonUniformBallot %[[#bitmask]] %[[#scope]] %[[#param]]
+ %0 = call token @llvm.experimental.convergence.entry()
+ %ret = call <4 x i32> @llvm.spv.wave.ballot(i1 %p1) [ "convergencectrl"(token %0) ]
+ ret <4 x i32> %ret
+}
+
+declare <4 x i32> @llvm.spv.wave.ballot(i1) #0
+
+attributes #0 = { convergent }
|
🪟 Windows x64 Test Results
✅ The build succeeded and all tests passed. |
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed. |
…ct into add_waveactiveballot
|
you are linking the wrong ticket should be this one: #99163 |
…#174638) This PR adds WaveActiveBallot as a builtin function to HLSL. Fixes llvm#99163
This PR adds WaveActiveBallot as a builtin function to HLSL.
Fixes #99163