Skip to content

Commit

Permalink
spv-out: refactor non-uniform indexing semantics to support buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark committed Mar 18, 2023
1 parent bbb7541 commit 777f5e5
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 44 deletions.
68 changes: 46 additions & 22 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ Implementations for `BlockContext` methods.
*/

use super::{
index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext, Dimension,
Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer, WriterFlags,
helpers, index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext,
Dimension, Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer,
WriterFlags,
};
use crate::{arena::Handle, proc::TypeResolution};
use spirv::Word;
Expand Down Expand Up @@ -220,7 +221,6 @@ impl<'w> BlockContext<'w> {
block: &mut Block,
) -> Result<(), Error> {
let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);

let id = match self.ir_function.expressions[expr_handle] {
crate::Expression::Access { base, index: _ } if self.is_intermediate(base) => {
// See `is_intermediate`; we'll handle this later in
Expand All @@ -236,9 +236,15 @@ impl<'w> BlockContext<'w> {
crate::TypeInner::BindingArray {
base: binding_type, ..
} => {
let space = match self.ir_function.expressions[base] {
crate::Expression::GlobalVariable(gvar) => {
self.ir_module.global_variables[gvar].space
}
_ => unreachable!(),
};
let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
base: binding_type,
class: spirv::StorageClass::UniformConstant,
class: helpers::map_storage_class(space),
});

let result_id = match self.write_expression_pointer(
Expand All @@ -264,15 +270,6 @@ impl<'w> BlockContext<'w> {
None,
));

if self.fun_info[index].uniformity.non_uniform_result.is_some() {
self.writer.require_any(
"NonUniformEXT",
&[spirv::Capability::ShaderNonUniform],
)?;
self.writer.use_extension("SPV_EXT_descriptor_indexing");
self.writer
.decorate(load_id, spirv::Decoration::NonUniform, &[]);
}
load_id
}
ref other => {
Expand Down Expand Up @@ -315,9 +312,15 @@ impl<'w> BlockContext<'w> {
crate::TypeInner::BindingArray {
base: binding_type, ..
} => {
let space = match self.ir_function.expressions[base] {
crate::Expression::GlobalVariable(gvar) => {
self.ir_module.global_variables[gvar].space
}
_ => unreachable!(),
};
let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
base: binding_type,
class: spirv::StorageClass::UniformConstant,
class: helpers::map_storage_class(space),
});

let result_id = match self.write_expression_pointer(
Expand Down Expand Up @@ -1433,11 +1436,25 @@ impl<'w> BlockContext<'w> {
// but we expect these checks to almost always succeed, and keeping branches to a
// minimum is essential.
let mut accumulated_checks = None;
// Is true if we are accessing into a binding array of buffers with a non-uniform index.
let mut is_non_uniform_binding_array = false;

self.temp_list.clear();
let root_id = loop {
expr_handle = match self.ir_function.expressions[expr_handle] {
crate::Expression::Access { base, index } => {
if let crate::Expression::GlobalVariable(var_handle) =
self.ir_function.expressions[base]
{
let gvar = &self.ir_module.global_variables[var_handle];
if let crate::TypeInner::BindingArray { .. } =
self.ir_module.types[gvar.ty].inner
{
is_non_uniform_binding_array |=
self.fun_info[index].uniformity.non_uniform_result.is_some();
}
}

let index_id = match self.write_bounds_check(base, index, block)? {
BoundsCheckResult::KnownInBounds(known_index) => {
// Even if the index is known, `OpAccessIndex`
Expand Down Expand Up @@ -1470,7 +1487,6 @@ impl<'w> BlockContext<'w> {
}
};
self.temp_list.push(index_id);

base
}
crate::Expression::AccessIndex { base, index } => {
Expand All @@ -1493,10 +1509,13 @@ impl<'w> BlockContext<'w> {
}
};

let pointer = if self.temp_list.is_empty() {
ExpressionPointer::Ready {
pointer_id: root_id,
}
let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
(
root_id,
ExpressionPointer::Ready {
pointer_id: root_id,
},
)
} else {
self.temp_list.reverse();
let pointer_id = self.gen_id();
Expand All @@ -1507,16 +1526,21 @@ impl<'w> BlockContext<'w> {
// caller to generate the branch, the access, the load or store, and
// the zero value (for loads). Otherwise, we can emit the access
// ourselves, and just hand them the id of the pointer.
match accumulated_checks {
let expr_pointer = match accumulated_checks {
Some(condition) => ExpressionPointer::Conditional { condition, access },
None => {
block.body.push(access);
ExpressionPointer::Ready { pointer_id }
}
}
};
(pointer_id, expr_pointer)
};
if is_non_uniform_binding_array {
self.writer
.decorate_non_uniform_binding_array_access(pointer_id)?;
}

Ok(pointer)
Ok(expr_pointer)
}

/// Build the instructions for matrix - matrix column operations
Expand Down
7 changes: 7 additions & 0 deletions src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1962,6 +1962,13 @@ impl Writer {
pub const fn get_capabilities_used(&self) -> &crate::FastHashSet<spirv::Capability> {
&self.capabilities_used
}

pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> {
self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?;
self.use_extension("SPV_EXT_descriptor_indexing");
self.decorate(id, spirv::Decoration::NonUniform, &[]);
Ok(())
}
}

#[test]
Expand Down
44 changes: 22 additions & 22 deletions tests/out/spv/binding-arrays.spvasm
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,28 @@ OpMemberDecorate %49 0 Offset 0
OpDecorate %65 Location 0
OpDecorate %65 Flat
OpDecorate %68 Location 0
OpDecorate %97 NonUniform
OpDecorate %121 NonUniform
OpDecorate %123 NonUniform
OpDecorate %148 NonUniform
OpDecorate %150 NonUniform
OpDecorate %188 NonUniform
OpDecorate %219 NonUniform
OpDecorate %238 NonUniform
OpDecorate %257 NonUniform
OpDecorate %279 NonUniform
OpDecorate %281 NonUniform
OpDecorate %303 NonUniform
OpDecorate %305 NonUniform
OpDecorate %327 NonUniform
OpDecorate %329 NonUniform
OpDecorate %351 NonUniform
OpDecorate %353 NonUniform
OpDecorate %375 NonUniform
OpDecorate %377 NonUniform
OpDecorate %399 NonUniform
OpDecorate %401 NonUniform
OpDecorate %424 NonUniform
OpDecorate %96 NonUniform
OpDecorate %120 NonUniform
OpDecorate %122 NonUniform
OpDecorate %147 NonUniform
OpDecorate %149 NonUniform
OpDecorate %187 NonUniform
OpDecorate %218 NonUniform
OpDecorate %237 NonUniform
OpDecorate %256 NonUniform
OpDecorate %278 NonUniform
OpDecorate %280 NonUniform
OpDecorate %302 NonUniform
OpDecorate %304 NonUniform
OpDecorate %326 NonUniform
OpDecorate %328 NonUniform
OpDecorate %350 NonUniform
OpDecorate %352 NonUniform
OpDecorate %374 NonUniform
OpDecorate %376 NonUniform
OpDecorate %398 NonUniform
OpDecorate %400 NonUniform
OpDecorate %423 NonUniform
%2 = OpTypeVoid
%4 = OpTypeInt 32 1
%3 = OpConstant %4 5
Expand Down
3 changes: 3 additions & 0 deletions tests/out/spv/binding-buffer-arrays.spvasm
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
; Generator: rspirv
; Bound: 66
OpCapability Shader
OpCapability ShaderNonUniform
OpExtension "SPV_KHR_storage_buffer_storage_class"
OpExtension "SPV_EXT_descriptor_indexing"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %29 "main" %24 %27
Expand All @@ -23,6 +25,7 @@ OpDecorate %24 Location 0
OpDecorate %24 Flat
OpDecorate %27 Location 0
OpDecorate %27 Flat
OpDecorate %57 NonUniform
%2 = OpTypeVoid
%4 = OpTypeInt 32 1
%3 = OpConstant %4 1
Expand Down

0 comments on commit 777f5e5

Please sign in to comment.