From 17653e41386a0dc155d9714ca5a4addc0d52ca11 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Wed, 1 Feb 2023 22:19:42 -0800 Subject: [PATCH 01/12] Add ray query types to the IR --- src/back/glsl/mod.rs | 2 ++ src/back/msl/writer.rs | 9 ++++++- src/back/spv/writer.rs | 4 ++- src/front/wgsl/lower/mod.rs | 2 ++ src/front/wgsl/mod.rs | 2 ++ src/front/wgsl/parse/ast.rs | 2 ++ src/front/wgsl/parse/mod.rs | 2 ++ src/lib.rs | 5 ++++ src/proc/layouter.rs | 6 ++++- src/proc/mod.rs | 6 ++++- src/valid/expression.rs | 4 +-- src/valid/handles.rs | 4 ++- src/valid/mod.rs | 6 ++++- src/valid/type.rs | 50 +++++++++++++++++++++++++----------- tests/in/ray-query.param.ron | 6 +++++ tests/in/ray-query.wgsl | 17 ++++++++++++ tests/snapshots.rs | 1 + 17 files changed, 105 insertions(+), 23 deletions(-) create mode 100644 tests/in/ray-query.param.ron create mode 100644 tests/in/ray-query.wgsl diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 44685fb99e..419f4c4385 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -879,6 +879,8 @@ impl<'a, W: Write> Writer<'a, W> { | TypeInner::Struct { .. } | TypeInner::Image { .. } | TypeInner::Sampler { .. } + | TypeInner::AccelerationStructure + | TypeInner::RayQuery | TypeInner::BindingArray { .. } => { return Err(Error::Custom(format!("Unable to write type {inner:?}"))) } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index ee23ca294a..68945315c9 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -194,6 +194,9 @@ impl<'a> Display for TypeContext<'a> { crate::TypeInner::Sampler { comparison: _ } => { write!(out, "{NAMESPACE}::sampler") } + crate::TypeInner::AccelerationStructure | crate::TypeInner::RayQuery => { + unreachable!("Ray queries are not supported yet"); + } crate::TypeInner::BindingArray { base, size } => { let base_tyname = Self { handle: base, @@ -485,7 +488,11 @@ impl crate::Type { // composite types are better to be aliased, regardless of the name Ti::Struct { .. } | Ti::Array { .. } => true, // handle types may be different, depending on the global var access, so we always inline them - Ti::Image { .. } | Ti::Sampler { .. } | Ti::BindingArray { .. } => false, + Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => false, } } } diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index f264c107d3..fbc53feedd 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1017,7 +1017,9 @@ impl Writer { | crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Image { .. } - | crate::TypeInner::Sampler { .. } => unreachable!(), + | crate::TypeInner::Sampler { .. } + | crate::TypeInner::AccelerationStructure + | crate::TypeInner::RayQuery => unreachable!(), }; instruction.to_words(&mut self.logical_layout.declarations); diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index f3b157caa7..8dfb735af4 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -2245,6 +2245,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { class, }, ast::Type::Sampler { comparison } => crate::TypeInner::Sampler { comparison }, + ast::Type::AccelerationStructure => crate::TypeInner::AccelerationStructure, + ast::Type::RayQuery => crate::TypeInner::RayQuery, ast::Type::BindingArray { base, size } => { let base = self.resolve_ast_type(base, ctx.reborrow())?; diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 8ac82fe45e..eb21fae6c9 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -206,6 +206,8 @@ impl crate::TypeInner { format!("texture{class_suffix}{dim_suffix}{array_suffix}{type_in_brackets}") } Ti::Sampler { .. } => "sampler".to_string(), + Ti::AccelerationStructure => "acceleration_structure".to_string(), + Ti::RayQuery => "ray_query".to_string(), Ti::BindingArray { base, size, .. } => { let member_type = &types[base]; let base = member_type.name.as_deref().unwrap_or("unknown"); diff --git a/src/front/wgsl/parse/ast.rs b/src/front/wgsl/parse/ast.rs index 734d9769fe..a5da4a49cc 100644 --- a/src/front/wgsl/parse/ast.rs +++ b/src/front/wgsl/parse/ast.rs @@ -229,6 +229,8 @@ pub enum Type<'a> { Sampler { comparison: bool, }, + AccelerationStructure, + RayQuery, BindingArray { base: Handle>, size: ArraySize<'a>, diff --git a/src/front/wgsl/parse/mod.rs b/src/front/wgsl/parse/mod.rs index 7ff762d673..f082ec1c4e 100644 --- a/src/front/wgsl/parse/mod.rs +++ b/src/front/wgsl/parse/mod.rs @@ -1367,6 +1367,8 @@ impl Parser { class: crate::ImageClass::Storage { format, access }, } } + "acceleration_structure" => ast::Type::AccelerationStructure, + "ray_query" => ast::Type::RayQuery, _ => return Ok(None), })) } diff --git a/src/lib.rs b/src/lib.rs index c1b48b8991..f8491f3a36 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -721,6 +721,11 @@ pub enum TypeInner { /// Can be used to sample values from images. Sampler { comparison: bool }, + /// Opaque object representing an acceleration structure of geometry. + AccelerationStructure, + /// Locally used handle for ray queries. + RayQuery, + /// Array of bindings. /// /// A `BindingArray` represents an array where each element draws its value diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index db07f261a4..65369d1cc8 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -238,7 +238,11 @@ impl Layouter { alignment, } } - Ti::Image { .. } | Ti::Sampler { .. } | Ti::BindingArray { .. } => TypeLayout { + Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => TypeLayout { size, alignment: Alignment::ONE, }, diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 6a8bfa03c7..a775272a19 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -134,7 +134,11 @@ impl super::TypeInner { count * stride } Self::Struct { span, .. } => span, - Self::Image { .. } | Self::Sampler { .. } | Self::BindingArray { .. } => 0, + Self::Image { .. } + | Self::Sampler { .. } + | Self::AccelerationStructure + | Self::RayQuery + | Self::BindingArray { .. } => 0, } } diff --git a/src/valid/expression.rs b/src/valid/expression.rs index af080fc183..9063eb0616 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1379,7 +1379,7 @@ impl super::Validator { _ => return Err(ExpressionError::InvalidCastArgument), }; let width = convert.unwrap_or(base_width); - if !self.check_width(kind, width) { + if self.check_width(kind, width).is_err() { return Err(ExpressionError::InvalidCastArgument); } ShaderStages::all() @@ -1390,7 +1390,7 @@ impl super::Validator { &crate::TypeInner::Scalar { kind: kind @ (crate::ScalarKind::Uint | crate::ScalarKind::Sint), width, - } => self.check_width(kind, width), + } => self.check_width(kind, width).is_ok(), _ => false, }; let good = match &module.types[ty].inner { diff --git a/src/valid/handles.rs b/src/valid/handles.rs index e3f9fe2531..871a73a219 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -76,7 +76,9 @@ impl super::Validator { | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Atomic { .. } | crate::TypeInner::Image { .. } - | crate::TypeInner::Sampler { .. } => (), + | crate::TypeInner::Sampler { .. } + | crate::TypeInner::AccelerationStructure + | crate::TypeInner::RayQuery => (), crate::TypeInner::Pointer { base, space: _ } => { this_handle.check_dep(base)?; } diff --git a/src/valid/mod.rs b/src/valid/mod.rs index eb92e8892d..6b3a2e1456 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -111,6 +111,8 @@ bitflags::bitflags! { const EARLY_DEPTH_TEST = 0x400; /// Support for [`Builtin::SampleIndex`] and [`Sampling::Sample`]. const MULTISAMPLED_SHADING = 0x800; + /// Support for ray queries and acceleration structures. + const RAY_QUERY = 0x1000; } } @@ -238,6 +240,8 @@ impl crate::TypeInner { Self::Array { .. } | Self::Image { .. } | Self::Sampler { .. } + | Self::AccelerationStructure + | Self::RayQuery | Self::BindingArray { .. } => false, } } @@ -302,7 +306,7 @@ impl Validator { let con = &constants[handle]; match con.inner { crate::ConstantInner::Scalar { width, ref value } => { - if !self.check_width(value.scalar_kind(), width) { + if self.check_width(value.scalar_kind(), width).is_err() { return Err(ConstantError::InvalidType); } } diff --git a/src/valid/type.rs b/src/valid/type.rs index 4fcc1a1c58..23f6ef4d1f 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -90,6 +90,8 @@ pub enum Disalignment { #[derive(Clone, Debug, thiserror::Error)] pub enum TypeError { + #[error("Capability {0:?} is required")] + MissingCapability(Capabilities), #[error("The {0:?} scalar width {1} is not supported")] InvalidWidth(crate::ScalarKind, crate::Bytes), #[error("The {0:?} scalar width {1} is not supported for an atomic")] @@ -203,13 +205,35 @@ impl TypeInfo { } impl super::Validator { - pub(super) const fn check_width(&self, kind: crate::ScalarKind, width: crate::Bytes) -> bool { - match kind { + fn require_type_capability(&self, capability: Capabilities) -> Result<(), TypeError> { + if self.capabilities.contains(capability) { + Ok(()) + } else { + Err(TypeError::MissingCapability(capability)) + } + } + + pub(super) fn check_width( + &self, + kind: crate::ScalarKind, + width: crate::Bytes, + ) -> Result<(), TypeError> { + let good = match kind { crate::ScalarKind::Bool => width == crate::BOOL_WIDTH, crate::ScalarKind::Float => { - width == 4 || (width == 8 && self.capabilities.contains(Capabilities::FLOAT64)) + if width == 8 { + self.require_type_capability(Capabilities::FLOAT64)?; + true + } else { + width == 4 + } } crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4, + }; + if good { + Ok(()) + } else { + Err(TypeError::InvalidWidth(kind, width)) } } @@ -228,9 +252,7 @@ impl super::Validator { use crate::TypeInner as Ti; Ok(match types[handle].inner { Ti::Scalar { kind, width } => { - if !self.check_width(kind, width) { - return Err(TypeError::InvalidWidth(kind, width)); - } + self.check_width(kind, width)?; let shareable = if kind.is_numeric() { TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE } else { @@ -247,9 +269,7 @@ impl super::Validator { ) } Ti::Vector { size, kind, width } => { - if !self.check_width(kind, width) { - return Err(TypeError::InvalidWidth(kind, width)); - } + self.check_width(kind, width)?; let shareable = if kind.is_numeric() { TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE } else { @@ -271,9 +291,7 @@ impl super::Validator { rows, width, } => { - if !self.check_width(crate::ScalarKind::Float, width) { - return Err(TypeError::InvalidWidth(crate::ScalarKind::Float, width)); - } + self.check_width(crate::ScalarKind::Float, width)?; TypeInfo::new( TypeFlags::DATA | TypeFlags::SIZED @@ -355,9 +373,7 @@ impl super::Validator { // However, some cases are trivial: All our implicit base types // are DATA and SIZED, so we can never return // `InvalidPointerBase` or `InvalidPointerToUnsized`. - if !self.check_width(kind, width) { - return Err(TypeError::InvalidWidth(kind, width)); - } + self.check_width(kind, width)?; // `Validator::validate_function` actually checks the storage // space of pointer arguments explicitly before checking the @@ -606,6 +622,10 @@ impl super::Validator { Ti::Image { .. } | Ti::Sampler { .. } => { TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE) } + Ti::AccelerationStructure | Ti::RayQuery => { + self.require_type_capability(Capabilities::RAY_QUERY)?; + TypeInfo::new(TypeFlags::empty(), Alignment::ONE) + } Ti::BindingArray { .. } => TypeInfo::new(TypeFlags::empty(), Alignment::ONE), }) } diff --git a/tests/in/ray-query.param.ron b/tests/in/ray-query.param.ron new file mode 100644 index 0000000000..9d8666954d --- /dev/null +++ b/tests/in/ray-query.param.ron @@ -0,0 +1,6 @@ +( + god_mode: true, + spv: ( + version: (1, 4), + ), +) diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl new file mode 100644 index 0000000000..26023ff2f1 --- /dev/null +++ b/tests/in/ray-query.wgsl @@ -0,0 +1,17 @@ +var acc_struct: acceleration_structure; + +struct Output { + visible: u32, +} +var output: Output; + +@compute +fn main() { + var rq: ray_query; + + rayQueryInitialize(rq, acceleration_structure, RAY_FLAGS_TERMINATE_ON_FIRST_HIT, 0xFF, vec3(0.0), 0.1, vec3(0.0, 1.0, 0.0), 100.0); + + rayQueryProceed(rq); + + output.visible = rayQueryGetCommittedIntersectionType(rq) == RAY_QUERY_COMMITTED_INTERSECTION_NONE; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 691c074a93..df94130a70 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -571,6 +571,7 @@ fn convert_wgsl() { ("sprite", Targets::SPIRV), ("force_point_size_vertex_shader_webgl", Targets::GLSL), ("invariant", Targets::GLSL), + ("ray-query", Targets::SPIRV), ]; for &(name, targets) in inputs.iter() { From bd59af51ff73b9bee799b8f3729d7e37f3ce2045 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Mon, 13 Feb 2023 21:12:06 -0800 Subject: [PATCH 02/12] Add ray query statements to the IR --- src/back/dot/mod.rs | 16 ++++++++++++++++ src/back/glsl/mod.rs | 1 + src/back/hlsl/writer.rs | 1 + src/back/msl/writer.rs | 1 + src/back/spv/block.rs | 11 +++++++++++ src/back/wgsl/writer.rs | 1 + src/front/spv/mod.rs | 3 ++- src/lib.rs | 17 +++++++++++++++++ src/proc/terminator.rs | 1 + src/valid/analyzer.rs | 12 ++++++++++++ src/valid/function.rs | 3 +++ src/valid/handles.rs | 14 ++++++++++++++ tests/in/ray-query.wgsl | 20 +++++++++++++++++++- 13 files changed, 99 insertions(+), 2 deletions(-) diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index 1167357e8d..f53d2faa1d 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -252,6 +252,22 @@ impl StatementGraph { } "Atomic" } + S::RayQuery { query, ref fun } => { + self.dependencies.push((id, query, "query")); + if let crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } = *fun + { + self.dependencies.push(( + id, + acceleration_structure, + "acceleration_structure", + )); + self.dependencies.push((id, descriptor, "descriptor")); + } + "RayQuery" + } }; // Set the last node to the merge node last_node = merge_id; diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 419f4c4385..e81c27c505 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2197,6 +2197,7 @@ impl<'a, W: Write> Writer<'a, W> { self.write_expr(value, ctx)?; writeln!(self.out, ");")?; } + Statement::RayQuery { .. } => unreachable!(), } Ok(()) diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index d11032bbf5..9e9aa19d76 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1980,6 +1980,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}}}")? } + Statement::RayQuery { .. } => unreachable!(), } Ok(()) diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 68945315c9..fa222d8d4b 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -2759,6 +2759,7 @@ impl Writer { // done writeln!(self.out, ";")?; } + crate::Statement::RayQuery { .. } => unreachable!(), } } diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index c3fa8455e9..02a83a19aa 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -2196,6 +2196,17 @@ impl<'w> BlockContext<'w> { block.body.push(instruction); } + crate::Statement::RayQuery { query, ref fun } => { + let query_id = self.cached[query]; + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => {} + crate::RayQueryFunction::Proceed => {} + crate::RayQueryFunction::Terminate => {} + } + } } } diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index f24f4a9c26..90d3b5330e 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -937,6 +937,7 @@ impl Writer { writeln!(self.out, "{level}workgroupBarrier();")?; } } + Statement::RayQuery { .. } => unreachable!(), } Ok(()) diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index ce42be35b2..c69a230cb0 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -3672,7 +3672,8 @@ impl> Frontend { | S::Barrier(_) | S::Store { .. } | S::ImageStore { .. } - | S::Atomic { .. } => {} + | S::Atomic { .. } + | S::RayQuery { .. } => {} S::Call { function: ref mut callee, ref arguments, diff --git a/src/lib.rs b/src/lib.rs index f8491f3a36..efb05c8ff1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1473,6 +1473,19 @@ pub struct SwitchCase { pub fall_through: bool, } +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum RayQueryFunction { + Initialize { + acceleration_structure: Handle, + descriptor: Handle, + }, + Proceed, + Terminate, +} + //TODO: consider removing `Clone`. It's not valid to clone `Statement::Emit` anyway. /// Instructions which make up an executable block. // Clone is used only for error reporting and is not intended for end users @@ -1646,6 +1659,10 @@ pub enum Statement { arguments: Vec>, result: Option>, }, + RayQuery { + query: Handle, + fun: RayQueryFunction, + }, } /// A function argument. diff --git a/src/proc/terminator.rs b/src/proc/terminator.rs index 5915616cc5..ca0c3f10bc 100644 --- a/src/proc/terminator.rs +++ b/src/proc/terminator.rs @@ -34,6 +34,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::Store { .. } | S::ImageStore { .. } | S::Call { .. } + | S::RayQuery { .. } | S::Atomic { .. } | S::Barrier(_)), ) diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 40d5f95c10..8d19430980 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -893,6 +893,18 @@ impl FunctionInfo { } FunctionUniformity::new() } + S::RayQuery { query, ref fun } => { + let _ = self.add_ref(query); + if let crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } = *fun + { + let _ = self.add_ref(acceleration_structure); + let _ = self.add_ref(descriptor); + } + FunctionUniformity::new() + } }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/src/valid/function.rs b/src/valid/function.rs index 464496f6d6..a13a07bcfa 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -807,6 +807,9 @@ impl super::Validator { } => { self.validate_atomic(pointer, fun, value, result, context)?; } + S::RayQuery { query: _, fun: _ } => { + //TODO + } } } Ok(BlockInfo { stages, finished }) diff --git a/src/valid/handles.rs b/src/valid/handles.rs index 871a73a219..20d8448a1a 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -496,6 +496,20 @@ impl super::Validator { validate_expr_opt(result)?; Ok(()) } + crate::Statement::RayQuery { query, ref fun } => { + validate_expr(query)?; + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + validate_expr(acceleration_structure)?; + validate_expr(descriptor)?; + } + crate::RayQueryFunction::Proceed | crate::RayQueryFunction::Terminate => {} + } + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl index 26023ff2f1..b772e69d7d 100644 --- a/tests/in/ray-query.wgsl +++ b/tests/in/ray-query.wgsl @@ -1,5 +1,23 @@ var acc_struct: acceleration_structure; +/* +let RAY_FLAG_NONE = 0u; +let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 4u; + +let RAY_QUERY_INTERSECTION_NONE = 0u; +let RAY_QUERY_INTERSECTION_TRIANGLE = 1u; +let RAY_QUERY_INTERSECTION_GENERATED = 2u; +let RAY_QUERY_INTERSECTION_AABB = 4u; + +struct RayDesc { + flags: u32, + cull_mask: u32, + origin: vec3, + t_min: f32, + dir: vec3, + t_max: f32, +}*/ + struct Output { visible: u32, } @@ -9,7 +27,7 @@ var output: Output; fn main() { var rq: ray_query; - rayQueryInitialize(rq, acceleration_structure, RAY_FLAGS_TERMINATE_ON_FIRST_HIT, 0xFF, vec3(0.0), 0.1, vec3(0.0, 1.0, 0.0), 100.0); + rayQueryInitialize(rq, acceleration_structure, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFF, vec3(0.0), 0.1, vec3(0.0, 1.0, 0.0), 100.0)); rayQueryProceed(rq); From 0753179e74613fe3badcd9c3da3cbb6c276397d1 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Wed, 15 Feb 2023 23:35:03 -0800 Subject: [PATCH 03/12] Ray query expressions and special types --- src/back/dot/mod.rs | 6 ++ src/back/glsl/mod.rs | 6 +- src/back/hlsl/writer.rs | 6 +- src/back/msl/writer.rs | 6 +- src/back/spv/block.rs | 4 + src/back/wgsl/writer.rs | 6 +- src/front/glsl/constants.rs | 5 + src/front/glsl/types.rs | 18 +--- src/front/mod.rs | 1 + src/front/type_gen.rs | 153 +++++++++++++++++++++++++++ src/front/wgsl/lower/construction.rs | 4 + src/front/wgsl/lower/mod.rs | 58 ++++++++-- src/front/wgsl/parse/ast.rs | 3 + src/front/wgsl/parse/mod.rs | 13 +++ src/lib.rs | 25 +++++ src/proc/typifier.rs | 31 ++++++ src/valid/analyzer.rs | 20 ++-- src/valid/expression.rs | 1 + src/valid/handles.rs | 16 ++- src/valid/interface.rs | 4 +- src/valid/type.rs | 6 +- tests/in/ray-query.wgsl | 30 ++++-- tests/out/ir/access.ron | 4 + tests/out/ir/collatz.ron | 4 + tests/out/ir/shadow.ron | 4 + 25 files changed, 388 insertions(+), 46 deletions(-) create mode 100644 src/front/type_gen.rs diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index f53d2faa1d..d293c3adc1 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -566,6 +566,12 @@ fn write_function_expressions( edges.insert("", expr); ("ArrayLength".into(), 7) } + E::RayQueryProceedResult => ("rayQueryProceedResult".into(), 4), + E::RayQueryGetIntersection { query, committed } => { + edges.insert("", query); + let ty = if committed { "Committed" } else { "Candidate" }; + (format!("rayQueryGet{}Intersection", ty).into(), 4) + } }; // give uniform expressions an outline diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index e81c27c505..9195b96837 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -3280,13 +3280,17 @@ impl<'a, W: Write> Writer<'a, W> { } } // These expressions never show up in `Emit`. - Expression::CallResult(_) | Expression::AtomicResult { .. } => unreachable!(), + Expression::CallResult(_) + | Expression::AtomicResult { .. } + | Expression::RayQueryProceedResult => unreachable!(), // `ArrayLength` is written as `expr.length()` and we convert it to a uint Expression::ArrayLength(expr) => { write!(self.out, "uint(")?; self.write_expr(expr, ctx)?; write!(self.out, ".length())")? } + // not supported yet + Expression::RayQueryGetIntersection { .. } => unreachable!(), } Ok(()) diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 9e9aa19d76..f9e52914f7 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2879,8 +2879,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_expr(module, reject, func_ctx)?; write!(self.out, ")")? } + // Not supported yet + Expression::RayQueryGetIntersection { .. } => unreachable!(), // Nothing to do here, since call expression already cached - Expression::CallResult(_) | Expression::AtomicResult { .. } => {} + Expression::CallResult(_) + | Expression::AtomicResult { .. } + | Expression::RayQueryProceedResult => {} } if !closing_bracket.is_empty() { diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index fa222d8d4b..2ec5bad339 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1838,7 +1838,9 @@ impl Writer { _ => return Err(Error::Validation), }, // has to be a named expression - crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } => { + crate::Expression::CallResult(_) + | crate::Expression::AtomicResult { .. } + | crate::Expression::RayQueryProceedResult => { unreachable!() } crate::Expression::ArrayLength(expr) => { @@ -1863,6 +1865,8 @@ impl Writer { write!(self.out, ")")?; } } + // hot supported yet + crate::Expression::RayQueryGetIntersection { .. } => unreachable!(), } Ok(()) } diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 02a83a19aa..090899b4a7 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -1386,6 +1386,10 @@ impl<'w> BlockContext<'w> { id } crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?, + //TODO + crate::Expression::RayQueryProceedResult => unreachable!(), + //TODO + crate::Expression::RayQueryGetIntersection { .. } => unreachable!(), }; self.cached[expr_handle] = id; diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 90d3b5330e..92086c94a8 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -1622,8 +1622,12 @@ impl Writer { write!(self.out, ")")? } + // Not supported yet + Expression::RayQueryGetIntersection { .. } => unreachable!(), // Nothing to do here, since call expression already cached - Expression::CallResult(_) | Expression::AtomicResult { .. } => {} + Expression::CallResult(_) + | Expression::AtomicResult { .. } + | Expression::RayQueryProceedResult => {} } Ok(()) diff --git a/src/front/glsl/constants.rs b/src/front/glsl/constants.rs index d9a6fc7cd7..045a9c6ffb 100644 --- a/src/front/glsl/constants.rs +++ b/src/front/glsl/constants.rs @@ -37,6 +37,8 @@ pub enum ConstantSolvingError { Load, #[error("Constants don't support image expressions")] ImageExpression, + #[error("Constants don't support ray query expressions")] + RayQueryExpression, #[error("Cannot access the type")] InvalidAccessBase, #[error("Cannot access at the index")] @@ -295,6 +297,9 @@ impl<'a> ConstantSolver<'a> { Expression::ImageSample { .. } | Expression::ImageLoad { .. } | Expression::ImageQuery { .. } => Err(ConstantSolvingError::ImageExpression), + Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => { + Err(ConstantSolvingError::RayQueryExpression) + } } } diff --git a/src/front/glsl/types.rs b/src/front/glsl/types.rs index 632378c60b..a7967848d5 100644 --- a/src/front/glsl/types.rs +++ b/src/front/glsl/types.rs @@ -246,14 +246,7 @@ impl Frontend { expr: Handle, meta: Span, ) -> Result<()> { - let resolve_ctx = ResolveContext { - constants: &self.module.constants, - types: &self.module.types, - global_vars: &self.module.global_variables, - local_vars: &ctx.locals, - functions: &self.module.functions, - arguments: &ctx.arguments, - }; + let resolve_ctx = ResolveContext::with_locals(&self.module, &ctx.locals, &ctx.arguments); ctx.typifier .grow(expr, &ctx.expressions, &resolve_ctx) @@ -312,14 +305,7 @@ impl Frontend { expr: Handle, meta: Span, ) -> Result<()> { - let resolve_ctx = ResolveContext { - constants: &self.module.constants, - types: &self.module.types, - global_vars: &self.module.global_variables, - local_vars: &ctx.locals, - functions: &self.module.functions, - arguments: &ctx.arguments, - }; + let resolve_ctx = ResolveContext::with_locals(&self.module, &ctx.locals, &ctx.arguments); ctx.typifier .invalidate(expr, &ctx.expressions, &resolve_ctx) diff --git a/src/front/mod.rs b/src/front/mod.rs index 071e805a69..d6f38671ea 100644 --- a/src/front/mod.rs +++ b/src/front/mod.rs @@ -3,6 +3,7 @@ Frontend parsers that consume binary and text shaders and load them into [`Modul */ mod interpolator; +mod type_gen; #[cfg(feature = "glsl-in")] pub mod glsl; diff --git a/src/front/type_gen.rs b/src/front/type_gen.rs new file mode 100644 index 0000000000..18d9ddd54c --- /dev/null +++ b/src/front/type_gen.rs @@ -0,0 +1,153 @@ +/*! +Type generators. +*/ + +use crate::{arena::Handle, span::Span}; + +impl crate::Module { + pub(super) fn generate_ray_desc_type(&mut self) -> Handle { + if let Some(handle) = self.special_types.ray_desc { + return handle; + } + + let width = 4; + let ty_flag = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + width, + kind: crate::ScalarKind::Uint, + }, + }, + Span::UNDEFINED, + ); + let ty_scalar = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + width, + kind: crate::ScalarKind::Float, + }, + }, + Span::UNDEFINED, + ); + let ty_vector = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size: crate::VectorSize::Tri, + kind: crate::ScalarKind::Float, + width, + }, + }, + Span::UNDEFINED, + ); + + let handle = self.types.insert( + crate::Type { + name: Some("RayDesc".to_string()), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("flags".to_string()), + ty: ty_flag, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("cull_mask".to_string()), + ty: ty_flag, + binding: None, + offset: 4, + }, + crate::StructMember { + name: Some("tmin".to_string()), + ty: ty_scalar, + binding: None, + offset: 8, + }, + crate::StructMember { + name: Some("tmax".to_string()), + ty: ty_scalar, + binding: None, + offset: 12, + }, + crate::StructMember { + name: Some("origin".to_string()), + ty: ty_vector, + binding: None, + offset: 16, + }, + crate::StructMember { + name: Some("dir".to_string()), + ty: ty_vector, + binding: None, + offset: 32, + }, + ], + span: 48, + }, + }, + Span::UNDEFINED, + ); + + self.special_types.ray_desc = Some(handle); + handle + } + + pub(super) fn generate_ray_intersection_type(&mut self) -> Handle { + if let Some(handle) = self.special_types.ray_intersection { + return handle; + } + + let width = 4; + let ty_flag = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + width, + kind: crate::ScalarKind::Uint, + }, + }, + Span::UNDEFINED, + ); + let ty_scalar = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + width, + kind: crate::ScalarKind::Float, + }, + }, + Span::UNDEFINED, + ); + + let handle = self.types.insert( + crate::Type { + name: Some("RayIntersection".to_string()), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("kind".to_string()), + ty: ty_flag, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("t".to_string()), + ty: ty_scalar, + binding: None, + offset: 4, + }, + //TODO: the rest + ], + span: 8, + }, + }, + Span::UNDEFINED, + ); + + self.special_types.ray_intersection = Some(handle); + handle + } +} diff --git a/src/front/wgsl/lower/construction.rs b/src/front/wgsl/lower/construction.rs index 723d4441f5..4b0371573a 100644 --- a/src/front/wgsl/lower/construction.rs +++ b/src/front/wgsl/lower/construction.rs @@ -660,6 +660,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }); ConcreteConstructorHandle::Type(ty) } + ast::ConstructorType::RayDesc => { + let ty = ctx.module.generate_ray_desc_type(); + ConcreteConstructorHandle::Type(ty) + } ast::ConstructorType::Type(ty) => ConcreteConstructorHandle::Type(ty), }; diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 8dfb735af4..971b3032b9 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -234,14 +234,8 @@ impl<'a> ExpressionContext<'a, '_, '_> { /// [`self.resolved_inner(handle)`]: ExpressionContext::resolved_inner /// [`Typifier`]: Typifier fn grow_types(&mut self, handle: Handle) -> Result<&mut Self, Error<'a>> { - let resolve_ctx = ResolveContext { - constants: &self.module.constants, - types: &self.module.types, - global_vars: &self.module.global_variables, - local_vars: self.local_vars, - functions: &self.module.functions, - arguments: self.arguments, - }; + let resolve_ctx = + ResolveContext::with_locals(&self.module, self.local_vars, self.arguments); self.typifier .grow(handle, self.naga_expressions, &resolve_ctx) .map_err(Error::InvalidResolve)?; @@ -1919,6 +1913,54 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { query: crate::ImageQuery::NumSamples, } } + "rayQueryInitialize" => { + let mut args = ctx.prepare_args(arguments, 3, span); + let query = self.expression(args.next()?, ctx.reborrow())?; + let acceleration_structure = + self.expression(args.next()?, ctx.reborrow())?; + let descriptor = self.expression(args.next()?, ctx.reborrow())?; + args.finish()?; + + let _ = ctx.module.generate_ray_desc_type(); + let fun = crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + }; + + ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions)); + ctx.emitter.start(ctx.naga_expressions); + ctx.block + .push(crate::Statement::RayQuery { query, fun }, span); + return Ok(None); + } + "rayQueryProceed" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.expression(args.next()?, ctx.reborrow())?; + args.finish()?; + + let fun = crate::RayQueryFunction::Proceed; + + ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions)); + let result = ctx + .naga_expressions + .append(crate::Expression::RayQueryProceedResult, span); + ctx.emitter.start(ctx.naga_expressions); + ctx.block + .push(crate::Statement::RayQuery { query, fun }, span); + return Ok(Some(result)); + } + "rayQueryGetCommittedIntersection" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.expression(args.next()?, ctx.reborrow())?; + args.finish()?; + + let _ = ctx.module.generate_ray_intersection_type(); + + crate::Expression::RayQueryGetIntersection { + query, + committed: true, + } + } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; diff --git a/src/front/wgsl/parse/ast.rs b/src/front/wgsl/parse/ast.rs index a5da4a49cc..9354c6c765 100644 --- a/src/front/wgsl/parse/ast.rs +++ b/src/front/wgsl/parse/ast.rs @@ -370,6 +370,9 @@ pub enum ConstructorType<'a> { size: ArraySize<'a>, }, + /// Ray description. + RayDesc, + /// Constructing a value of a known Naga IR type. /// /// This variant is produced only during lowering, when we have Naga types diff --git a/src/front/wgsl/parse/mod.rs b/src/front/wgsl/parse/mod.rs index f082ec1c4e..e4a0d160fd 100644 --- a/src/front/wgsl/parse/mod.rs +++ b/src/front/wgsl/parse/mod.rs @@ -441,6 +441,7 @@ impl Parser { })) } "array" => ast::ConstructorType::PartialArray, + "RayDesc" => ast::ConstructorType::RayDesc, "atomic" | "binding_array" | "sampler" @@ -622,6 +623,18 @@ impl Parser { let num = res.map_err(|err| Error::BadNumber(span, err))?; ast::Expression::Literal(ast::Literal::Number(num)) } + (Token::Word("RAY_FLAG_NONE"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Number(Number::U32(0))) + } + (Token::Word("RAY_FLAG_TERMINATE_ON_FIRST_HIT"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Number(Number::U32(4))) + } + (Token::Word("RAY_QUERY_INTERSECTION_NONE"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Number(Number::U32(0))) + } (Token::Word(word), span) => { let start = lexer.start_byte_offset(); let _ = lexer.next(); diff --git a/src/lib.rs b/src/lib.rs index efb05c8ff1..9344fd0a53 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -107,6 +107,9 @@ Naga's rules for when `Expression`s are evaluated are as follows: [`Atomic`] statement, representing the result of the atomic operation, is evaluated when the `Atomic` statement is executed. +- Similarly, an [`RayQueryProceedResult`] expression, which is a boolean + indicating if the ray query is finished. + - All other expressions are evaluated when the (unique) [`Statement::Emit`] statement that covers them is executed. @@ -1441,6 +1444,13 @@ pub enum Expression { /// This doesn't match the semantics of spirv's `OpArrayLength`, which must be passed /// a pointer to a structure containing a runtime array in its' last field. ArrayLength(Handle), + /// Result of `rayQueryProceed`. + RayQueryProceedResult, + /// Result of `rayQueryGet*Intersection`. + RayQueryGetIntersection { + query: Handle, + committed: bool, + }, } pub use block::Block; @@ -1779,6 +1789,19 @@ pub struct EntryPoint { pub function: Function, } +/// Set of special types that can be optionally generated by the frontends. +#[derive(Debug, Default)] +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct SpecialTypes { + /// Type for `RayDesc`. + ray_desc: Option>, + /// Type for `RayIntersection`. + ray_intersection: Option>, +} + /// Shader module. /// /// A module is a set of constants, global variables and functions, as well as @@ -1798,6 +1821,8 @@ pub struct EntryPoint { pub struct Module { /// Arena for the types defined in this module. pub types: UniqueArena, + /// Dictionary of special type handles. + pub special_types: SpecialTypes, /// Arena for the constants defined in this module. pub constants: Arena, /// Arena for the global variables defined in this module. diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index b9ac468313..f7a4ba94d5 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -193,11 +193,14 @@ pub enum ResolveError { IncompatibleOperands(String), #[error("Function argument {0} doesn't exist")] FunctionArgumentNotFound(u32), + #[error("Special type is not registered within the module")] + MissingSpecialType, } pub struct ResolveContext<'a> { pub constants: &'a Arena, pub types: &'a UniqueArena, + pub special_types: &'a crate::SpecialTypes, pub global_vars: &'a Arena, pub local_vars: &'a Arena, pub functions: &'a Arena, @@ -205,6 +208,23 @@ pub struct ResolveContext<'a> { } impl<'a> ResolveContext<'a> { + /// Initialize a resolve context from the module. + pub fn with_locals( + module: &'a crate::Module, + local_vars: &'a Arena, + arguments: &'a [crate::FunctionArgument], + ) -> Self { + Self { + constants: &module.constants, + types: &module.types, + special_types: &module.special_types, + global_vars: &module.global_variables, + local_vars, + functions: &module.functions, + arguments, + } + } + /// Determine the type of `expr`. /// /// The `past` argument must be a closure that can resolve the types of any @@ -867,6 +887,17 @@ impl<'a> ResolveContext<'a> { kind: crate::ScalarKind::Uint, width: 4, }), + crate::Expression::RayQueryProceedResult => TypeResolution::Value(Ti::Scalar { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + }), + crate::Expression::RayQueryGetIntersection { .. } => { + let result = self + .special_types + .ray_intersection + .ok_or(ResolveError::MissingSpecialType)?; + TypeResolution::Handle(result) + } }) } } diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 8d19430980..e9b155b6eb 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -686,7 +686,7 @@ impl FunctionInfo { requirements: UniformityRequirements::empty(), }, E::CallResult(function) => other_functions[function.index()].uniformity.clone(), - E::AtomicResult { .. } => Uniformity { + E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity { non_uniform_result: Some(handle), requirements: UniformityRequirements::empty(), }, @@ -694,6 +694,13 @@ impl FunctionInfo { non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY), requirements: UniformityRequirements::empty(), }, + E::RayQueryGetIntersection { + query, + committed: _, + } => Uniformity { + non_uniform_result: self.add_ref(query), + requirements: UniformityRequirements::empty(), + }, }; let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; @@ -934,14 +941,8 @@ impl ModuleInfo { expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(), sampling: crate::FastHashSet::default(), }; - let resolve_context = ResolveContext { - constants: &module.constants, - types: &module.types, - global_vars: &module.global_variables, - local_vars: &fun.local_variables, - functions: &module.functions, - arguments: &fun.arguments, - }; + let resolve_context = + ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments); for (handle, expr) in fun.expressions.iter() { if let Err(source) = info.process_expression( @@ -1064,6 +1065,7 @@ fn uniform_control_flow() { let resolve_context = ResolveContext { constants: &constant_arena, types: &type_arena, + special_types: &crate::SpecialTypes::default(), global_vars: &global_var_arena, local_vars: &Arena::new(), functions: &Arena::new(), diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 9063eb0616..408dccaf10 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1427,6 +1427,7 @@ impl super::Validator { return Err(ExpressionError::InvalidArrayType(expr)); } }, + E::RayQueryProceedResult | E::RayQueryGetIntersection { .. } => ShaderStages::all(), }; Ok(stages) } diff --git a/src/valid/handles.rs b/src/valid/handles.rs index 20d8448a1a..1257eb1bdf 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -39,6 +39,7 @@ impl super::Validator { ref functions, ref global_variables, ref types, + ref special_types, } = module; // NOTE: Types being first is important. All other forms of validation depend on this. @@ -194,6 +195,13 @@ impl super::Validator { validate_function(Some(function_handle), function)?; } + if let Some(ty) = special_types.ray_desc { + validate_type(ty)?; + } + if let Some(ty) = special_types.ray_intersection { + validate_type(ty)?; + } + Ok(()) } @@ -379,10 +387,16 @@ impl super::Validator { handle.check_dep(function)?; } } - crate::Expression::AtomicResult { .. } => (), + crate::Expression::AtomicResult { .. } | crate::Expression::RayQueryProceedResult => (), crate::Expression::ArrayLength(array) => { handle.check_dep(array)?; } + crate::Expression::RayQueryGetIntersection { + query, + committed: _, + } => { + handle.check_dep(query)?; + } } Ok(()) } diff --git a/src/valid/interface.rs b/src/valid/interface.rs index 13dbd75761..d9ee9f5402 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -440,7 +440,9 @@ impl super::Validator { match types[var.ty].inner { crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } - | crate::TypeInner::BindingArray { .. } => {} + | crate::TypeInner::BindingArray { .. } + | crate::TypeInner::AccelerationStructure + | crate::TypeInner::RayQuery => {} _ => { return Err(GlobalVariableError::InvalidType(var.space)); } diff --git a/src/valid/type.rs b/src/valid/type.rs index 23f6ef4d1f..2a2d5e7335 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -622,10 +622,14 @@ impl super::Validator { Ti::Image { .. } | Ti::Sampler { .. } => { TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE) } - Ti::AccelerationStructure | Ti::RayQuery => { + Ti::AccelerationStructure => { self.require_type_capability(Capabilities::RAY_QUERY)?; TypeInfo::new(TypeFlags::empty(), Alignment::ONE) } + Ti::RayQuery => { + self.require_type_capability(Capabilities::RAY_QUERY)?; + TypeInfo::new(TypeFlags::DATA | TypeFlags::SIZED, Alignment::ONE) + } Ti::BindingArray { .. } => TypeInfo::new(TypeFlags::empty(), Alignment::ONE), }) } diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl index b772e69d7d..0aec9ca142 100644 --- a/tests/in/ray-query.wgsl +++ b/tests/in/ray-query.wgsl @@ -1,3 +1,4 @@ +@group(0) @binding(0) var acc_struct: acceleration_structure; /* @@ -12,24 +13,41 @@ let RAY_QUERY_INTERSECTION_AABB = 4u; struct RayDesc { flags: u32, cull_mask: u32, - origin: vec3, t_min: f32, - dir: vec3, t_max: f32, -}*/ + origin: vec3, + dir: vec3, +} + +struct RayIntersection { + kind: u32, + t: f32, + instance_custom_index: u32, + instance_id: u32, + sbt_record_offset: u32, + geometry_index: u32, + primitive_index: u32, + barycentrics: vec2, + front_face: bool, + //TODO: object ray direction, origin, matrices +} +*/ struct Output { visible: u32, } + +@group(0) @binding(1) var output: Output; -@compute +@compute @workgroup_size(1) fn main() { var rq: ray_query; - rayQueryInitialize(rq, acceleration_structure, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFF, vec3(0.0), 0.1, vec3(0.0, 1.0, 0.0), 100.0)); + rayQueryInitialize(rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3(0.0), vec3(0.0, 1.0, 0.0))); rayQueryProceed(rq); - output.visible = rayQueryGetCommittedIntersectionType(rq) == RAY_QUERY_COMMITTED_INTERSECTION_NONE; + let intersection = rayQueryGetCommittedIntersection(rq); + output.visible = u32(intersection.kind == RAY_QUERY_INTERSECTION_NONE); } diff --git a/tests/out/ir/access.ron b/tests/out/ir/access.ron index e544ee1a5d..41772b9332 100644 --- a/tests/out/ir/access.ron +++ b/tests/out/ir/access.ron @@ -333,6 +333,10 @@ ), ), ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ), constants: [ ( name: None, diff --git a/tests/out/ir/collatz.ron b/tests/out/ir/collatz.ron index 00cab8e885..1be31e6eff 100644 --- a/tests/out/ir/collatz.ron +++ b/tests/out/ir/collatz.ron @@ -38,6 +38,10 @@ ), ), ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ), constants: [ ( name: None, diff --git a/tests/out/ir/shadow.ron b/tests/out/ir/shadow.ron index 8956076ef3..9311f9e188 100644 --- a/tests/out/ir/shadow.ron +++ b/tests/out/ir/shadow.ron @@ -286,6 +286,10 @@ ), ), ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ), constants: [ ( name: None, From fbeb223129945a8d983fb3ad48a115b407d358f7 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 16 Feb 2023 23:03:56 -0800 Subject: [PATCH 04/12] spv-out: basic ray query support --- src/back/spv/block.rs | 162 +++++++++++++++++++++++++++++++-- src/back/spv/image.rs | 2 +- src/back/spv/instructions.rs | 60 ++++++++++++ src/back/spv/mod.rs | 8 +- src/back/spv/writer.rs | 100 +++++++++++--------- src/front/wgsl/lower/mod.rs | 34 ++++++- src/lib.rs | 4 +- src/valid/handles.rs | 5 +- tests/in/ray-query.wgsl | 6 +- tests/out/spv/ray-query.spvasm | 81 +++++++++++++++++ 10 files changed, 400 insertions(+), 62 deletions(-) create mode 100644 tests/out/spv/ray-query.spvasm diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 090899b4a7..77d7267fd8 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -1084,9 +1084,9 @@ impl<'w> BlockContext<'w> { } } crate::Expression::FunctionArgument(index) => self.function.parameter_id(index), - crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } => { - self.cached[expr_handle] - } + crate::Expression::CallResult(_) + | crate::Expression::AtomicResult { .. } + | crate::Expression::RayQueryProceedResult => self.cached[expr_handle], crate::Expression::As { expr, kind, @@ -1386,10 +1386,61 @@ impl<'w> BlockContext<'w> { id } crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?, - //TODO - crate::Expression::RayQueryProceedResult => unreachable!(), - //TODO - crate::Expression::RayQueryGetIntersection { .. } => unreachable!(), + crate::Expression::RayQueryGetIntersection { query, committed } => { + let width = 4; + let query_id = self.cached[query]; + let intersection_id = self.writer.get_constant_scalar( + crate::ScalarValue::Uint( + spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _, + ), + width, + ); + if !committed { + return Err(Error::FeatureNotImplemented("candidate intersection")); + } + + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Uint, + width, + pointer_space: None, + })); + let kind_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTypeKHR, + flag_type_id, + kind_id, + query_id, + intersection_id, + )); + + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let t_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTKHR, + scalar_type_id, + t_id, + query_id, + intersection_id, + )); + + let id = self.gen_id(); + let intersection_type_id = self.get_type_id(LookupType::Handle( + self.ir_module.special_types.ray_intersection.unwrap(), + )); + //Note: the arguments must match `generate_ray_intersection_type` layout + block.body.push(Instruction::composite_construct( + intersection_type_id, + id, + &[kind_id, t_id], + )); + id + } }; self.cached[expr_handle] = id; @@ -2206,8 +2257,101 @@ impl<'w> BlockContext<'w> { crate::RayQueryFunction::Initialize { acceleration_structure, descriptor, - } => {} - crate::RayQueryFunction::Proceed => {} + } => { + //Note: composite extract indices and types must match `generate_ray_desc_type` + let desc_id = self.cached[descriptor]; + let acc_struct_id = self.get_image_id(acceleration_structure); + let width = 4; + + let flag_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Uint, + width, + pointer_space: None, + })); + let ray_flags_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + flag_type_id, + ray_flags_id, + desc_id, + &[0], + )); + let cull_mask_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + flag_type_id, + cull_mask_id, + desc_id, + &[1], + )); + + let scalar_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let tmin_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmin_id, + desc_id, + &[2], + )); + let tmax_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmax_id, + desc_id, + &[3], + )); + + let vector_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let ray_origin_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_origin_id, + desc_id, + &[4], + )); + let ray_dir_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_dir_id, + desc_id, + &[5], + )); + + block.body.push(Instruction::ray_query_initialize( + query_id, + acc_struct_id, + ray_flags_id, + cull_mask_id, + ray_origin_id, + tmin_id, + ray_dir_id, + tmax_id, + )); + } + crate::RayQueryFunction::Proceed { result } => { + let id = self.gen_id(); + self.cached[result] = id; + let result_type_id = + self.get_expression_type_id(&self.fun_info[result].ty); + + block.body.push(Instruction::ray_query_proceed( + result_type_id, + id, + query_id, + )); + } crate::RayQueryFunction::Terminate => {} } } diff --git a/src/back/spv/image.rs b/src/back/spv/image.rs index 1a136af77e..dc4f249949 100644 --- a/src/back/spv/image.rs +++ b/src/back/spv/image.rs @@ -373,7 +373,7 @@ impl<'w> BlockContext<'w> { }) } - fn get_image_id(&mut self, expr_handle: Handle) -> Word { + pub(super) fn get_image_id(&mut self, expr_handle: Handle) -> Word { let id = match self.ir_function.expressions[expr_handle] { crate::Expression::GlobalVariable(handle) => { self.writer.global_variables[handle.index()].handle_id diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index c213790188..3038d60d40 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -249,6 +249,18 @@ impl super::Instruction { instruction } + pub(super) fn type_acceleration_structure(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeAccelerationStructureKHR); + instruction.set_result(id); + instruction + } + + pub(super) fn type_ray_query(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeRayQueryKHR); + instruction.set_result(id); + instruction + } + pub(super) fn type_sampled_image(id: Word, image_type_id: Word) -> Self { let mut instruction = Self::new(Op::TypeSampledImage); instruction.set_result(id); @@ -627,6 +639,54 @@ impl super::Instruction { instruction } + // + // Ray Query Instructions + // + pub(super) fn ray_query_initialize( + query: Word, + acceleration_structure: Word, + ray_flags: Word, + cull_mask: Word, + ray_origin: Word, + ray_tmin: Word, + ray_dir: Word, + ray_tmax: Word, + ) -> Self { + let mut instruction = Self::new(Op::RayQueryInitializeKHR); + instruction.add_operand(query); + instruction.add_operand(acceleration_structure); + instruction.add_operand(ray_flags); + instruction.add_operand(cull_mask); + instruction.add_operand(ray_origin); + instruction.add_operand(ray_tmin); + instruction.add_operand(ray_dir); + instruction.add_operand(ray_tmax); + instruction + } + + pub(super) fn ray_query_proceed(result_type_id: Word, id: Word, query: Word) -> Self { + let mut instruction = Self::new(Op::RayQueryProceedKHR); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(query); + instruction + } + + pub(super) fn ray_query_get_intersection( + op: Op, + result_type_id: Word, + id: Word, + query: Word, + intersection: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(query); + instruction.add_operand(intersection); + instruction + } + // // Conversion Instructions // diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index 72898219af..5fba4f0dea 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -295,6 +295,8 @@ enum LocalType { base: Handle, size: u64, }, + AccelerationStructure, + RayQuery, } /// A type encountered during SPIR-V generation. @@ -383,7 +385,11 @@ fn make_local(inner: &crate::TypeInner) -> Option { class, } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)), crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler, - _ => return None, + crate::TypeInner::AccelerationStructure => LocalType::AccelerationStructure, + crate::TypeInner::RayQuery => LocalType::RayQuery, + crate::TypeInner::Array { .. } + | crate::TypeInner::Struct { .. } + | crate::TypeInner::BindingArray { .. } => return None, }) } diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index fbc53feedd..800a40ed68 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -350,9 +350,12 @@ impl Writer { pointer_type_id, id, spirv::StorageClass::Function, - init_word.or_else(|| { - let type_id = self.get_type_id(LookupType::Handle(variable.ty)); - Some(self.write_constant_null(type_id)) + init_word.or_else(|| match ir_module.types[variable.ty].inner { + crate::TypeInner::RayQuery => None, + _ => { + let type_id = self.get_type_id(LookupType::Handle(variable.ty)); + Some(self.write_constant_null(type_id)) + } }), ); function @@ -814,47 +817,54 @@ impl Writer { } } - fn request_image_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> { - if let crate::TypeInner::Image { - dim, - arrayed, - class, - } = *inner - { - let sampled = match class { - crate::ImageClass::Sampled { .. } => true, - crate::ImageClass::Depth { .. } => true, - crate::ImageClass::Storage { format, .. } => { - self.request_image_format_capabilities(format.into())?; - false - } - }; + fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> { + match *inner { + crate::TypeInner::Image { + dim, + arrayed, + class, + } => { + let sampled = match class { + crate::ImageClass::Sampled { .. } => true, + crate::ImageClass::Depth { .. } => true, + crate::ImageClass::Storage { format, .. } => { + self.request_image_format_capabilities(format.into())?; + false + } + }; - match dim { - crate::ImageDimension::D1 => { - if sampled { - self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?; - } else { - self.require_any("1D storage images", &[spirv::Capability::Image1D])?; + match dim { + crate::ImageDimension::D1 => { + if sampled { + self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?; + } else { + self.require_any("1D storage images", &[spirv::Capability::Image1D])?; + } } - } - crate::ImageDimension::Cube if arrayed => { - if sampled { - self.require_any( - "sampled cube array images", - &[spirv::Capability::SampledCubeArray], - )?; - } else { - self.require_any( - "cube array storage images", - &[spirv::Capability::ImageCubeArray], - )?; + crate::ImageDimension::Cube if arrayed => { + if sampled { + self.require_any( + "sampled cube array images", + &[spirv::Capability::SampledCubeArray], + )?; + } else { + self.require_any( + "cube array storage images", + &[spirv::Capability::ImageCubeArray], + )?; + } } + _ => {} } - _ => {} } + crate::TypeInner::AccelerationStructure => { + self.require_any("Acceleration Structure", &[spirv::Capability::RayQueryKHR])?; + } + crate::TypeInner::RayQuery => { + self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?; + } + _ => {} } - Ok(()) } @@ -935,6 +945,8 @@ impl Writer { self.get_type_id(LookupType::Local(LocalType::BindingArray { base, size })); Instruction::type_pointer(id, spirv::StorageClass::UniformConstant, inner_ty) } + LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id), + LocalType::RayQuery => Instruction::type_ray_query(id), }; instruction.to_words(&mut self.logical_layout.declarations); @@ -961,9 +973,9 @@ impl Writer { self.write_type_declaration_local(id, local); - // If it's an image type, request SPIR-V capabilities here, so - // write_type_declaration_local can stay infallible. - self.request_image_capabilities(&ty.inner)?; + // If it's an type that needs SPIR-V capabilities, request them now, + // so write_type_declaration_local can stay infallible. + self.request_type_capabilities(&ty.inner)?; id } @@ -1758,6 +1770,8 @@ impl Writer { .iter() .flat_map(|entry| entry.function.arguments.iter()) .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty)); + let has_ray_query = ir_module.special_types.ray_desc.is_some() + | ir_module.special_types.ray_intersection.is_some(); if self.physical_layout.version < 0x10300 && has_storage_buffers { // enable the storage buffer class on < SPV-1.3 @@ -1768,6 +1782,10 @@ impl Writer { Instruction::extension("SPV_KHR_multiview") .to_words(&mut self.logical_layout.extensions) } + if has_ray_query { + Instruction::extension("SPV_KHR_ray_query") + .to_words(&mut self.logical_layout.extensions) + } Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations); Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450") .to_words(&mut self.logical_layout.ext_inst_imports); diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 971b3032b9..8cba80fb5c 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -1915,7 +1915,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "rayQueryInitialize" => { let mut args = ctx.prepare_args(arguments, 3, span); - let query = self.expression(args.next()?, ctx.reborrow())?; + let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?; let acceleration_structure = self.expression(args.next()?, ctx.reborrow())?; let descriptor = self.expression(args.next()?, ctx.reborrow())?; @@ -1935,15 +1935,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "rayQueryProceed" => { let mut args = ctx.prepare_args(arguments, 1, span); - let query = self.expression(args.next()?, ctx.reborrow())?; + let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?; args.finish()?; - let fun = crate::RayQueryFunction::Proceed; - ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions)); let result = ctx .naga_expressions .append(crate::Expression::RayQueryProceedResult, span); + let fun = crate::RayQueryFunction::Proceed { result }; + ctx.emitter.start(ctx.naga_expressions); ctx.block .push(crate::Statement::RayQuery { query, fun }, span); @@ -1951,7 +1951,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "rayQueryGetCommittedIntersection" => { let mut args = ctx.prepare_args(arguments, 1, span); - let query = self.expression(args.next()?, ctx.reborrow())?; + let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?; args.finish()?; let _ = ctx.module.generate_ray_intersection_type(); @@ -2422,4 +2422,28 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { binding } + + fn ray_query_pointer( + &mut self, + expr: Handle>, + mut ctx: ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let pointer = self.expression(expr, ctx.reborrow())?; + + ctx.grow_types(pointer)?; + match *ctx.resolved_inner(pointer) { + crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner { + crate::TypeInner::RayQuery => Ok(pointer), + ref other => { + log::error!("Pointer type to {:?} passed to ray query op", other); + Err(Error::InvalidAtomicPointer(span)) + } + }, + ref other => { + log::error!("Type {:?} passed to ray query op", other); + Err(Error::InvalidAtomicPointer(span)) + } + } + } } diff --git a/src/lib.rs b/src/lib.rs index 9344fd0a53..162c23cd6b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1492,7 +1492,9 @@ pub enum RayQueryFunction { acceleration_structure: Handle, descriptor: Handle, }, - Proceed, + Proceed { + result: Handle, + }, Terminate, } diff --git a/src/valid/handles.rs b/src/valid/handles.rs index 1257eb1bdf..c9e6cad502 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -520,7 +520,10 @@ impl super::Validator { validate_expr(acceleration_structure)?; validate_expr(descriptor)?; } - crate::RayQueryFunction::Proceed | crate::RayQueryFunction::Terminate => {} + crate::RayQueryFunction::Proceed { result } => { + validate_expr(result)?; + } + crate::RayQueryFunction::Terminate => {} } Ok(()) } diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl index 0aec9ca142..7b1053bdcb 100644 --- a/tests/in/ray-query.wgsl +++ b/tests/in/ray-query.wgsl @@ -44,10 +44,10 @@ var output: Output; fn main() { var rq: ray_query; - rayQueryInitialize(rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3(0.0), vec3(0.0, 1.0, 0.0))); + rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3(0.0), vec3(0.0, 1.0, 0.0))); - rayQueryProceed(rq); + rayQueryProceed(&rq); - let intersection = rayQueryGetCommittedIntersection(rq); + let intersection = rayQueryGetCommittedIntersection(&rq); output.visible = u32(intersection.kind == RAY_QUERY_INTERSECTION_NONE); } diff --git a/tests/out/spv/ray-query.spvasm b/tests/out/spv/ray-query.spvasm new file mode 100644 index 0000000000..455a8bffe9 --- /dev/null +++ b/tests/out/spv/ray-query.spvasm @@ -0,0 +1,81 @@ +; SPIR-V +; Version: 1.4 +; Generator: rspirv +; Bound: 52 +OpCapability RayQueryKHR +OpCapability Shader +OpExtension "SPV_KHR_ray_query" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %26 "main" %18 %20 +OpExecutionMode %26 LocalSize 1 1 1 +OpMemberDecorate %13 0 Offset 0 +OpMemberDecorate %16 0 Offset 0 +OpMemberDecorate %16 1 Offset 4 +OpMemberDecorate %16 2 Offset 8 +OpMemberDecorate %16 3 Offset 12 +OpMemberDecorate %16 4 Offset 16 +OpMemberDecorate %16 5 Offset 32 +OpMemberDecorate %17 0 Offset 0 +OpMemberDecorate %17 1 Offset 4 +OpDecorate %18 DescriptorSet 0 +OpDecorate %18 Binding 0 +OpDecorate %20 DescriptorSet 0 +OpDecorate %20 Binding 1 +OpDecorate %21 Block +OpMemberDecorate %21 0 Offset 0 +%2 = OpTypeVoid +%4 = OpTypeInt 32 0 +%3 = OpConstant %4 4 +%5 = OpConstant %4 255 +%7 = OpTypeFloat 32 +%6 = OpConstant %7 0.1 +%8 = OpConstant %7 100.0 +%9 = OpConstant %7 0.0 +%10 = OpConstant %7 1.0 +%11 = OpConstant %4 0 +%12 = OpTypeAccelerationStructureNV +%13 = OpTypeStruct %4 +%14 = OpTypeRayQueryKHR +%15 = OpTypeVector %7 3 +%16 = OpTypeStruct %4 %4 %7 %7 %15 %15 +%17 = OpTypeStruct %4 %7 +%19 = OpTypePointer UniformConstant %12 +%18 = OpVariable %19 UniformConstant +%21 = OpTypeStruct %13 +%22 = OpTypePointer StorageBuffer %21 +%20 = OpVariable %22 StorageBuffer +%24 = OpTypePointer Function %14 +%27 = OpTypeFunction %2 +%29 = OpTypePointer StorageBuffer %13 +%42 = OpTypeBool +%43 = OpConstant %4 1 +%47 = OpTypePointer StorageBuffer %4 +%26 = OpFunction %2 None %27 +%25 = OpLabel +%23 = OpVariable %24 Function +%28 = OpLoad %12 %18 +%30 = OpAccessChain %29 %20 %11 +OpBranch %31 +%31 = OpLabel +%32 = OpCompositeConstruct %15 %9 %9 %9 +%33 = OpCompositeConstruct %15 %9 %10 %9 +%34 = OpCompositeConstruct %16 %3 %5 %6 %8 %32 %33 +%35 = OpCompositeExtract %4 %34 0 +%36 = OpCompositeExtract %4 %34 1 +%37 = OpCompositeExtract %7 %34 2 +%38 = OpCompositeExtract %7 %34 3 +%39 = OpCompositeExtract %15 %34 4 +%40 = OpCompositeExtract %15 %34 5 +OpRayQueryInitializeKHR %23 %28 %35 %36 %39 %37 %40 %38 +%41 = OpRayQueryProceedKHR %42 %23 +%44 = OpRayQueryGetIntersectionTypeKHR %4 %23 %43 +%45 = OpRayQueryGetIntersectionTKHR %7 %23 %43 +%46 = OpCompositeConstruct %17 %44 %45 +%48 = OpCompositeExtract %4 %46 0 +%49 = OpIEqual %42 %48 %11 +%50 = OpSelect %4 %49 %43 %11 +%51 = OpAccessChain %47 %30 %11 +OpStore %51 %50 +OpReturn +OpFunctionEnd \ No newline at end of file From 06cde6235e212dc562ab9b764f9eb3134c4f0bbf Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 17 Feb 2023 21:50:31 -0800 Subject: [PATCH 05/12] clippy fixes --- src/back/spv/instructions.rs | 1 + src/front/wgsl/lower/mod.rs | 3 +-- src/lib.rs | 1 + src/proc/typifier.rs | 2 +- src/valid/type.rs | 2 +- 5 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index 3038d60d40..96d0278285 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -642,6 +642,7 @@ impl super::Instruction { // // Ray Query Instructions // + #[allow(clippy::too_many_arguments)] pub(super) fn ray_query_initialize( query: Word, acceleration_structure: Word, diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 8cba80fb5c..730a33815a 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -234,8 +234,7 @@ impl<'a> ExpressionContext<'a, '_, '_> { /// [`self.resolved_inner(handle)`]: ExpressionContext::resolved_inner /// [`Typifier`]: Typifier fn grow_types(&mut self, handle: Handle) -> Result<&mut Self, Error<'a>> { - let resolve_ctx = - ResolveContext::with_locals(&self.module, self.local_vars, self.arguments); + let resolve_ctx = ResolveContext::with_locals(self.module, self.local_vars, self.arguments); self.typifier .grow(handle, self.naga_expressions, &resolve_ctx) .map_err(Error::InvalidResolve)?; diff --git a/src/lib.rs b/src/lib.rs index 162c23cd6b..bf5f9a3a2a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -169,6 +169,7 @@ need to be stored in a local variable to be carried upwards in the statement tree. [`AtomicResult`]: Expression::AtomicResult +[`RayQueryProceedResult`]: Expression::RayQueryProceedResult [`CallResult`]: Expression::CallResult [`Constant`]: Expression::Constant [`Derivative`]: Expression::Derivative diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index f7a4ba94d5..0bb9019a29 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -209,7 +209,7 @@ pub struct ResolveContext<'a> { impl<'a> ResolveContext<'a> { /// Initialize a resolve context from the module. - pub fn with_locals( + pub const fn with_locals( module: &'a crate::Module, local_vars: &'a Arena, arguments: &'a [crate::FunctionArgument], diff --git a/src/valid/type.rs b/src/valid/type.rs index 2a2d5e7335..d8dd37d09b 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -205,7 +205,7 @@ impl TypeInfo { } impl super::Validator { - fn require_type_capability(&self, capability: Capabilities) -> Result<(), TypeError> { + const fn require_type_capability(&self, capability: Capabilities) -> Result<(), TypeError> { if self.capabilities.contains(capability) { Ok(()) } else { From a328fdc11b818badaaf7e121ae3f9c637f21edc9 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 21 Feb 2023 23:31:32 -0800 Subject: [PATCH 06/12] fill up the ray query intersection struct --- src/back/spv/block.rs | 154 +------------------ src/back/spv/mod.rs | 1 + src/back/spv/ray.rs | 273 +++++++++++++++++++++++++++++++++ src/front/type_gen.rs | 91 ++++++++++- tests/in/ray-query.wgsl | 3 +- tests/out/spv/ray-query.spvasm | 122 +++++++++------ 6 files changed, 440 insertions(+), 204 deletions(-) create mode 100644 src/back/spv/ray.rs diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 77d7267fd8..b28b94fe91 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -1387,59 +1387,10 @@ impl<'w> BlockContext<'w> { } crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?, crate::Expression::RayQueryGetIntersection { query, committed } => { - let width = 4; - let query_id = self.cached[query]; - let intersection_id = self.writer.get_constant_scalar( - crate::ScalarValue::Uint( - spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _, - ), - width, - ); if !committed { return Err(Error::FeatureNotImplemented("candidate intersection")); } - - let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - kind: crate::ScalarKind::Uint, - width, - pointer_space: None, - })); - let kind_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionTypeKHR, - flag_type_id, - kind_id, - query_id, - intersection_id, - )); - - let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - kind: crate::ScalarKind::Float, - width, - pointer_space: None, - })); - let t_id = self.gen_id(); - block.body.push(Instruction::ray_query_get_intersection( - spirv::Op::RayQueryGetIntersectionTKHR, - scalar_type_id, - t_id, - query_id, - intersection_id, - )); - - let id = self.gen_id(); - let intersection_type_id = self.get_type_id(LookupType::Handle( - self.ir_module.special_types.ray_intersection.unwrap(), - )); - //Note: the arguments must match `generate_ray_intersection_type` layout - block.body.push(Instruction::composite_construct( - intersection_type_id, - id, - &[kind_id, t_id], - )); - id + self.write_ray_query_get_intersection(query, block) } }; @@ -2252,108 +2203,7 @@ impl<'w> BlockContext<'w> { block.body.push(instruction); } crate::Statement::RayQuery { query, ref fun } => { - let query_id = self.cached[query]; - match *fun { - crate::RayQueryFunction::Initialize { - acceleration_structure, - descriptor, - } => { - //Note: composite extract indices and types must match `generate_ray_desc_type` - let desc_id = self.cached[descriptor]; - let acc_struct_id = self.get_image_id(acceleration_structure); - let width = 4; - - let flag_type_id = - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - kind: crate::ScalarKind::Uint, - width, - pointer_space: None, - })); - let ray_flags_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - flag_type_id, - ray_flags_id, - desc_id, - &[0], - )); - let cull_mask_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - flag_type_id, - cull_mask_id, - desc_id, - &[1], - )); - - let scalar_type_id = - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - kind: crate::ScalarKind::Float, - width, - pointer_space: None, - })); - let tmin_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - scalar_type_id, - tmin_id, - desc_id, - &[2], - )); - let tmax_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - scalar_type_id, - tmax_id, - desc_id, - &[3], - )); - - let vector_type_id = - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(crate::VectorSize::Tri), - kind: crate::ScalarKind::Float, - width, - pointer_space: None, - })); - let ray_origin_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - vector_type_id, - ray_origin_id, - desc_id, - &[4], - )); - let ray_dir_id = self.gen_id(); - block.body.push(Instruction::composite_extract( - vector_type_id, - ray_dir_id, - desc_id, - &[5], - )); - - block.body.push(Instruction::ray_query_initialize( - query_id, - acc_struct_id, - ray_flags_id, - cull_mask_id, - ray_origin_id, - tmin_id, - ray_dir_id, - tmax_id, - )); - } - crate::RayQueryFunction::Proceed { result } => { - let id = self.gen_id(); - self.cached[result] = id; - let result_type_id = - self.get_expression_type_id(&self.fun_info[result].ty); - - block.body.push(Instruction::ray_query_proceed( - result_type_id, - id, - query_id, - )); - } - crate::RayQueryFunction::Terminate => {} - } + self.write_ray_query_function(query, fun, &mut block); } } } diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index 5fba4f0dea..9b084911b1 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -10,6 +10,7 @@ mod image; mod index; mod instructions; mod layout; +mod ray; mod recyclable; mod selection; mod writer; diff --git a/src/back/spv/ray.rs b/src/back/spv/ray.rs new file mode 100644 index 0000000000..79eb2ff971 --- /dev/null +++ b/src/back/spv/ray.rs @@ -0,0 +1,273 @@ +/*! +Generating SPIR-V for ray query operations. +*/ + +use super::{Block, BlockContext, Instruction, LocalType, LookupType}; +use crate::arena::Handle; + +impl<'w> BlockContext<'w> { + pub(super) fn write_ray_query_function( + &mut self, + query: Handle, + function: &crate::RayQueryFunction, + block: &mut Block, + ) { + let query_id = self.cached[query]; + match *function { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + //Note: composite extract indices and types must match `generate_ray_desc_type` + let desc_id = self.cached[descriptor]; + let acc_struct_id = self.get_image_id(acceleration_structure); + let width = 4; + + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Uint, + width, + pointer_space: None, + })); + let ray_flags_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + flag_type_id, + ray_flags_id, + desc_id, + &[0], + )); + let cull_mask_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + flag_type_id, + cull_mask_id, + desc_id, + &[1], + )); + + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let tmin_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmin_id, + desc_id, + &[2], + )); + let tmax_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmax_id, + desc_id, + &[3], + )); + + let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let ray_origin_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_origin_id, + desc_id, + &[4], + )); + let ray_dir_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_dir_id, + desc_id, + &[5], + )); + + block.body.push(Instruction::ray_query_initialize( + query_id, + acc_struct_id, + ray_flags_id, + cull_mask_id, + ray_origin_id, + tmin_id, + ray_dir_id, + tmax_id, + )); + } + crate::RayQueryFunction::Proceed { result } => { + let id = self.gen_id(); + self.cached[result] = id; + let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty); + + block + .body + .push(Instruction::ray_query_proceed(result_type_id, id, query_id)); + } + crate::RayQueryFunction::Terminate => {} + } + } + + pub(super) fn write_ray_query_get_intersection( + &mut self, + query: Handle, + block: &mut Block, + ) -> spirv::Word { + let width = 4; + let query_id = self.cached[query]; + let intersection_id = self.writer.get_constant_scalar( + crate::ScalarValue::Uint( + spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _, + ), + width, + ); + + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Uint, + width, + pointer_space: None, + })); + let kind_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTypeKHR, + flag_type_id, + kind_id, + query_id, + intersection_id, + )); + let instance_custom_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR, + flag_type_id, + instance_custom_index_id, + query_id, + intersection_id, + )); + let instance_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceIdKHR, + flag_type_id, + instance_id, + query_id, + intersection_id, + )); + let sbt_record_offset_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR, + flag_type_id, + sbt_record_offset_id, + query_id, + intersection_id, + )); + let geometry_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionGeometryIndexKHR, + flag_type_id, + geometry_index_id, + query_id, + intersection_id, + )); + let primitive_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR, + flag_type_id, + primitive_index_id, + query_id, + intersection_id, + )); + + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let t_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTKHR, + scalar_type_id, + t_id, + query_id, + intersection_id, + )); + + let barycentrics_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Bi), + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let barycentrics_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionBarycentricsKHR, + barycentrics_type_id, + barycentrics_id, + query_id, + intersection_id, + )); + + let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + pointer_space: None, + })); + let front_face_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionFrontFaceKHR, + bool_type_id, + front_face_id, + query_id, + intersection_id, + )); + + let transform_type_id = self.get_type_id(LookupType::Local(LocalType::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + width, + })); + let object_to_world_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionObjectToWorldKHR, + transform_type_id, + object_to_world_id, + query_id, + intersection_id, + )); + let world_to_object_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionWorldToObjectKHR, + transform_type_id, + world_to_object_id, + query_id, + intersection_id, + )); + + let id = self.gen_id(); + let intersection_type_id = self.get_type_id(LookupType::Handle( + self.ir_module.special_types.ray_intersection.unwrap(), + )); + //Note: the arguments must match `generate_ray_intersection_type` layout + block.body.push(Instruction::composite_construct( + intersection_type_id, + id, + &[ + kind_id, + t_id, + instance_custom_index_id, + instance_id, + sbt_record_offset_id, + geometry_index_id, + primitive_index_id, + barycentrics_id, + front_face_id, + object_to_world_id, + world_to_object_id, + ], + )); + id + } +} diff --git a/src/front/type_gen.rs b/src/front/type_gen.rs index 18d9ddd54c..bb734ac69c 100644 --- a/src/front/type_gen.rs +++ b/src/front/type_gen.rs @@ -5,6 +5,7 @@ Type generators. use crate::{arena::Handle, span::Span}; impl crate::Module { + //Note: has to match `struct RayDesc` pub(super) fn generate_ray_desc_type(&mut self) -> Handle { if let Some(handle) = self.special_types.ray_desc { return handle; @@ -95,6 +96,7 @@ impl crate::Module { handle } + //Note: has to match `struct RayIntersection` pub(super) fn generate_ray_intersection_type(&mut self) -> Handle { if let Some(handle) = self.special_types.ray_intersection { return handle; @@ -121,6 +123,38 @@ impl crate::Module { }, Span::UNDEFINED, ); + let ty_barycentrics = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + width, + size: crate::VectorSize::Bi, + kind: crate::ScalarKind::Float, + }, + }, + Span::UNDEFINED, + ); + let ty_bool = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + width: crate::BOOL_WIDTH, + kind: crate::ScalarKind::Bool, + }, + }, + Span::UNDEFINED, + ); + let ty_transform = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + width, + }, + }, + Span::UNDEFINED, + ); let handle = self.types.insert( crate::Type { @@ -139,9 +173,62 @@ impl crate::Module { binding: None, offset: 4, }, - //TODO: the rest + crate::StructMember { + name: Some("instance_custom_index".to_string()), + ty: ty_flag, + binding: None, + offset: 8, + }, + crate::StructMember { + name: Some("instance_id".to_string()), + ty: ty_flag, + binding: None, + offset: 12, + }, + crate::StructMember { + name: Some("sbt_record_offset".to_string()), + ty: ty_flag, + binding: None, + offset: 16, + }, + crate::StructMember { + name: Some("geometry_index".to_string()), + ty: ty_flag, + binding: None, + offset: 20, + }, + crate::StructMember { + name: Some("primitive_index".to_string()), + ty: ty_flag, + binding: None, + offset: 24, + }, + crate::StructMember { + name: Some("barycentrics".to_string()), + ty: ty_barycentrics, + binding: None, + offset: 28, + }, + crate::StructMember { + name: Some("front_face".to_string()), + ty: ty_bool, + binding: None, + offset: 36, + }, + crate::StructMember { + name: Some("object_to_world".to_string()), + ty: ty_transform, + binding: None, + offset: 48, + }, + crate::StructMember { + name: Some("world_to_object".to_string()), + ty: ty_transform, + binding: None, + offset: 112, + }, ], - span: 8, + span: 176, }, }, Span::UNDEFINED, diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl index 7b1053bdcb..5eabf3a2d3 100644 --- a/tests/in/ray-query.wgsl +++ b/tests/in/ray-query.wgsl @@ -29,7 +29,8 @@ struct RayIntersection { primitive_index: u32, barycentrics: vec2, front_face: bool, - //TODO: object ray direction, origin, matrices + object_to_world: mat4x3, + world_to_object: mat4x3, } */ diff --git a/tests/out/spv/ray-query.spvasm b/tests/out/spv/ray-query.spvasm index 455a8bffe9..6bc41ee30f 100644 --- a/tests/out/spv/ray-query.spvasm +++ b/tests/out/spv/ray-query.spvasm @@ -1,14 +1,14 @@ ; SPIR-V ; Version: 1.4 ; Generator: rspirv -; Bound: 52 +; Bound: 63 OpCapability RayQueryKHR OpCapability Shader OpExtension "SPV_KHR_ray_query" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %26 "main" %18 %20 -OpExecutionMode %26 LocalSize 1 1 1 +OpEntryPoint GLCompute %29 "main" %21 %23 +OpExecutionMode %29 LocalSize 1 1 1 OpMemberDecorate %13 0 Offset 0 OpMemberDecorate %16 0 Offset 0 OpMemberDecorate %16 1 Offset 4 @@ -16,14 +16,27 @@ OpMemberDecorate %16 2 Offset 8 OpMemberDecorate %16 3 Offset 12 OpMemberDecorate %16 4 Offset 16 OpMemberDecorate %16 5 Offset 32 -OpMemberDecorate %17 0 Offset 0 -OpMemberDecorate %17 1 Offset 4 -OpDecorate %18 DescriptorSet 0 -OpDecorate %18 Binding 0 -OpDecorate %20 DescriptorSet 0 -OpDecorate %20 Binding 1 -OpDecorate %21 Block -OpMemberDecorate %21 0 Offset 0 +OpMemberDecorate %20 0 Offset 0 +OpMemberDecorate %20 1 Offset 4 +OpMemberDecorate %20 2 Offset 8 +OpMemberDecorate %20 3 Offset 12 +OpMemberDecorate %20 4 Offset 16 +OpMemberDecorate %20 5 Offset 20 +OpMemberDecorate %20 6 Offset 24 +OpMemberDecorate %20 7 Offset 28 +OpMemberDecorate %20 8 Offset 36 +OpMemberDecorate %20 9 Offset 48 +OpMemberDecorate %20 9 ColMajor +OpMemberDecorate %20 9 MatrixStride 16 +OpMemberDecorate %20 10 Offset 112 +OpMemberDecorate %20 10 ColMajor +OpMemberDecorate %20 10 MatrixStride 16 +OpDecorate %21 DescriptorSet 0 +OpDecorate %21 Binding 0 +OpDecorate %23 DescriptorSet 0 +OpDecorate %23 Binding 1 +OpDecorate %24 Block +OpMemberDecorate %24 0 Offset 0 %2 = OpTypeVoid %4 = OpTypeInt 32 0 %3 = OpConstant %4 4 @@ -39,43 +52,54 @@ OpMemberDecorate %21 0 Offset 0 %14 = OpTypeRayQueryKHR %15 = OpTypeVector %7 3 %16 = OpTypeStruct %4 %4 %7 %7 %15 %15 -%17 = OpTypeStruct %4 %7 -%19 = OpTypePointer UniformConstant %12 -%18 = OpVariable %19 UniformConstant -%21 = OpTypeStruct %13 -%22 = OpTypePointer StorageBuffer %21 -%20 = OpVariable %22 StorageBuffer -%24 = OpTypePointer Function %14 -%27 = OpTypeFunction %2 -%29 = OpTypePointer StorageBuffer %13 -%42 = OpTypeBool -%43 = OpConstant %4 1 -%47 = OpTypePointer StorageBuffer %4 -%26 = OpFunction %2 None %27 -%25 = OpLabel -%23 = OpVariable %24 Function -%28 = OpLoad %12 %18 -%30 = OpAccessChain %29 %20 %11 -OpBranch %31 -%31 = OpLabel -%32 = OpCompositeConstruct %15 %9 %9 %9 -%33 = OpCompositeConstruct %15 %9 %10 %9 -%34 = OpCompositeConstruct %16 %3 %5 %6 %8 %32 %33 -%35 = OpCompositeExtract %4 %34 0 -%36 = OpCompositeExtract %4 %34 1 -%37 = OpCompositeExtract %7 %34 2 -%38 = OpCompositeExtract %7 %34 3 -%39 = OpCompositeExtract %15 %34 4 -%40 = OpCompositeExtract %15 %34 5 -OpRayQueryInitializeKHR %23 %28 %35 %36 %39 %37 %40 %38 -%41 = OpRayQueryProceedKHR %42 %23 -%44 = OpRayQueryGetIntersectionTypeKHR %4 %23 %43 -%45 = OpRayQueryGetIntersectionTKHR %7 %23 %43 -%46 = OpCompositeConstruct %17 %44 %45 -%48 = OpCompositeExtract %4 %46 0 -%49 = OpIEqual %42 %48 %11 -%50 = OpSelect %4 %49 %43 %11 -%51 = OpAccessChain %47 %30 %11 -OpStore %51 %50 +%17 = OpTypeVector %7 2 +%18 = OpTypeBool +%19 = OpTypeMatrix %15 4 +%20 = OpTypeStruct %4 %7 %4 %4 %4 %4 %4 %17 %18 %19 %19 +%22 = OpTypePointer UniformConstant %12 +%21 = OpVariable %22 UniformConstant +%24 = OpTypeStruct %13 +%25 = OpTypePointer StorageBuffer %24 +%23 = OpVariable %25 StorageBuffer +%27 = OpTypePointer Function %14 +%30 = OpTypeFunction %2 +%32 = OpTypePointer StorageBuffer %13 +%45 = OpConstant %4 1 +%58 = OpTypePointer StorageBuffer %4 +%29 = OpFunction %2 None %30 +%28 = OpLabel +%26 = OpVariable %27 Function +%31 = OpLoad %12 %21 +%33 = OpAccessChain %32 %23 %11 +OpBranch %34 +%34 = OpLabel +%35 = OpCompositeConstruct %15 %9 %9 %9 +%36 = OpCompositeConstruct %15 %9 %10 %9 +%37 = OpCompositeConstruct %16 %3 %5 %6 %8 %35 %36 +%38 = OpCompositeExtract %4 %37 0 +%39 = OpCompositeExtract %4 %37 1 +%40 = OpCompositeExtract %7 %37 2 +%41 = OpCompositeExtract %7 %37 3 +%42 = OpCompositeExtract %15 %37 4 +%43 = OpCompositeExtract %15 %37 5 +OpRayQueryInitializeKHR %26 %31 %38 %39 %42 %40 %43 %41 +%44 = OpRayQueryProceedKHR %18 %26 +%46 = OpRayQueryGetIntersectionTypeKHR %4 %26 %45 +%47 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %4 %26 %45 +%48 = OpRayQueryGetIntersectionInstanceIdKHR %4 %26 %45 +%49 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %4 %26 %45 +%50 = OpRayQueryGetIntersectionGeometryIndexKHR %4 %26 %45 +%51 = OpRayQueryGetIntersectionPrimitiveIndexKHR %4 %26 %45 +%52 = OpRayQueryGetIntersectionTKHR %7 %26 %45 +%53 = OpRayQueryGetIntersectionBarycentricsKHR %17 %26 %45 +%54 = OpRayQueryGetIntersectionFrontFaceKHR %18 %26 %45 +%55 = OpRayQueryGetIntersectionObjectToWorldKHR %19 %26 %45 +%56 = OpRayQueryGetIntersectionWorldToObjectKHR %19 %26 %45 +%57 = OpCompositeConstruct %20 %46 %52 %47 %48 %49 %50 %51 %53 %54 %55 %56 +%59 = OpCompositeExtract %4 %57 0 +%60 = OpIEqual %18 %59 %11 +%61 = OpSelect %4 %60 %45 %11 +%62 = OpAccessChain %58 %33 %11 +OpStore %62 %61 OpReturn OpFunctionEnd \ No newline at end of file From cfa4cba01dddd247d455ecdf534d6cceb8d1ea37 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Wed, 22 Feb 2023 23:14:49 -0800 Subject: [PATCH 07/12] wgsl: handle RayDesc/RayIntersection at the type decl level instead of an AST constructor --- src/front/wgsl/lower/construction.rs | 4 ---- src/front/wgsl/lower/mod.rs | 23 +++++++++++++++----- src/front/wgsl/parse/ast.rs | 5 ++--- src/front/wgsl/parse/mod.rs | 3 ++- tests/out/spv/ray-query.spvasm | 32 ++++++++++++++-------------- 5 files changed, 38 insertions(+), 29 deletions(-) diff --git a/src/front/wgsl/lower/construction.rs b/src/front/wgsl/lower/construction.rs index 4b0371573a..723d4441f5 100644 --- a/src/front/wgsl/lower/construction.rs +++ b/src/front/wgsl/lower/construction.rs @@ -660,10 +660,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }); ConcreteConstructorHandle::Type(ty) } - ast::ConstructorType::RayDesc => { - let ty = ctx.module.generate_ray_desc_type(); - ConcreteConstructorHandle::Type(ty) - } ast::ConstructorType::Type(ty) => ConcreteConstructorHandle::Type(ty), }; diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 730a33815a..8df3460b2e 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -637,14 +637,21 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { module: &mut module, }; - for decl in self.index.visit_ordered() { - let span = tu.decls.get_span(decl); - let decl = &tu.decls[decl]; + for decl_handle in self.index.visit_ordered() { + let span = tu.decls.get_span(decl_handle); + let decl = &tu.decls[decl_handle]; + + //TODO: find a nicer way? + if let Some(dep) = decl.dependencies.iter().find(|dep| dep.ident == "RayDesc") { + let ty_handle = ctx.module.generate_ray_desc_type(); + ctx.globals + .insert(dep.ident, LoweredGlobalDecl::Type(ty_handle)); + } match decl.kind { ast::GlobalDeclKind::Fn(ref f) => { - let decl = self.function(f, span, ctx.reborrow())?; - ctx.globals.insert(f.name.name, decl); + let lowered_decl = self.function(f, span, ctx.reborrow())?; + ctx.globals.insert(f.name.name, lowered_decl); } ast::GlobalDeclKind::Var(ref v) => { let ty = self.resolve_ast_type(v.ty, ctx.reborrow())?; @@ -2302,6 +2309,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, } } + ast::Type::RayDesc => { + return Ok(ctx.module.generate_ray_desc_type()); + } + ast::Type::RayIntersection => { + return Ok(ctx.module.generate_ray_intersection_type()); + } ast::Type::User(ref ident) => { return match ctx.globals.get(ident.name) { Some(&LoweredGlobalDecl::Type(handle)) => Ok(handle), diff --git a/src/front/wgsl/parse/ast.rs b/src/front/wgsl/parse/ast.rs index 9354c6c765..2a56ac6f80 100644 --- a/src/front/wgsl/parse/ast.rs +++ b/src/front/wgsl/parse/ast.rs @@ -231,6 +231,8 @@ pub enum Type<'a> { }, AccelerationStructure, RayQuery, + RayDesc, + RayIntersection, BindingArray { base: Handle>, size: ArraySize<'a>, @@ -370,9 +372,6 @@ pub enum ConstructorType<'a> { size: ArraySize<'a>, }, - /// Ray description. - RayDesc, - /// Constructing a value of a known Naga IR type. /// /// This variant is produced only during lowering, when we have Naga types diff --git a/src/front/wgsl/parse/mod.rs b/src/front/wgsl/parse/mod.rs index e4a0d160fd..7a030259b8 100644 --- a/src/front/wgsl/parse/mod.rs +++ b/src/front/wgsl/parse/mod.rs @@ -441,7 +441,6 @@ impl Parser { })) } "array" => ast::ConstructorType::PartialArray, - "RayDesc" => ast::ConstructorType::RayDesc, "atomic" | "binding_array" | "sampler" @@ -1382,6 +1381,8 @@ impl Parser { } "acceleration_structure" => ast::Type::AccelerationStructure, "ray_query" => ast::Type::RayQuery, + "RayDesc" => ast::Type::RayDesc, + "RayIntersection" => ast::Type::RayIntersection, _ => return Ok(None), })) } diff --git a/tests/out/spv/ray-query.spvasm b/tests/out/spv/ray-query.spvasm index 6bc41ee30f..1a1a18bba1 100644 --- a/tests/out/spv/ray-query.spvasm +++ b/tests/out/spv/ray-query.spvasm @@ -10,12 +10,12 @@ OpMemoryModel Logical GLSL450 OpEntryPoint GLCompute %29 "main" %21 %23 OpExecutionMode %29 LocalSize 1 1 1 OpMemberDecorate %13 0 Offset 0 -OpMemberDecorate %16 0 Offset 0 -OpMemberDecorate %16 1 Offset 4 -OpMemberDecorate %16 2 Offset 8 -OpMemberDecorate %16 3 Offset 12 -OpMemberDecorate %16 4 Offset 16 -OpMemberDecorate %16 5 Offset 32 +OpMemberDecorate %15 0 Offset 0 +OpMemberDecorate %15 1 Offset 4 +OpMemberDecorate %15 2 Offset 8 +OpMemberDecorate %15 3 Offset 12 +OpMemberDecorate %15 4 Offset 16 +OpMemberDecorate %15 5 Offset 32 OpMemberDecorate %20 0 Offset 0 OpMemberDecorate %20 1 Offset 4 OpMemberDecorate %20 2 Offset 8 @@ -49,19 +49,19 @@ OpMemberDecorate %24 0 Offset 0 %11 = OpConstant %4 0 %12 = OpTypeAccelerationStructureNV %13 = OpTypeStruct %4 -%14 = OpTypeRayQueryKHR -%15 = OpTypeVector %7 3 -%16 = OpTypeStruct %4 %4 %7 %7 %15 %15 +%14 = OpTypeVector %7 3 +%15 = OpTypeStruct %4 %4 %7 %7 %14 %14 +%16 = OpTypeRayQueryKHR %17 = OpTypeVector %7 2 %18 = OpTypeBool -%19 = OpTypeMatrix %15 4 +%19 = OpTypeMatrix %14 4 %20 = OpTypeStruct %4 %7 %4 %4 %4 %4 %4 %17 %18 %19 %19 %22 = OpTypePointer UniformConstant %12 %21 = OpVariable %22 UniformConstant %24 = OpTypeStruct %13 %25 = OpTypePointer StorageBuffer %24 %23 = OpVariable %25 StorageBuffer -%27 = OpTypePointer Function %14 +%27 = OpTypePointer Function %16 %30 = OpTypeFunction %2 %32 = OpTypePointer StorageBuffer %13 %45 = OpConstant %4 1 @@ -73,15 +73,15 @@ OpMemberDecorate %24 0 Offset 0 %33 = OpAccessChain %32 %23 %11 OpBranch %34 %34 = OpLabel -%35 = OpCompositeConstruct %15 %9 %9 %9 -%36 = OpCompositeConstruct %15 %9 %10 %9 -%37 = OpCompositeConstruct %16 %3 %5 %6 %8 %35 %36 +%35 = OpCompositeConstruct %14 %9 %9 %9 +%36 = OpCompositeConstruct %14 %9 %10 %9 +%37 = OpCompositeConstruct %15 %3 %5 %6 %8 %35 %36 %38 = OpCompositeExtract %4 %37 0 %39 = OpCompositeExtract %4 %37 1 %40 = OpCompositeExtract %7 %37 2 %41 = OpCompositeExtract %7 %37 3 -%42 = OpCompositeExtract %15 %37 4 -%43 = OpCompositeExtract %15 %37 5 +%42 = OpCompositeExtract %14 %37 4 +%43 = OpCompositeExtract %14 %37 5 OpRayQueryInitializeKHR %26 %31 %38 %39 %42 %40 %43 %41 %44 = OpRayQueryProceedKHR %18 %26 %46 = OpRayQueryGetIntersectionTypeKHR %4 %26 %45 From 9d218d4b40f81bf017f691a8657558757f1b9465 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sun, 26 Feb 2023 00:27:03 -0800 Subject: [PATCH 08/12] msl: ray query support --- src/back/mod.rs | 23 +++ src/back/msl/mod.rs | 10 +- src/back/msl/writer.rs | 199 +++++++++++++++++++-- tests/in/ray-query.param.ron | 8 + tests/in/ray-query.wgsl | 13 +- tests/out/msl/binding-arrays.msl | 2 +- tests/out/msl/bounds-check-image-rzsw.msl | 2 +- tests/out/msl/bounds-check-zero-atomic.msl | 2 +- tests/out/msl/bounds-check-zero.msl | 2 +- tests/out/msl/policy-mix.msl | 2 +- tests/out/msl/ray-query.msl | 58 ++++++ tests/out/msl/resource-binding-map.msl | 2 +- tests/snapshots.rs | 2 +- 13 files changed, 295 insertions(+), 30 deletions(-) create mode 100644 tests/out/msl/ray-query.msl diff --git a/src/back/mod.rs b/src/back/mod.rs index 6755983c07..ac37a498a2 100644 --- a/src/back/mod.rs +++ b/src/back/mod.rs @@ -218,3 +218,26 @@ impl crate::Statement { } } } + +bitflags::bitflags! { + /// Ray flags. + #[derive(Default)] + pub struct RayFlag: u32 { + const OPAQUE = 0x01; + const NO_OPAQUE = 0x02; + const TERMINATE_ON_FIRST_HIT = 0x04; + const SKIP_CLOSEST_HIT_SHADER = 0x08; + const CULL_FRONT_FACING = 0x10; + const CULL_BACK_FACING = 0x20; + const CULL_OPAQUE = 0x40; + const CULL_NO_OPAQUE = 0x80; + const SKIP_TRIANGLES = 0x100; + const SKIP_AABBS = 0x200; + } +} + +#[repr(u32)] +enum RayIntersectionType { + Triangle = 1, + BoundingBox = 4, +} diff --git a/src/back/msl/mod.rs b/src/back/msl/mod.rs index d794976602..3174d4b756 100644 --- a/src/back/msl/mod.rs +++ b/src/back/msl/mod.rs @@ -314,10 +314,7 @@ impl Options { match slot { Some(slot) => Ok(ResolvedBinding::Resource(BindTarget { buffer: Some(slot), - texture: None, - sampler: None, - binding_array_size: None, - mutable: false, + ..Default::default() })), None if self.fake_missing_bindings => Ok(ResolvedBinding::User { prefix: "fake", @@ -338,10 +335,7 @@ impl Options { match slot { Some(slot) => Ok(ResolvedBinding::Resource(BindTarget { buffer: Some(slot), - texture: None, - sampler: None, - binding_array_size: None, - mutable: false, + ..Default::default() })), None if self.fake_missing_bindings => Ok(ResolvedBinding::User { prefix: "fake", diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 2ec5bad339..88600a8f11 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -25,6 +25,13 @@ const WRAPPED_ARRAY_FIELD: &str = "inner"; // Some more general handling of pointers is needed to be implemented here. const ATOMIC_REFERENCE: &str = "&"; +const RT_NAMESPACE: &str = "metal::raytracing"; +const RAY_QUERY_TYPE: &str = "_RayQuery"; +const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector"; +const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection"; +const RAY_QUERY_FIELD_READY: &str = "ready"; +const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type"; + /// Write the Metal name for a Naga numeric type: scalar, vector, or matrix. /// /// The `sizes` slice determines whether this function writes a @@ -194,8 +201,11 @@ impl<'a> Display for TypeContext<'a> { crate::TypeInner::Sampler { comparison: _ } => { write!(out, "{NAMESPACE}::sampler") } - crate::TypeInner::AccelerationStructure | crate::TypeInner::RayQuery => { - unreachable!("Ray queries are not supported yet"); + crate::TypeInner::AccelerationStructure => { + write!(out, "{RT_NAMESPACE}::instance_acceleration_structure") + } + crate::TypeInner::RayQuery => { + write!(out, "{RAY_QUERY_TYPE}") } crate::TypeInner::BindingArray { base, size } => { let base_tyname = Self { @@ -1865,8 +1875,39 @@ impl Writer { write!(self.out, ")")?; } } - // hot supported yet - crate::Expression::RayQueryGetIntersection { .. } => unreachable!(), + crate::Expression::RayQueryGetIntersection { query, committed } => { + if !committed { + unimplemented!() + } + let ty = context.module.special_types.ray_intersection.unwrap(); + let type_name = &self.names[&NameKey::Type(ty)]; + write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?; + self.put_expression(query, context, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?; + let fields = [ + "distance", + "user_instance_id", + "instance_id", + "", // SBT offset + "geometry_id", + "primitive_id", + "triangle_barycentric_coord", + "triangle_front_facing", + "", // padding + "object_to_world_transform", + "world_to_object_transform", + ]; + for field in fields { + write!(self.out, ", ")?; + if field.is_empty() { + write!(self.out, "{{}}")?; + } else { + self.put_expression(query, context, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?; + } + } + write!(self.out, "}}")?; + } } Ok(()) } @@ -2320,6 +2361,7 @@ impl Writer { ) { use crate::Expression; self.need_bake_expressions.clear(); + for (expr_handle, expr) in func.expressions.iter() { // Expressions whose reference count is above the // threshold should always be stored in temporaries. @@ -2327,6 +2369,16 @@ impl Writer { let min_ref_count = func.expressions[expr_handle].bake_ref_count(); if min_ref_count <= expr_info.ref_count { self.need_bake_expressions.insert(expr_handle); + } else { + match expr_info.ty { + // force ray desc to be baked: it's used multiple times internally + TypeResolution::Handle(h) + if Some(h) == context.module.special_types.ray_desc => + { + self.need_bake_expressions.insert(expr_handle); + } + _ => {} + } } if let Expression::Math { fun, arg, arg1, .. } = *expr { @@ -2338,11 +2390,11 @@ impl Writer { // times, once for each component (see `put_dot_product`), so to // avoid duplicated evaluation, we must bake integer operands. - use crate::TypeInner; // check what kind of product this is depending // on the resolve type of the Dot function itself - let inner = context.resolve_type(expr_handle); - if let TypeInner::Scalar { kind, .. } = *inner { + if let crate::TypeInner::Scalar { kind, .. } = + *context.resolve_type(expr_handle) + { match kind { crate::ScalarKind::Sint | crate::ScalarKind::Uint => { self.need_bake_expressions.insert(arg); @@ -2763,7 +2815,100 @@ impl Writer { // done writeln!(self.out, ";")?; } - crate::Statement::RayQuery { .. } => unreachable!(), + crate::Statement::RayQuery { query, ref fun } => { + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + //TODO: how to deal with winding? + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?; + { + let f_opaque = back::RayFlag::CULL_OPAQUE.bits(); + let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode((" + )?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?; + writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?; + } + { + let f_opaque = back::RayFlag::OPAQUE.bits(); + let f_no_opaque = back::RayFlag::NO_OPAQUE.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?; + writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?; + } + { + let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection((" + )?; + self.put_expression(descriptor, &context.expression, true)?; + writeln!(self.out, ".flags & {flag}) != 0);")?; + } + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray(" + )?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".origin, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".dir, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".tmin, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".tmax), ")?; + self.put_expression(acceleration_structure, &context.expression, true)?; + write!(self.out, ", ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".cull_mask);")?; + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?; + } + crate::RayQueryFunction::Proceed { result } => { + write!(self.out, "{level}")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?; + //TODO: actually proceed? + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?; + } + crate::RayQueryFunction::Terminate => { + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.abort();")?; + } + } + } } } @@ -2875,14 +3020,41 @@ impl Writer { writeln!(self.out)?; // Work around Metal bug where `uint` is not available by default writeln!(self.out, "using {NAMESPACE}::uint;")?; - writeln!(self.out)?; + if module.types.iter().any(|(_, t)| match t.inner { + crate::TypeInner::RayQuery => true, + _ => false, + }) { + let tab = back::INDENT; + writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?; + let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>"); + writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?; + writeln!( + self.out, + "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};" + )?; + writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?; + writeln!(self.out, "}};")?; + writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?; + let v_triangle = back::RayIntersectionType::Triangle as u32; + let v_bbox = back::RayIntersectionType::BoundingBox as u32; + writeln!( + self.out, + "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : " + )?; + writeln!( + self.out, + "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;" + )?; + writeln!(self.out, "}}")?; + } if options .bounds_check_policies .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite) { self.put_default_constructible()?; } + writeln!(self.out)?; { let mut indices = vec![]; @@ -2924,11 +3096,12 @@ impl Writer { /// /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite fn put_default_constructible(&mut self) -> BackendResult { + let tab = back::INDENT; writeln!(self.out, "struct DefaultConstructible {{")?; - writeln!(self.out, " template")?; - writeln!(self.out, " operator T() && {{")?; - writeln!(self.out, " return T {{}};")?; - writeln!(self.out, " }}")?; + writeln!(self.out, "{tab}template")?; + writeln!(self.out, "{tab}operator T() && {{")?; + writeln!(self.out, "{tab}{tab}return T {{}};")?; + writeln!(self.out, "{tab}}}")?; writeln!(self.out, "}};")?; Ok(()) } diff --git a/tests/in/ray-query.param.ron b/tests/in/ray-query.param.ron index 9d8666954d..c400db8c64 100644 --- a/tests/in/ray-query.param.ron +++ b/tests/in/ray-query.param.ron @@ -3,4 +3,12 @@ spv: ( version: (1, 4), ), + msl: ( + lang_version: (2, 4), + spirv_cross_compatibility: false, + fake_missing_bindings: true, + zero_initialize_workgroup_memory: false, + per_entry_point_map: {}, + inline_samplers: [], + ), ) diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl index 5eabf3a2d3..b755d8f60a 100644 --- a/tests/in/ray-query.wgsl +++ b/tests/in/ray-query.wgsl @@ -2,8 +2,17 @@ var acc_struct: acceleration_structure; /* -let RAY_FLAG_NONE = 0u; -let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 4u; +let RAY_FLAG_NONE = 0x00u; +let RAY_FLAG_OPAQUE = 0x01u; +let RAY_FLAG_NO_OPAQUE = 0x02u; +let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 0x04u; +let RAY_FLAG_SKIP_CLOSEST_HIT_SHADER = 0x08u; +let RAY_FLAG_CULL_FRONT_FACING = 0x10u; +let RAY_FLAG_CULL_BACK_FACING = 0x20u; +let RAY_FLAG_CULL_OPAQUE = 0x40u; +let RAY_FLAG_CULL_NO_OPAQUE = 0x80u; +let RAY_FLAG_SKIP_TRIANGLES = 0x100u; +let RAY_FLAG_SKIP_AABBS = 0x200u; let RAY_QUERY_INTERSECTION_NONE = 0u; let RAY_QUERY_INTERSECTION_TRIANGLE = 1u; diff --git a/tests/out/msl/binding-arrays.msl b/tests/out/msl/binding-arrays.msl index da1078b5d8..694f79452d 100644 --- a/tests/out/msl/binding-arrays.msl +++ b/tests/out/msl/binding-arrays.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + struct UniformIndex { uint index; }; diff --git a/tests/out/msl/bounds-check-image-rzsw.msl b/tests/out/msl/bounds-check-image-rzsw.msl index 9032af14ca..eeb03c9849 100644 --- a/tests/out/msl/bounds-check-image-rzsw.msl +++ b/tests/out/msl/bounds-check-image-rzsw.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + constant metal::int2 const_type_4_ = {0, 0}; constant metal::int3 const_type_7_ = {0, 0, 0}; constant metal::float4 const_type_2_ = {0.0, 0.0, 0.0, 0.0}; diff --git a/tests/out/msl/bounds-check-zero-atomic.msl b/tests/out/msl/bounds-check-zero-atomic.msl index 95028ee796..daaa079233 100644 --- a/tests/out/msl/bounds-check-zero-atomic.msl +++ b/tests/out/msl/bounds-check-zero-atomic.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + struct _mslBufferSizes { uint size0; }; diff --git a/tests/out/msl/bounds-check-zero.msl b/tests/out/msl/bounds-check-zero.msl index fece92de35..816983d98b 100644 --- a/tests/out/msl/bounds-check-zero.msl +++ b/tests/out/msl/bounds-check-zero.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + struct _mslBufferSizes { uint size0; }; diff --git a/tests/out/msl/policy-mix.msl b/tests/out/msl/policy-mix.msl index 842c57e58c..7eb0c61ede 100644 --- a/tests/out/msl/policy-mix.msl +++ b/tests/out/msl/policy-mix.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + struct type_1 { metal::float4 inner[10]; }; diff --git a/tests/out/msl/ray-query.msl b/tests/out/msl/ray-query.msl new file mode 100644 index 0000000000..2a09737873 --- /dev/null +++ b/tests/out/msl/ray-query.msl @@ -0,0 +1,58 @@ +// language: metal2.4 +#include +#include + +using metal::uint; +struct _RayQuery { + metal::raytracing::intersector intersector; + metal::raytracing::intersector::result_type intersection; + bool ready = false; +}; +constexpr metal::uint _map_intersection_type(const metal::raytracing::intersection_type ty) { + return ty==metal::raytracing::intersection_type::triangle ? 1 : + ty==metal::raytracing::intersection_type::bounding_box ? 4 : 0; +} + +struct Output { + uint visible_; +}; +struct RayDesc { + uint flags; + uint cull_mask; + float tmin; + float tmax; + metal::float3 origin; + metal::float3 dir; +}; +struct RayIntersection { + uint kind; + float t; + uint instance_custom_index; + uint instance_id; + uint sbt_record_offset; + uint geometry_index; + uint primitive_index; + metal::float2 barycentrics; + bool front_face; + char _pad9[11]; + metal::float4x3 object_to_world; + metal::float4x3 world_to_object; +}; + +kernel void main_( + metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]] +, device Output& output [[user(fake0)]] +) { + _RayQuery rq = {}; + RayDesc _e12 = RayDesc {4u, 255u, 0.10000000149011612, 100.0, metal::float3(0.0), metal::float3(0.0, 1.0, 0.0)}; + rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle); + rq.intersector.set_opacity_cull_mode((_e12.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e12.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none); + rq.intersector.force_opacity((_e12.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e12.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none); + rq.intersector.accept_any_intersection((_e12.flags & 4) != 0); + rq.intersection = rq.intersector.intersect(metal::raytracing::ray(_e12.origin, _e12.dir, _e12.tmin, _e12.tmax), acc_struct, _e12.cull_mask); rq.ready = true; + bool _e13 = rq.ready; + rq.ready = false; + RayIntersection intersection = RayIntersection {_map_intersection_type(rq.intersection.type), rq.intersection.distance, rq.intersection.user_instance_id, rq.intersection.instance_id, {}, rq.intersection.geometry_id, rq.intersection.primitive_id, rq.intersection.triangle_barycentric_coord, rq.intersection.triangle_front_facing, {}, rq.intersection.object_to_world_transform, rq.intersection.world_to_object_transform}; + output.visible_ = static_cast(intersection.kind == 0u); + return; +} diff --git a/tests/out/msl/resource-binding-map.msl b/tests/out/msl/resource-binding-map.msl index 4e0b601320..b4a53d97b5 100644 --- a/tests/out/msl/resource-binding-map.msl +++ b/tests/out/msl/resource-binding-map.msl @@ -3,7 +3,6 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { @@ -11,6 +10,7 @@ struct DefaultConstructible { } }; + struct entry_point_oneInput { }; struct entry_point_oneOutput { diff --git a/tests/snapshots.rs b/tests/snapshots.rs index df94130a70..d968a0dfc1 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -571,7 +571,7 @@ fn convert_wgsl() { ("sprite", Targets::SPIRV), ("force_point_size_vertex_shader_webgl", Targets::GLSL), ("invariant", Targets::GLSL), - ("ray-query", Targets::SPIRV), + ("ray-query", Targets::SPIRV | Targets::METAL), ]; for &(name, targets) in inputs.iter() { From e4ad3159eec9b207fa4739a330823cbdeee5cb7c Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Mon, 27 Feb 2023 12:42:57 -0800 Subject: [PATCH 09/12] Expand on the documentation for ray-tracing features. --- src/back/mod.rs | 9 ++++++- src/front/type_gen.rs | 29 +++++++++++++++++++-- src/lib.rs | 60 ++++++++++++++++++++++++++++++++++++++++--- src/valid/handles.rs | 2 +- 4 files changed, 92 insertions(+), 8 deletions(-) diff --git a/src/back/mod.rs b/src/back/mod.rs index ac37a498a2..f51262524c 100644 --- a/src/back/mod.rs +++ b/src/back/mod.rs @@ -220,7 +220,14 @@ impl crate::Statement { } bitflags::bitflags! { - /// Ray flags. + /// Ray flags, for a [`RayDesc`]'s `flags` field. + /// + /// Note that these exactly correspond to the SPIR-V "Ray Flags" mask, and + /// the SPIR-V backend passes them directly through to the + /// `OpRayQueryInitializeKHR` instruction. (We have to choose something, so + /// we might as well make one back end's life easier.) + /// + /// [`RayDesc`]: crate::Module::generate_ray_desc_type #[derive(Default)] pub struct RayFlag: u32 { const OPAQUE = 0x01; diff --git a/src/front/type_gen.rs b/src/front/type_gen.rs index bb734ac69c..b695b52792 100644 --- a/src/front/type_gen.rs +++ b/src/front/type_gen.rs @@ -5,7 +5,21 @@ Type generators. use crate::{arena::Handle, span::Span}; impl crate::Module { - //Note: has to match `struct RayDesc` + /// Populate this module's [`SpecialTypes::ray_desc`] type. + /// + /// [`SpecialTypes::ray_desc`] is the type of the [`descriptor`] operand of + /// an [`Initialize`] [`RayQuery`] statement. In WGSL, it is a struct type + /// referred to as `RayDesc`. + /// + /// Backends consume values of this type to drive platform APIs, so if you + /// change any its fields, you must update the backends to match. Look for + /// backend code dealing with [`RayQueryFunction::Initialize`]. + /// + /// [`SpecialTypes::ray_desc`]: crate::SpecialTypes::ray_desc + /// [`descriptor`]: crate::RayQueryFunction::Initialize::descriptor + /// [`Initialize`]: crate::RayQueryFunction::Initialize + /// [`RayQuery`]: crate::Statement::RayQuery + /// [`RayQueryFunction::Initialize`]: crate::RayQueryFunction::Initialize pub(super) fn generate_ray_desc_type(&mut self) -> Handle { if let Some(handle) = self.special_types.ray_desc { return handle; @@ -96,7 +110,18 @@ impl crate::Module { handle } - //Note: has to match `struct RayIntersection` + /// Populate this module's [`SpecialTypes::ray_intersection`] type. + /// + /// [`SpecialTypes::ray_intersection`] is the type of a + /// `RayQueryGetIntersection` expression. In WGSL, it is a struct type + /// referred to as `RayIntersection`. + /// + /// Backends construct values of this type based on platform APIs, so if you + /// change any its fields, you must update the backends to match. Look for + /// the backend's handling for [`Expression::RayQueryGetIntersection`]. + /// + /// [`SpecialTypes::ray_intersection`]: crate::SpecialTypes::ray_intersection + /// [`Expression::RayQueryGetIntersection`]: crate::Expression::RayQueryGetIntersection pub(super) fn generate_ray_intersection_type(&mut self) -> Handle { if let Some(handle) = self.special_types.ray_intersection { return handle; diff --git a/src/lib.rs b/src/lib.rs index bf5f9a3a2a..f91dc7ce56 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -107,8 +107,10 @@ Naga's rules for when `Expression`s are evaluated are as follows: [`Atomic`] statement, representing the result of the atomic operation, is evaluated when the `Atomic` statement is executed. -- Similarly, an [`RayQueryProceedResult`] expression, which is a boolean - indicating if the ray query is finished. +- A [`RayQueryProceedResult`] expression, which is a boolean + indicating if the ray query is finished, is evaluated when the + [`RayQuery`] statement whose [`Proceed::result`] points to it is + executed. - All other expressions are evaluated when the (unique) [`Statement::Emit`] statement that covers them is executed. @@ -184,6 +186,9 @@ tree. [`Call`]: Statement::Call [`Emit`]: Statement::Emit [`Store`]: Statement::Store +[`RayQuery`]: Statement::RayQuery + +[`Proceed::result`]: RayQueryFunction::Proceed::result [`Validator::validate`]: valid::Validator::validate [`ModuleInfo`]: valid::ModuleInfo @@ -727,6 +732,7 @@ pub enum TypeInner { /// Opaque object representing an acceleration structure of geometry. AccelerationStructure, + /// Locally used handle for ray queries. RayQuery, @@ -1445,9 +1451,16 @@ pub enum Expression { /// This doesn't match the semantics of spirv's `OpArrayLength`, which must be passed /// a pointer to a structure containing a runtime array in its' last field. ArrayLength(Handle), - /// Result of `rayQueryProceed`. + + /// Result of a [`Proceed`] [`RayQuery`] statement. + /// + /// [`Proceed`]: RayQueryFunction::Proceed + /// [`RayQuery`]: Statement::RayQuery RayQueryProceedResult, - /// Result of `rayQueryGet*Intersection`. + + /// Return an intersection found by `query`. + /// + /// If `committed` is true, return the committed result available when RayQueryGetIntersection { query: Handle, committed: bool, @@ -1484,18 +1497,45 @@ pub struct SwitchCase { pub fall_through: bool, } +/// An operation that a [`RayQuery` statement] applies to its [`query`] operand. +/// +/// [`RayQuery` statement]: Statement::RayQuery +/// [`query`]: Statement::RayQuery::query #[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum RayQueryFunction { + /// Initialize the `RayQuery` object. Initialize { + /// The acceleration structure within which this query should search for hits. + /// + /// The expression must be an [`AccelerationStructure`]. + /// + /// [`AccelerationStructure`]: TypeInner::AccelerationStructure acceleration_structure: Handle, + + #[allow(rustdoc::private_intra_doc_links)] + /// A struct of detailed parameters for the ray query. + /// + /// This expression should have the struct type given in + /// [`SpecialTypes::ray_desc`]. This is available in the WGSL + /// front end as the `RayDesc` type. descriptor: Handle, }, + + /// Start or continue the query given by the statement's [`query`] operand. + /// + /// After executing this statement, the `result` expression is a + /// [`Bool`] scalar indicating whether there are more intersection + /// candidates to consider. + /// + /// [`query`]: Statement::RayQuery::query + /// [`Bool`]: ScalarKind::Bool Proceed { result: Handle, }, + Terminate, } @@ -1673,7 +1713,12 @@ pub enum Statement { result: Option>, }, RayQuery { + /// The [`RayQuery`] object this statement operates on. + /// + /// [`RayQuery`]: TypeInner::RayQuery query: Handle, + + /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, } @@ -1800,8 +1845,15 @@ pub struct EntryPoint { #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct SpecialTypes { /// Type for `RayDesc`. + /// + /// Call [`Module::generate_ray_desc_type`] to populate this if + /// needed and return the handle. ray_desc: Option>, + /// Type for `RayIntersection`. + /// + /// Call [`Module::generate_ray_intersection_type`] to populate + /// this if needed and return the handle. ray_intersection: Option>, } diff --git a/src/valid/handles.rs b/src/valid/handles.rs index c9e6cad502..fdd43cd585 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -1,4 +1,4 @@ -//! Implementation of [`super::Validator::validate_module_handles`]. +//! Implementation of `Validator::validate_module_handles`. use crate::{ arena::{BadHandle, BadRangeError}, From b52dd0f0b2ac55da58e27a7f9218ded2ed5ca894 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 17 Mar 2023 22:53:13 -0700 Subject: [PATCH 10/12] ray query: validation, better test --- src/valid/expression.rs | 23 +++- src/valid/function.rs | 95 +++++++++++--- tests/in/ray-query.wgsl | 14 +- tests/out/msl/ray-query.msl | 47 +++++-- tests/out/spv/ray-query.spvasm | 229 ++++++++++++++++++++------------- 5 files changed, 284 insertions(+), 124 deletions(-) diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 408dccaf10..1a91fe4d0a 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -35,6 +35,8 @@ pub enum ExpressionError { InvalidPointerType(Handle), #[error("Array length of {0:?} can't be done")] InvalidArrayType(Handle), + #[error("Get intersection of {0:?} can't be done")] + InvalidRayQueryType(Handle), #[error("Splatting {0:?} can't be done")] InvalidSplatType(Handle), #[error("Swizzling {0:?} can't be done")] @@ -1427,7 +1429,26 @@ impl super::Validator { return Err(ExpressionError::InvalidArrayType(expr)); } }, - E::RayQueryProceedResult | E::RayQueryGetIntersection { .. } => ShaderStages::all(), + E::RayQueryProceedResult => ShaderStages::all(), + E::RayQueryGetIntersection { + query, + committed: _, + } => match resolver[query] { + Ti::Pointer { + base, + space: crate::AddressSpace::Function, + } => match resolver.types[base].inner { + Ti::RayQuery => ShaderStages::all(), + ref other => { + log::error!("Intersection result of a pointer to {:?}", other); + return Err(ExpressionError::InvalidRayQueryType(query)); + } + }, + ref other => { + log::error!("Intersection result of {:?}", other); + return Err(ExpressionError::InvalidRayQueryType(query)); + } + }, }; Ok(stages) } diff --git a/src/valid/function.rs b/src/valid/function.rs index a13a07bcfa..737f33dc28 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -47,8 +47,6 @@ pub enum AtomicError { InvalidPointer(Handle), #[error("Operand {0:?} has invalid type.")] InvalidOperand(Handle), - #[error("Result expression {0:?} has already been introduced earlier")] - ResultAlreadyInScope(Handle), #[error("Result type for {0:?} doesn't match the statement")] ResultTypeMismatch(Handle), } @@ -131,6 +129,14 @@ pub enum FunctionError { }, #[error("Atomic operation is invalid")] InvalidAtomic(#[from] AtomicError), + #[error("Ray Query {0:?} is not a local variable")] + InvalidRayQueryExpression(Handle), + #[error("Acceleration structure {0:?} is not a matching expression")] + InvalidAccelerationStructure(Handle), + #[error("Ray descriptor {0:?} is not a matching expression")] + InvalidRayDescriptor(Handle), + #[error("Ray Query {0:?} does not have a matching type")] + InvalidRayQueryType(Handle), #[error( "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}" )] @@ -169,8 +175,10 @@ struct BlockContext<'a> { info: &'a FunctionInfo, expressions: &'a Arena, types: &'a UniqueArena, + local_vars: &'a Arena, global_vars: &'a Arena, functions: &'a Arena, + special_types: &'a crate::SpecialTypes, prev_infos: &'a [FunctionInfo], return_type: Option>, } @@ -188,8 +196,10 @@ impl<'a> BlockContext<'a> { info, expressions: &fun.expressions, types: &module.types, + local_vars: &fun.local_variables, global_vars: &module.global_variables, functions: &module.functions, + special_types: &module.special_types, prev_infos, return_type: fun.result.as_ref().map(|fr| fr.ty), } @@ -299,6 +309,21 @@ impl super::Validator { Ok(callee_info.available_stages) } + #[cfg(feature = "validate")] + fn emit_expression( + &mut self, + handle: Handle, + context: &BlockContext, + ) -> Result<(), WithSpan> { + if self.valid_expression_set.insert(handle.index()) { + self.valid_expression_list.push(handle); + Ok(()) + } else { + Err(FunctionError::ExpressionAlreadyInScope(handle) + .with_span_handle(handle, context.expressions)) + } + } + #[cfg(feature = "validate")] fn validate_atomic( &mut self, @@ -347,13 +372,7 @@ impl super::Validator { } } - if self.valid_expression_set.insert(result.index()) { - self.valid_expression_list.push(result); - } else { - return Err(AtomicError::ResultAlreadyInScope(result) - .with_span_handle(result, context.expressions) - .into_other()); - } + self.emit_expression(result, context)?; match context.expressions[result] { crate::Expression::AtomicResult { ty, comparison } if { @@ -401,12 +420,7 @@ impl super::Validator { match *statement { S::Emit(ref range) => { for handle in range.clone() { - if self.valid_expression_set.insert(handle.index()) { - self.valid_expression_list.push(handle); - } else { - return Err(FunctionError::ExpressionAlreadyInScope(handle) - .with_span_handle(handle, context.expressions)); - } + self.emit_expression(handle, context)?; } } S::Block(ref block) => { @@ -807,8 +821,55 @@ impl super::Validator { } => { self.validate_atomic(pointer, fun, value, result, context)?; } - S::RayQuery { query: _, fun: _ } => { - //TODO + S::RayQuery { query, ref fun } => { + let query_var = match *context.get_expression(query) { + crate::Expression::LocalVariable(var) => &context.local_vars[var], + ref other => { + log::error!("Unexpected ray query expression {other:?}"); + return Err(FunctionError::InvalidRayQueryExpression(query) + .with_span_static(span, "invalid query expression")); + } + }; + match context.types[query_var.ty].inner { + Ti::RayQuery => {} + ref other => { + log::error!("Unexpected ray query type {other:?}"); + return Err(FunctionError::InvalidRayQueryType(query_var.ty) + .with_span_static(span, "invalid query type")); + } + } + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + match *context + .resolve_type(acceleration_structure, &self.valid_expression_set)? + { + Ti::AccelerationStructure => {} + _ => { + return Err(FunctionError::InvalidAccelerationStructure( + acceleration_structure, + ) + .with_span_static(span, "invalid acceleration structure")) + } + } + let desc_ty_given = + context.resolve_type(descriptor, &self.valid_expression_set)?; + let desc_ty_expected = context + .special_types + .ray_desc + .map(|handle| &context.types[handle].inner); + if Some(desc_ty_given) != desc_ty_expected { + return Err(FunctionError::InvalidRayDescriptor(descriptor) + .with_span_static(span, "invalid ray descriptor")); + } + } + crate::RayQueryFunction::Proceed { result } => { + self.emit_expression(result, context)?; + } + crate::RayQueryFunction::Terminate => {} + } } } } diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl index b755d8f60a..1a9c967490 100644 --- a/tests/in/ray-query.wgsl +++ b/tests/in/ray-query.wgsl @@ -45,19 +45,29 @@ struct RayIntersection { struct Output { visible: u32, + normal: vec3, } @group(0) @binding(1) var output: Output; +fn get_torus_normal(world_point: vec3, intersection: RayIntersection) -> vec3 { + let local_point = intersection.world_to_object * vec4(world_point, 1.0); + let point_on_guiding_line = normalize(local_point.xy) * 2.4; + let world_point_on_guiding_line = intersection.object_to_world * vec4(point_on_guiding_line, 0.0, 1.0); + return normalize(world_point - world_point_on_guiding_line); +} + @compute @workgroup_size(1) fn main() { var rq: ray_query; - rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3(0.0), vec3(0.0, 1.0, 0.0))); + let dir = vec3(0.0, 1.0, 0.0); + rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3(0.0), dir)); - rayQueryProceed(&rq); + while (rayQueryProceed(&rq)) {} let intersection = rayQueryGetCommittedIntersection(&rq); output.visible = u32(intersection.kind == RAY_QUERY_INTERSECTION_NONE); + output.normal = get_torus_normal(dir * intersection.t, intersection); } diff --git a/tests/out/msl/ray-query.msl b/tests/out/msl/ray-query.msl index 2a09737873..dc24f80674 100644 --- a/tests/out/msl/ray-query.msl +++ b/tests/out/msl/ray-query.msl @@ -15,14 +15,8 @@ constexpr metal::uint _map_intersection_type(const metal::raytracing::intersecti struct Output { uint visible_; -}; -struct RayDesc { - uint flags; - uint cull_mask; - float tmin; - float tmax; - metal::float3 origin; - metal::float3 dir; + char _pad1[12]; + metal::float3 normal; }; struct RayIntersection { uint kind; @@ -38,21 +32,48 @@ struct RayIntersection { metal::float4x3 object_to_world; metal::float4x3 world_to_object; }; +struct RayDesc { + uint flags; + uint cull_mask; + float tmin; + float tmax; + metal::float3 origin; + metal::float3 dir; +}; + +metal::float3 get_torus_normal( + metal::float3 world_point, + RayIntersection intersection +) { + metal::float3 local_point = intersection.world_to_object * metal::float4(world_point, 1.0); + metal::float2 point_on_guiding_line = metal::normalize(local_point.xy) * 2.4000000953674316; + metal::float3 world_point_on_guiding_line = intersection.object_to_world * metal::float4(point_on_guiding_line, 0.0, 1.0); + return metal::normalize(world_point - world_point_on_guiding_line); +} kernel void main_( metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]] , device Output& output [[user(fake0)]] ) { _RayQuery rq = {}; - RayDesc _e12 = RayDesc {4u, 255u, 0.10000000149011612, 100.0, metal::float3(0.0), metal::float3(0.0, 1.0, 0.0)}; + metal::float3 dir = metal::float3(0.0, 1.0, 0.0); + RayDesc _e12 = RayDesc {4u, 255u, 0.10000000149011612, 100.0, metal::float3(0.0), dir}; rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle); rq.intersector.set_opacity_cull_mode((_e12.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e12.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none); rq.intersector.force_opacity((_e12.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e12.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none); rq.intersector.accept_any_intersection((_e12.flags & 4) != 0); rq.intersection = rq.intersector.intersect(metal::raytracing::ray(_e12.origin, _e12.dir, _e12.tmin, _e12.tmax), acc_struct, _e12.cull_mask); rq.ready = true; - bool _e13 = rq.ready; - rq.ready = false; - RayIntersection intersection = RayIntersection {_map_intersection_type(rq.intersection.type), rq.intersection.distance, rq.intersection.user_instance_id, rq.intersection.instance_id, {}, rq.intersection.geometry_id, rq.intersection.primitive_id, rq.intersection.triangle_barycentric_coord, rq.intersection.triangle_front_facing, {}, rq.intersection.object_to_world_transform, rq.intersection.world_to_object_transform}; - output.visible_ = static_cast(intersection.kind == 0u); + while(true) { + bool _e13 = rq.ready; + rq.ready = false; + if (_e13) { + } else { + break; + } + } + RayIntersection intersection_1 = RayIntersection {_map_intersection_type(rq.intersection.type), rq.intersection.distance, rq.intersection.user_instance_id, rq.intersection.instance_id, {}, rq.intersection.geometry_id, rq.intersection.primitive_id, rq.intersection.triangle_barycentric_coord, rq.intersection.triangle_front_facing, {}, rq.intersection.object_to_world_transform, rq.intersection.world_to_object_transform}; + output.visible_ = static_cast(intersection_1.kind == 0u); + metal::float3 _e25 = get_torus_normal(dir * intersection_1.t, intersection_1); + output.normal = _e25; return; } diff --git a/tests/out/spv/ray-query.spvasm b/tests/out/spv/ray-query.spvasm index 1a1a18bba1..306cda758c 100644 --- a/tests/out/spv/ray-query.spvasm +++ b/tests/out/spv/ray-query.spvasm @@ -1,105 +1,152 @@ ; SPIR-V ; Version: 1.4 ; Generator: rspirv -; Bound: 63 +; Bound: 95 OpCapability RayQueryKHR OpCapability Shader OpExtension "SPV_KHR_ray_query" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %29 "main" %21 %23 -OpExecutionMode %29 LocalSize 1 1 1 -OpMemberDecorate %13 0 Offset 0 +OpEntryPoint GLCompute %48 "main" %23 %25 +OpExecutionMode %48 LocalSize 1 1 1 OpMemberDecorate %15 0 Offset 0 -OpMemberDecorate %15 1 Offset 4 -OpMemberDecorate %15 2 Offset 8 -OpMemberDecorate %15 3 Offset 12 -OpMemberDecorate %15 4 Offset 16 -OpMemberDecorate %15 5 Offset 32 -OpMemberDecorate %20 0 Offset 0 -OpMemberDecorate %20 1 Offset 4 -OpMemberDecorate %20 2 Offset 8 -OpMemberDecorate %20 3 Offset 12 -OpMemberDecorate %20 4 Offset 16 -OpMemberDecorate %20 5 Offset 20 -OpMemberDecorate %20 6 Offset 24 -OpMemberDecorate %20 7 Offset 28 -OpMemberDecorate %20 8 Offset 36 -OpMemberDecorate %20 9 Offset 48 -OpMemberDecorate %20 9 ColMajor -OpMemberDecorate %20 9 MatrixStride 16 -OpMemberDecorate %20 10 Offset 112 -OpMemberDecorate %20 10 ColMajor -OpMemberDecorate %20 10 MatrixStride 16 -OpDecorate %21 DescriptorSet 0 -OpDecorate %21 Binding 0 +OpMemberDecorate %15 1 Offset 16 +OpMemberDecorate %19 0 Offset 0 +OpMemberDecorate %19 1 Offset 4 +OpMemberDecorate %19 2 Offset 8 +OpMemberDecorate %19 3 Offset 12 +OpMemberDecorate %19 4 Offset 16 +OpMemberDecorate %19 5 Offset 20 +OpMemberDecorate %19 6 Offset 24 +OpMemberDecorate %19 7 Offset 28 +OpMemberDecorate %19 8 Offset 36 +OpMemberDecorate %19 9 Offset 48 +OpMemberDecorate %19 9 ColMajor +OpMemberDecorate %19 9 MatrixStride 16 +OpMemberDecorate %19 10 Offset 112 +OpMemberDecorate %19 10 ColMajor +OpMemberDecorate %19 10 MatrixStride 16 +OpMemberDecorate %21 0 Offset 0 +OpMemberDecorate %21 1 Offset 4 +OpMemberDecorate %21 2 Offset 8 +OpMemberDecorate %21 3 Offset 12 +OpMemberDecorate %21 4 Offset 16 +OpMemberDecorate %21 5 Offset 32 OpDecorate %23 DescriptorSet 0 -OpDecorate %23 Binding 1 -OpDecorate %24 Block -OpMemberDecorate %24 0 Offset 0 +OpDecorate %23 Binding 0 +OpDecorate %25 DescriptorSet 0 +OpDecorate %25 Binding 1 +OpDecorate %26 Block +OpMemberDecorate %26 0 Offset 0 %2 = OpTypeVoid -%4 = OpTypeInt 32 0 -%3 = OpConstant %4 4 -%5 = OpConstant %4 255 -%7 = OpTypeFloat 32 -%6 = OpConstant %7 0.1 -%8 = OpConstant %7 100.0 -%9 = OpConstant %7 0.0 -%10 = OpConstant %7 1.0 -%11 = OpConstant %4 0 -%12 = OpTypeAccelerationStructureNV -%13 = OpTypeStruct %4 -%14 = OpTypeVector %7 3 -%15 = OpTypeStruct %4 %4 %7 %7 %14 %14 -%16 = OpTypeRayQueryKHR -%17 = OpTypeVector %7 2 -%18 = OpTypeBool -%19 = OpTypeMatrix %14 4 -%20 = OpTypeStruct %4 %7 %4 %4 %4 %4 %4 %17 %18 %19 %19 -%22 = OpTypePointer UniformConstant %12 -%21 = OpVariable %22 UniformConstant -%24 = OpTypeStruct %13 -%25 = OpTypePointer StorageBuffer %24 -%23 = OpVariable %25 StorageBuffer -%27 = OpTypePointer Function %16 -%30 = OpTypeFunction %2 -%32 = OpTypePointer StorageBuffer %13 -%45 = OpConstant %4 1 -%58 = OpTypePointer StorageBuffer %4 -%29 = OpFunction %2 None %30 +%4 = OpTypeFloat 32 +%3 = OpConstant %4 1.0 +%5 = OpConstant %4 2.4 +%6 = OpConstant %4 0.0 +%8 = OpTypeInt 32 0 +%7 = OpConstant %8 4 +%9 = OpConstant %8 255 +%10 = OpConstant %4 0.1 +%11 = OpConstant %4 100.0 +%12 = OpConstant %8 0 +%13 = OpTypeAccelerationStructureNV +%14 = OpTypeVector %4 3 +%15 = OpTypeStruct %8 %14 +%16 = OpTypeVector %4 2 +%17 = OpTypeBool +%18 = OpTypeMatrix %14 4 +%19 = OpTypeStruct %8 %4 %8 %8 %8 %8 %8 %16 %17 %18 %18 +%20 = OpTypeVector %4 4 +%21 = OpTypeStruct %8 %8 %4 %4 %14 %14 +%22 = OpTypeRayQueryKHR +%24 = OpTypePointer UniformConstant %13 +%23 = OpVariable %24 UniformConstant +%26 = OpTypeStruct %15 +%27 = OpTypePointer StorageBuffer %26 +%25 = OpVariable %27 StorageBuffer +%32 = OpTypeFunction %14 %14 %19 +%46 = OpTypePointer Function %22 +%49 = OpTypeFunction %2 +%51 = OpTypePointer StorageBuffer %15 +%72 = OpConstant %8 1 +%85 = OpTypePointer StorageBuffer %8 +%90 = OpTypePointer StorageBuffer %14 +%31 = OpFunction %14 None %32 +%29 = OpFunctionParameter %14 +%30 = OpFunctionParameter %19 %28 = OpLabel -%26 = OpVariable %27 Function -%31 = OpLoad %12 %21 -%33 = OpAccessChain %32 %23 %11 -OpBranch %34 -%34 = OpLabel -%35 = OpCompositeConstruct %14 %9 %9 %9 -%36 = OpCompositeConstruct %14 %9 %10 %9 -%37 = OpCompositeConstruct %15 %3 %5 %6 %8 %35 %36 -%38 = OpCompositeExtract %4 %37 0 -%39 = OpCompositeExtract %4 %37 1 -%40 = OpCompositeExtract %7 %37 2 -%41 = OpCompositeExtract %7 %37 3 -%42 = OpCompositeExtract %14 %37 4 -%43 = OpCompositeExtract %14 %37 5 -OpRayQueryInitializeKHR %26 %31 %38 %39 %42 %40 %43 %41 -%44 = OpRayQueryProceedKHR %18 %26 -%46 = OpRayQueryGetIntersectionTypeKHR %4 %26 %45 -%47 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %4 %26 %45 -%48 = OpRayQueryGetIntersectionInstanceIdKHR %4 %26 %45 -%49 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %4 %26 %45 -%50 = OpRayQueryGetIntersectionGeometryIndexKHR %4 %26 %45 -%51 = OpRayQueryGetIntersectionPrimitiveIndexKHR %4 %26 %45 -%52 = OpRayQueryGetIntersectionTKHR %7 %26 %45 -%53 = OpRayQueryGetIntersectionBarycentricsKHR %17 %26 %45 -%54 = OpRayQueryGetIntersectionFrontFaceKHR %18 %26 %45 -%55 = OpRayQueryGetIntersectionObjectToWorldKHR %19 %26 %45 -%56 = OpRayQueryGetIntersectionWorldToObjectKHR %19 %26 %45 -%57 = OpCompositeConstruct %20 %46 %52 %47 %48 %49 %50 %51 %53 %54 %55 %56 -%59 = OpCompositeExtract %4 %57 0 -%60 = OpIEqual %18 %59 %11 -%61 = OpSelect %4 %60 %45 %11 -%62 = OpAccessChain %58 %33 %11 -OpStore %62 %61 +OpBranch %33 +%33 = OpLabel +%34 = OpCompositeExtract %18 %30 10 +%35 = OpCompositeConstruct %20 %29 %3 +%36 = OpMatrixTimesVector %14 %34 %35 +%37 = OpVectorShuffle %16 %36 %36 0 1 +%38 = OpExtInst %16 %1 Normalize %37 +%39 = OpVectorTimesScalar %16 %38 %5 +%40 = OpCompositeExtract %18 %30 9 +%41 = OpCompositeConstruct %20 %39 %6 %3 +%42 = OpMatrixTimesVector %14 %40 %41 +%43 = OpFSub %14 %29 %42 +%44 = OpExtInst %14 %1 Normalize %43 +OpReturnValue %44 +OpFunctionEnd +%48 = OpFunction %2 None %49 +%47 = OpLabel +%45 = OpVariable %46 Function +%50 = OpLoad %13 %23 +%52 = OpAccessChain %51 %25 %12 +OpBranch %53 +%53 = OpLabel +%54 = OpCompositeConstruct %14 %6 %3 %6 +%55 = OpCompositeConstruct %14 %6 %6 %6 +%56 = OpCompositeConstruct %21 %7 %9 %10 %11 %55 %54 +%57 = OpCompositeExtract %8 %56 0 +%58 = OpCompositeExtract %8 %56 1 +%59 = OpCompositeExtract %4 %56 2 +%60 = OpCompositeExtract %4 %56 3 +%61 = OpCompositeExtract %14 %56 4 +%62 = OpCompositeExtract %14 %56 5 +OpRayQueryInitializeKHR %45 %50 %57 %58 %61 %59 %62 %60 +OpBranch %63 +%63 = OpLabel +OpLoopMerge %64 %66 None +OpBranch %65 +%65 = OpLabel +%67 = OpRayQueryProceedKHR %17 %45 +OpSelectionMerge %68 None +OpBranchConditional %67 %68 %69 +%69 = OpLabel +OpBranch %64 +%68 = OpLabel +OpBranch %70 +%70 = OpLabel +OpBranch %71 +%71 = OpLabel +OpBranch %66 +%66 = OpLabel +OpBranch %63 +%64 = OpLabel +%73 = OpRayQueryGetIntersectionTypeKHR %8 %45 %72 +%74 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %8 %45 %72 +%75 = OpRayQueryGetIntersectionInstanceIdKHR %8 %45 %72 +%76 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %8 %45 %72 +%77 = OpRayQueryGetIntersectionGeometryIndexKHR %8 %45 %72 +%78 = OpRayQueryGetIntersectionPrimitiveIndexKHR %8 %45 %72 +%79 = OpRayQueryGetIntersectionTKHR %4 %45 %72 +%80 = OpRayQueryGetIntersectionBarycentricsKHR %16 %45 %72 +%81 = OpRayQueryGetIntersectionFrontFaceKHR %17 %45 %72 +%82 = OpRayQueryGetIntersectionObjectToWorldKHR %18 %45 %72 +%83 = OpRayQueryGetIntersectionWorldToObjectKHR %18 %45 %72 +%84 = OpCompositeConstruct %19 %73 %79 %74 %75 %76 %77 %78 %80 %81 %82 %83 +%86 = OpCompositeExtract %8 %84 0 +%87 = OpIEqual %17 %86 %12 +%88 = OpSelect %8 %87 %72 %12 +%89 = OpAccessChain %85 %52 %12 +OpStore %89 %88 +%91 = OpCompositeExtract %4 %84 1 +%92 = OpVectorTimesScalar %14 %54 %91 +%93 = OpFunctionCall %14 %31 %92 %84 +%94 = OpAccessChain %90 %52 %72 +OpStore %94 %93 OpReturn OpFunctionEnd \ No newline at end of file From 85383aaea3b9fa4ce0e2ed8e2f2ba2f42f89767a Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 21 Mar 2023 22:47:53 -0700 Subject: [PATCH 11/12] Address Jim's review notes, use typegen module for atomic struct --- src/back/dot/mod.rs | 28 +++++++----- src/back/mod.rs | 4 +- src/back/spv/writer.rs | 2 +- src/front/type_gen.rs | 53 ++++++++++++++++++++++- src/front/wgsl/error.rs | 6 +++ src/front/wgsl/lower/mod.rs | 53 ++++------------------- src/lib.rs | 6 +-- tests/in/ray-query.wgsl | 4 +- tests/out/wgsl/atomicCompareExchange.wgsl | 4 +- 9 files changed, 92 insertions(+), 68 deletions(-) diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index d293c3adc1..1eebbee067 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -254,19 +254,25 @@ impl StatementGraph { } S::RayQuery { query, ref fun } => { self.dependencies.push((id, query, "query")); - if let crate::RayQueryFunction::Initialize { - acceleration_structure, - descriptor, - } = *fun - { - self.dependencies.push(( - id, + match *fun { + crate::RayQueryFunction::Initialize { acceleration_structure, - "acceleration_structure", - )); - self.dependencies.push((id, descriptor, "descriptor")); + descriptor, + } => { + self.dependencies.push(( + id, + acceleration_structure, + "acceleration_structure", + )); + self.dependencies.push((id, descriptor, "descriptor")); + "RayQueryInitialize" + } + crate::RayQueryFunction::Proceed { result } => { + self.emits.push((id, result)); + "RayQueryProceed" + } + crate::RayQueryFunction::Terminate => "RayQueryTerminate", } - "RayQuery" } }; // Set the last node to the merge node diff --git a/src/back/mod.rs b/src/back/mod.rs index f51262524c..8467ee787b 100644 --- a/src/back/mod.rs +++ b/src/back/mod.rs @@ -234,8 +234,8 @@ bitflags::bitflags! { const NO_OPAQUE = 0x02; const TERMINATE_ON_FIRST_HIT = 0x04; const SKIP_CLOSEST_HIT_SHADER = 0x08; - const CULL_FRONT_FACING = 0x10; - const CULL_BACK_FACING = 0x20; + const CULL_BACK_FACING = 0x10; + const CULL_FRONT_FACING = 0x20; const CULL_OPAQUE = 0x40; const CULL_NO_OPAQUE = 0x80; const SKIP_TRIANGLES = 0x100; diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 800a40ed68..ba235e6d03 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -973,7 +973,7 @@ impl Writer { self.write_type_declaration_local(id, local); - // If it's an type that needs SPIR-V capabilities, request them now, + // If it's a type that needs SPIR-V capabilities, request them now, // so write_type_declaration_local can stay infallible. self.request_type_capabilities(&ty.inner)?; diff --git a/src/front/type_gen.rs b/src/front/type_gen.rs index b695b52792..1ee454c448 100644 --- a/src/front/type_gen.rs +++ b/src/front/type_gen.rs @@ -5,6 +5,55 @@ Type generators. use crate::{arena::Handle, span::Span}; impl crate::Module { + pub fn generate_atomic_compare_exchange_result( + &mut self, + kind: crate::ScalarKind, + width: crate::Bytes, + ) -> Handle { + let bool_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + }, + }, + Span::UNDEFINED, + ); + let scalar_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { kind, width }, + }, + Span::UNDEFINED, + ); + + self.types.insert( + crate::Type { + name: Some(format!( + "__atomic_compare_exchange_result<{kind:?},{width}>" + )), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("old_value".to_string()), + ty: scalar_ty, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("exchanged".to_string()), + ty: bool_ty, + binding: None, + offset: 4, + }, + ], + span: 8, + }, + }, + Span::UNDEFINED, + ) + } /// Populate this module's [`SpecialTypes::ray_desc`] type. /// /// [`SpecialTypes::ray_desc`] is the type of the [`descriptor`] operand of @@ -20,7 +69,7 @@ impl crate::Module { /// [`Initialize`]: crate::RayQueryFunction::Initialize /// [`RayQuery`]: crate::Statement::RayQuery /// [`RayQueryFunction::Initialize`]: crate::RayQueryFunction::Initialize - pub(super) fn generate_ray_desc_type(&mut self) -> Handle { + pub fn generate_ray_desc_type(&mut self) -> Handle { if let Some(handle) = self.special_types.ray_desc { return handle; } @@ -122,7 +171,7 @@ impl crate::Module { /// /// [`SpecialTypes::ray_intersection`]: crate::SpecialTypes::ray_intersection /// [`Expression::RayQueryGetIntersection`]: crate::Expression::RayQueryGetIntersection - pub(super) fn generate_ray_intersection_type(&mut self) -> Handle { + pub fn generate_ray_intersection_type(&mut self) -> Handle { if let Some(handle) = self.special_types.ray_intersection { return handle; } diff --git a/src/front/wgsl/error.rs b/src/front/wgsl/error.rs index a4e6540237..2e71a76624 100644 --- a/src/front/wgsl/error.rs +++ b/src/front/wgsl/error.rs @@ -188,6 +188,7 @@ pub enum Error<'a> { MissingAttribute(&'static str, Span), InvalidAtomicPointer(Span), InvalidAtomicOperandType(Span), + InvalidRayQueryPointer(Span), Pointer(&'static str, Span), NotPointer(Span), NotReference(&'static str, Span), @@ -526,6 +527,11 @@ impl<'a> Error<'a> { labels: vec![(span, "atomic operand type is invalid".into())], notes: vec![], }, + Error::InvalidRayQueryPointer(span) => ParseError { + message: "ray query operation is done on a pointer to a non-ray-query".to_string(), + labels: vec![(span, "ray query pointer is invalid".into())], + notes: vec![], + }, Error::NotPointer(span) => ParseError { message: "the operand of the `*` operator must be a pointer".to_string(), labels: vec![(span, "expression is not a pointer".into())], diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 8df3460b2e..314eea52ec 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -641,6 +641,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let span = tu.decls.get_span(decl_handle); let decl = &tu.decls[decl_handle]; + //NOTE: This is done separately from `resolve_ast_type` because `RayDesc` may be + // first encountered in a local constructor invocation. //TODO: find a nicer way? if let Some(dep) = decl.dependencies.iter().find(|dep| dep.ident == "RayDesc") { let ty_handle = ctx.module.generate_ray_desc_type(); @@ -1733,50 +1735,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let expression = match *ctx.resolved_inner(value) { crate::TypeInner::Scalar { kind, width } => { - let bool_ty = ctx.module.types.insert( - crate::Type { - name: None, - inner: crate::TypeInner::Scalar { - kind: crate::ScalarKind::Bool, - width: crate::BOOL_WIDTH, - }, - }, - Span::UNDEFINED, - ); - let scalar_ty = ctx.module.types.insert( - crate::Type { - name: None, - inner: crate::TypeInner::Scalar { kind, width }, - }, - Span::UNDEFINED, - ); - let struct_ty = ctx.module.types.insert( - crate::Type { - name: Some( - "__atomic_compare_exchange_result".to_string(), - ), - inner: crate::TypeInner::Struct { - members: vec![ - crate::StructMember { - name: Some("old_value".to_string()), - ty: scalar_ty, - binding: None, - offset: 0, - }, - crate::StructMember { - name: Some("exchanged".to_string()), - ty: bool_ty, - binding: None, - offset: 4, - }, - ], - span: 8, - }, - }, - Span::UNDEFINED, - ); crate::Expression::AtomicResult { - ty: struct_ty, + //TODO: cache this to avoid generating duplicate types + ty: ctx + .module + .generate_atomic_compare_exchange_result(kind, width), comparison: true, } } @@ -2449,12 +2412,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { crate::TypeInner::RayQuery => Ok(pointer), ref other => { log::error!("Pointer type to {:?} passed to ray query op", other); - Err(Error::InvalidAtomicPointer(span)) + Err(Error::InvalidRayQueryPointer(span)) } }, ref other => { log::error!("Type {:?} passed to ray query op", other); - Err(Error::InvalidAtomicPointer(span)) + Err(Error::InvalidRayQueryPointer(span)) } } } diff --git a/src/lib.rs b/src/lib.rs index f91dc7ce56..a70015d16d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1460,7 +1460,7 @@ pub enum Expression { /// Return an intersection found by `query`. /// - /// If `committed` is true, return the committed result available when + /// If `committed` is true, return the committed result available when RayQueryGetIntersection { query: Handle, committed: bool, @@ -1848,13 +1848,13 @@ pub struct SpecialTypes { /// /// Call [`Module::generate_ray_desc_type`] to populate this if /// needed and return the handle. - ray_desc: Option>, + pub ray_desc: Option>, /// Type for `RayIntersection`. /// /// Call [`Module::generate_ray_intersection_type`] to populate /// this if needed and return the handle. - ray_intersection: Option>, + pub ray_intersection: Option>, } /// Shader module. diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl index 1a9c967490..4826547ded 100644 --- a/tests/in/ray-query.wgsl +++ b/tests/in/ray-query.wgsl @@ -7,8 +7,8 @@ let RAY_FLAG_OPAQUE = 0x01u; let RAY_FLAG_NO_OPAQUE = 0x02u; let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 0x04u; let RAY_FLAG_SKIP_CLOSEST_HIT_SHADER = 0x08u; -let RAY_FLAG_CULL_FRONT_FACING = 0x10u; -let RAY_FLAG_CULL_BACK_FACING = 0x20u; +let RAY_FLAG_CULL_BACK_FACING = 0x10u; +let RAY_FLAG_CULL_FRONT_FACING = 0x20u; let RAY_FLAG_CULL_OPAQUE = 0x40u; let RAY_FLAG_CULL_NO_OPAQUE = 0x80u; let RAY_FLAG_SKIP_TRIANGLES = 0x100u; diff --git a/tests/out/wgsl/atomicCompareExchange.wgsl b/tests/out/wgsl/atomicCompareExchange.wgsl index 2c213c8fec..bfad298fab 100644 --- a/tests/out/wgsl/atomicCompareExchange.wgsl +++ b/tests/out/wgsl/atomicCompareExchange.wgsl @@ -1,9 +1,9 @@ -struct gen___atomic_compare_exchange_result { +struct gen___atomic_compare_exchange_resultSint4_ { old_value: i32, exchanged: bool, } -struct gen___atomic_compare_exchange_result_1 { +struct gen___atomic_compare_exchange_resultUint4_ { old_value: u32, exchanged: bool, } From 5807db4dbc7c24dbcfbf152db2adfb64bfa3a35f Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Wed, 22 Mar 2023 09:00:08 -0700 Subject: [PATCH 12/12] Clean up handling of `RayDesc` builtin type, somewhat. --- src/front/wgsl/lower/mod.rs | 20 +++++++++++--------- tests/out/spv/ray-query.spvasm | 20 ++++++++++---------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 314eea52ec..bc9cce1bee 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -641,15 +641,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let span = tu.decls.get_span(decl_handle); let decl = &tu.decls[decl_handle]; - //NOTE: This is done separately from `resolve_ast_type` because `RayDesc` may be - // first encountered in a local constructor invocation. - //TODO: find a nicer way? - if let Some(dep) = decl.dependencies.iter().find(|dep| dep.ident == "RayDesc") { - let ty_handle = ctx.module.generate_ray_desc_type(); - ctx.globals - .insert(dep.ident, LoweredGlobalDecl::Type(ty_handle)); - } - match decl.kind { ast::GlobalDeclKind::Fn(ref f) => { let lowered_decl = self.function(f, span, ctx.reborrow())?; @@ -1930,6 +1921,17 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { committed: true, } } + "RayDesc" => { + let ty = ctx.module.generate_ray_desc_type(); + let handle = self.construct( + span, + &ast::ConstructorType::Type(ty), + function.span, + arguments, + ctx.reborrow(), + )?; + return Ok(Some(handle)); + } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; diff --git a/tests/out/spv/ray-query.spvasm b/tests/out/spv/ray-query.spvasm index 306cda758c..31dc7d75e6 100644 --- a/tests/out/spv/ray-query.spvasm +++ b/tests/out/spv/ray-query.spvasm @@ -26,12 +26,12 @@ OpMemberDecorate %19 9 MatrixStride 16 OpMemberDecorate %19 10 Offset 112 OpMemberDecorate %19 10 ColMajor OpMemberDecorate %19 10 MatrixStride 16 -OpMemberDecorate %21 0 Offset 0 -OpMemberDecorate %21 1 Offset 4 -OpMemberDecorate %21 2 Offset 8 -OpMemberDecorate %21 3 Offset 12 -OpMemberDecorate %21 4 Offset 16 -OpMemberDecorate %21 5 Offset 32 +OpMemberDecorate %22 0 Offset 0 +OpMemberDecorate %22 1 Offset 4 +OpMemberDecorate %22 2 Offset 8 +OpMemberDecorate %22 3 Offset 12 +OpMemberDecorate %22 4 Offset 16 +OpMemberDecorate %22 5 Offset 32 OpDecorate %23 DescriptorSet 0 OpDecorate %23 Binding 0 OpDecorate %25 DescriptorSet 0 @@ -57,15 +57,15 @@ OpMemberDecorate %26 0 Offset 0 %18 = OpTypeMatrix %14 4 %19 = OpTypeStruct %8 %4 %8 %8 %8 %8 %8 %16 %17 %18 %18 %20 = OpTypeVector %4 4 -%21 = OpTypeStruct %8 %8 %4 %4 %14 %14 -%22 = OpTypeRayQueryKHR +%21 = OpTypeRayQueryKHR +%22 = OpTypeStruct %8 %8 %4 %4 %14 %14 %24 = OpTypePointer UniformConstant %13 %23 = OpVariable %24 UniformConstant %26 = OpTypeStruct %15 %27 = OpTypePointer StorageBuffer %26 %25 = OpVariable %27 StorageBuffer %32 = OpTypeFunction %14 %14 %19 -%46 = OpTypePointer Function %22 +%46 = OpTypePointer Function %21 %49 = OpTypeFunction %2 %51 = OpTypePointer StorageBuffer %15 %72 = OpConstant %8 1 @@ -99,7 +99,7 @@ OpBranch %53 %53 = OpLabel %54 = OpCompositeConstruct %14 %6 %3 %6 %55 = OpCompositeConstruct %14 %6 %6 %6 -%56 = OpCompositeConstruct %21 %7 %9 %10 %11 %55 %54 +%56 = OpCompositeConstruct %22 %7 %9 %10 %11 %55 %54 %57 = OpCompositeExtract %8 %56 0 %58 = OpCompositeExtract %8 %56 1 %59 = OpCompositeExtract %4 %56 2