Skip to content

Commit b5752c3

Browse files
committed
scalar or vector: adjust subgroup_all_equal to accept composites
1 parent 07413a0 commit b5752c3

File tree

3 files changed

+86
-14
lines changed

3 files changed

+86
-14
lines changed

crates/spirv-std/src/arch/subgroup.rs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -244,24 +244,35 @@ pub fn subgroup_any(predicate: bool) -> bool {
244244
#[spirv_std_macros::gpu_only]
245245
#[doc(alias = "OpGroupNonUniformAllEqual")]
246246
#[inline]
247-
pub fn subgroup_all_equal<T: ScalarOrVector>(value: T) -> bool {
248-
let mut result = false;
247+
pub fn subgroup_all_equal<T: ScalarOrVectorComposite>(value: T) -> bool {
248+
struct Transform(bool);
249249

250-
unsafe {
251-
asm! {
252-
"%bool = OpTypeBool",
253-
"%u32 = OpTypeInt 32 0",
254-
"%subgroup = OpConstant %u32 {subgroup}",
255-
"%value = OpLoad _ {value}",
256-
"%result = OpGroupNonUniformAllEqual %bool %subgroup %value",
257-
"OpStore {result} %result",
258-
subgroup = const SUBGROUP,
259-
value = in(reg) &value,
260-
result = in(reg) &mut result,
250+
impl ScalarOrVectorTransform for Transform {
251+
#[inline]
252+
fn transform<T: ScalarOrVector>(&mut self, value: T) -> T {
253+
let mut result = false;
254+
unsafe {
255+
asm! {
256+
"%bool = OpTypeBool",
257+
"%u32 = OpTypeInt 32 0",
258+
"%subgroup = OpConstant %u32 {subgroup}",
259+
"%value = OpLoad _ {value}",
260+
"%result = OpGroupNonUniformAllEqual %bool %subgroup %value",
261+
"OpStore {result} %result",
262+
subgroup = const SUBGROUP,
263+
value = in(reg) &value,
264+
result = in(reg) &mut result,
265+
}
266+
}
267+
self.0 &= result;
268+
value
261269
}
262270
}
263271

264-
result
272+
let mut transform = Transform(true);
273+
// The transformed value is discarded on purpose: only the flag accumulated in `transform.0` is the result.
274+
value.transform(&mut transform);
275+
transform.0
265276
}
266277

267278
/// The result is the `value` of the invocation identified by the id `id`, broadcast to all active invocations in the group.
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// build-pass
2+
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformVote,+ext:SPV_KHR_vulkan_memory_model
3+
// compile-flags: -C llvm-args=--disassemble-fn=subgroup_composite_all_equals::disassembly
4+
// normalize-stderr-test "OpLine .*\n" -> ""
5+
6+
use glam::*;
7+
use spirv_std::ScalarOrVectorComposite;
8+
use spirv_std::arch::*;
9+
use spirv_std::spirv;
10+
11+
#[derive(Copy, Clone, ScalarOrVectorComposite)]
12+
pub struct MyStruct {
13+
a: f32,
14+
b: UVec3,
15+
c: Nested,
16+
d: Zst,
17+
}
18+
19+
#[derive(Copy, Clone, ScalarOrVectorComposite)]
20+
pub struct Nested(i32);
21+
22+
#[derive(Copy, Clone, ScalarOrVectorComposite)]
23+
pub struct Zst;
24+
25+
/// This should compile to 3 `OpGroupNonUniformAllEqual` instructions (one each for `a`, `b` and `c.0`; `Zst` emits none), with all calls inlined.
26+
fn disassembly(my_struct: MyStruct) -> bool {
27+
subgroup_all_equal(my_struct)
28+
}
29+
30+
#[spirv(compute(threads(32)))]
31+
pub fn main(
32+
#[spirv(local_invocation_index)] inv_id: UVec3,
33+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] output: &mut u32,
34+
) {
35+
unsafe {
36+
let my_struct = MyStruct {
37+
a: inv_id.x as f32,
38+
b: inv_id,
39+
c: Nested(5i32 - inv_id.x as i32),
40+
d: Zst,
41+
};
42+
43+
let bool = disassembly(my_struct);
44+
*output = u32::from(bool);
45+
}
46+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
%1 = OpFunction %2 None %3
2+
%4 = OpFunctionParameter %5
3+
%6 = OpLabel
4+
%8 = OpCompositeExtract %9 %4 0
5+
%11 = OpGroupNonUniformAllEqual %2 %12 %8
6+
%13 = OpLogicalAnd %2 %14 %11
7+
%15 = OpCompositeExtract %16 %4 1
8+
%17 = OpGroupNonUniformAllEqual %2 %12 %15
9+
%18 = OpLogicalAnd %2 %13 %17
10+
%19 = OpCompositeExtract %20 %4 2
11+
%21 = OpGroupNonUniformAllEqual %2 %12 %19
12+
%22 = OpLogicalAnd %2 %18 %21
13+
OpNoLine
14+
OpReturnValue %22
15+
OpFunctionEnd

0 commit comments

Comments
 (0)