Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -5055,6 +5055,12 @@ def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
let Prototype = "bool(bool)";
}

// Clang builtin backing HLSL's WaveActiveBallot. Registered as an HLSL_LANG
// language builtin; per the prototype it takes one bool and returns a vector
// of four unsigned ints.
def HLSLWaveActiveBallot : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_ballot"];
let Attributes = [NoThrow, Const];
let Prototype = "_ExtVector<4, unsigned int>(bool)";
}

def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
let Attributes = [NoThrow, Const];
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/CodeGen/CGHLSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
}
case Builtin::BI__builtin_hlsl_wave_active_ballot: {
  // The builtin's single argument is the per-lane predicate; it has to be an
  // i1 by the time it is handed to the intrinsic.
  Value *Predicate = EmitScalarExpr(E->getArg(0));
  assert(Predicate->getType()->isIntegerTy(1) &&
         "Intrinsic WaveActiveBallot operand must be a bool");

  // Ask the HLSL runtime which intrinsic to use, declare it in the current
  // module, and emit the call with the predicate as its only operand.
  Intrinsic::ID BallotID = CGM.getHLSLRuntime().getWaveActiveBallotIntrinsic();
  auto *BallotDecl =
      Intrinsic::getOrInsertDeclaration(&CGM.getModule(), BallotID);
  return EmitRuntimeCall(BallotDecl, {Predicate});
}
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
// Presumably generates getWaveActiveBallotIntrinsic(), resolving to the
// target's wave_ballot intrinsic (macro definition not visible here — confirm).
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveBallot, wave_ballot)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -2410,6 +2410,18 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_any_true)
__attribute__((convergent)) bool WaveActiveAnyTrue(bool Val);

/// \brief Returns a uint4 bitmask holding the result of evaluating the
/// boolean expression on every active lane in the current wave.
///
/// The least-significant bit corresponds to the lane with index zero.
/// Bits corresponding to inactive lanes are zero, as are bits at positions
/// greater than or equal to WaveGetLaneCount.
///
/// \param Val The boolean expression to evaluate.
/// \return A uint4 bitmask with one bit per lane, laid out as described above.
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_ballot)
__attribute__((convergent)) uint4 WaveActiveBallot(bool Val);

/// \brief Counts the number of boolean variables which evaluate to true across
/// all active lanes in the current wave.
///
Expand Down
17 changes: 17 additions & 0 deletions clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV

// Test basic lowering of WaveActiveBallot on a bool value to the
// target-specific wave ballot intrinsic call (DXIL and SPIR-V).

// CHECK-LABEL: define {{.*}}test
uint4 test(bool p1) {
// CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry()
// CHECK-SPIRV: %[[RET:.*]] = call spir_func <4 x i32> @llvm.spv.wave.ballot(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ]
// CHECK-DXIL: %[[RET:.*]] = call <4 x i32> @llvm.dx.wave.ballot(i1 %{{[a-zA-Z0-9]+}})
// CHECK: ret <4 x i32> %[[RET]]
return WaveActiveBallot(p1);
}
21 changes: 21 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/WaveActiveBallot-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify

// Calling the builtin with no arguments must be diagnosed.
uint4 test_too_few_arg() {
return __builtin_hlsl_wave_active_ballot();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

// Calling the builtin with two arguments must be diagnosed.
uint4 test_too_many_arg(bool p0) {
return __builtin_hlsl_wave_active_ballot(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

// Plain struct used by test_type_check as an argument type that cannot be
// converted to bool.
struct Foo
{
int a;
};

// Passing a struct where the builtin takes a bool must be diagnosed as a
// failed conversion.
uint4 test_type_check(Foo p0) {
return __builtin_hlsl_wave_active_ballot(p0);
// expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
// WaveActiveBallot for DirectX: i1 predicate in, <4 x i32> out; convergent and does not access memory.
def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
// WaveActiveBallot for SPIR-V: i1 predicate in, <4 x i32> out; convergent and does not access memory.
def int_spv_wave_ballot : DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_min : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,14 @@ def WaveReadLaneAt : DXILOp<117, waveReadLaneAt> {
let stages = [Stages<DXIL1_0, [all_stages]>];
}

// DXIL operation selected for int_dx_wave_ballot: one i1 argument, available
// in all shader stages starting with DXIL 1.0.
// NOTE(review): this record uses opcode 118 and the waveAnyTrue op class. The
// DXIL specification appears to assign WaveActiveBallot opcode 116 with its
// own op class (118 being WaveReadLaneFirst) and a four-i32 struct result —
// confirm against DXIL.rst and, if so, fix this record together with
// llvm/test/CodeGen/DirectX/WaveActiveBallot.ll.
def WaveActiveBallot : DXILOp<118, waveAnyTrue> {
let Doc = "returns uint4 containing a bitmask of the evaluation of the boolean expression for all active lanes in the current wave.";
let intrinsics = [IntrinSelect<int_dx_wave_ballot>];
let arguments = [Int1Ty];
let result = OverloadTy;
let stages = [Stages<DXIL1_0, [all_stages]>];
}

def WaveActiveOp : DXILOp<119, waveActiveOp> {
let Doc = "returns the result of the operation across waves";
let intrinsics = [
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3803,6 +3803,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll);
case Intrinsic::spv_wave_any:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
// The ballot intrinsic goes through the same wave-op helper as the all/any
// cases above, selecting OpGroupNonUniformBallot.
case Intrinsic::spv_wave_ballot:
return selectWaveOpInst(ResVReg, ResType, I,
SPIRV::OpGroupNonUniformBallot);
case Intrinsic::spv_wave_is_first_lane:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformElect);
case Intrinsic::spv_wave_reduce_umax:
Expand Down
10 changes: 10 additions & 0 deletions llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s

; Checks that llvm.dx.wave.ballot is rewritten by -dxil-op-lower into a DXIL
; op call that forwards the i1 predicate.
; NOTE(review): the directive below pins opcode 118 and the waveAnyTrue name.
; The DXIL specification appears to list WaveActiveBallot as opcode 116 with
; its own op class — confirm against DXIL.rst and the DXIL.td record.
define noundef <4 x i32> @wave_ballot_simple(i1 noundef %p1) {
entry:
; CHECK: call <4 x i32> @dx.op.waveAnyTrue.void(i32 118, i1 %p1)
%ret = call <4 x i32> @llvm.dx.wave.ballot(i1 %p1)
ret <4 x i32> %ret
}

declare <4 x i32> @llvm.dx.wave.ballot(i1)
22 changes: 22 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#bool:]] = OpTypeBool
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
; CHECK-DAG: %[[#bitmask:]] = OpTypeVector %[[#uint]] 4
; CHECK-DAG: OpCapability GroupNonUniformBallot

; Verifies that llvm.spv.wave.ballot, called under a convergence entry token,
; is selected to OpGroupNonUniformBallot using the 4 x uint result type, the
; scope constant, and the bool function parameter matched above.
; CHECK-LABEL: Begin function test_wave_ballot
define <4 x i32> @test_wave_ballot(i1 %p1) #0 {
entry:
; CHECK: %[[#param:]] = OpFunctionParameter %[[#bool]]
; CHECK: %{{.+}} = OpGroupNonUniformBallot %[[#bitmask]] %[[#scope]] %[[#param]]
%0 = call token @llvm.experimental.convergence.entry()
%ret = call <4 x i32> @llvm.spv.wave.ballot(i1 %p1) [ "convergencectrl"(token %0) ]
ret <4 x i32> %ret
}

declare <4 x i32> @llvm.spv.wave.ballot(i1) #0

attributes #0 = { convergent }