Skip to content

Commit

Permalink
[naga] Introduce Scalar type to IR.
Browse files Browse the repository at this point in the history
Introduce a new struct type, `Scalar`, combining a `ScalarKind` and a
`Bytes` width, and use this whenever such pairs of values are passed
around.

In particular, use `Scalar` in `TypeInner` variants `Scalar`, `Vector`,
`Atomic`, and `ValuePointer`.

Introduce associated `Scalar` constants `I32`, `U32`, `F32`, `BOOL`
and `F64`, for common cases.

Introduce a helper function `Scalar::float` for constructing `Float`
scalars of a given width, for dealing with `TypeInner::Matrix`, which
only supplies the scalar width of its elements, not a kind.

Introduce helper functions on `Literal` and `TypeInner`, to produce
the `Scalar` describing elements' values.

Use `Scalar` in `wgpu_core::validation::NumericType` as well.
  • Loading branch information
jimblandy authored and teoxoy committed Nov 14, 2023
1 parent 049cb75 commit 9f91c95
Show file tree
Hide file tree
Showing 55 changed files with 2,353 additions and 2,322 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ For naga changelogs at or before v0.14.0. See [naga's changelog](naga/CHANGELOG.
- Log vulkan validation layer messages during instance creation and destruction: By @exrook in [#4586](https://github.com/gfx-rs/wgpu/pull/4586)
- `TextureFormat::block_size` is deprecated, use `TextureFormat::block_copy_size` instead: By @wumpf in [#4647](https://github.com/gfx-rs/wgpu/pull/4647)

#### Naga

- Introduce a new `Scalar` struct type for use in Naga's IR, and update all frontend, middle, and backend code appropriately. By @jimblandy in [#4673](https://github.com/gfx-rs/wgpu/pull/4673).

### Bug Fixes

#### WGL
Expand Down
14 changes: 7 additions & 7 deletions naga/src/back/glsl/features.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{BackendResult, Error, Version, Writer};
use crate::{
AddressSpace, Binding, Bytes, Expression, Handle, ImageClass, ImageDimension, Interpolation,
Sampling, ScalarKind, ShaderStage, StorageFormat, Type, TypeInner,
AddressSpace, Binding, Expression, Handle, ImageClass, ImageDimension, Interpolation, Sampling,
Scalar, ScalarKind, ShaderStage, StorageFormat, Type, TypeInner,
};
use std::fmt::Write;

Expand Down Expand Up @@ -275,10 +275,10 @@ impl<'a, W> Writer<'a, W> {

for (ty_handle, ty) in self.module.types.iter() {
match ty.inner {
TypeInner::Scalar { kind, width } => self.scalar_required_features(kind, width),
TypeInner::Vector { kind, width, .. } => self.scalar_required_features(kind, width),
TypeInner::Scalar(scalar) => self.scalar_required_features(scalar),
TypeInner::Vector { scalar, .. } => self.scalar_required_features(scalar),
TypeInner::Matrix { width, .. } => {
self.scalar_required_features(ScalarKind::Float, width)
self.scalar_required_features(Scalar::float(width))
}
TypeInner::Array { base, size, .. } => {
if let TypeInner::Array { .. } = self.module.types[base].inner {
Expand Down Expand Up @@ -463,8 +463,8 @@ impl<'a, W> Writer<'a, W> {
}

/// Helper method that checks the [`Features`] needed by a scalar
fn scalar_required_features(&mut self, kind: ScalarKind, width: Bytes) {
if kind == ScalarKind::Float && width == 8 {
fn scalar_required_features(&mut self, scalar: Scalar) {
if scalar.kind == ScalarKind::Float && scalar.width == 8 {
self.features.request(Features::DOUBLE_TYPE);
}
}
Expand Down
133 changes: 71 additions & 62 deletions naga/src/back/glsl/mod.rs

Large diffs are not rendered by default.

24 changes: 13 additions & 11 deletions naga/src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,23 @@ impl crate::ScalarKind {
Self::Bool => unreachable!(),
}
}
}

impl crate::Scalar {
/// Helper function that returns scalar related strings
///
/// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-scalar>
pub(super) const fn to_hlsl_str(self, width: crate::Bytes) -> Result<&'static str, Error> {
match self {
Self::Sint => Ok("int"),
Self::Uint => Ok("uint"),
Self::Float => match width {
pub(super) const fn to_hlsl_str(self) -> Result<&'static str, Error> {
match self.kind {
crate::ScalarKind::Sint => Ok("int"),
crate::ScalarKind::Uint => Ok("uint"),
crate::ScalarKind::Float => match self.width {
2 => Ok("half"),
4 => Ok("float"),
8 => Ok("double"),
_ => Err(Error::UnsupportedScalar(self, width)),
_ => Err(Error::UnsupportedScalar(self)),
},
Self::Bool => Ok("bool"),
crate::ScalarKind::Bool => Ok("bool"),
}
}
}
Expand Down Expand Up @@ -71,10 +73,10 @@ impl crate::TypeInner {
names: &'a crate::FastHashMap<crate::proc::NameKey, String>,
) -> Result<Cow<'a, str>, Error> {
Ok(match gctx.types[base].inner {
crate::TypeInner::Scalar { kind, width } => Cow::Borrowed(kind.to_hlsl_str(width)?),
crate::TypeInner::Vector { size, kind, width } => Cow::Owned(format!(
crate::TypeInner::Scalar(scalar) => Cow::Borrowed(scalar.to_hlsl_str()?),
crate::TypeInner::Vector { size, scalar } => Cow::Owned(format!(
"{}{}",
kind.to_hlsl_str(width)?,
scalar.to_hlsl_str()?,
crate::back::vector_size_str(size)
)),
crate::TypeInner::Matrix {
Expand All @@ -83,7 +85,7 @@ impl crate::TypeInner {
width,
} => Cow::Owned(format!(
"{}{}x{}",
crate::ScalarKind::Float.to_hlsl_str(width)?,
crate::Scalar::float(width).to_hlsl_str()?,
crate::back::vector_size_str(columns),
crate::back::vector_size_str(rows),
)),
Expand Down
12 changes: 5 additions & 7 deletions naga/src/back/hlsl/help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ impl<'a, W: Write> super::Writer<'a, W> {
}
crate::ImageClass::Sampled { kind, multi } => {
let multi_str = if multi { "MS" } else { "" };
let scalar_kind_str = kind.to_hlsl_str(4)?;
let scalar_kind_str = crate::Scalar { kind, width: 4 }.to_hlsl_str()?;
write!(self.out, "{multi_str}<{scalar_kind_str}4>")?
}
crate::ImageClass::Storage { format, .. } => {
Expand Down Expand Up @@ -658,8 +658,7 @@ impl<'a, W: Write> super::Writer<'a, W> {
let vec_ty = match module.types[member.ty].inner {
crate::TypeInner::Matrix { rows, width, .. } => crate::TypeInner::Vector {
size: rows,
kind: crate::ScalarKind::Float,
width,
scalar: crate::Scalar::float(width),
},
_ => unreachable!(),
};
Expand Down Expand Up @@ -737,10 +736,9 @@ impl<'a, W: Write> super::Writer<'a, W> {
_ => unreachable!(),
};
let scalar_ty = match module.types[member.ty].inner {
crate::TypeInner::Matrix { width, .. } => crate::TypeInner::Scalar {
kind: crate::ScalarKind::Float,
width,
},
crate::TypeInner::Matrix { width, .. } => {
crate::TypeInner::Scalar(crate::Scalar::float(width))
}
_ => unreachable!(),
};
self.write_value_type(module, &scalar_ty)?;
Expand Down
4 changes: 2 additions & 2 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ pub struct ReflectionInfo {
pub enum Error {
#[error(transparent)]
IoError(#[from] FmtError),
#[error("A scalar with an unsupported width was requested: {0:?} {1:?}")]
UnsupportedScalar(crate::ScalarKind, crate::Bytes),
#[error("A scalar with an unsupported width was requested: {0:?}")]
UnsupportedScalar(crate::Scalar),
#[error("{0}")]
Unimplemented(String), // TODO: Error used only during development
#[error("{0}")]
Expand Down
32 changes: 13 additions & 19 deletions naga/src/back/hlsl/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,25 +157,21 @@ impl<W: fmt::Write> super::Writer<'_, W> {
func_ctx: &FunctionCtx,
) -> BackendResult {
match *result_ty.inner_with(&module.types) {
crate::TypeInner::Scalar { kind, width: _ } => {
crate::TypeInner::Scalar(scalar) => {
// working around the borrow checker in `self.write_expr`
let chain = mem::take(&mut self.temp_access_chain);
let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
let cast = kind.to_hlsl_cast();
let cast = scalar.kind.to_hlsl_cast();
write!(self.out, "{cast}({var_name}.Load(")?;
self.write_storage_address(module, &chain, func_ctx)?;
write!(self.out, "))")?;
self.temp_access_chain = chain;
}
crate::TypeInner::Vector {
size,
kind,
width: _,
} => {
crate::TypeInner::Vector { size, scalar } => {
// working around the borrow checker in `self.write_expr`
let chain = mem::take(&mut self.temp_access_chain);
let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
let cast = kind.to_hlsl_cast();
let cast = scalar.kind.to_hlsl_cast();
write!(self.out, "{}({}.Load{}(", cast, var_name, size as u8)?;
self.write_storage_address(module, &chain, func_ctx)?;
write!(self.out, "))")?;
Expand All @@ -189,7 +185,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
write!(
self.out,
"{}{}x{}(",
crate::ScalarKind::Float.to_hlsl_str(width)?,
crate::Scalar::float(width).to_hlsl_str()?,
columns as u8,
rows as u8,
)?;
Expand All @@ -199,8 +195,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
let iter = (0..columns as u32).map(|i| {
let ty_inner = crate::TypeInner::Vector {
size: rows,
kind: crate::ScalarKind::Float,
width,
scalar: crate::Scalar::float(width),
};
(TypeResolution::Value(ty_inner), i * row_stride)
});
Expand Down Expand Up @@ -296,7 +291,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
}
};
match *ty_resolution.inner_with(&module.types) {
crate::TypeInner::Scalar { .. } => {
crate::TypeInner::Scalar(_) => {
// working around the borrow checker in `self.write_expr`
let chain = mem::take(&mut self.temp_access_chain);
let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
Expand Down Expand Up @@ -330,7 +325,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
self.out,
"{}{}{}x{} {}{} = ",
level.next(),
crate::ScalarKind::Float.to_hlsl_str(width)?,
crate::Scalar::float(width).to_hlsl_str()?,
columns as u8,
rows as u8,
STORE_TEMP_NAME,
Expand All @@ -348,8 +343,7 @@ impl<W: fmt::Write> super::Writer<'_, W> {
.push(SubAccess::Offset(i * row_stride));
let ty_inner = crate::TypeInner::Vector {
size: rows,
kind: crate::ScalarKind::Float,
width,
scalar: crate::Scalar::float(width),
};
let sv = StoreValue::TempIndex {
depth,
Expand Down Expand Up @@ -470,8 +464,8 @@ impl<W: fmt::Write> super::Writer<'_, W> {
crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members),
crate::TypeInner::Array { stride, .. } => Parent::Array { stride },
crate::TypeInner::Vector { width, .. } => Parent::Array {
stride: width as u32,
crate::TypeInner::Vector { scalar, .. } => Parent::Array {
stride: scalar.width as u32,
},
crate::TypeInner::Matrix { rows, width, .. } => Parent::Array {
// The stride between matrices is the count of rows as this is how
Expand All @@ -480,8 +474,8 @@ impl<W: fmt::Write> super::Writer<'_, W> {
},
_ => unreachable!(),
},
crate::TypeInner::ValuePointer { width, .. } => Parent::Array {
stride: width as u32,
crate::TypeInner::ValuePointer { scalar, .. } => Parent::Array {
stride: scalar.width as u32,
},
_ => unreachable!(),
};
Expand Down
43 changes: 23 additions & 20 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -912,8 +912,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
} if member.binding.is_none() && rows == crate::VectorSize::Bi => {
let vec_ty = crate::TypeInner::Vector {
size: rows,
kind: crate::ScalarKind::Float,
width,
scalar: crate::Scalar::float(width),
};
let field_name_key = NameKey::StructMember(handle, index as u32);

Expand Down Expand Up @@ -1024,14 +1023,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
/// Adds no trailing or leading whitespace
pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
match *inner {
TypeInner::Scalar { kind, width } | TypeInner::Atomic { kind, width } => {
write!(self.out, "{}", kind.to_hlsl_str(width)?)?;
TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
write!(self.out, "{}", scalar.to_hlsl_str()?)?;
}
TypeInner::Vector { size, kind, width } => {
TypeInner::Vector { size, scalar } => {
write!(
self.out,
"{}{}",
kind.to_hlsl_str(width)?,
scalar.to_hlsl_str()?,
back::vector_size_str(size)
)?;
}
Expand All @@ -1047,7 +1046,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(
self.out,
"{}{}x{}",
crate::ScalarKind::Float.to_hlsl_str(width)?,
crate::Scalar::float(width).to_hlsl_str()?,
back::vector_size_str(columns),
back::vector_size_str(rows),
)?;
Expand Down Expand Up @@ -2484,7 +2483,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, ")")?;

// return x component if return type is scalar
if let TypeInner::Scalar { .. } = *func_ctx.resolve_type(expr, &module.types) {
if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
write!(self.out, ".x")?;
}
}
Expand Down Expand Up @@ -2567,23 +2566,27 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
let inner = func_ctx.resolve_type(expr, &module.types);
match convert {
Some(dst_width) => {
let scalar = crate::Scalar {
kind,
width: dst_width,
};
match *inner {
TypeInner::Vector { size, .. } => {
write!(
self.out,
"{}{}(",
kind.to_hlsl_str(dst_width)?,
scalar.to_hlsl_str()?,
back::vector_size_str(size)
)?;
}
TypeInner::Scalar { .. } => {
write!(self.out, "{}(", kind.to_hlsl_str(dst_width)?,)?;
TypeInner::Scalar(_) => {
write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
}
TypeInner::Matrix { columns, rows, .. } => {
write!(
self.out,
"{}{}x{}(",
kind.to_hlsl_str(dst_width)?,
scalar.to_hlsl_str()?,
back::vector_size_str(columns),
back::vector_size_str(rows)
)?;
Expand Down Expand Up @@ -2964,14 +2967,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
Function::CountTrailingZeros => {
match *func_ctx.resolve_type(arg, &module.types) {
TypeInner::Vector { size, kind, .. } => {
TypeInner::Vector { size, scalar } => {
let s = match size {
crate::VectorSize::Bi => ".xx",
crate::VectorSize::Tri => ".xxx",
crate::VectorSize::Quad => ".xxxx",
};

if let ScalarKind::Uint = kind {
if let ScalarKind::Uint = scalar.kind {
write!(self.out, "min((32u){s}, firstbitlow(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
Expand All @@ -2981,8 +2984,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, ")))")?;
}
}
TypeInner::Scalar { kind, .. } => {
if let ScalarKind::Uint = kind {
TypeInner::Scalar(scalar) => {
if let ScalarKind::Uint = scalar.kind {
write!(self.out, "min(32u, firstbitlow(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
Expand All @@ -2999,14 +3002,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
Function::CountLeadingZeros => {
match *func_ctx.resolve_type(arg, &module.types) {
TypeInner::Vector { size, kind, .. } => {
TypeInner::Vector { size, scalar } => {
let s = match size {
crate::VectorSize::Bi => ".xx",
crate::VectorSize::Tri => ".xxx",
crate::VectorSize::Quad => ".xxxx",
};

if let ScalarKind::Uint = kind {
if let ScalarKind::Uint = scalar.kind {
write!(self.out, "((31u){s} - firstbithigh(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
Expand All @@ -3021,8 +3024,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, ")))")?;
}
}
TypeInner::Scalar { kind, .. } => {
if let ScalarKind::Uint = kind {
TypeInner::Scalar(scalar) => {
if let ScalarKind::Uint = scalar.kind {
write!(self.out, "(31u - firstbithigh(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
Expand Down
Loading

0 comments on commit 9f91c95

Please sign in to comment.