From 8a2a7ac681dbbced96482d1e4cf472356d3d37be Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 16 Feb 2023 23:03:56 -0800 Subject: [PATCH] spv-out: basic ray query support --- src/back/spv/block.rs | 162 +++++++++++++++++++++++++++++++-- src/back/spv/image.rs | 2 +- src/back/spv/instructions.rs | 60 ++++++++++++ src/back/spv/mod.rs | 8 +- src/back/spv/writer.rs | 100 +++++++++++--------- src/front/wgsl/lower/mod.rs | 34 ++++++- src/lib.rs | 4 +- src/valid/handles.rs | 5 +- tests/in/ray-query.wgsl | 6 +- tests/out/spv/ray-query.spvasm | 81 +++++++++++++++++ 10 files changed, 400 insertions(+), 62 deletions(-) create mode 100644 tests/out/spv/ray-query.spvasm diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 004c717b2a..56fefae3a9 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -1069,9 +1069,9 @@ impl<'w> BlockContext<'w> { } } crate::Expression::FunctionArgument(index) => self.function.parameter_id(index), - crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } => { - self.cached[expr_handle] - } + crate::Expression::CallResult(_) + | crate::Expression::AtomicResult { .. } + | crate::Expression::RayQueryProceedResult => self.cached[expr_handle], crate::Expression::As { expr, kind, @@ -1364,10 +1364,61 @@ impl<'w> BlockContext<'w> { id } crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?, - //TODO - crate::Expression::RayQueryProceedResult => unreachable!(), - //TODO - crate::Expression::RayQueryGetIntersection { .. } => unreachable!(), + crate::Expression::RayQueryGetIntersection { query, committed } => { + let width = 4; + let query_id = self.cached[query]; + let intersection_id = self.writer.get_constant_scalar( + crate::ScalarValue::Uint( + spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _, + ), + width, + ); + if !committed { + return Err(Error::FeatureNotImplemented("candidate intersection")); + } + + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Uint, + width, + pointer_space: None, + })); + let kind_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTypeKHR, + flag_type_id, + kind_id, + query_id, + intersection_id, + )); + + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let t_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTKHR, + scalar_type_id, + t_id, + query_id, + intersection_id, + )); + + let id = self.gen_id(); + let intersection_type_id = self.get_type_id(LookupType::Handle( + self.ir_module.special_types.ray_intersection.unwrap(), + )); + //Note: the arguments must match `generate_ray_intersection_type` layout + block.body.push(Instruction::composite_construct( + intersection_type_id, + id, + &[kind_id, t_id], + )); + id + } }; self.cached[expr_handle] = id; @@ -2181,8 +2232,101 @@ impl<'w> BlockContext<'w> { crate::RayQueryFunction::Initialize { acceleration_structure, descriptor, - } => {} - crate::RayQueryFunction::Proceed => {} + } => { + //Note: composite extract indices and types must match `generate_ray_desc_type` + let desc_id = self.cached[descriptor]; + let acc_struct_id = self.get_image_id(acceleration_structure); + let width = 4; + + let flag_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Uint, + width, + pointer_space: None, + })); + let ray_flags_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + flag_type_id, + ray_flags_id, + desc_id, + &[0], + )); + let cull_mask_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + flag_type_id, + cull_mask_id, + desc_id, + &[1], + )); + + let scalar_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let tmin_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmin_id, + desc_id, + &[2], + )); + let tmax_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmax_id, + desc_id, + &[3], + )); + + let vector_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + let ray_origin_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_origin_id, + desc_id, + &[4], + )); + let ray_dir_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_dir_id, + desc_id, + &[5], + )); + + block.body.push(Instruction::ray_query_initialize( + query_id, + acc_struct_id, + ray_flags_id, + cull_mask_id, + ray_origin_id, + tmin_id, + ray_dir_id, + tmax_id, + )); + } + crate::RayQueryFunction::Proceed { result } => { + let id = self.gen_id(); + self.cached[result] = id; + let result_type_id = + self.get_expression_type_id(&self.fun_info[result].ty); + + block.body.push(Instruction::ray_query_proceed( + result_type_id, + id, + query_id, + )); + } crate::RayQueryFunction::Terminate => {} } } diff --git a/src/back/spv/image.rs b/src/back/spv/image.rs index bf2ce28f1d..b0a1c7aaa4 100644 --- a/src/back/spv/image.rs +++ b/src/back/spv/image.rs @@ -373,7 +373,7 @@ impl<'w> BlockContext<'w> { }) } - fn get_image_id(&mut self, expr_handle: Handle) -> Word { + pub(super) fn get_image_id(&mut self, expr_handle: Handle) -> Word { let id = match self.ir_function.expressions[expr_handle] { crate::Expression::GlobalVariable(handle) => { self.writer.global_variables[handle.index()].handle_id diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index c213790188..3038d60d40 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -249,6 +249,18 @@ impl super::Instruction { instruction } + pub(super) fn type_acceleration_structure(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeAccelerationStructureKHR); + instruction.set_result(id); + instruction + } + + pub(super) fn type_ray_query(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeRayQueryKHR); + instruction.set_result(id); + instruction + } + pub(super) fn type_sampled_image(id: Word, image_type_id: Word) -> Self { let mut instruction = Self::new(Op::TypeSampledImage); instruction.set_result(id); @@ -627,6 +639,54 @@ impl super::Instruction { instruction } + // + // Ray Query Instructions + // + pub(super) fn ray_query_initialize( + query: Word, + acceleration_structure: Word, + ray_flags: Word, + cull_mask: Word, + ray_origin: Word, + ray_tmin: Word, + ray_dir: Word, + ray_tmax: Word, + ) -> Self { + let mut instruction = Self::new(Op::RayQueryInitializeKHR); + instruction.add_operand(query); + instruction.add_operand(acceleration_structure); + instruction.add_operand(ray_flags); + instruction.add_operand(cull_mask); + instruction.add_operand(ray_origin); + instruction.add_operand(ray_tmin); + instruction.add_operand(ray_dir); + instruction.add_operand(ray_tmax); + instruction + } + + pub(super) fn ray_query_proceed(result_type_id: Word, id: Word, query: Word) -> Self { + let mut instruction = Self::new(Op::RayQueryProceedKHR); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(query); + instruction + } + + pub(super) fn ray_query_get_intersection( + op: Op, + result_type_id: Word, + id: Word, + query: Word, + intersection: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(query); + instruction.add_operand(intersection); + instruction + } + // // Conversion Instructions // diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index e19d825a05..a12834c4ed 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -295,6 +295,8 @@ enum LocalType { base: Handle, size: u64, }, + AccelerationStructure, + RayQuery, } /// A type encountered during SPIR-V generation. @@ -383,7 +385,11 @@ fn make_local(inner: &crate::TypeInner) -> Option { class, } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)), crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler, - _ => return None, + crate::TypeInner::AccelerationStructure => LocalType::AccelerationStructure, + crate::TypeInner::RayQuery => LocalType::RayQuery, + crate::TypeInner::Array { .. } + | crate::TypeInner::Struct { .. } + | crate::TypeInner::BindingArray { .. } => return None, }) } diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index e5b116b14f..fc439f7a99 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -350,9 +350,12 @@ impl Writer { pointer_type_id, id, spirv::StorageClass::Function, - init_word.or_else(|| { - let type_id = self.get_type_id(LookupType::Handle(variable.ty)); - Some(self.write_constant_null(type_id)) + init_word.or_else(|| match ir_module.types[variable.ty].inner { + crate::TypeInner::RayQuery => None, + _ => { + let type_id = self.get_type_id(LookupType::Handle(variable.ty)); + Some(self.write_constant_null(type_id)) + } }), ); function @@ -814,47 +817,54 @@ impl Writer { } } - fn request_image_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> { - if let crate::TypeInner::Image { - dim, - arrayed, - class, - } = *inner - { - let sampled = match class { - crate::ImageClass::Sampled { .. } => true, - crate::ImageClass::Depth { .. } => true, - crate::ImageClass::Storage { format, .. } => { - self.request_image_format_capabilities(format.into())?; - false - } - }; + fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> { + match *inner { + crate::TypeInner::Image { + dim, + arrayed, + class, + } => { + let sampled = match class { + crate::ImageClass::Sampled { .. } => true, + crate::ImageClass::Depth { .. } => true, + crate::ImageClass::Storage { format, .. } => { + self.request_image_format_capabilities(format.into())?; + false + } + }; - match dim { - crate::ImageDimension::D1 => { - if sampled { - self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?; - } else { - self.require_any("1D storage images", &[spirv::Capability::Image1D])?; + match dim { + crate::ImageDimension::D1 => { + if sampled { + self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?; + } else { + self.require_any("1D storage images", &[spirv::Capability::Image1D])?; + } } - } - crate::ImageDimension::Cube if arrayed => { - if sampled { - self.require_any( - "sampled cube array images", - &[spirv::Capability::SampledCubeArray], - )?; - } else { - self.require_any( - "cube array storage images", - &[spirv::Capability::ImageCubeArray], - )?; + crate::ImageDimension::Cube if arrayed => { + if sampled { + self.require_any( + "sampled cube array images", + &[spirv::Capability::SampledCubeArray], + )?; + } else { + self.require_any( + "cube array storage images", + &[spirv::Capability::ImageCubeArray], + )?; + } } + _ => {} } - _ => {} } + crate::TypeInner::AccelerationStructure => { + self.require_any("Acceleration Structure", &[spirv::Capability::RayQueryKHR])?; + } + crate::TypeInner::RayQuery => { + self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?; + } + _ => {} } - Ok(()) } @@ -935,6 +945,8 @@ impl Writer { self.get_type_id(LookupType::Local(LocalType::BindingArray { base, size })); Instruction::type_pointer(id, spirv::StorageClass::UniformConstant, inner_ty) } + LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id), + LocalType::RayQuery => Instruction::type_ray_query(id), }; instruction.to_words(&mut self.logical_layout.declarations); @@ -961,9 +973,9 @@ impl Writer { self.write_type_declaration_local(id, local); - // If it's an image type, request SPIR-V capabilities here, so - // write_type_declaration_local can stay infallible. - self.request_image_capabilities(&ty.inner)?; + // If it's an type that needs SPIR-V capabilities, request them now, + // so write_type_declaration_local can stay infallible. + self.request_type_capabilities(&ty.inner)?; id } @@ -1736,6 +1748,8 @@ impl Writer { .iter() .flat_map(|entry| entry.function.arguments.iter()) .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty)); + let has_ray_query = ir_module.special_types.ray_desc.is_some() + | ir_module.special_types.ray_intersection.is_some(); if self.physical_layout.version < 0x10300 && has_storage_buffers { // enable the storage buffer class on < SPV-1.3 @@ -1746,6 +1760,10 @@ impl Writer { Instruction::extension("SPV_KHR_multiview") .to_words(&mut self.logical_layout.extensions) } + if has_ray_query { + Instruction::extension("SPV_KHR_ray_query") + .to_words(&mut self.logical_layout.extensions) + } Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations); Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450") .to_words(&mut self.logical_layout.ext_inst_imports); diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index da870ec747..32a1271e29 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -1915,7 +1915,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "rayQueryInitialize" => { let mut args = ctx.prepare_args(arguments, 3, span); - let query = self.expression(args.next()?, ctx.reborrow())?; + let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?; let acceleration_structure = self.expression(args.next()?, ctx.reborrow())?; let descriptor = self.expression(args.next()?, ctx.reborrow())?; @@ -1935,15 +1935,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "rayQueryProceed" => { let mut args = ctx.prepare_args(arguments, 1, span); - let query = self.expression(args.next()?, ctx.reborrow())?; + let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?; args.finish()?; - let fun = crate::RayQueryFunction::Proceed; - ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions)); let result = ctx .naga_expressions .append(crate::Expression::RayQueryProceedResult, span); + let fun = crate::RayQueryFunction::Proceed { result }; + ctx.emitter.start(ctx.naga_expressions); ctx.block .push(crate::Statement::RayQuery { query, fun }, span); @@ -1951,7 +1951,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "rayQueryGetCommittedIntersection" => { let mut args = ctx.prepare_args(arguments, 1, span); - let query = self.expression(args.next()?, ctx.reborrow())?; + let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?; args.finish()?; let _ = ctx.module.generate_ray_intersection_type(); @@ -2422,4 +2422,28 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { binding } + + fn ray_query_pointer( + &mut self, + expr: Handle>, + mut ctx: ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let pointer = self.expression(expr, ctx.reborrow())?; + + ctx.grow_types(pointer)?; + match *ctx.resolved_inner(pointer) { + crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner { + crate::TypeInner::RayQuery => Ok(pointer), + ref other => { + log::error!("Pointer type to {:?} passed to ray query op", other); + Err(Error::InvalidAtomicPointer(span)) + } + }, + ref other => { + log::error!("Type {:?} passed to ray query op", other); + Err(Error::InvalidAtomicPointer(span)) + } + } + } } diff --git a/src/lib.rs b/src/lib.rs index 57102c3150..2985a24773 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1477,7 +1477,9 @@ pub enum RayQueryFunction { acceleration_structure: Handle, descriptor: Handle, }, - Proceed, + Proceed { + result: Handle, + }, Terminate, } diff --git a/src/valid/handles.rs b/src/valid/handles.rs index b5f4dd82d3..d788222f3d 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -523,7 +523,10 @@ impl super::Validator { validate_expr(acceleration_structure)?; validate_expr(descriptor)?; } - crate::RayQueryFunction::Proceed | crate::RayQueryFunction::Terminate => {} + crate::RayQueryFunction::Proceed { result } => { + validate_expr(result)?; + } + crate::RayQueryFunction::Terminate => {} } Ok(()) } diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl index 0aec9ca142..7b1053bdcb 100644 --- a/tests/in/ray-query.wgsl +++ b/tests/in/ray-query.wgsl @@ -44,10 +44,10 @@ var output: Output; fn main() { var rq: ray_query; - rayQueryInitialize(rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3(0.0), vec3(0.0, 1.0, 0.0))); + rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3(0.0), vec3(0.0, 1.0, 0.0))); - rayQueryProceed(rq); + rayQueryProceed(&rq); - let intersection = rayQueryGetCommittedIntersection(rq); + let intersection = rayQueryGetCommittedIntersection(&rq); output.visible = u32(intersection.kind == RAY_QUERY_INTERSECTION_NONE); } diff --git a/tests/out/spv/ray-query.spvasm b/tests/out/spv/ray-query.spvasm new file mode 100644 index 0000000000..455a8bffe9 --- /dev/null +++ b/tests/out/spv/ray-query.spvasm @@ -0,0 +1,81 @@ +; SPIR-V +; Version: 1.4 +; Generator: rspirv +; Bound: 52 +OpCapability RayQueryKHR +OpCapability Shader +OpExtension "SPV_KHR_ray_query" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %26 "main" %18 %20 +OpExecutionMode %26 LocalSize 1 1 1 +OpMemberDecorate %13 0 Offset 0 +OpMemberDecorate %16 0 Offset 0 +OpMemberDecorate %16 1 Offset 4 +OpMemberDecorate %16 2 Offset 8 +OpMemberDecorate %16 3 Offset 12 +OpMemberDecorate %16 4 Offset 16 +OpMemberDecorate %16 5 Offset 32 +OpMemberDecorate %17 0 Offset 0 +OpMemberDecorate %17 1 Offset 4 +OpDecorate %18 DescriptorSet 0 +OpDecorate %18 Binding 0 +OpDecorate %20 DescriptorSet 0 +OpDecorate %20 Binding 1 +OpDecorate %21 Block +OpMemberDecorate %21 0 Offset 0 +%2 = OpTypeVoid +%4 = OpTypeInt 32 0 +%3 = OpConstant %4 4 +%5 = OpConstant %4 255 +%7 = OpTypeFloat 32 +%6 = OpConstant %7 0.1 +%8 = OpConstant %7 100.0 +%9 = OpConstant %7 0.0 +%10 = OpConstant %7 1.0 +%11 = OpConstant %4 0 +%12 = OpTypeAccelerationStructureNV +%13 = OpTypeStruct %4 +%14 = OpTypeRayQueryKHR +%15 = OpTypeVector %7 3 +%16 = OpTypeStruct %4 %4 %7 %7 %15 %15 +%17 = OpTypeStruct %4 %7 +%19 = OpTypePointer UniformConstant %12 +%18 = OpVariable %19 UniformConstant +%21 = OpTypeStruct %13 +%22 = OpTypePointer StorageBuffer %21 +%20 = OpVariable %22 StorageBuffer +%24 = OpTypePointer Function %14 +%27 = OpTypeFunction %2 +%29 = OpTypePointer StorageBuffer %13 +%42 = OpTypeBool +%43 = OpConstant %4 1 +%47 = OpTypePointer StorageBuffer %4 +%26 = OpFunction %2 None %27 +%25 = OpLabel +%23 = OpVariable %24 Function +%28 = OpLoad %12 %18 +%30 = OpAccessChain %29 %20 %11 +OpBranch %31 +%31 = OpLabel +%32 = OpCompositeConstruct %15 %9 %9 %9 +%33 = OpCompositeConstruct %15 %9 %10 %9 +%34 = OpCompositeConstruct %16 %3 %5 %6 %8 %32 %33 +%35 = OpCompositeExtract %4 %34 0 +%36 = OpCompositeExtract %4 %34 1 +%37 = OpCompositeExtract %7 %34 2 +%38 = OpCompositeExtract %7 %34 3 +%39 = OpCompositeExtract %15 %34 4 +%40 = OpCompositeExtract %15 %34 5 +OpRayQueryInitializeKHR %23 %28 %35 %36 %39 %37 %40 %38 +%41 = OpRayQueryProceedKHR %42 %23 +%44 = OpRayQueryGetIntersectionTypeKHR %4 %23 %43 +%45 = OpRayQueryGetIntersectionTKHR %7 %23 %43 +%46 = OpCompositeConstruct %17 %44 %45 +%48 = OpCompositeExtract %4 %46 0 +%49 = OpIEqual %42 %48 %11 +%50 = OpSelect %4 %49 %43 %11 +%51 = OpAccessChain %47 %30 %11 +OpStore %51 %50 +OpReturn +OpFunctionEnd \ No newline at end of file