Skip to content

Commit 8c3fa7b

Browse files
committed
subgroup: add subgroup_broadcast_const variant for pre-spv1.5 broadcasts
1 parent 9fc1a5f commit 8c3fa7b

File tree

7 files changed

+106
-5
lines changed

7 files changed

+106
-5
lines changed

crates/spirv-std/src/arch/subgroup.rs

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,9 @@ pub fn subgroup_all_equal<T: ScalarOrVector>(value: T) -> bool {
281281
///
282282
/// # Safety
283283
/// * `id` must be dynamically uniform
284-
/// * before 1.5: `id` must be constant
285284
/// * Result is undefined if `id` is an inactive invocation or out of bounds
285+
/// * This variant with a dynamic `id` requires at least `spv1.5` or `vulkan1.2`. Alternatively, you can use
286+
/// [`subgroup_broadcast_const`] with a constant `id`.
286287
#[spirv_std_macros::gpu_only]
287288
#[doc(alias = "OpGroupNonUniformBroadcast")]
288289
#[inline]
@@ -307,6 +308,48 @@ pub unsafe fn subgroup_broadcast<T: ScalarOrVector>(value: T, id: u32) -> T {
307308
result
308309
}
309310

311+
/// Result is the `value` of the invocation identified by the id `id` to all active invocations in the group.
312+
///
313+
/// Result Type must be a scalar or vector of floating-point type, integer type, or Boolean type.
314+
///
315+
/// Execution is a Scope that identifies the group of invocations affected by this command. It must be Subgroup.
316+
///
317+
/// The type of `value` must be the same as Result Type.
318+
///
319+
/// `id` must be a scalar of integer type, whose Signedness operand is 0.
320+
///
321+
/// Before version 1.5, `id` must come from a constant instruction. Starting with version 1.5, this restriction is lifted. However, behavior is undefined when `id` is not dynamically uniform.
322+
///
323+
/// The resulting value is undefined if `id` is an inactive invocation, or is greater than or equal to the size of the group.
324+
///
325+
/// Requires Capability `GroupNonUniformBallot`.
326+
///
327+
/// # Safety
328+
/// * Result is undefined if `id` is an inactive invocation or out of bounds
329+
#[spirv_std_macros::gpu_only]
330+
#[doc(alias = "OpGroupNonUniformBroadcast")]
331+
#[inline]
332+
pub unsafe fn subgroup_broadcast_const<T: ScalarOrVector, const ID: u32>(value: T) -> T {
333+
let mut result = T::default();
334+
335+
unsafe {
336+
asm! {
337+
"%u32 = OpTypeInt 32 0",
338+
"%subgroup = OpConstant %u32 {subgroup}",
339+
"%id = OpConstant %u32 {id}",
340+
"%value = OpLoad _ {value}",
341+
"%result = OpGroupNonUniformBroadcast _ %subgroup %value %id",
342+
"OpStore {result} %result",
343+
subgroup = const SUBGROUP,
344+
value = in(reg) &value,
345+
id = const ID,
346+
result = in(reg) &mut result,
347+
}
348+
}
349+
350+
result
351+
}
352+
310353
/// Result is the `value` of the invocation from the active invocation with the lowest id in the group to all active invocations in the group.
311354
///
312355
/// Result Type must be a scalar or vector of floating-point type, integer type, or Boolean type.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// build-pass
2+
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformBallot,+ext:SPV_KHR_vulkan_memory_model
3+
// compile-flags: -C llvm-args=--disassemble-fn=subgroup_broadcast::disassembly
4+
// normalize-stderr-test "OpLine .*\n" -> ""
5+
// ignore-vulkan1.0
6+
// ignore-vulkan1.1
7+
// ignore-spv1.0
8+
// ignore-spv1.1
9+
// ignore-spv1.2
10+
// ignore-spv1.3
11+
// ignore-spv1.4
12+
13+
use spirv_std::arch::{GroupOperation, SubgroupMask};
14+
use spirv_std::spirv;
15+
16+
unsafe fn disassembly(value: i32, id: u32) -> i32 {
17+
spirv_std::arch::subgroup_broadcast(value, id)
18+
}
19+
20+
#[spirv(compute(threads(32, 1, 1)))]
21+
pub fn main() {
22+
unsafe {
23+
disassembly(42, 5);
24+
}
25+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
%1 = OpFunction %2 None %3
2+
%4 = OpFunctionParameter %2
3+
%5 = OpFunctionParameter %6
4+
%7 = OpLabel
5+
%9 = OpGroupNonUniformBroadcast %2 %10 %4 %5
6+
OpNoLine
7+
OpReturnValue %9
8+
OpFunctionEnd
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// build-pass
2+
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformBallot,+ext:SPV_KHR_vulkan_memory_model
3+
// compile-flags: -C llvm-args=--disassemble-fn=subgroup_broadcast_const::disassembly
4+
// normalize-stderr-test "OpLine .*\n" -> ""
5+
6+
use spirv_std::arch::{GroupOperation, SubgroupMask};
7+
use spirv_std::spirv;
8+
9+
unsafe fn disassembly(value: i32) -> i32 {
10+
spirv_std::arch::subgroup_broadcast_const::<_, 5>(value)
11+
}
12+
13+
#[spirv(compute(threads(32, 1, 1)))]
14+
pub fn main() {
15+
unsafe {
16+
disassembly(-42);
17+
}
18+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
%1 = OpFunction %2 None %3
2+
%4 = OpFunctionParameter %2
3+
%5 = OpLabel
4+
%7 = OpGroupNonUniformBroadcast %2 %8 %4 %9
5+
OpNoLine
6+
OpReturnValue %7
7+
OpFunctionEnd

tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_0_fail.stderr

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
error[E0080]: evaluation panicked: `ClusterSize` must be at least 1
2-
--> $SPIRV_STD_SRC/arch/subgroup.rs:825:1
2+
--> $SPIRV_STD_SRC/arch/subgroup.rs:868:1
33
|
44
LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r"
55
LL | | An integer add group operation of all `value` operands contributed by active invocations in the group.
@@ -13,7 +13,7 @@ LL | | ");
1313
= note: this error originates in the macro `$crate::panic::panic_2021` which comes from the expansion of the macro `macro_subgroup_op_clustered` (in Nightly builds, run with -Z macro-backtrace for more info)
1414

1515
note: erroneous constant encountered
16-
--> $SPIRV_STD_SRC/arch/subgroup.rs:825:1
16+
--> $SPIRV_STD_SRC/arch/subgroup.rs:868:1
1717
|
1818
LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r"
1919
LL | | An integer add group operation of all `value` operands contributed by active invocations in the group.

tests/compiletests/ui/arch/subgroup/subgroup_cluster_size_non_power_of_two_fail.stderr

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
error[E0080]: evaluation panicked: `ClusterSize` must be a power of 2
2-
--> $SPIRV_STD_SRC/arch/subgroup.rs:825:1
2+
--> $SPIRV_STD_SRC/arch/subgroup.rs:868:1
33
|
44
LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r"
55
LL | | An integer add group operation of all `value` operands contributed by active invocations in the group.
@@ -13,7 +13,7 @@ LL | | ");
1313
= note: this error originates in the macro `$crate::panic::panic_2021` which comes from the expansion of the macro `macro_subgroup_op_clustered` (in Nightly builds, run with -Z macro-backtrace for more info)
1414

1515
note: erroneous constant encountered
16-
--> $SPIRV_STD_SRC/arch/subgroup.rs:825:1
16+
--> $SPIRV_STD_SRC/arch/subgroup.rs:868:1
1717
|
1818
LL | / macro_subgroup_op_clustered!(impl Integer, "OpGroupNonUniformIAdd", subgroup_clustered_i_add; r"
1919
LL | | An integer add group operation of all `value` operands contributed by active invocations in the group.

0 commit comments

Comments
 (0)