Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support array bindings of buffers #2282

Merged
merged 6 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 50 additions & 27 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ Implementations for `BlockContext` methods.
*/

use super::{
index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext, Dimension,
Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer, WriterFlags,
helpers, index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext,
Dimension, Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer,
WriterFlags,
};
use crate::{arena::Handle, proc::TypeResolution};
use spirv::Word;
Expand Down Expand Up @@ -196,9 +197,8 @@ impl<'w> BlockContext<'w> {
fn is_intermediate(&self, expr_handle: Handle<crate::Expression>) -> bool {
match self.ir_function.expressions[expr_handle] {
crate::Expression::GlobalVariable(handle) => {
let ty = self.ir_module.global_variables[handle].ty;
match self.ir_module.types[ty].inner {
crate::TypeInner::BindingArray { .. } => false,
match self.ir_module.global_variables[handle].space {
crate::AddressSpace::Handle => false,
kvark marked this conversation as resolved.
Show resolved Hide resolved
_ => true,
}
}
Expand All @@ -221,7 +221,6 @@ impl<'w> BlockContext<'w> {
block: &mut Block,
) -> Result<(), Error> {
let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);

let id = match self.ir_function.expressions[expr_handle] {
crate::Expression::Access { base, index: _ } if self.is_intermediate(base) => {
// See `is_intermediate`; we'll handle this later in
Expand All @@ -237,9 +236,15 @@ impl<'w> BlockContext<'w> {
crate::TypeInner::BindingArray {
base: binding_type, ..
} => {
let space = match self.ir_function.expressions[base] {
crate::Expression::GlobalVariable(gvar) => {
self.ir_module.global_variables[gvar].space
}
_ => unreachable!(),
};
let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
base: binding_type,
class: spirv::StorageClass::UniformConstant,
class: helpers::map_storage_class(space),
});

let result_id = match self.write_expression_pointer(
Expand All @@ -265,15 +270,6 @@ impl<'w> BlockContext<'w> {
None,
));

if self.fun_info[index].uniformity.non_uniform_result.is_some() {
self.writer.require_any(
"NonUniformEXT",
&[spirv::Capability::ShaderNonUniform],
)?;
self.writer.use_extension("SPV_EXT_descriptor_indexing");
self.writer
.decorate(load_id, spirv::Decoration::NonUniform, &[]);
}
load_id
}
ref other => {
Expand Down Expand Up @@ -316,9 +312,15 @@ impl<'w> BlockContext<'w> {
crate::TypeInner::BindingArray {
base: binding_type, ..
} => {
let space = match self.ir_function.expressions[base] {
crate::Expression::GlobalVariable(gvar) => {
self.ir_module.global_variables[gvar].space
}
_ => unreachable!(),
};
let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
base: binding_type,
class: spirv::StorageClass::UniformConstant,
class: helpers::map_storage_class(space),
});

let result_id = match self.write_expression_pointer(
Expand Down Expand Up @@ -1403,8 +1405,8 @@ impl<'w> BlockContext<'w> {
/// Emit any needed bounds-checking expressions to `block`.
///
/// Some cases we need to generate a different return type than what the IR gives us.
/// This is because pointers to binding arrays don't exist in the IR, but we need to
/// create them to create an access chain in SPIRV.
/// This is because pointers to binding arrays of handles (such as images or samplers)
/// don't exist in the IR, but we need to create them to create an access chain in SPIRV.
///
/// On success, the return value is an [`ExpressionPointer`] value; see the
/// documentation for that type.
Expand Down Expand Up @@ -1434,11 +1436,25 @@ impl<'w> BlockContext<'w> {
// but we expect these checks to almost always succeed, and keeping branches to a
// minimum is essential.
let mut accumulated_checks = None;
// Is true if we are accessing into a binding array of buffers with a non-uniform index.
let mut is_non_uniform_binding_array = false;

self.temp_list.clear();
let root_id = loop {
expr_handle = match self.ir_function.expressions[expr_handle] {
crate::Expression::Access { base, index } => {
if let crate::Expression::GlobalVariable(var_handle) =
self.ir_function.expressions[base]
{
let gvar = &self.ir_module.global_variables[var_handle];
if let crate::TypeInner::BindingArray { .. } =
self.ir_module.types[gvar.ty].inner
{
is_non_uniform_binding_array |=
self.fun_info[index].uniformity.non_uniform_result.is_some();
}
}

let index_id = match self.write_bounds_check(base, index, block)? {
BoundsCheckResult::KnownInBounds(known_index) => {
// Even if the index is known, `OpAccessIndex`
Expand Down Expand Up @@ -1471,7 +1487,6 @@ impl<'w> BlockContext<'w> {
}
};
self.temp_list.push(index_id);

base
}
crate::Expression::AccessIndex { base, index } => {
Expand All @@ -1494,10 +1509,13 @@ impl<'w> BlockContext<'w> {
}
};

let pointer = if self.temp_list.is_empty() {
ExpressionPointer::Ready {
pointer_id: root_id,
}
let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
(
root_id,
ExpressionPointer::Ready {
pointer_id: root_id,
},
)
} else {
self.temp_list.reverse();
let pointer_id = self.gen_id();
Expand All @@ -1508,16 +1526,21 @@ impl<'w> BlockContext<'w> {
// caller to generate the branch, the access, the load or store, and
// the zero value (for loads). Otherwise, we can emit the access
// ourselves, and just hand them the id of the pointer.
match accumulated_checks {
let expr_pointer = match accumulated_checks {
Some(condition) => ExpressionPointer::Conditional { condition, access },
None => {
block.body.push(access);
ExpressionPointer::Ready { pointer_id }
}
}
};
(pointer_id, expr_pointer)
};
if is_non_uniform_binding_array {
self.writer
.decorate_non_uniform_binding_array_access(pointer_id)?;
}

Ok(pointer)
Ok(expr_pointer)
}

/// Build the instructions for matrix - matrix column operations
Expand Down
3 changes: 2 additions & 1 deletion src/back/spv/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariab
},
None => false,
},
// if it's not a structure, let's wrap it to be able to put "Block"
crate::TypeInner::BindingArray { .. } => false,
// if it's not a structure or a binding array, let's wrap it to be able to put "Block"
_ => true,
}
}
6 changes: 6 additions & 0 deletions src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,15 @@ enum LocalType {
image_type_id: Word,
},
Sampler,
/// Equivalent to a [`LocalType::Pointer`] whose `base` is a Naga IR [`BindingArray`]. SPIR-V
/// permits duplicated `OpTypePointer` ids, so it's fine to have two different [`LocalType`]
/// representations for pointer types.
///
/// [`BindingArray`]: crate::TypeInner::BindingArray
PointerToBindingArray {
kvark marked this conversation as resolved.
Show resolved Hide resolved
base: Handle<crate::Type>,
size: u64,
space: crate::AddressSpace,
},
BindingArray {
base: Handle<crate::Type>,
Expand Down
24 changes: 21 additions & 3 deletions src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -940,10 +940,11 @@ impl Writer {
let scalar_id = self.get_constant_scalar(crate::ScalarValue::Uint(size), 4);
Instruction::type_array(id, inner_ty, scalar_id)
}
LocalType::PointerToBindingArray { base, size } => {
LocalType::PointerToBindingArray { base, size, space } => {
let inner_ty =
self.get_type_id(LookupType::Local(LocalType::BindingArray { base, size }));
Instruction::type_pointer(id, spirv::StorageClass::UniformConstant, inner_ty)
let class = map_storage_class(space);
Instruction::type_pointer(id, class, inner_ty)
}
LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id),
LocalType::RayQuery => Instruction::type_ray_query(id),
Expand Down Expand Up @@ -1579,6 +1580,9 @@ impl Writer {
}
}

// Note: we should be able to substitute `binding_array<Foo, 0>`,
// but there is still code that tries to register the pre-substituted type,
// and it is failing on 0.
let mut substitute_inner_type_lookup = None;
if let Some(ref res_binding) = global_variable.binding {
self.decorate(id, Decoration::DescriptorSet, &[res_binding.group]);
Expand All @@ -1595,6 +1599,7 @@ impl Writer {
Some(LookupType::Local(LocalType::PointerToBindingArray {
base,
size: remapped_binding_array_size as u64,
space: global_variable.space,
}))
}
} else {
Expand Down Expand Up @@ -1635,7 +1640,13 @@ impl Writer {
// a runtime-sized array. In this case, we need to decorate it with
// Block.
if let crate::AddressSpace::Storage { .. } = global_variable.space {
self.decorate(inner_type_id, Decoration::Block, &[]);
let decorated_id = match ir_module.types[global_variable.ty].inner {
crate::TypeInner::BindingArray { base, .. } => {
self.get_type_id(LookupType::Handle(base))
}
_ => inner_type_id,
};
self.decorate(decorated_id, Decoration::Block, &[]);
}
if substitute_inner_type_lookup.is_some() {
inner_type_id
Expand Down Expand Up @@ -1955,6 +1966,13 @@ impl Writer {
pub const fn get_capabilities_used(&self) -> &crate::FastHashSet<spirv::Capability> {
&self.capabilities_used
}

pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> {
self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?;
self.use_extension("SPV_EXT_descriptor_indexing");
self.decorate(id, spirv::Decoration::NonUniform, &[]);
Ok(())
}
}

#[test]
Expand Down
17 changes: 8 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ tree.
clippy::match_like_matches_macro,
clippy::collapsible_if,
clippy::derive_partial_eq_without_eq,
clippy::needless_borrowed_reference
clippy::needless_borrowed_reference,
clippy::single_match
)]
#![warn(
trivial_casts,
Expand Down Expand Up @@ -752,13 +753,12 @@ pub enum TypeInner {
/// buffers could have elements that are dynamically sized arrays, each with
/// a different length.
///
/// Binding arrays are not [`DATA`]. This means that all binding array
/// globals must be placed in the [`Handle`] address space. Referring to
/// such a global produces a `BindingArray` value directly; there are never
/// pointers to binding arrays. The only operation permitted on
/// `BindingArray` values is indexing, which yields the element by value,
/// not a pointer to the element. (This means that buffer array contents
/// cannot be stored to; [naga#1864] covers lifting this restriction.)
/// Binding arrays are in the same address spaces as their underlying type.
/// As such, referring to an array of images produces an [`Image`] value
/// directly (as opposed to a pointer). The only operation permitted on
/// `BindingArray` values is indexing, which works transparently: indexing
/// a binding array of samplers yields a [`Sampler`], indexing a pointer to the
/// binding array of storage buffers produces a pointer to the storage struct.
///
/// Unlike textures and samplers, binding arrays are not [`ARGUMENT`], so
/// they cannot be passed as arguments to functions.
Expand All @@ -774,7 +774,6 @@ pub enum TypeInner {
/// [`SamplerArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.SamplerArray
/// [`BufferArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.BufferArray
/// [`DATA`]: crate::valid::TypeFlags::DATA
/// [`Handle`]: AddressSpace::Handle
/// [`ARGUMENT`]: crate::valid::TypeFlags::ARGUMENT
/// [naga#1864]: https://github.com/gfx-rs/naga/issues/1864
BindingArray { base: Handle<Type>, size: ArraySize },
Expand Down
4 changes: 3 additions & 1 deletion src/proc/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,9 @@ impl crate::TypeInner {
match *base_inner {
Ti::Vector { size, .. } => size as _,
Ti::Matrix { columns, .. } => columns as _,
Ti::Array { size, .. } => return size.to_indexable_length(module),
Ti::Array { size, .. } | Ti::BindingArray { size, .. } => {
return size.to_indexable_length(module)
}
_ => return Err(IndexableLengthError::TypeNotIndexable),
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ impl<'a> ResolveContext<'a> {
width,
space,
},
Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
jimblandy marked this conversation as resolved.
Show resolved Hide resolved
ref other => {
log::error!("Access sub-type {:?}", other);
return Err(ResolveError::InvalidSubAccess {
Expand Down Expand Up @@ -401,6 +402,7 @@ impl<'a> ResolveContext<'a> {
space,
}
}
Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
ref other => {
log::error!("Access index sub-type {:?}", other);
return Err(ResolveError::InvalidSubAccess {
Expand Down
55 changes: 27 additions & 28 deletions src/valid/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,14 @@ impl super::Validator {
use super::TypeFlags;

log::debug!("var {:?}", var);
let type_info = &self.types[var.ty.index()];
let inner_ty = match types[var.ty].inner {
// A binding array is (mostly) supposed to behave the same as a
// series of individually bound resources, so we can (mostly)
// validate a `binding_array<T>` as if it were just a plain `T`.
crate::TypeInner::BindingArray { base, .. } => base,
kvark marked this conversation as resolved.
Show resolved Hide resolved
_ => var.ty,
};
let type_info = &self.types[inner_ty.index()];

let (required_type_flags, is_resource) = match var.space {
crate::AddressSpace::Function => {
Expand Down Expand Up @@ -437,22 +444,8 @@ impl super::Validator {
)
}
crate::AddressSpace::Handle => {
match types[var.ty].inner {
crate::TypeInner::Image { .. }
| crate::TypeInner::Sampler { .. }
| crate::TypeInner::BindingArray { .. }
| crate::TypeInner::AccelerationStructure
| crate::TypeInner::RayQuery => {}
_ => {
return Err(GlobalVariableError::InvalidType(var.space));
}
};
let inner_ty = match &types[var.ty].inner {
&crate::TypeInner::BindingArray { base, .. } => &types[base].inner,
ty => ty,
};
if let crate::TypeInner::Image {
class:
match types[inner_ty].inner {
crate::TypeInner::Image { class, .. } => match class {
crate::ImageClass::Storage {
format:
crate::StorageFormat::R16Unorm
Expand All @@ -462,17 +455,23 @@ impl super::Validator {
| crate::StorageFormat::Rgba16Unorm
| crate::StorageFormat::Rgba16Snorm,
..
},
..
} = *inner_ty
{
if !self
.capabilities
.contains(Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS)
{
return Err(GlobalVariableError::UnsupportedCapability(
Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
));
} => {
if !self
.capabilities
.contains(Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS)
{
return Err(GlobalVariableError::UnsupportedCapability(
Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
));
}
}
_ => {}
},
crate::TypeInner::Sampler { .. }
| crate::TypeInner::AccelerationStructure
| crate::TypeInner::RayQuery => {}
_ => {
return Err(GlobalVariableError::InvalidType(var.space));
}
}

Expand Down
Loading