Skip to content

Commit

Permalink
Support array bindings of buffers (#2282)
Browse files Browse the repository at this point in the history
* Support buffer resource arrays in IR, wgsl-in, and spv-out

* spv-out: refactor non-uniform indexing semantics to support buffers

* Update the doc comment on BindingArray type

* Improve TypeInfo restrictions on binding arrays

* Strip DATA out of binding arrays

* Include suggested documentation, more binding array tests, enforce structs
  • Loading branch information
kvark authored Apr 25, 2023
1 parent 1421a5e commit 37b3c36
Show file tree
Hide file tree
Showing 16 changed files with 380 additions and 93 deletions.
77 changes: 50 additions & 27 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ Implementations for `BlockContext` methods.
*/

use super::{
index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext, Dimension,
Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer, WriterFlags,
helpers, index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext,
Dimension, Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer,
WriterFlags,
};
use crate::{arena::Handle, proc::TypeResolution};
use spirv::Word;
Expand Down Expand Up @@ -196,9 +197,8 @@ impl<'w> BlockContext<'w> {
fn is_intermediate(&self, expr_handle: Handle<crate::Expression>) -> bool {
match self.ir_function.expressions[expr_handle] {
crate::Expression::GlobalVariable(handle) => {
let ty = self.ir_module.global_variables[handle].ty;
match self.ir_module.types[ty].inner {
crate::TypeInner::BindingArray { .. } => false,
match self.ir_module.global_variables[handle].space {
crate::AddressSpace::Handle => false,
_ => true,
}
}
Expand All @@ -221,7 +221,6 @@ impl<'w> BlockContext<'w> {
block: &mut Block,
) -> Result<(), Error> {
let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);

let id = match self.ir_function.expressions[expr_handle] {
crate::Expression::Access { base, index: _ } if self.is_intermediate(base) => {
// See `is_intermediate`; we'll handle this later in
Expand All @@ -237,9 +236,15 @@ impl<'w> BlockContext<'w> {
crate::TypeInner::BindingArray {
base: binding_type, ..
} => {
let space = match self.ir_function.expressions[base] {
crate::Expression::GlobalVariable(gvar) => {
self.ir_module.global_variables[gvar].space
}
_ => unreachable!(),
};
let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
base: binding_type,
class: spirv::StorageClass::UniformConstant,
class: helpers::map_storage_class(space),
});

let result_id = match self.write_expression_pointer(
Expand All @@ -265,15 +270,6 @@ impl<'w> BlockContext<'w> {
None,
));

if self.fun_info[index].uniformity.non_uniform_result.is_some() {
self.writer.require_any(
"NonUniformEXT",
&[spirv::Capability::ShaderNonUniform],
)?;
self.writer.use_extension("SPV_EXT_descriptor_indexing");
self.writer
.decorate(load_id, spirv::Decoration::NonUniform, &[]);
}
load_id
}
ref other => {
Expand Down Expand Up @@ -316,9 +312,15 @@ impl<'w> BlockContext<'w> {
crate::TypeInner::BindingArray {
base: binding_type, ..
} => {
let space = match self.ir_function.expressions[base] {
crate::Expression::GlobalVariable(gvar) => {
self.ir_module.global_variables[gvar].space
}
_ => unreachable!(),
};
let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
base: binding_type,
class: spirv::StorageClass::UniformConstant,
class: helpers::map_storage_class(space),
});

let result_id = match self.write_expression_pointer(
Expand Down Expand Up @@ -1403,8 +1405,8 @@ impl<'w> BlockContext<'w> {
/// Emit any needed bounds-checking expressions to `block`.
///
/// Some cases we need to generate a different return type than what the IR gives us.
/// This is because pointers to binding arrays don't exist in the IR, but we need to
/// create them to create an access chain in SPIRV.
/// This is because pointers to binding arrays of handles (such as images or samplers)
/// don't exist in the IR, but we need to create them to create an access chain in SPIRV.
///
/// On success, the return value is an [`ExpressionPointer`] value; see the
/// documentation for that type.
Expand Down Expand Up @@ -1434,11 +1436,25 @@ impl<'w> BlockContext<'w> {
// but we expect these checks to almost always succeed, and keeping branches to a
// minimum is essential.
let mut accumulated_checks = None;
// Is true if we are accessing into a binding array of buffers with a non-uniform index.
let mut is_non_uniform_binding_array = false;

self.temp_list.clear();
let root_id = loop {
expr_handle = match self.ir_function.expressions[expr_handle] {
crate::Expression::Access { base, index } => {
if let crate::Expression::GlobalVariable(var_handle) =
self.ir_function.expressions[base]
{
let gvar = &self.ir_module.global_variables[var_handle];
if let crate::TypeInner::BindingArray { .. } =
self.ir_module.types[gvar.ty].inner
{
is_non_uniform_binding_array |=
self.fun_info[index].uniformity.non_uniform_result.is_some();
}
}

let index_id = match self.write_bounds_check(base, index, block)? {
BoundsCheckResult::KnownInBounds(known_index) => {
// Even if the index is known, `OpAccessIndex`
Expand Down Expand Up @@ -1471,7 +1487,6 @@ impl<'w> BlockContext<'w> {
}
};
self.temp_list.push(index_id);

base
}
crate::Expression::AccessIndex { base, index } => {
Expand All @@ -1494,10 +1509,13 @@ impl<'w> BlockContext<'w> {
}
};

let pointer = if self.temp_list.is_empty() {
ExpressionPointer::Ready {
pointer_id: root_id,
}
let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
(
root_id,
ExpressionPointer::Ready {
pointer_id: root_id,
},
)
} else {
self.temp_list.reverse();
let pointer_id = self.gen_id();
Expand All @@ -1508,16 +1526,21 @@ impl<'w> BlockContext<'w> {
// caller to generate the branch, the access, the load or store, and
// the zero value (for loads). Otherwise, we can emit the access
// ourselves, and just hand them the id of the pointer.
match accumulated_checks {
let expr_pointer = match accumulated_checks {
Some(condition) => ExpressionPointer::Conditional { condition, access },
None => {
block.body.push(access);
ExpressionPointer::Ready { pointer_id }
}
}
};
(pointer_id, expr_pointer)
};
if is_non_uniform_binding_array {
self.writer
.decorate_non_uniform_binding_array_access(pointer_id)?;
}

Ok(pointer)
Ok(expr_pointer)
}

/// Build the instructions for matrix - matrix column operations
Expand Down
3 changes: 2 additions & 1 deletion src/back/spv/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariab
},
None => false,
},
// if it's not a structure, let's wrap it to be able to put "Block"
crate::TypeInner::BindingArray { .. } => false,
// if it's not a structure or a binding array, let's wrap it to be able to put "Block"
_ => true,
}
}
6 changes: 6 additions & 0 deletions src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,15 @@ enum LocalType {
image_type_id: Word,
},
Sampler,
/// Equivalent to a [`LocalType::Pointer`] whose `base` is a Naga IR [`BindingArray`]. SPIR-V
/// permits duplicated `OpTypePointer` ids, so it's fine to have two different [`LocalType`]
/// representations for pointer types.
///
/// [`BindingArray`]: crate::TypeInner::BindingArray
PointerToBindingArray {
base: Handle<crate::Type>,
size: u64,
space: crate::AddressSpace,
},
BindingArray {
base: Handle<crate::Type>,
Expand Down
24 changes: 21 additions & 3 deletions src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -940,10 +940,11 @@ impl Writer {
let scalar_id = self.get_constant_scalar(crate::ScalarValue::Uint(size), 4);
Instruction::type_array(id, inner_ty, scalar_id)
}
LocalType::PointerToBindingArray { base, size } => {
LocalType::PointerToBindingArray { base, size, space } => {
let inner_ty =
self.get_type_id(LookupType::Local(LocalType::BindingArray { base, size }));
Instruction::type_pointer(id, spirv::StorageClass::UniformConstant, inner_ty)
let class = map_storage_class(space);
Instruction::type_pointer(id, class, inner_ty)
}
LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id),
LocalType::RayQuery => Instruction::type_ray_query(id),
Expand Down Expand Up @@ -1579,6 +1580,9 @@ impl Writer {
}
}

// Note: we should be able to substitute `binding_array<Foo, 0>`,
// but there is still code that tries to register the pre-substituted type,
// and it is failing on 0.
let mut substitute_inner_type_lookup = None;
if let Some(ref res_binding) = global_variable.binding {
self.decorate(id, Decoration::DescriptorSet, &[res_binding.group]);
Expand All @@ -1595,6 +1599,7 @@ impl Writer {
Some(LookupType::Local(LocalType::PointerToBindingArray {
base,
size: remapped_binding_array_size as u64,
space: global_variable.space,
}))
}
} else {
Expand Down Expand Up @@ -1635,7 +1640,13 @@ impl Writer {
// a runtime-sized array. In this case, we need to decorate it with
// Block.
if let crate::AddressSpace::Storage { .. } = global_variable.space {
self.decorate(inner_type_id, Decoration::Block, &[]);
let decorated_id = match ir_module.types[global_variable.ty].inner {
crate::TypeInner::BindingArray { base, .. } => {
self.get_type_id(LookupType::Handle(base))
}
_ => inner_type_id,
};
self.decorate(decorated_id, Decoration::Block, &[]);
}
if substitute_inner_type_lookup.is_some() {
inner_type_id
Expand Down Expand Up @@ -1955,6 +1966,13 @@ impl Writer {
pub const fn get_capabilities_used(&self) -> &crate::FastHashSet<spirv::Capability> {
&self.capabilities_used
}

pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> {
self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?;
self.use_extension("SPV_EXT_descriptor_indexing");
self.decorate(id, spirv::Decoration::NonUniform, &[]);
Ok(())
}
}

#[test]
Expand Down
17 changes: 8 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ tree.
clippy::match_like_matches_macro,
clippy::collapsible_if,
clippy::derive_partial_eq_without_eq,
clippy::needless_borrowed_reference
clippy::needless_borrowed_reference,
clippy::single_match
)]
#![warn(
trivial_casts,
Expand Down Expand Up @@ -752,13 +753,12 @@ pub enum TypeInner {
/// buffers could have elements that are dynamically sized arrays, each with
/// a different length.
///
/// Binding arrays are not [`DATA`]. This means that all binding array
/// globals must be placed in the [`Handle`] address space. Referring to
/// such a global produces a `BindingArray` value directly; there are never
/// pointers to binding arrays. The only operation permitted on
/// `BindingArray` values is indexing, which yields the element by value,
/// not a pointer to the element. (This means that buffer array contents
/// cannot be stored to; [naga#1864] covers lifting this restriction.)
/// Binding arrays are in the same address spaces as their underlying type.
/// As such, referring to an array of images produces an [`Image`] value
/// directly (as opposed to a pointer). The only operation permitted on
/// `BindingArray` values is indexing, which works transparently: indexing
/// a binding array of samplers yields a [`Sampler`], indexing a pointer to the
/// binding array of storage buffers produces a pointer to the storage struct.
///
/// Unlike textures and samplers, binding arrays are not [`ARGUMENT`], so
/// they cannot be passed as arguments to functions.
Expand All @@ -774,7 +774,6 @@ pub enum TypeInner {
/// [`SamplerArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.SamplerArray
/// [`BufferArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.BufferArray
/// [`DATA`]: crate::valid::TypeFlags::DATA
/// [`Handle`]: AddressSpace::Handle
/// [`ARGUMENT`]: crate::valid::TypeFlags::ARGUMENT
/// [naga#1864]: https://github.com/gfx-rs/naga/issues/1864
BindingArray { base: Handle<Type>, size: ArraySize },
Expand Down
4 changes: 3 additions & 1 deletion src/proc/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,9 @@ impl crate::TypeInner {
match *base_inner {
Ti::Vector { size, .. } => size as _,
Ti::Matrix { columns, .. } => columns as _,
Ti::Array { size, .. } => return size.to_indexable_length(module),
Ti::Array { size, .. } | Ti::BindingArray { size, .. } => {
return size.to_indexable_length(module)
}
_ => return Err(IndexableLengthError::TypeNotIndexable),
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ impl<'a> ResolveContext<'a> {
width,
space,
},
Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
ref other => {
log::error!("Access sub-type {:?}", other);
return Err(ResolveError::InvalidSubAccess {
Expand Down Expand Up @@ -401,6 +402,7 @@ impl<'a> ResolveContext<'a> {
space,
}
}
Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
ref other => {
log::error!("Access index sub-type {:?}", other);
return Err(ResolveError::InvalidSubAccess {
Expand Down
55 changes: 27 additions & 28 deletions src/valid/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,14 @@ impl super::Validator {
use super::TypeFlags;

log::debug!("var {:?}", var);
let type_info = &self.types[var.ty.index()];
let inner_ty = match types[var.ty].inner {
// A binding array is (mostly) supposed to behave the same as a
// series of individually bound resources, so we can (mostly)
// validate a `binding_array<T>` as if it were just a plain `T`.
crate::TypeInner::BindingArray { base, .. } => base,
_ => var.ty,
};
let type_info = &self.types[inner_ty.index()];

let (required_type_flags, is_resource) = match var.space {
crate::AddressSpace::Function => {
Expand Down Expand Up @@ -437,22 +444,8 @@ impl super::Validator {
)
}
crate::AddressSpace::Handle => {
match types[var.ty].inner {
crate::TypeInner::Image { .. }
| crate::TypeInner::Sampler { .. }
| crate::TypeInner::BindingArray { .. }
| crate::TypeInner::AccelerationStructure
| crate::TypeInner::RayQuery => {}
_ => {
return Err(GlobalVariableError::InvalidType(var.space));
}
};
let inner_ty = match &types[var.ty].inner {
&crate::TypeInner::BindingArray { base, .. } => &types[base].inner,
ty => ty,
};
if let crate::TypeInner::Image {
class:
match types[inner_ty].inner {
crate::TypeInner::Image { class, .. } => match class {
crate::ImageClass::Storage {
format:
crate::StorageFormat::R16Unorm
Expand All @@ -462,17 +455,23 @@ impl super::Validator {
| crate::StorageFormat::Rgba16Unorm
| crate::StorageFormat::Rgba16Snorm,
..
},
..
} = *inner_ty
{
if !self
.capabilities
.contains(Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS)
{
return Err(GlobalVariableError::UnsupportedCapability(
Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
));
} => {
if !self
.capabilities
.contains(Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS)
{
return Err(GlobalVariableError::UnsupportedCapability(
Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
));
}
}
_ => {}
},
crate::TypeInner::Sampler { .. }
| crate::TypeInner::AccelerationStructure
| crate::TypeInner::RayQuery => {}
_ => {
return Err(GlobalVariableError::InvalidType(var.space));
}
}

Expand Down
Loading

0 comments on commit 37b3c36

Please sign in to comment.