Skip to content

Commit f438c8c

Browse files
committed
scalar or vector: rename VectorOrScalar to ScalarOrVector and give it its own mod
1 parent 9806b1c commit f438c8c

File tree

7 files changed

+48
-40
lines changed

7 files changed

+48
-40
lines changed

crates/spirv-std/src/arch/subgroup.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::VectorOrScalar;
1+
use crate::ScalarOrVector;
22
#[cfg(target_arch = "spirv")]
33
use crate::arch::barrier;
44
#[cfg(target_arch = "spirv")]
@@ -243,7 +243,7 @@ pub fn subgroup_any(predicate: bool) -> bool {
243243
#[spirv_std_macros::gpu_only]
244244
#[doc(alias = "OpGroupNonUniformAllEqual")]
245245
#[inline]
246-
pub fn subgroup_all_equal<T: VectorOrScalar>(value: T) -> bool {
246+
pub fn subgroup_all_equal<T: ScalarOrVector>(value: T) -> bool {
247247
let mut result = false;
248248

249249
unsafe {
@@ -286,7 +286,7 @@ pub fn subgroup_all_equal<T: VectorOrScalar>(value: T) -> bool {
286286
#[spirv_std_macros::gpu_only]
287287
#[doc(alias = "OpGroupNonUniformBroadcast")]
288288
#[inline]
289-
pub unsafe fn subgroup_broadcast<T: VectorOrScalar>(value: T, id: u32) -> T {
289+
pub unsafe fn subgroup_broadcast<T: ScalarOrVector>(value: T, id: u32) -> T {
290290
let mut result = T::default();
291291

292292
unsafe {
@@ -319,7 +319,7 @@ pub unsafe fn subgroup_broadcast<T: VectorOrScalar>(value: T, id: u32) -> T {
319319
#[spirv_std_macros::gpu_only]
320320
#[doc(alias = "OpGroupNonUniformBroadcastFirst")]
321321
#[inline]
322-
pub fn subgroup_broadcast_first<T: VectorOrScalar>(value: T) -> T {
322+
pub fn subgroup_broadcast_first<T: ScalarOrVector>(value: T) -> T {
323323
let mut result = T::default();
324324

325325
unsafe {
@@ -594,7 +594,7 @@ pub fn subgroup_ballot_find_msb(value: SubgroupMask) -> u32 {
594594
#[spirv_std_macros::gpu_only]
595595
#[doc(alias = "OpGroupNonUniformShuffle")]
596596
#[inline]
597-
pub fn subgroup_shuffle<T: VectorOrScalar>(value: T, id: u32) -> T {
597+
pub fn subgroup_shuffle<T: ScalarOrVector>(value: T, id: u32) -> T {
598598
let mut result = T::default();
599599

600600
unsafe {
@@ -635,7 +635,7 @@ pub fn subgroup_shuffle<T: VectorOrScalar>(value: T, id: u32) -> T {
635635
#[spirv_std_macros::gpu_only]
636636
#[doc(alias = "OpGroupNonUniformShuffleXor")]
637637
#[inline]
638-
pub fn subgroup_shuffle_xor<T: VectorOrScalar>(value: T, mask: u32) -> T {
638+
pub fn subgroup_shuffle_xor<T: ScalarOrVector>(value: T, mask: u32) -> T {
639639
let mut result = T::default();
640640

641641
unsafe {
@@ -676,7 +676,7 @@ pub fn subgroup_shuffle_xor<T: VectorOrScalar>(value: T, mask: u32) -> T {
676676
#[spirv_std_macros::gpu_only]
677677
#[doc(alias = "OpGroupNonUniformShuffleUp")]
678678
#[inline]
679-
pub fn subgroup_shuffle_up<T: VectorOrScalar>(value: T, delta: u32) -> T {
679+
pub fn subgroup_shuffle_up<T: ScalarOrVector>(value: T, delta: u32) -> T {
680680
let mut result = T::default();
681681

682682
unsafe {
@@ -717,7 +717,7 @@ pub fn subgroup_shuffle_up<T: VectorOrScalar>(value: T, delta: u32) -> T {
717717
#[spirv_std_macros::gpu_only]
718718
#[doc(alias = "OpGroupNonUniformShuffleDown")]
719719
#[inline]
720-
pub fn subgroup_shuffle_down<T: VectorOrScalar>(value: T, delta: u32) -> T {
720+
pub fn subgroup_shuffle_down<T: ScalarOrVector>(value: T, delta: u32) -> T {
721721
let mut result = T::default();
722722

723723
unsafe {
@@ -744,7 +744,7 @@ macro_rules! macro_subgroup_op {
744744
#[spirv_std_macros::gpu_only]
745745
#[doc(alias = $asm_op)]
746746
#[inline]
747-
pub fn $name<I: VectorOrScalar<Scalar = $scalar>>(
747+
pub fn $name<I: ScalarOrVector<Scalar = $scalar>>(
748748
value: I,
749749
) -> I {
750750
let mut result = I::default();
@@ -772,7 +772,7 @@ macro_rules! macro_subgroup_op_clustered {
772772
#[spirv_std_macros::gpu_only]
773773
#[doc(alias = $asm_op)]
774774
#[inline]
775-
pub unsafe fn $name<const CLUSTER_SIZE: u32, I: VectorOrScalar<Scalar = $scalar>>(
775+
pub unsafe fn $name<const CLUSTER_SIZE: u32, I: ScalarOrVector<Scalar = $scalar>>(
776776
value: I,
777777
) -> I {
778778
const {
@@ -1344,7 +1344,7 @@ Requires Capability `GroupNonUniformArithmetic` and `GroupNonUniformClustered`.
13441344
#[spirv_std_macros::gpu_only]
13451345
#[doc(alias = "OpGroupNonUniformQuadBroadcast")]
13461346
#[inline]
1347-
pub fn subgroup_quad_broadcast<T: VectorOrScalar>(value: T, index: u32) -> T {
1347+
pub fn subgroup_quad_broadcast<T: ScalarOrVector>(value: T, index: u32) -> T {
13481348
let mut result = T::default();
13491349

13501350
unsafe {
@@ -1427,7 +1427,7 @@ pub enum QuadDirection {
14271427
#[spirv_std_macros::gpu_only]
14281428
#[doc(alias = "OpGroupNonUniformQuadSwap")]
14291429
#[inline]
1430-
pub fn subgroup_quad_swap<const DIRECTION: u32, T: VectorOrScalar>(value: T) -> T {
1430+
pub fn subgroup_quad_swap<const DIRECTION: u32, T: ScalarOrVector>(value: T) -> T {
14311431
let mut result = T::default();
14321432

14331433
unsafe {

crates/spirv-std/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ pub mod ray_tracing;
100100
mod runtime_array;
101101
mod sampler;
102102
mod scalar;
103-
pub(crate) mod sealed;
103+
mod scalar_or_vector;
104104
mod typed_buffer;
105105
mod vector;
106106

@@ -110,6 +110,7 @@ pub use byte_addressable_buffer::ByteAddressableBuffer;
110110
pub use num_traits;
111111
pub use runtime_array::*;
112112
pub use scalar::*;
113+
pub use scalar_or_vector::*;
113114
pub use typed_buffer::*;
114115
pub use vector::*;
115116

crates/spirv-std/src/scalar.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Traits related to scalars.
22
3-
use crate::VectorOrScalar;
3+
use crate::ScalarOrVector;
44
use crate::sealed::Sealed;
55
use core::num::NonZeroUsize;
66

@@ -13,7 +13,7 @@ use core::num::NonZeroUsize;
1313
///
1414
/// # Safety
1515
/// Implementing this trait on non-scalar types breaks assumptions of other unsafe code, and should not be done.
16-
pub unsafe trait Scalar: VectorOrScalar<Scalar = Self> + crate::sealed::Sealed {}
16+
pub unsafe trait Scalar: ScalarOrVector<Scalar = Self> + crate::sealed::Sealed {}
1717

1818
/// Abstract trait representing a SPIR-V integer or floating-point type. Unlike [`Scalar`], excludes the boolean type.
1919
///
@@ -61,9 +61,9 @@ pub unsafe trait Float: num_traits::Float + Number {
6161
macro_rules! impl_scalar {
6262
(impl Scalar for $ty:ty;) => {
6363
impl Sealed for $ty {}
64-
unsafe impl VectorOrScalar for $ty {
64+
unsafe impl ScalarOrVector for $ty {
6565
type Scalar = Self;
66-
const DIM: NonZeroUsize = NonZeroUsize::new(1).unwrap();
66+
const N: NonZeroUsize = NonZeroUsize::new(1).unwrap();
6767
}
6868
unsafe impl Scalar for $ty {}
6969
};
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
use crate::Scalar;
2+
use core::num::NonZeroUsize;
3+
4+
pub(crate) mod sealed {
5+
/// A marker trait used to prevent other traits from being implemented outside
6+
/// of `spirv-std`.
7+
pub trait Sealed {}
8+
}
9+
10+
/// Abstract trait representing either a [`Scalar`] or [`Vector`] type.
11+
///
12+
/// # Safety
13+
/// Your type must also implement [`Scalar`] or [`Vector`], see their safety sections as well.
14+
///
15+
/// [`Vector`]: crate::Vector
16+
pub unsafe trait ScalarOrVector: Copy + Default + Send + Sync + 'static {
17+
/// Either the scalar component type of the vector or the scalar itself.
18+
type Scalar: Scalar;
19+
20+
/// The dimension of the vector, or 1 if it is a scalar
21+
const N: NonZeroUsize;
22+
}

crates/spirv-std/src/sealed.rs

Lines changed: 0 additions & 3 deletions
This file was deleted.

crates/spirv-std/src/vector.rs

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,10 @@
11
//! Traits related to vectors.
22
3-
use crate::Scalar;
43
use crate::sealed::Sealed;
4+
use crate::{Scalar, ScalarOrVector};
55
use core::num::NonZeroUsize;
66
use glam::{Vec3Swizzles, Vec4Swizzles};
77

8-
/// Abstract trait representing either a vector or a scalar type.
9-
///
10-
/// # Safety
11-
/// Your type must also implement [`Vector`] or [`Scalar`], see their safety sections as well.
12-
pub unsafe trait VectorOrScalar: Copy + Default + Send + Sync + 'static {
13-
/// Either the scalar component type of the vector or the scalar itself.
14-
type Scalar: Scalar;
15-
16-
/// The dimension of the vector, or 1 if it is a scalar
17-
const DIM: NonZeroUsize;
18-
}
19-
208
/// Abstract trait representing a SPIR-V vector type.
219
///
2210
/// To implement this trait, your struct must be marked with:
@@ -63,15 +51,15 @@ pub unsafe trait VectorOrScalar: Copy + Default + Send + Sync + 'static {
6351
// While it's possible with `T: Scalar`, it's not with `const N: usize`, since some impl blocks in `image::params` need
6452
// to be conditional on a specific N value. And you can only express that with const generics, but not with associated
6553
// constants due to lack of const generics support in rustc.
66-
pub unsafe trait Vector<T: Scalar, const N: usize>: VectorOrScalar<Scalar = T> {}
54+
pub unsafe trait Vector<T: Scalar, const N: usize>: ScalarOrVector<Scalar = T> {}
6755

6856
macro_rules! impl_vector {
6957
($($ty:ty: [$scalar:ty; $n:literal];)+) => {
7058
$(
7159
impl Sealed for $ty {}
72-
unsafe impl VectorOrScalar for $ty {
60+
unsafe impl ScalarOrVector for $ty {
7361
type Scalar = $scalar;
74-
const DIM: NonZeroUsize = NonZeroUsize::new($n).unwrap();
62+
const N: NonZeroUsize = NonZeroUsize::new($n).unwrap();
7563
}
7664
unsafe impl Vector<$scalar, $n> for $ty {}
7765
)+

tests/compiletests/ui/arch/debug_printf_type_checking.stderr

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ LL | debug_printf!("%f", 11_u32);
7575
| |
7676
| this argument influences the return type of `debug_printf_assert_is_type`
7777
note: function defined here
78-
--> $SPIRV_STD_SRC/lib.rs:134:8
78+
--> $SPIRV_STD_SRC/lib.rs:135:8
7979
|
8080
LL | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
8181
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -103,7 +103,7 @@ LL | debug_printf!("%u", 11.0_f32);
103103
| |
104104
| this argument influences the return type of `debug_printf_assert_is_type`
105105
note: function defined here
106-
--> $SPIRV_STD_SRC/lib.rs:134:8
106+
--> $SPIRV_STD_SRC/lib.rs:135:8
107107
|
108108
LL | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
109109
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -131,7 +131,7 @@ LL | debug_printf!("%v2f", 11.0);
131131
`IVec3` implements `Vector<i32, 3>`
132132
and 8 others
133133
note: required by a bound in `debug_printf_assert_is_vector`
134-
--> $SPIRV_STD_SRC/lib.rs:139:53
134+
--> $SPIRV_STD_SRC/lib.rs:140:53
135135
|
136136
LL | pub fn debug_printf_assert_is_vector<TY: Scalar, V: Vector<TY, SIZE>, const SIZE: usize>(
137137
| ^^^^^^^^^^^^^^^^ required by this bound in `debug_printf_assert_is_vector`
@@ -154,7 +154,7 @@ LL | debug_printf!("%f", Vec2::splat(33.3));
154154
| |
155155
| this argument influences the return type of `debug_printf_assert_is_type`
156156
note: function defined here
157-
--> $SPIRV_STD_SRC/lib.rs:134:8
157+
--> $SPIRV_STD_SRC/lib.rs:135:8
158158
|
159159
LL | pub fn debug_printf_assert_is_type<T>(ty: T) -> T {
160160
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^

0 commit comments

Comments
 (0)