diff --git a/CHANGELOG.md b/CHANGELOG.md index e209fc8541..45ba0ca539 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -88,6 +88,20 @@ By @teoxoy in [#3534](https://github.com/gfx-rs/wgpu/pull/3534) - All `fxhash` dependencies have been replaced with `rustc-hash`. By @james7132 in [#3502](https://github.com/gfx-rs/wgpu/pull/3502) - Change type of `bytes_per_row` and `rows_per_image` (members of `ImageDataLayout`) from `Option<NonZeroU32>` to `Option<u32>`. By @teoxoy in [#3529](https://github.com/gfx-rs/wgpu/pull/3529) +### Added/New Features + +#### General +- Added feature flags for ray tracing (currently hal only): `RAY_QUERY` and `RAY_TRACING`. By @daniel-keitel (started by @expenses) in [#3507](https://github.com/gfx-rs/wgpu/pull/3507) + +#### Vulkan + +- Implemented the basic ray-tracing API for acceleration structures and ray queries. By @daniel-keitel (started by @expenses) in [#3507](https://github.com/gfx-rs/wgpu/pull/3507) + +#### Hal + +- Added the basic ray-tracing API for acceleration structures and ray queries. By @daniel-keitel (started by @expenses) in [#3507](https://github.com/gfx-rs/wgpu/pull/3507) + + ### Changes #### General diff --git a/Cargo.lock b/Cargo.lock index 9995234bb2..5a78b7c94a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3063,6 +3063,7 @@ dependencies = [ "d3d12", "env_logger", "foreign-types 0.3.2", + "glam", "glow", "glutin", "gpu-alloc", diff --git a/wgpu-core/src/binding_model.rs b/wgpu-core/src/binding_model.rs index dbf96f0439..347fe015ae 100644 --- a/wgpu-core/src/binding_model.rs +++ b/wgpu-core/src/binding_model.rs @@ -330,6 +330,7 @@ impl BindingTypeMaxCountValidator { wgt::BindingType::StorageTexture { .. } => { self.storage_textures.add(binding.visibility, count); } + wgt::BindingType::AccelerationStructure => todo!(), } } diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index 5f6a148129..d89d692e93 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -1716,6 +1716,7 @@ impl Device { }, ) } + Bt::AccelerationStructure => todo!(), }; // Validate the count parameter @@ -2202,6 +2203,7 @@ impl Device { buffers: &hal_buffers, samplers: &hal_samplers, textures: &hal_textures, + acceleration_structures: &[], }; let raw = unsafe { self.raw diff --git a/wgpu-hal/Cargo.toml b/wgpu-hal/Cargo.toml index 6722a8d76e..722fb1196c 100644 --- a/wgpu-hal/Cargo.toml +++ b/wgpu-hal/Cargo.toml @@ -131,8 +131,9 @@ version = "0.11.0" features = ["wgsl-in"] [dev-dependencies] +winit = "0.27.1" # for "halmark" example env_logger = "0.10" -winit = "0.27.1" # for "halmark" example +glam = "0.21.3" # for "ray-traced-triangle" example [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] glutin = "0.29.1" # for "gles" example diff --git a/wgpu-hal/examples/halmark/main.rs b/wgpu-hal/examples/halmark/main.rs index b4f25c9179..6816c270ff 100644 --- a/wgpu-hal/examples/halmark/main.rs +++ b/wgpu-hal/examples/halmark/main.rs @@ -432,6 +432,7 @@ impl Example { buffers: &[global_buffer_binding], samplers: &[&sampler], textures: &[texture_binding], + acceleration_structures: &[], entries: &[ hal::BindGroupEntry { binding: 0, @@ -465,6 +466,7 @@ impl Example { buffers: &[local_buffer_binding], samplers: &[], textures: &[], + acceleration_structures: &[], entries: &[hal::BindGroupEntry { binding: 0, resource_index: 0, diff --git a/wgpu-hal/examples/ray-traced-triangle/main.rs b/wgpu-hal/examples/ray-traced-triangle/main.rs new file mode 100644 index 0000000000..88c66c4bba --- /dev/null +++ b/wgpu-hal/examples/ray-traced-triangle/main.rs @@
-0,0 +1,1082 @@ +extern crate wgpu_hal as hal; + +use hal::{ + Adapter as _, CommandEncoder as _, Device as _, Instance as _, Queue as _, Surface as _, +}; +use raw_window_handle::{HasRawDisplayHandle, HasRawWindowHandle}; + +use glam::{Affine3A, Mat4, Vec3}; +use std::{ + borrow::{Borrow, Cow}, + iter, mem, + mem::{align_of, size_of}, + ptr::{self, copy_nonoverlapping}, + time::Instant, +}; + +const COMMAND_BUFFER_PER_CONTEXT: usize = 100; +const DESIRED_FRAMES: u32 = 3; + +/// [D3D12_RAYTRACING_INSTANCE_DESC](https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#d3d12_raytracing_instance_desc) +/// [VkAccelerationStructureInstanceKHR](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkAccelerationStructureInstanceKHR.html) +#[derive(Clone)] +#[repr(C)] +struct AccelerationStructureInstance { + transform: [f32; 12], + custom_index_and_mask: u32, + shader_binding_table_record_offset_and_flags: u32, + acceleration_structure_reference: u64, +} + +impl std::fmt::Debug for AccelerationStructureInstance { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Instance") + .field("transform", &self.transform) + .field("custom_index()", &self.custom_index()) + .field("mask()", &self.mask()) + .field( + "shader_binding_table_record_offset()", + &self.shader_binding_table_record_offset(), + ) + .field("flags()", &self.flags()) + .field( + "acceleration_structure_reference", + &self.acceleration_structure_reference, + ) + .finish() + } +} + +#[allow(dead_code)] +impl AccelerationStructureInstance { + const LOW_24_MASK: u32 = 0x00ff_ffff; + const MAX_U24: u32 = (1u32 << 24u32) - 1u32; + + #[inline] + fn affine_to_rows(mat: &Affine3A) -> [f32; 12] { + let row_0 = mat.matrix3.row(0); + let row_1 = mat.matrix3.row(1); + let row_2 = mat.matrix3.row(2); + let translation = mat.translation; + [ + row_0.x, + row_0.y, + row_0.z, + translation.x, + row_1.x, + row_1.y, + row_1.z, + translation.y, + row_2.x, + row_2.y, + row_2.z, + translation.z, + ] + } + + #[inline] + fn rows_to_affine(rows: &[f32; 12]) -> Affine3A { + Affine3A::from_cols_array(&[ + rows[0], rows[3], rows[6], rows[9], rows[1], rows[4], rows[7], rows[10], rows[2], + rows[5], rows[8], rows[11], + ]) + } + + pub fn transform_as_affine(&self) -> Affine3A { + Self::rows_to_affine(&self.transform) + } + pub fn set_transform(&mut self, transform: &Affine3A) { + self.transform = Self::affine_to_rows(&transform); + } + + pub fn custom_index(&self) -> u32 { + self.custom_index_and_mask & Self::LOW_24_MASK + } + + pub fn mask(&self) -> u8 { + (self.custom_index_and_mask >> 24) as u8 + } + + pub fn shader_binding_table_record_offset(&self) -> u32 { + self.shader_binding_table_record_offset_and_flags & Self::LOW_24_MASK + } + + pub fn flags(&self) -> u8 { + (self.shader_binding_table_record_offset_and_flags >> 24) as u8 + } + + pub fn set_custom_index(&mut self, custom_index: u32) { + debug_assert!( + custom_index <= Self::MAX_U24, + "custom_index uses more than 24 bits! 
{custom_index} > {}", + Self::MAX_U24 + ); + self.custom_index_and_mask = + (custom_index & Self::LOW_24_MASK) | (self.custom_index_and_mask & !Self::LOW_24_MASK) + } + + pub fn set_mask(&mut self, mask: u8) { + self.custom_index_and_mask = + (self.custom_index_and_mask & Self::LOW_24_MASK) | (u32::from(mask) << 24) + } + + pub fn set_shader_binding_table_record_offset( + &mut self, + shader_binding_table_record_offset: u32, + ) { + debug_assert!(shader_binding_table_record_offset <= Self::MAX_U24, "shader_binding_table_record_offset uses more than 24 bits! {shader_binding_table_record_offset} > {}", Self::MAX_U24); + self.shader_binding_table_record_offset_and_flags = (shader_binding_table_record_offset + & Self::LOW_24_MASK) + | (self.shader_binding_table_record_offset_and_flags & !Self::LOW_24_MASK) + } + + pub fn set_flags(&mut self, flags: u8) { + self.shader_binding_table_record_offset_and_flags = + (self.shader_binding_table_record_offset_and_flags & Self::LOW_24_MASK) + | (u32::from(flags) << 24) + } + + pub fn new( + transform: &Affine3A, + custom_index: u32, + mask: u8, + shader_binding_table_record_offset: u32, + flags: u8, + acceleration_structure_reference: u64, + ) -> Self { + debug_assert!( + custom_index <= Self::MAX_U24, + "custom_index uses more than 24 bits! {custom_index} > {}", + Self::MAX_U24 + ); + debug_assert!( + shader_binding_table_record_offset <= Self::MAX_U24, + "shader_binding_table_record_offset uses more than 24 bits! {shader_binding_table_record_offset} > {}", Self::MAX_U24 + ); + AccelerationStructureInstance { + transform: Self::affine_to_rows(transform), + custom_index_and_mask: (custom_index & Self::MAX_U24) | (u32::from(mask) << 24), + shader_binding_table_record_offset_and_flags: (shader_binding_table_record_offset + & Self::MAX_U24) + | (u32::from(flags) << 24), + acceleration_structure_reference, + } + } +} + +struct ExecutionContext<A: hal::Api> { + encoder: A::CommandEncoder, + fence: A::Fence, + fence_value: hal::FenceValue, + used_views: Vec<A::TextureView>, + used_cmd_bufs: Vec<A::CommandBuffer>, + frames_recorded: usize, +} + +impl<A: hal::Api> ExecutionContext<A> { + unsafe fn wait_and_clear(&mut self, device: &A::Device) { + device.wait(&self.fence, self.fence_value, !0).unwrap(); + self.encoder.reset_all(self.used_cmd_bufs.drain(..)); + for view in self.used_views.drain(..) { + device.destroy_texture_view(view); + } + self.frames_recorded = 0; + } +} + +#[allow(dead_code)] +struct Example<A: hal::Api> { + instance: A::Instance, + adapter: A::Adapter, + surface: A::Surface, + surface_format: wgt::TextureFormat, + device: A::Device, + queue: A::Queue, + + contexts: Vec<ExecutionContext<A>>, + context_index: usize, + extent: [u32; 2], + start: Instant, + pipeline: A::ComputePipeline, + bind_group: A::BindGroup, + bgl: A::BindGroupLayout, + shader_module: A::ShaderModule, + texture_view: A::TextureView, + uniform_buffer: A::Buffer, + pipeline_layout: A::PipelineLayout, + vertices_buffer: A::Buffer, + indices_buffer: A::Buffer, + texture: A::Texture, + instances: [AccelerationStructureInstance; 3], + instances_buffer: A::Buffer, + blas: A::AccelerationStructure, + tlas: A::AccelerationStructure, + scratch_buffer: A::Buffer, + time: f32, +} + +impl<A: hal::Api> Example<A> { + fn init(window: &winit::window::Window) -> Result<Self, hal::InstanceError> { + let instance_desc = hal::InstanceDescriptor { + name: "example", + flags: if cfg!(debug_assertions) { + hal::InstanceFlags::all() + } else { + hal::InstanceFlags::empty() + }, + dx12_shader_compiler: wgt::Dx12Compiler::Fxc, + }; + let instance = unsafe { A::Instance::init(&instance_desc)?
}; + let mut surface = unsafe { + instance + .create_surface(window.raw_display_handle(), window.raw_window_handle()) + .unwrap() + }; + + let (adapter, features) = unsafe { + let mut adapters = instance.enumerate_adapters(); + if adapters.is_empty() { + return Err(hal::InstanceError); + } + let exposed = adapters.swap_remove(0); + dbg!(exposed.features); + (exposed.adapter, exposed.features) + }; + let surface_caps = + unsafe { adapter.surface_capabilities(&surface) }.ok_or(hal::InstanceError)?; + log::info!("Surface caps: {:#?}", surface_caps); + + let hal::OpenDevice { device, mut queue } = + unsafe { adapter.open(features, &wgt::Limits::default()).unwrap() }; + + let window_size: (u32, u32) = window.inner_size().into(); + dbg!(&surface_caps.formats); + let surface_format = if surface_caps + .formats + .contains(&wgt::TextureFormat::Rgba8Snorm) + { + wgt::TextureFormat::Rgba8Unorm + } else { + *surface_caps.formats.first().unwrap() + }; + let surface_config = hal::SurfaceConfiguration { + swap_chain_size: DESIRED_FRAMES + .max(*surface_caps.swap_chain_sizes.start()) + .min(*surface_caps.swap_chain_sizes.end()), + present_mode: wgt::PresentMode::Fifo, + composite_alpha_mode: wgt::CompositeAlphaMode::Opaque, + format: surface_format, + extent: wgt::Extent3d { + width: window_size.0, + height: window_size.1, + depth_or_array_layers: 1, + }, + usage: hal::TextureUses::COLOR_TARGET | hal::TextureUses::COPY_DST, + view_formats: vec![surface_format], + }; + unsafe { + surface.configure(&device, &surface_config).unwrap(); + }; + + #[allow(dead_code)] + struct Uniforms { + view_inverse: glam::Mat4, + proj_inverse: glam::Mat4, + } + + let bgl_desc = hal::BindGroupLayoutDescriptor { + label: None, + flags: hal::BindGroupLayoutFlags::empty(), + entries: &[ + wgt::BindGroupLayoutEntry { + binding: 0, + visibility: wgt::ShaderStages::COMPUTE, + ty: wgt::BindingType::Buffer { + ty: wgt::BufferBindingType::Uniform, + has_dynamic_offset: false, + min_binding_size: wgt::BufferSize::new(mem::size_of::<Uniforms>() as _), + }, + count: None, + }, + wgt::BindGroupLayoutEntry { + binding: 1, + visibility: wgt::ShaderStages::COMPUTE, + ty: wgt::BindingType::StorageTexture { + access: wgt::StorageTextureAccess::WriteOnly, + format: wgt::TextureFormat::Rgba8Unorm, + view_dimension: wgt::TextureViewDimension::D2, + }, + count: None, + }, + wgt::BindGroupLayoutEntry { + binding: 2, + visibility: wgt::ShaderStages::COMPUTE, + ty: wgt::BindingType::AccelerationStructure, + count: None, + }, + ], + }; + + let bgl = unsafe { device.create_bind_group_layout(&bgl_desc).unwrap() }; + + pub fn make_spirv_raw(data: &[u8]) -> Cow<[u32]> { + const MAGIC_NUMBER: u32 = 0x0723_0203; + assert_eq!( + data.len() % size_of::<u32>(), + 0, + "data size is not a multiple of 4" + ); + + // If the data happens to be aligned, directly use the byte array, + // otherwise copy the byte array into an owned vector and use that instead. + let words = if data.as_ptr().align_offset(align_of::<u32>()) == 0 { + let (pre, words, post) = unsafe { data.align_to::<u32>() }; + debug_assert!(pre.is_empty()); + debug_assert!(post.is_empty()); + Cow::from(words) + } else { + let mut words = vec![0u32; data.len() / size_of::<u32>()]; + unsafe { + copy_nonoverlapping(data.as_ptr(), words.as_mut_ptr() as *mut u8, data.len()); + } + Cow::from(words) + }; + + assert_eq!( + words[0], MAGIC_NUMBER, + "wrong magic word {:x}.
Make sure you are using a binary SPIRV file.", + words[0] + ); + + words + } + + let shader_module = unsafe { + device + .create_shader_module( + &hal::ShaderModuleDescriptor { + label: None, + runtime_checks: false, + }, + hal::ShaderInput::SpirV(&make_spirv_raw(include_bytes!("shader.comp.spv"))), + ) + .unwrap() + }; + + let pipeline_layout_desc = hal::PipelineLayoutDescriptor { + label: None, + flags: hal::PipelineLayoutFlags::empty(), + bind_group_layouts: &[&bgl], + push_constant_ranges: &[], + }; + let pipeline_layout = unsafe { + device + .create_pipeline_layout(&pipeline_layout_desc) + .unwrap() + }; + + let pipeline = unsafe { + device.create_compute_pipeline(&hal::ComputePipelineDescriptor { + label: Some("pipeline"), + layout: &pipeline_layout, + stage: hal::ProgrammableStage { + module: &shader_module, + entry_point: "main", + }, + }) + } + .unwrap(); + + let vertices: [f32; 9] = [1.0, 1.0, 0.0, -1.0, 1.0, 0.0, 0.0, -1.0, 0.0]; + + let vertices_size_in_bytes = vertices.len() * 4; + + let indices: [u32; 3] = [0, 1, 2]; + + let indices_size_in_bytes = indices.len() * 4; + + let vertices_buffer = unsafe { + let vertices_buffer = device + .create_buffer(&hal::BufferDescriptor { + label: Some("vertices buffer"), + size: vertices_size_in_bytes as u64, + usage: hal::BufferUses::MAP_WRITE + | hal::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT, + memory_flags: hal::MemoryFlags::TRANSIENT | hal::MemoryFlags::PREFER_COHERENT, + }) + .unwrap(); + + let mapping = device + .map_buffer(&vertices_buffer, 0..vertices_size_in_bytes as u64) + .unwrap(); + ptr::copy_nonoverlapping( + vertices.as_ptr() as *const u8, + mapping.ptr.as_ptr(), + vertices_size_in_bytes, + ); + device.unmap_buffer(&vertices_buffer).unwrap(); + assert!(mapping.is_coherent); + + vertices_buffer + }; + + let indices_buffer = unsafe { + let indices_buffer = device + .create_buffer(&hal::BufferDescriptor { + label: Some("indices buffer"), + size: indices_size_in_bytes as u64, + usage: hal::BufferUses::MAP_WRITE + | hal::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT, + memory_flags: hal::MemoryFlags::TRANSIENT | hal::MemoryFlags::PREFER_COHERENT, + }) + .unwrap(); + + let mapping = device + .map_buffer(&indices_buffer, 0..indices_size_in_bytes as u64) + .unwrap(); + ptr::copy_nonoverlapping( + indices.as_ptr() as *const u8, + mapping.ptr.as_ptr(), + indices_size_in_bytes, + ); + device.unmap_buffer(&indices_buffer).unwrap(); + assert!(mapping.is_coherent); + + indices_buffer + }; + + let blas_triangles = vec![hal::AccelerationStructureTriangles { + vertex_buffer: Some(&vertices_buffer), + first_vertex: 0, + vertex_format: wgt::VertexFormat::Float32x3, + vertex_count: vertices.len() as u32, + vertex_stride: 3 * 4, + indices: Some(hal::AccelerationStructureTriangleIndices { + buffer: Some(&indices_buffer), + format: wgt::IndexFormat::Uint32, + offset: 0, + count: indices.len() as u32, + }), + transform: None, + flags: hal::AccelerationStructureGeometryFlags::OPAQUE, + }]; + let blas_entries = hal::AccelerationStructureEntries::Triangles(&blas_triangles); + + let mut tlas_entries = + hal::AccelerationStructureEntries::Instances(hal::AccelerationStructureInstances { + buffer: None, + count: 3, + offset: 0, + }); + + let blas_sizes = unsafe { + device.get_acceleration_structure_build_sizes( + &hal::GetAccelerationStructureBuildSizesDescriptor { + entries: &blas_entries, + flags: hal::AccelerationStructureBuildFlags::PREFER_FAST_TRACE, + }, + ) + }; + + let tlas_flags = 
hal::AccelerationStructureBuildFlags::PREFER_FAST_TRACE + | hal::AccelerationStructureBuildFlags::ALLOW_UPDATE; + + let tlas_sizes = unsafe { + device.get_acceleration_structure_build_sizes( + &hal::GetAccelerationStructureBuildSizesDescriptor { + entries: &tlas_entries, + flags: tlas_flags, + }, + ) + }; + + let blas = unsafe { + device.create_acceleration_structure(&hal::AccelerationStructureDescriptor { + label: Some("blas"), + size: blas_sizes.acceleration_structure_size, + format: hal::AccelerationStructureFormat::BottomLevel, + }) + } + .unwrap(); + + let tlas = unsafe { + device.create_acceleration_structure(&hal::AccelerationStructureDescriptor { + label: Some("tlas"), + size: tlas_sizes.acceleration_structure_size, + format: hal::AccelerationStructureFormat::TopLevel, + }) + } + .unwrap(); + + let uniforms = { + let view = Mat4::look_at_rh(Vec3::new(0.0, 0.0, 2.5), Vec3::ZERO, Vec3::Y); + let proj = Mat4::perspective_rh(59.0_f32.to_radians(), 1.0, 0.001, 1000.0); + + Uniforms { + view_inverse: view.inverse(), + proj_inverse: proj.inverse(), + } + }; + + let uniforms_size = std::mem::size_of::(); + + let uniform_buffer = unsafe { + let uniform_buffer = device + .create_buffer(&hal::BufferDescriptor { + label: Some("uniform buffer"), + size: uniforms_size as u64, + usage: hal::BufferUses::MAP_WRITE | hal::BufferUses::UNIFORM, + memory_flags: hal::MemoryFlags::PREFER_COHERENT, + }) + .unwrap(); + + let mapping = device + .map_buffer(&uniform_buffer, 0..uniforms_size as u64) + .unwrap(); + ptr::copy_nonoverlapping( + &uniforms as *const Uniforms as *const u8, + mapping.ptr.as_ptr(), + uniforms_size, + ); + device.unmap_buffer(&uniform_buffer).unwrap(); + assert!(mapping.is_coherent); + uniform_buffer + }; + + let texture_desc = hal::TextureDescriptor { + label: None, + size: wgt::Extent3d { + width: 512, + height: 512, + depth_or_array_layers: 1, + }, + mip_level_count: 1, + sample_count: 1, + dimension: wgt::TextureDimension::D2, + format: wgt::TextureFormat::Rgba8Unorm, + usage: hal::TextureUses::STORAGE_READ_WRITE | hal::TextureUses::COPY_SRC, + memory_flags: hal::MemoryFlags::empty(), + view_formats: vec![wgt::TextureFormat::Rgba8Unorm], + }; + let texture = unsafe { device.create_texture(&texture_desc).unwrap() }; + + let view_desc = hal::TextureViewDescriptor { + label: None, + format: texture_desc.format, + dimension: wgt::TextureViewDimension::D2, + usage: hal::TextureUses::STORAGE_READ_WRITE | hal::TextureUses::COPY_SRC, + range: wgt::ImageSubresourceRange::default(), + }; + let texture_view = unsafe { device.create_texture_view(&texture, &view_desc).unwrap() }; + + let bind_group = { + let buffer_binding = hal::BufferBinding { + buffer: &uniform_buffer, + offset: 0, + size: None, + }; + let texture_binding = hal::TextureBinding { + view: &texture_view, + usage: hal::TextureUses::STORAGE_READ_WRITE, + }; + let group_desc = hal::BindGroupDescriptor { + label: Some("bind group"), + layout: &bgl, + buffers: &[buffer_binding], + samplers: &[], + textures: &[texture_binding], + acceleration_structures: &[&tlas], + entries: &[ + hal::BindGroupEntry { + binding: 0, + resource_index: 0, + count: 1, + }, + hal::BindGroupEntry { + binding: 1, + resource_index: 0, + count: 1, + }, + hal::BindGroupEntry { + binding: 2, + resource_index: 0, + count: 1, + }, + ], + }; + unsafe { device.create_bind_group(&group_desc).unwrap() } + }; + + let scratch_buffer = unsafe { + device + .create_buffer(&hal::BufferDescriptor { + label: Some("scratch buffer"), + size: blas_sizes + .build_scratch_size + 
.max(tlas_sizes.build_scratch_size), + usage: hal::BufferUses::ACCELERATION_STRUCTURE_SCRATCH, + memory_flags: hal::MemoryFlags::empty(), + }) + .unwrap() + }; + + let instances = [ + AccelerationStructureInstance::new( + &Affine3A::from_translation(Vec3 { + x: 0.0, + y: 0.0, + z: 0.0, + }), + 0, + 0xff, + 0, + 0, + unsafe { device.get_acceleration_structure_device_address(&blas) }, + ), + AccelerationStructureInstance::new( + &Affine3A::from_translation(Vec3 { + x: -1.0, + y: -1.0, + z: -2.0, + }), + 0, + 0xff, + 0, + 0, + unsafe { device.get_acceleration_structure_device_address(&blas) }, + ), + AccelerationStructureInstance::new( + &Affine3A::from_translation(Vec3 { + x: 1.0, + y: -1.0, + z: -2.0, + }), + 0, + 0xff, + 0, + 0, + unsafe { device.get_acceleration_structure_device_address(&blas) }, + ), + ]; + + let instances_buffer_size = + instances.len() * std::mem::size_of::<AccelerationStructureInstance>(); + + let instances_buffer = unsafe { + let instances_buffer = device + .create_buffer(&hal::BufferDescriptor { + label: Some("instances_buffer"), + size: instances_buffer_size as u64, + usage: hal::BufferUses::MAP_WRITE + | hal::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT, + memory_flags: hal::MemoryFlags::TRANSIENT | hal::MemoryFlags::PREFER_COHERENT, + }) + .unwrap(); + + let mapping = device + .map_buffer(&instances_buffer, 0..instances_buffer_size as u64) + .unwrap(); + ptr::copy_nonoverlapping( + instances.as_ptr() as *const u8, + mapping.ptr.as_ptr(), + instances_buffer_size, + ); + device.unmap_buffer(&instances_buffer).unwrap(); + assert!(mapping.is_coherent); + + instances_buffer + }; + + if let hal::AccelerationStructureEntries::Instances(ref mut i) = tlas_entries { + i.buffer = Some(&instances_buffer); + assert!( + instances.len() <= i.count as usize, + "Tlas allocation too small" + ); + } + + let cmd_encoder_desc = hal::CommandEncoderDescriptor { + label: None, + queue: &queue, + }; + let mut cmd_encoder = unsafe { device.create_command_encoder(&cmd_encoder_desc).unwrap() }; + + unsafe { cmd_encoder.begin_encoding(Some("init")).unwrap() }; + + unsafe { + cmd_encoder.build_acceleration_structures(&[ + &hal::BuildAccelerationStructureDescriptor { + mode: hal::AccelerationStructureBuildMode::Build, + flags: hal::AccelerationStructureBuildFlags::PREFER_FAST_TRACE, + destination_acceleration_structure: &blas, + scratch_buffer: &scratch_buffer, + entries: &blas_entries, + source_acceleration_structure: None, + }, + ]); + + let as_barrier = hal::BufferBarrier { + buffer: &scratch_buffer, + usage: hal::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT + ..hal::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT, + }; + cmd_encoder.transition_buffers(iter::once(as_barrier)); + + cmd_encoder.build_acceleration_structures(&[ + &hal::BuildAccelerationStructureDescriptor { + mode: hal::AccelerationStructureBuildMode::Build, + flags: tlas_flags, + destination_acceleration_structure: &tlas, + scratch_buffer: &scratch_buffer, + entries: &tlas_entries, + source_acceleration_structure: None, + }, + ]); + + let texture_barrier = hal::TextureBarrier { + texture: &texture, + range: wgt::ImageSubresourceRange::default(), + usage: hal::TextureUses::UNINITIALIZED..hal::TextureUses::STORAGE_READ_WRITE, + }; + + cmd_encoder.transition_textures(iter::once(texture_barrier)); + } + + let init_fence_value = 1; + let fence = unsafe { + let mut fence = device.create_fence().unwrap(); + let init_cmd = cmd_encoder.end_encoding().unwrap(); + queue + .submit(&[&init_cmd], Some((&mut fence, init_fence_value))) + .unwrap(); +
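// Stall on the fence until the init submission has completed, so that the + // one-off init command buffer can be reset below and reused as the first + // per-frame encoder. +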
device.wait(&fence, init_fence_value, !0).unwrap(); + cmd_encoder.reset_all(iter::once(init_cmd)); + fence + }; + + Ok(Self { + instance, + adapter, + surface, + surface_format: surface_config.format, + device, + queue, + pipeline, + contexts: vec![ExecutionContext { + encoder: cmd_encoder, + fence, + fence_value: init_fence_value + 1, + used_views: Vec::new(), + used_cmd_bufs: Vec::new(), + frames_recorded: 0, + }], + context_index: 0, + extent: [window_size.0, window_size.1], + start: Instant::now(), + pipeline_layout, + bind_group, + texture, + instances, + instances_buffer, + blas, + tlas, + scratch_buffer, + time: 0.0, + indices_buffer, + vertices_buffer, + uniform_buffer, + texture_view, + bgl, + shader_module, + }) + } + + fn update(&mut self, _event: winit::event::WindowEvent) {} + + fn render(&mut self) { + let ctx = &mut self.contexts[self.context_index]; + + let surface_tex = unsafe { self.surface.acquire_texture(None).unwrap().unwrap().texture }; + + let target_barrier0 = hal::TextureBarrier { + texture: surface_tex.borrow(), + range: wgt::ImageSubresourceRange::default(), + usage: hal::TextureUses::UNINITIALIZED..hal::TextureUses::COPY_DST, + }; + + let instances_buffer_size = + self.instances.len() * std::mem::size_of::(); + + let tlas_flags = hal::AccelerationStructureBuildFlags::PREFER_FAST_TRACE + | hal::AccelerationStructureBuildFlags::ALLOW_UPDATE; + + self.time += 1.0 / 60.0; + + self.instances[0].set_transform(&Affine3A::from_rotation_y(self.time)); + + unsafe { + let mapping = self + .device + .map_buffer(&self.instances_buffer, 0..instances_buffer_size as u64) + .unwrap(); + ptr::copy_nonoverlapping( + self.instances.as_ptr() as *const u8, + mapping.ptr.as_ptr(), + instances_buffer_size, + ); + self.device.unmap_buffer(&self.instances_buffer).unwrap(); + assert!(mapping.is_coherent); + } + + unsafe { + ctx.encoder.begin_encoding(Some("frame")).unwrap(); + + let instances = hal::AccelerationStructureInstances { + buffer: Some(&self.instances_buffer), + count: self.instances.len() as u32, + offset: 0, + }; + ctx.encoder.build_acceleration_structures(&[ + &hal::BuildAccelerationStructureDescriptor { + mode: hal::AccelerationStructureBuildMode::Update, + flags: tlas_flags, + destination_acceleration_structure: &self.tlas, + scratch_buffer: &self.scratch_buffer, + entries: &hal::AccelerationStructureEntries::Instances(instances), + source_acceleration_structure: Some(&self.tlas), + }, + ]); + + let as_barrier = hal::BufferBarrier { + buffer: &self.scratch_buffer, + usage: hal::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT + ..hal::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT, + }; + ctx.encoder.transition_buffers(iter::once(as_barrier)); + + ctx.encoder.transition_textures(iter::once(target_barrier0)); + } + + let surface_view_desc = hal::TextureViewDescriptor { + label: None, + format: self.surface_format, + dimension: wgt::TextureViewDimension::D2, + usage: hal::TextureUses::COPY_DST, + range: wgt::ImageSubresourceRange::default(), + }; + let surface_tex_view = unsafe { + self.device + .create_texture_view(surface_tex.borrow(), &surface_view_desc) + .unwrap() + }; + unsafe { + ctx.encoder + .begin_compute_pass(&hal::ComputePassDescriptor { label: None }); + ctx.encoder.set_compute_pipeline(&self.pipeline); + ctx.encoder + .set_bind_group(&self.pipeline_layout, 0, &self.bind_group, &[]); + ctx.encoder.dispatch([512 / 8, 512 / 8, 1]); + } + + ctx.frames_recorded += 1; + let do_fence = ctx.frames_recorded > COMMAND_BUFFER_PER_CONTEXT; + + let target_barrier1 
= hal::TextureBarrier { + texture: surface_tex.borrow(), + range: wgt::ImageSubresourceRange::default(), + usage: hal::TextureUses::COPY_DST..hal::TextureUses::PRESENT, + }; + let target_barrier2 = hal::TextureBarrier { + texture: &self.texture, + range: wgt::ImageSubresourceRange::default(), + usage: hal::TextureUses::STORAGE_READ_WRITE..hal::TextureUses::COPY_SRC, + }; + let target_barrier3 = hal::TextureBarrier { + texture: &self.texture, + range: wgt::ImageSubresourceRange::default(), + usage: hal::TextureUses::COPY_SRC..hal::TextureUses::STORAGE_READ_WRITE, + }; + unsafe { + ctx.encoder.end_compute_pass(); + ctx.encoder.transition_textures(iter::once(target_barrier2)); + ctx.encoder.copy_texture_to_texture( + &self.texture, + hal::TextureUses::COPY_SRC, + &surface_tex.borrow(), + std::iter::once(hal::TextureCopy { + src_base: hal::TextureCopyBase { + mip_level: 0, + array_layer: 0, + origin: wgt::Origin3d::ZERO, + aspect: hal::FormatAspects::COLOR, + }, + dst_base: hal::TextureCopyBase { + mip_level: 0, + array_layer: 0, + origin: wgt::Origin3d::ZERO, + aspect: hal::FormatAspects::COLOR, + }, + size: hal::CopyExtent { + width: 512, + height: 512, + depth: 1, + }, + }), + ); + ctx.encoder.transition_textures(iter::once(target_barrier1)); + ctx.encoder.transition_textures(iter::once(target_barrier3)); + } + + unsafe { + let cmd_buf = ctx.encoder.end_encoding().unwrap(); + let fence_param = if do_fence { + Some((&mut ctx.fence, ctx.fence_value)) + } else { + None + }; + self.queue.submit(&[&cmd_buf], fence_param).unwrap(); + self.queue.present(&mut self.surface, surface_tex).unwrap(); + ctx.used_cmd_bufs.push(cmd_buf); + ctx.used_views.push(surface_tex_view); + }; + + if do_fence { + log::info!("Context switch from {}", self.context_index); + let old_fence_value = ctx.fence_value; + if self.contexts.len() == 1 { + let hal_desc = hal::CommandEncoderDescriptor { + label: None, + queue: &self.queue, + }; + self.contexts.push(unsafe { + ExecutionContext { + encoder: self.device.create_command_encoder(&hal_desc).unwrap(), + fence: self.device.create_fence().unwrap(), + fence_value: 0, + used_views: Vec::new(), + used_cmd_bufs: Vec::new(), + frames_recorded: 0, + } + }); + } + self.context_index = (self.context_index + 1) % self.contexts.len(); + let next = &mut self.contexts[self.context_index]; + unsafe { + next.wait_and_clear(&self.device); + } + next.fence_value = old_fence_value + 1; + } + } + + fn exit(mut self) { + unsafe { + { + let ctx = &mut self.contexts[self.context_index]; + self.queue + .submit(&[], Some((&mut ctx.fence, ctx.fence_value))) + .unwrap(); + } + + for mut ctx in self.contexts { + ctx.wait_and_clear(&self.device); + self.device.destroy_command_encoder(ctx.encoder); + self.device.destroy_fence(ctx.fence); + } + + self.device.destroy_bind_group(self.bind_group); + self.device.destroy_buffer(self.scratch_buffer); + self.device.destroy_buffer(self.instances_buffer); + self.device.destroy_buffer(self.indices_buffer); + self.device.destroy_buffer(self.vertices_buffer); + self.device.destroy_buffer(self.uniform_buffer); + self.device.destroy_acceleration_structure(self.tlas); + self.device.destroy_acceleration_structure(self.blas); + self.device.destroy_texture_view(self.texture_view); + self.device.destroy_texture(self.texture); + self.device.destroy_compute_pipeline(self.pipeline); + self.device.destroy_pipeline_layout(self.pipeline_layout); + self.device.destroy_bind_group_layout(self.bgl); + self.device.destroy_shader_module(self.shader_module); + + 
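// Every in-flight context has been waited on above, so the resources below + // can be destroyed safely; the surface is unconfigured before the device + // exits and is destroyed last. +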
self.surface.unconfigure(&self.device); + self.device.exit(self.queue); + self.instance.destroy_surface(self.surface); + drop(self.adapter); + } + } +} + +#[cfg(all(feature = "metal"))] +type Api = hal::api::Metal; +#[cfg(all(feature = "vulkan", not(feature = "metal")))] +type Api = hal::api::Vulkan; +#[cfg(all(feature = "gles", not(feature = "metal"), not(feature = "vulkan")))] +type Api = hal::api::Gles; +#[cfg(all( + feature = "dx12", + not(feature = "metal"), + not(feature = "vulkan"), + not(feature = "gles") +))] +type Api = hal::api::Dx12; +#[cfg(not(any( + feature = "metal", + feature = "vulkan", + feature = "gles", + feature = "dx12" +)))] +type Api = hal::api::Empty; + +fn main() { + env_logger::init(); + + let event_loop = winit::event_loop::EventLoop::new(); + let window = winit::window::WindowBuilder::new() + .with_title("hal-ray-traced-triangle") + .with_inner_size(winit::dpi::PhysicalSize { + width: 512, + height: 512, + }) + .with_resizable(false) + .build(&event_loop) + .unwrap(); + + let example_result = Example::<Api>::init(&window); + let mut example = Some(example_result.expect("Selected backend is not supported")); + + event_loop.run(move |event, _, control_flow| { + let _ = &window; // force ownership by the closure + *control_flow = winit::event_loop::ControlFlow::Poll; + match event { + winit::event::Event::RedrawEventsCleared => { + window.request_redraw(); + } + winit::event::Event::WindowEvent { event, .. } => match event { + winit::event::WindowEvent::KeyboardInput { + input: + winit::event::KeyboardInput { + virtual_keycode: Some(winit::event::VirtualKeyCode::Escape), + state: winit::event::ElementState::Pressed, + .. + }, + .. + } + | winit::event::WindowEvent::CloseRequested => { + *control_flow = winit::event_loop::ControlFlow::Exit; + } + _ => { + example.as_mut().unwrap().update(event); + } + }, + winit::event::Event::RedrawRequested(_) => { + let ex = example.as_mut().unwrap(); + + ex.render(); + } + winit::event::Event::LoopDestroyed => { + example.take().unwrap().exit(); + } + _ => {} + } + }); +} diff --git a/wgpu-hal/examples/ray-traced-triangle/shader.comp b/wgpu-hal/examples/ray-traced-triangle/shader.comp new file mode 100644 index 0000000000..d31f29115f --- /dev/null +++ b/wgpu-hal/examples/ray-traced-triangle/shader.comp @@ -0,0 +1,44 @@ +#version 460 +#extension GL_EXT_ray_query : enable + +layout(set = 0, binding = 0) uniform Uniforms +{ + mat4 viewInverse; + mat4 projInverse; +} cam; +layout(set = 0, binding = 1, rgba8) uniform image2D image; +layout(set = 0, binding = 2) uniform accelerationStructureEXT tlas; + +layout(local_size_x = 8, local_size_y = 8) in; + +void main() +{ + uvec2 launch_id = gl_GlobalInvocationID.xy; + uvec2 launch_size = gl_NumWorkGroups.xy * 8; + + const vec2 pixelCenter = vec2(launch_id) + vec2(0.5); + const vec2 inUV = pixelCenter/vec2(launch_size); + vec2 d = inUV * 2.0 - 1.0; + + vec4 origin = cam.viewInverse * vec4(0,0,0,1); + vec4 target = cam.projInverse * vec4(d.x, d.y, 1, 1) ; + vec4 direction = cam.viewInverse*vec4(normalize(target.xyz), 0) ; + + float tmin = 0.001; + float tmax = 10000.0; + + rayQueryEXT rayQuery; + rayQueryInitializeEXT(rayQuery, tlas, gl_RayFlagsOpaqueEXT, 0xff, origin.xyz, tmin, direction.xyz, tmax); + + rayQueryProceedEXT(rayQuery); + + vec3 out_colour = vec3(0.0, 0.0, 0.0); + + if (rayQueryGetIntersectionTypeEXT(rayQuery, true) == gl_RayQueryCommittedIntersectionTriangleEXT ) { + vec2 barycentrics = rayQueryGetIntersectionBarycentricsEXT(rayQuery, true); + + out_colour = vec3(barycentrics.x,
barycentrics.y, 1.0 - barycentrics.x - barycentrics.y); + } + + imageStore(image, ivec2(launch_id), vec4(out_colour, 1.0)); +} \ No newline at end of file diff --git a/wgpu-hal/examples/ray-traced-triangle/shader.comp.spv b/wgpu-hal/examples/ray-traced-triangle/shader.comp.spv new file mode 100644 index 0000000000..345085c948 Binary files /dev/null and b/wgpu-hal/examples/ray-traced-triangle/shader.comp.spv differ diff --git a/wgpu-hal/examples/ray-tracing-pipeline-triangle/main.rs b/wgpu-hal/examples/ray-tracing-pipeline-triangle/main.rs new file mode 100644 index 0000000000..84c175c802 --- /dev/null +++ b/wgpu-hal/examples/ray-tracing-pipeline-triangle/main.rs @@ -0,0 +1,1290 @@ +extern crate wgpu_hal as hal; + +use hal::{ + Adapter as _, CommandEncoder as _, Device as _, Instance as _, Queue as _, + RayTracingHitShaderGroup, RayTracingPipeline, ShaderBindingTableReference, SkipHitType, + Surface as _, +}; +use raw_window_handle::{HasRawDisplayHandle, HasRawWindowHandle}; + +use glam::{Affine3A, Mat4, Vec3}; +use std::{ + borrow::{Borrow, Cow}, + iter, mem, + mem::{align_of, size_of}, + ptr::{self, copy_nonoverlapping}, + time::Instant, +}; + +const COMMAND_BUFFER_PER_CONTEXT: usize = 100; +const DESIRED_FRAMES: u32 = 3; + +/// [D3D12_RAYTRACING_INSTANCE_DESC](https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#d3d12_raytracing_instance_desc) +/// [VkAccelerationStructureInstanceKHR](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkAccelerationStructureInstanceKHR.html) +#[derive(Clone)] +#[repr(C)] +struct AccelerationStructureInstance { + transform: [f32; 12], + custom_index_and_mask: u32, + shader_binding_table_record_offset_and_flags: u32, + acceleration_structure_reference: u64, +} + +impl std::fmt::Debug for AccelerationStructureInstance { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Instance") + .field("transform", &self.transform) + .field("custom_index()", &self.custom_index()) + .field("mask()", &self.mask()) + .field( + "shader_binding_table_record_offset()", + &self.shader_binding_table_record_offset(), + ) + .field("flags()", &self.flags()) + .field( + "acceleration_structure_reference", + &self.acceleration_structure_reference, + ) + .finish() + } +} + +#[allow(dead_code)] +impl AccelerationStructureInstance { + const LOW_24_MASK: u32 = 0x00ff_ffff; + const MAX_U24: u32 = (1u32 << 24u32) - 1u32; + + #[inline] + fn affine_to_rows(mat: &Affine3A) -> [f32; 12] { + let row_0 = mat.matrix3.row(0); + let row_1 = mat.matrix3.row(1); + let row_2 = mat.matrix3.row(2); + let translation = mat.translation; + [ + row_0.x, + row_0.y, + row_0.z, + translation.x, + row_1.x, + row_1.y, + row_1.z, + translation.y, + row_2.x, + row_2.y, + row_2.z, + translation.z, + ] + } + + #[inline] + fn rows_to_affine(rows: &[f32; 12]) -> Affine3A { + Affine3A::from_cols_array(&[ + rows[0], rows[3], rows[6], rows[9], rows[1], rows[4], rows[7], rows[10], rows[2], + rows[5], rows[8], rows[11], + ]) + } + + pub fn transform_as_affine(&self) -> Affine3A { + Self::rows_to_affine(&self.transform) + } + pub fn set_transform(&mut self, transform: &Affine3A) { + self.transform = Self::affine_to_rows(&transform); + } + + pub fn custom_index(&self) -> u32 { + self.custom_index_and_mask & Self::LOW_24_MASK + } + + pub fn mask(&self) -> u8 { + (self.custom_index_and_mask >> 24) as u8 + } + + pub fn shader_binding_table_record_offset(&self) -> u32 { + self.shader_binding_table_record_offset_and_flags & Self::LOW_24_MASK + } + + pub fn 
flags(&self) -> u8 { + (self.shader_binding_table_record_offset_and_flags >> 24) as u8 + } + + pub fn set_custom_index(&mut self, custom_index: u32) { + debug_assert!( + custom_index <= Self::MAX_U24, + "custom_index uses more than 24 bits! {custom_index} > {}", + Self::MAX_U24 + ); + self.custom_index_and_mask = + (custom_index & Self::LOW_24_MASK) | (self.custom_index_and_mask & !Self::LOW_24_MASK) + } + + pub fn set_mask(&mut self, mask: u8) { + self.custom_index_and_mask = + (self.custom_index_and_mask & Self::LOW_24_MASK) | (u32::from(mask) << 24) + } + + pub fn set_shader_binding_table_record_offset( + &mut self, + shader_binding_table_record_offset: u32, + ) { + debug_assert!(shader_binding_table_record_offset <= Self::MAX_U24, "shader_binding_table_record_offset uses more than 24 bits! {shader_binding_table_record_offset} > {}", Self::MAX_U24); + self.shader_binding_table_record_offset_and_flags = (shader_binding_table_record_offset + & Self::LOW_24_MASK) + | (self.shader_binding_table_record_offset_and_flags & !Self::LOW_24_MASK) + } + + pub fn set_flags(&mut self, flags: u8) { + self.shader_binding_table_record_offset_and_flags = + (self.shader_binding_table_record_offset_and_flags & Self::LOW_24_MASK) + | (u32::from(flags) << 24) + } + + pub fn new( + transform: &Affine3A, + custom_index: u32, + mask: u8, + shader_binding_table_record_offset: u32, + flags: u8, + acceleration_structure_reference: u64, + ) -> Self { + debug_assert!( + custom_index <= Self::MAX_U24, + "custom_index uses more than 24 bits! {custom_index} > {}", + Self::MAX_U24 + ); + debug_assert!( + shader_binding_table_record_offset <= Self::MAX_U24, + "shader_binding_table_record_offset uses more than 24 bits! {shader_binding_table_record_offset} > {}", Self::MAX_U24 + ); + AccelerationStructureInstance { + transform: Self::affine_to_rows(transform), + custom_index_and_mask: (custom_index & Self::MAX_U24) | (u32::from(mask) << 24), + shader_binding_table_record_offset_and_flags: (shader_binding_table_record_offset + & Self::MAX_U24) + | (u32::from(flags) << 24), + acceleration_structure_reference, + } + } +} + +struct ExecutionContext<A: hal::Api> { + encoder: A::CommandEncoder, + fence: A::Fence, + fence_value: hal::FenceValue, + used_views: Vec<A::TextureView>, + used_cmd_bufs: Vec<A::CommandBuffer>, + frames_recorded: usize, +} + +impl<A: hal::Api> ExecutionContext<A> { + unsafe fn wait_and_clear(&mut self, device: &A::Device) { + device.wait(&self.fence, self.fence_value, !0).unwrap(); + self.encoder.reset_all(self.used_cmd_bufs.drain(..)); + for view in self.used_views.drain(..)
{ + device.destroy_texture_view(view); + } + self.frames_recorded = 0; + } +} + +#[allow(dead_code)] +struct Example<A: hal::Api> { + instance: A::Instance, + adapter: A::Adapter, + surface: A::Surface, + surface_format: wgt::TextureFormat, + device: A::Device, + queue: A::Queue, + + contexts: Vec<ExecutionContext<A>>, + context_index: usize, + extent: [u32; 2], + start: Instant, + pipeline: A::RayTracingPipeline, + bind_group: A::BindGroup, + bgl: A::BindGroupLayout, + gen_shader_module: A::ShaderModule, + miss_shader_module: A::ShaderModule, + call_shader_module: A::ShaderModule, + hit_shader_module: A::ShaderModule, + texture_view: A::TextureView, + uniform_buffer: A::Buffer, + pipeline_layout: A::PipelineLayout, + vertices_buffer: A::Buffer, + indices_buffer: A::Buffer, + texture: A::Texture, + instances: [AccelerationStructureInstance; 3], + instances_buffer: A::Buffer, + blas: A::AccelerationStructure, + tlas: A::AccelerationStructure, + scratch_buffer: A::Buffer, + sbt_buffer: A::Buffer, + gen_sbt_ref: ShaderBindingTableReference, + miss_sbt_ref: ShaderBindingTableReference, + call_sbt_ref: ShaderBindingTableReference, + hit_sbt_ref: ShaderBindingTableReference, + time: f32, +} + +impl<A: hal::Api> Example<A> { + fn init(window: &winit::window::Window) -> Result<Self, hal::InstanceError> { + let instance_desc = hal::InstanceDescriptor { + name: "example", + flags: if cfg!(debug_assertions) { + hal::InstanceFlags::all() + } else { + hal::InstanceFlags::empty() + }, + dx12_shader_compiler: wgt::Dx12Compiler::Fxc, + }; + let instance = unsafe { A::Instance::init(&instance_desc)? }; + let mut surface = unsafe { + instance + .create_surface(window.raw_display_handle(), window.raw_window_handle()) + .unwrap() + }; + + let (adapter, features) = unsafe { + let mut adapters = instance.enumerate_adapters(); + if adapters.is_empty() { + return Err(hal::InstanceError); + } + let exposed = adapters.swap_remove(0); + dbg!(exposed.features); + (exposed.adapter, exposed.features) + }; + let surface_caps = + unsafe { adapter.surface_capabilities(&surface) }.ok_or(hal::InstanceError)?; + log::info!("Surface caps: {:#?}", surface_caps); + + let hal::OpenDevice { device, mut queue } = + unsafe { adapter.open(features, &wgt::Limits::default()).unwrap() }; + + let window_size: (u32, u32) = window.inner_size().into(); + dbg!(&surface_caps.formats); + let surface_format = if surface_caps + .formats + .contains(&wgt::TextureFormat::Rgba8Snorm) + { + wgt::TextureFormat::Rgba8Unorm + } else { + *surface_caps.formats.first().unwrap() + }; + let surface_config = hal::SurfaceConfiguration { + swap_chain_size: DESIRED_FRAMES + .max(*surface_caps.swap_chain_sizes.start()) + .min(*surface_caps.swap_chain_sizes.end()), + present_mode: wgt::PresentMode::Fifo, + composite_alpha_mode: wgt::CompositeAlphaMode::Opaque, + format: surface_format, + extent: wgt::Extent3d { + width: window_size.0, + height: window_size.1, + depth_or_array_layers: 1, + }, + usage: hal::TextureUses::COLOR_TARGET | hal::TextureUses::COPY_DST, + view_formats: vec![surface_format], + }; + unsafe { + surface.configure(&device, &surface_config).unwrap(); + }; + + #[allow(dead_code)] + struct Uniforms { + view_inverse: glam::Mat4, + proj_inverse: glam::Mat4, + } + + let bgl_desc = hal::BindGroupLayoutDescriptor { + label: None, + flags: hal::BindGroupLayoutFlags::empty(), + entries: &[ + wgt::BindGroupLayoutEntry { + binding: 0, + visibility: wgt::ShaderStages::RAYGEN | wgt::ShaderStages::CLOSEST_HIT, + ty: wgt::BindingType::Buffer { + ty: wgt::BufferBindingType::Uniform, + has_dynamic_offset: false, + min_binding_size:
wgt::BufferSize::new(mem::size_of::<Uniforms>() as _), + }, + count: None, + }, + wgt::BindGroupLayoutEntry { + binding: 1, + visibility: wgt::ShaderStages::RAYGEN | wgt::ShaderStages::CLOSEST_HIT, + ty: wgt::BindingType::StorageTexture { + access: wgt::StorageTextureAccess::WriteOnly, + format: wgt::TextureFormat::Rgba8Unorm, + view_dimension: wgt::TextureViewDimension::D2, + }, + count: None, + }, + wgt::BindGroupLayoutEntry { + binding: 2, + visibility: wgt::ShaderStages::RAYGEN | wgt::ShaderStages::CLOSEST_HIT, + ty: wgt::BindingType::AccelerationStructure, + count: None, + }, + ], + }; + + let bgl = unsafe { device.create_bind_group_layout(&bgl_desc).unwrap() }; + + pub fn make_spirv_raw(data: &[u8]) -> Cow<[u32]> { + const MAGIC_NUMBER: u32 = 0x0723_0203; + assert_eq!( + data.len() % size_of::<u32>(), + 0, + "data size is not a multiple of 4" + ); + + // If the data happens to be aligned, directly use the byte array, + // otherwise copy the byte array into an owned vector and use that instead. + let words = if data.as_ptr().align_offset(align_of::<u32>()) == 0 { + let (pre, words, post) = unsafe { data.align_to::<u32>() }; + debug_assert!(pre.is_empty()); + debug_assert!(post.is_empty()); + Cow::from(words) + } else { + let mut words = vec![0u32; data.len() / size_of::<u32>()]; + unsafe { + copy_nonoverlapping(data.as_ptr(), words.as_mut_ptr() as *mut u8, data.len()); + } + Cow::from(words) + }; + + assert_eq!( + words[0], MAGIC_NUMBER, + "wrong magic word {:x}. Make sure you are using a binary SPIRV file.", + words[0] + ); + + words + } + + let gen_shader_module = unsafe { + device + .create_shader_module( + &hal::ShaderModuleDescriptor { + label: None, + runtime_checks: false, + }, + hal::ShaderInput::SpirV(&make_spirv_raw(include_bytes!("shader.rgen.spv"))), + ) + .unwrap() + }; + + let miss_shader_module = unsafe { + device + .create_shader_module( + &hal::ShaderModuleDescriptor { + label: None, + runtime_checks: false, + }, + hal::ShaderInput::SpirV(&make_spirv_raw(include_bytes!("shader.rmiss.spv"))), + ) + .unwrap() + }; + + let call_shader_module = unsafe { + device + .create_shader_module( + &hal::ShaderModuleDescriptor { + label: None, + runtime_checks: false, + }, + hal::ShaderInput::SpirV(&make_spirv_raw(include_bytes!("shader.rcall.spv"))), + ) + .unwrap() + }; + + let hit_shader_module = unsafe { + device + .create_shader_module( + &hal::ShaderModuleDescriptor { + label: None, + runtime_checks: false, + }, + hal::ShaderInput::SpirV(&make_spirv_raw(include_bytes!("shader.rchit.spv"))), + ) + .unwrap() + }; + + let pipeline_layout_desc = hal::PipelineLayoutDescriptor { + label: None, + flags: hal::PipelineLayoutFlags::empty(), + bind_group_layouts: &[&bgl], + push_constant_ranges: &[], + }; + let pipeline_layout = unsafe { + device + .create_pipeline_layout(&pipeline_layout_desc) + .unwrap() + }; + + let gen_group = hal::ProgrammableStage { + module: &gen_shader_module, + entry_point: "main", + }; + + let miss_group = hal::ProgrammableStage { + module: &miss_shader_module, + entry_point: "main", + }; + + let call_group = hal::ProgrammableStage { + module: &call_shader_module, + entry_point: "main", + }; + + let hit_group = RayTracingHitShaderGroup { + closest_hit: Some(hal::ProgrammableStage { + module: &hit_shader_module, + entry_point: "main", + }), + any_hit: None, + intersection: None, + }; + + let pipeline = unsafe { + device.create_ray_tracing_pipeline(&hal::RayTracingPipelineDescriptor { + label: Some("pipeline"), + layout: &pipeline_layout, + max_recursion_depth: 1, + skip_hit_type:
SkipHitType::None, + gen_groups: &[gen_group], + miss_groups: &[miss_group], + call_groups: &[call_group], + hit_groups: &[hit_group], + }) + } + .unwrap(); + + //SBT + let (sbt_buffer, gen_sbt_ref, miss_sbt_ref, call_sbt_ref, hit_sbt_ref) = { + let col_a = glam::vec4(1.0, 1.0, 1.0, 0.5); + let col_b = glam::vec4(0.0, 0.0, 0.0, 0.0); + + let mut col_a_mem = [0u8; 16]; + let mut col_b_mem = [0u8; 16]; + + unsafe { + ptr::copy_nonoverlapping( + &col_a as *const glam::Vec4 as *const u8, + col_a_mem.as_mut_ptr(), + 16, + ); + ptr::copy_nonoverlapping( + &col_b as *const glam::Vec4 as *const u8, + col_b_mem.as_mut_ptr(), + 16, + ); + } + + let gen_records: [&[u8]; 1] = [&[]]; + let miss_records: [&[u8]; 1] = [&[]]; + let call_records: [&[u8]; 1] = [&[]]; + let hit_records: [&[u8]; 2] = [&col_a_mem, &col_b_mem]; + + let gen_handles = pipeline.gen_handles(); + let miss_handles = pipeline.miss_handles(); + let call_handles = pipeline.call_handles(); + let hit_handles = pipeline.hit_handles().repeat(2); + + let gen_sbt_data = device.assemble_sbt_data(&gen_handles, &gen_records); + let miss_sbt_data = device.assemble_sbt_data(&miss_handles, &miss_records); + let call_sbt_data = device.assemble_sbt_data(&call_handles, &call_records); + let hit_sbt_data = device.assemble_sbt_data(&hit_handles, &hit_records); + + let combined_iterator = gen_sbt_data + .data + .chain(miss_sbt_data.data) + .chain(call_sbt_data.data) + .chain(hit_sbt_data.data); + + let sbt_size = + gen_sbt_data.padded_size + miss_sbt_data.padded_size + call_sbt_data.padded_size + hit_sbt_data.padded_size; + + let sbt_buffer = unsafe { + let sbt_buffer = device + .create_buffer(&hal::BufferDescriptor { + label: Some("sbt buffer"), + size: sbt_size, + usage: hal::BufferUses::MAP_WRITE | hal::BufferUses::SHADER_BINDING_TABLE, + memory_flags: hal::MemoryFlags::TRANSIENT + | hal::MemoryFlags::PREFER_COHERENT, + }) + .unwrap(); + + let mapping = device.map_buffer(&sbt_buffer, 0..sbt_size).unwrap(); + let slice = ptr::slice_from_raw_parts_mut(mapping.ptr.as_ptr(), sbt_size as usize); + + for (i, src) in combined_iterator.enumerate() { + (*slice)[i] = src; + } + device.unmap_buffer(&sbt_buffer).unwrap(); + assert!(mapping.is_coherent); + + sbt_buffer + }; + + let sbt_address = unsafe { device.get_buffer_device_address(&sbt_buffer) }; + + let mut offset = 0; + + let gen_sbt_ref = ShaderBindingTableReference { + address: sbt_address + offset, + stride: gen_sbt_data.stride, + size: gen_sbt_data.size, + }; + offset += gen_sbt_data.padded_size; + + let miss_sbt_ref = ShaderBindingTableReference { + address: sbt_address + offset, + stride: miss_sbt_data.stride, + size: miss_sbt_data.size, + }; + offset += miss_sbt_data.padded_size; + + let call_sbt_ref = ShaderBindingTableReference { + address: sbt_address + offset, + stride: call_sbt_data.stride, + size: call_sbt_data.size, + }; + offset += call_sbt_data.padded_size; + + let hit_sbt_ref = ShaderBindingTableReference { + address: sbt_address + offset, + stride: hit_sbt_data.stride as u64, + size: hit_sbt_data.size, + }; + (sbt_buffer, gen_sbt_ref, miss_sbt_ref, call_sbt_ref, hit_sbt_ref) + }; + + // t[0] = &[1u8; 8]; + + // std::vector table_data; + // for (size_t i = 0; i < count; i++) { + // group_strides[i] = align_up < VkDeviceSize > (handle_size + max_record_sizes[i], properties.shaderGroupHandleAlignment); + // sizes[i] = align_up(group_counts[i] * group_strides[i], properties.shaderGroupBaseAlignment); + // size_t offset = table_data.size(); + // table_data.insert(table_data.end(), sizes[i], 
0); + // record_offsets[i] = offset + handle_size; + // for (size_t c = 0; c < group_counts[i]; c++) { + // memcpy(&table_data[offset], &handles[cur_group * handle_size], handle_size); + // offset += group_strides[i]; + // cur_group++; + // } + // } + + let vertices: [f32; 9] = [1.0, 1.0, 0.0, -1.0, 1.0, 0.0, 0.0, -1.0, 0.0]; + + let vertices_size_in_bytes = vertices.len() * 4; + + let indices: [u32; 3] = [0, 1, 2]; + + let indices_size_in_bytes = indices.len() * 4; + + let vertices_buffer = unsafe { + let vertices_buffer = device + .create_buffer(&hal::BufferDescriptor { + label: Some("vertices buffer"), + size: vertices_size_in_bytes as u64, + usage: hal::BufferUses::MAP_WRITE + | hal::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT, + memory_flags: hal::MemoryFlags::TRANSIENT | hal::MemoryFlags::PREFER_COHERENT, + }) + .unwrap(); + + let mapping = device + .map_buffer(&vertices_buffer, 0..vertices_size_in_bytes as u64) + .unwrap(); + ptr::copy_nonoverlapping( + vertices.as_ptr() as *const u8, + mapping.ptr.as_ptr(), + vertices_size_in_bytes, + ); + device.unmap_buffer(&vertices_buffer).unwrap(); + assert!(mapping.is_coherent); + + vertices_buffer + }; + + let indices_buffer = unsafe { + let indices_buffer = device + .create_buffer(&hal::BufferDescriptor { + label: Some("indices buffer"), + size: indices_size_in_bytes as u64, + usage: hal::BufferUses::MAP_WRITE + | hal::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT, + memory_flags: hal::MemoryFlags::TRANSIENT | hal::MemoryFlags::PREFER_COHERENT, + }) + .unwrap(); + + let mapping = device + .map_buffer(&indices_buffer, 0..indices_size_in_bytes as u64) + .unwrap(); + ptr::copy_nonoverlapping( + indices.as_ptr() as *const u8, + mapping.ptr.as_ptr(), + indices_size_in_bytes, + ); + device.unmap_buffer(&indices_buffer).unwrap(); + assert!(mapping.is_coherent); + + indices_buffer + }; + + let blas_triangles = vec![hal::AccelerationStructureTriangles { + vertex_buffer: Some(&vertices_buffer), + first_vertex: 0, + vertex_format: wgt::VertexFormat::Float32x3, + vertex_count: vertices.len() as u32, + vertex_stride: 3 * 4, + indices: Some(hal::AccelerationStructureTriangleIndices { + buffer: Some(&indices_buffer), + format: wgt::IndexFormat::Uint32, + offset: 0, + count: indices.len() as u32, + }), + transform: None, + flags: hal::AccelerationStructureGeometryFlags::OPAQUE, + }]; + let blas_entries = hal::AccelerationStructureEntries::Triangles(&blas_triangles); + + let mut tlas_entries = + hal::AccelerationStructureEntries::Instances(hal::AccelerationStructureInstances { + buffer: None, + count: 3, + offset: 0, + }); + + let blas_sizes = unsafe { + device.get_acceleration_structure_build_sizes( + &hal::GetAccelerationStructureBuildSizesDescriptor { + entries: &blas_entries, + flags: hal::AccelerationStructureBuildFlags::PREFER_FAST_TRACE, + }, + ) + }; + + let tlas_flags = hal::AccelerationStructureBuildFlags::PREFER_FAST_TRACE + | hal::AccelerationStructureBuildFlags::ALLOW_UPDATE; + + let tlas_sizes = unsafe { + device.get_acceleration_structure_build_sizes( + &hal::GetAccelerationStructureBuildSizesDescriptor { + entries: &tlas_entries, + flags: tlas_flags, + }, + ) + }; + + let blas = unsafe { + device.create_acceleration_structure(&hal::AccelerationStructureDescriptor { + label: Some("blas"), + size: blas_sizes.acceleration_structure_size, + format: hal::AccelerationStructureFormat::BottomLevel, + }) + } + .unwrap(); + + let tlas = unsafe { + device.create_acceleration_structure(&hal::AccelerationStructureDescriptor { + label: 
Some("tlas"), + size: tlas_sizes.acceleration_structure_size, + format: hal::AccelerationStructureFormat::TopLevel, + }) + } + .unwrap(); + + let uniforms = { + let view = Mat4::look_at_rh(Vec3::new(0.0, 0.0, 2.5), Vec3::ZERO, Vec3::Y); + let proj = Mat4::perspective_rh(59.0_f32.to_radians(), 1.0, 0.001, 1000.0); + + Uniforms { + view_inverse: view.inverse(), + proj_inverse: proj.inverse(), + } + }; + + let uniforms_size = std::mem::size_of::(); + + let uniform_buffer = unsafe { + let uniform_buffer = device + .create_buffer(&hal::BufferDescriptor { + label: Some("uniform buffer"), + size: uniforms_size as u64, + usage: hal::BufferUses::MAP_WRITE | hal::BufferUses::UNIFORM, + memory_flags: hal::MemoryFlags::PREFER_COHERENT, + }) + .unwrap(); + + let mapping = device + .map_buffer(&uniform_buffer, 0..uniforms_size as u64) + .unwrap(); + ptr::copy_nonoverlapping( + &uniforms as *const Uniforms as *const u8, + mapping.ptr.as_ptr(), + uniforms_size, + ); + device.unmap_buffer(&uniform_buffer).unwrap(); + assert!(mapping.is_coherent); + uniform_buffer + }; + + let texture_desc = hal::TextureDescriptor { + label: None, + size: wgt::Extent3d { + width: 512, + height: 512, + depth_or_array_layers: 1, + }, + mip_level_count: 1, + sample_count: 1, + dimension: wgt::TextureDimension::D2, + format: wgt::TextureFormat::Rgba8Unorm, + usage: hal::TextureUses::STORAGE_READ_WRITE | hal::TextureUses::COPY_SRC, + memory_flags: hal::MemoryFlags::empty(), + view_formats: vec![wgt::TextureFormat::Rgba8Unorm], + }; + let texture = unsafe { device.create_texture(&texture_desc).unwrap() }; + + let view_desc = hal::TextureViewDescriptor { + label: None, + format: texture_desc.format, + dimension: wgt::TextureViewDimension::D2, + usage: hal::TextureUses::STORAGE_READ_WRITE | hal::TextureUses::COPY_SRC, + range: wgt::ImageSubresourceRange::default(), + }; + let texture_view = unsafe { device.create_texture_view(&texture, &view_desc).unwrap() }; + + let bind_group = { + let buffer_binding = hal::BufferBinding { + buffer: &uniform_buffer, + offset: 0, + size: None, + }; + let texture_binding = hal::TextureBinding { + view: &texture_view, + usage: hal::TextureUses::STORAGE_READ_WRITE, + }; + let group_desc = hal::BindGroupDescriptor { + label: Some("bind group"), + layout: &bgl, + buffers: &[buffer_binding], + samplers: &[], + textures: &[texture_binding], + acceleration_structures: &[&tlas], + entries: &[ + hal::BindGroupEntry { + binding: 0, + resource_index: 0, + count: 1, + }, + hal::BindGroupEntry { + binding: 1, + resource_index: 0, + count: 1, + }, + hal::BindGroupEntry { + binding: 2, + resource_index: 0, + count: 1, + }, + ], + }; + unsafe { device.create_bind_group(&group_desc).unwrap() } + }; + + let scratch_buffer = unsafe { + device + .create_buffer(&hal::BufferDescriptor { + label: Some("scratch buffer"), + size: blas_sizes + .build_scratch_size + .max(tlas_sizes.build_scratch_size), + usage: hal::BufferUses::ACCELERATION_STRUCTURE_SCRATCH, + memory_flags: hal::MemoryFlags::empty(), + }) + .unwrap() + }; + + let instances = [ + AccelerationStructureInstance::new( + &Affine3A::from_translation(Vec3 { + x: 0.0, + y: 0.0, + z: 0.0, + }), + 0, + 0xff, + 1, + 0, + unsafe { device.get_acceleration_structure_device_address(&blas) }, + ), + AccelerationStructureInstance::new( + &Affine3A::from_translation(Vec3 { + x: -1.0, + y: -1.0, + z: -2.0, + }), + 1, + 0xff, + 0, + 0, + unsafe { device.get_acceleration_structure_device_address(&blas) }, + ), + AccelerationStructureInstance::new( + 
+                &Affine3A::from_translation(Vec3 {
+                    x: 1.0,
+                    y: -1.0,
+                    z: -2.0,
+                }),
+                0,
+                0xff,
+                0,
+                0,
+                unsafe { device.get_acceleration_structure_device_address(&blas) },
+            ),
+        ];
+
+        let instances_buffer_size =
+            instances.len() * std::mem::size_of::<AccelerationStructureInstance>();
+
+        let instances_buffer = unsafe {
+            let instances_buffer = device
+                .create_buffer(&hal::BufferDescriptor {
+                    label: Some("instances_buffer"),
+                    size: instances_buffer_size as u64,
+                    usage: hal::BufferUses::MAP_WRITE
+                        | hal::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
+                    memory_flags: hal::MemoryFlags::TRANSIENT | hal::MemoryFlags::PREFER_COHERENT,
+                })
+                .unwrap();
+
+            let mapping = device
+                .map_buffer(&instances_buffer, 0..instances_buffer_size as u64)
+                .unwrap();
+            ptr::copy_nonoverlapping(
+                instances.as_ptr() as *const u8,
+                mapping.ptr.as_ptr(),
+                instances_buffer_size,
+            );
+            device.unmap_buffer(&instances_buffer).unwrap();
+            assert!(mapping.is_coherent);
+
+            instances_buffer
+        };
+
+        if let hal::AccelerationStructureEntries::Instances(ref mut i) = tlas_entries {
+            i.buffer = Some(&instances_buffer);
+            assert!(
+                instances.len() <= i.count as usize,
+                "Tlas allocation too small"
+            );
+        }
+
+        let cmd_encoder_desc = hal::CommandEncoderDescriptor {
+            label: None,
+            queue: &queue,
+        };
+        let mut cmd_encoder = unsafe { device.create_command_encoder(&cmd_encoder_desc).unwrap() };
+
+        unsafe { cmd_encoder.begin_encoding(Some("init")).unwrap() };
+
+        unsafe {
+            cmd_encoder.build_acceleration_structures(&[
+                &hal::BuildAccelerationStructureDescriptor {
+                    mode: hal::AccelerationStructureBuildMode::Build,
+                    flags: hal::AccelerationStructureBuildFlags::PREFER_FAST_TRACE,
+                    destination_acceleration_structure: &blas,
+                    scratch_buffer: &scratch_buffer,
+                    entries: &blas_entries,
+                    source_acceleration_structure: None,
+                },
+            ]);
+
+            let as_barrier = hal::BufferBarrier {
+                buffer: &scratch_buffer,
+                usage: hal::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT
+                    ..hal::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
+            };
+            cmd_encoder.transition_buffers(iter::once(as_barrier));
+
+            cmd_encoder.build_acceleration_structures(&[
+                &hal::BuildAccelerationStructureDescriptor {
+                    mode: hal::AccelerationStructureBuildMode::Build,
+                    flags: tlas_flags,
+                    destination_acceleration_structure: &tlas,
+                    scratch_buffer: &scratch_buffer,
+                    entries: &tlas_entries,
+                    source_acceleration_structure: None,
+                },
+            ]);
+
+            let texture_barrier = hal::TextureBarrier {
+                texture: &texture,
+                range: wgt::ImageSubresourceRange::default(),
+                usage: hal::TextureUses::UNINITIALIZED..hal::TextureUses::STORAGE_READ_WRITE,
+            };
+
+            cmd_encoder.transition_textures(iter::once(texture_barrier));
+        }
+
+        let init_fence_value = 1;
+        let fence = unsafe {
+            let mut fence = device.create_fence().unwrap();
+            let init_cmd = cmd_encoder.end_encoding().unwrap();
+            queue
+                .submit(&[&init_cmd], Some((&mut fence, init_fence_value)))
+                .unwrap();
+            device.wait(&fence, init_fence_value, !0).unwrap();
+            cmd_encoder.reset_all(iter::once(init_cmd));
+            fence
+        };
+
+        Ok(Self {
+            instance,
+            adapter,
+            surface,
+            surface_format: surface_config.format,
+            device,
+            queue,
+            pipeline,
+            contexts: vec![ExecutionContext {
+                encoder: cmd_encoder,
+                fence,
+                fence_value: init_fence_value + 1,
+                used_views: Vec::new(),
+                used_cmd_bufs: Vec::new(),
+                frames_recorded: 0,
+            }],
+            context_index: 0,
+            extent: [window_size.0, window_size.1],
+            start: Instant::now(),
+            pipeline_layout,
+            bind_group,
+            texture,
+            instances,
+            instances_buffer,
+            blas,
+            tlas,
+            scratch_buffer,
+            sbt_buffer,
+            time: 0.0,
+            indices_buffer,
+            vertices_buffer,
+            uniform_buffer,
+            texture_view,
+            bgl,
+            gen_shader_module,
+            miss_shader_module,
+            call_shader_module,
+            hit_shader_module,
+            gen_sbt_ref,
+            miss_sbt_ref,
+            call_sbt_ref,
+            hit_sbt_ref,
+        })
+    }
+
+    fn update(&mut self, _event: winit::event::WindowEvent) {}
+
+    fn render(&mut self) {
+        let ctx = &mut self.contexts[self.context_index];
+
+        let surface_tex = unsafe { self.surface.acquire_texture(None).unwrap().unwrap().texture };
+
+        let target_barrier0 = hal::TextureBarrier {
+            texture: surface_tex.borrow(),
+            range: wgt::ImageSubresourceRange::default(),
+            usage: hal::TextureUses::UNINITIALIZED..hal::TextureUses::COPY_DST,
+        };
+
+        let instances_buffer_size =
+            self.instances.len() * std::mem::size_of::<AccelerationStructureInstance>();
+
+        let tlas_flags = hal::AccelerationStructureBuildFlags::PREFER_FAST_TRACE
+            | hal::AccelerationStructureBuildFlags::ALLOW_UPDATE;
+
+        self.time += 1.0 / 60.0;
+
+        self.instances[0].set_transform(&Affine3A::from_rotation_y(self.time));
+
+        unsafe {
+            let mapping = self
+                .device
+                .map_buffer(&self.instances_buffer, 0..instances_buffer_size as u64)
+                .unwrap();
+            ptr::copy_nonoverlapping(
+                self.instances.as_ptr() as *const u8,
+                mapping.ptr.as_ptr(),
+                instances_buffer_size,
+            );
+            self.device.unmap_buffer(&self.instances_buffer).unwrap();
+            assert!(mapping.is_coherent);
+        }
+
+        unsafe {
+            ctx.encoder.begin_encoding(Some("frame")).unwrap();
+
+            let instances = hal::AccelerationStructureInstances {
+                buffer: Some(&self.instances_buffer),
+                count: self.instances.len() as u32,
+                offset: 0,
+            };
+            ctx.encoder.build_acceleration_structures(&[
+                &hal::BuildAccelerationStructureDescriptor {
+                    mode: hal::AccelerationStructureBuildMode::Update,
+                    flags: tlas_flags,
+                    destination_acceleration_structure: &self.tlas,
+                    scratch_buffer: &self.scratch_buffer,
+                    entries: &hal::AccelerationStructureEntries::Instances(instances),
+                    source_acceleration_structure: Some(&self.tlas),
+                },
+            ]);
+
+            let as_barrier = hal::BufferBarrier {
+                buffer: &self.scratch_buffer,
+                usage: hal::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT
+                    ..hal::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
+            };
+            ctx.encoder.transition_buffers(iter::once(as_barrier));
+
+            ctx.encoder.transition_textures(iter::once(target_barrier0));
+        }
+
+        let surface_view_desc = hal::TextureViewDescriptor {
+            label: None,
+            format: self.surface_format,
+            dimension: wgt::TextureViewDimension::D2,
+            usage: hal::TextureUses::COPY_DST,
+            range: wgt::ImageSubresourceRange::default(),
+        };
+        let surface_tex_view = unsafe {
+            self.device
+                .create_texture_view(surface_tex.borrow(), &surface_view_desc)
+                .unwrap()
+        };
+        unsafe {
+            ctx.encoder
+                .begin_ray_tracing_pass(&hal::RayTracingPassDescriptor { label: None });
+            ctx.encoder.set_ray_tracing_pipeline(&self.pipeline);
+            ctx.encoder
+                .set_bind_group(&self.pipeline_layout, 0, &self.bind_group, &[]);
+
+            ctx.encoder.trace_rays(
+                &self.gen_sbt_ref,
+                &self.miss_sbt_ref,
+                &self.call_sbt_ref,
+                &self.hit_sbt_ref,
+                [512, 512, 1],
+            )
+        }
+
+        ctx.frames_recorded += 1;
+        let do_fence = ctx.frames_recorded > COMMAND_BUFFER_PER_CONTEXT;
+
+        let target_barrier1 = hal::TextureBarrier {
+            texture: surface_tex.borrow(),
+            range: wgt::ImageSubresourceRange::default(),
+            usage: hal::TextureUses::COPY_DST..hal::TextureUses::PRESENT,
+        };
+        let target_barrier2 = hal::TextureBarrier {
+            texture: &self.texture,
+            range: wgt::ImageSubresourceRange::default(),
+            usage: hal::TextureUses::STORAGE_READ_WRITE..hal::TextureUses::COPY_SRC,
+        };
+        let target_barrier3 = hal::TextureBarrier {
+            texture: &self.texture,
+            range: wgt::ImageSubresourceRange::default(),
+            usage: hal::TextureUses::COPY_SRC..hal::TextureUses::STORAGE_READ_WRITE,
+        };
+        unsafe {
+            ctx.encoder.end_ray_tracing_pass();
+            ctx.encoder.transition_textures(iter::once(target_barrier2));
+            ctx.encoder.copy_texture_to_texture(
+                &self.texture,
+                hal::TextureUses::COPY_SRC,
+                &surface_tex.borrow(),
+                std::iter::once(hal::TextureCopy {
+                    src_base: hal::TextureCopyBase {
+                        mip_level: 0,
+                        array_layer: 0,
+                        origin: wgt::Origin3d::ZERO,
+                        aspect: hal::FormatAspects::COLOR,
+                    },
+                    dst_base: hal::TextureCopyBase {
+                        mip_level: 0,
+                        array_layer: 0,
+                        origin: wgt::Origin3d::ZERO,
+                        aspect: hal::FormatAspects::COLOR,
+                    },
+                    size: hal::CopyExtent {
+                        width: 512,
+                        height: 512,
+                        depth: 1,
+                    },
+                }),
+            );
+            ctx.encoder.transition_textures(iter::once(target_barrier1));
+            ctx.encoder.transition_textures(iter::once(target_barrier3));
+        }
+
+        unsafe {
+            let cmd_buf = ctx.encoder.end_encoding().unwrap();
+            let fence_param = if do_fence {
+                Some((&mut ctx.fence, ctx.fence_value))
+            } else {
+                None
+            };
+            self.queue.submit(&[&cmd_buf], fence_param).unwrap();
+            self.queue.present(&mut self.surface, surface_tex).unwrap();
+            ctx.used_cmd_bufs.push(cmd_buf);
+            ctx.used_views.push(surface_tex_view);
+        };
+
+        if do_fence {
+            log::info!("Context switch from {}", self.context_index);
+            let old_fence_value = ctx.fence_value;
+            if self.contexts.len() == 1 {
+                let hal_desc = hal::CommandEncoderDescriptor {
+                    label: None,
+                    queue: &self.queue,
+                };
+                self.contexts.push(unsafe {
+                    ExecutionContext {
+                        encoder: self.device.create_command_encoder(&hal_desc).unwrap(),
+                        fence: self.device.create_fence().unwrap(),
+                        fence_value: 0,
+                        used_views: Vec::new(),
+                        used_cmd_bufs: Vec::new(),
+                        frames_recorded: 0,
+                    }
+                });
+            }
+            self.context_index = (self.context_index + 1) % self.contexts.len();
+            let next = &mut self.contexts[self.context_index];
+            unsafe {
+                next.wait_and_clear(&self.device);
+            }
+            next.fence_value = old_fence_value + 1;
+        }
+    }
+
+    fn exit(mut self) {
+        unsafe {
+            {
+                let ctx = &mut self.contexts[self.context_index];
+                self.queue
+                    .submit(&[], Some((&mut ctx.fence, ctx.fence_value)))
+                    .unwrap();
+            }
+
+            for mut ctx in self.contexts {
+                ctx.wait_and_clear(&self.device);
+                self.device.destroy_command_encoder(ctx.encoder);
+                self.device.destroy_fence(ctx.fence);
+            }
+
+            self.device.destroy_bind_group(self.bind_group);
+            self.device.destroy_buffer(self.scratch_buffer);
+            self.device.destroy_buffer(self.sbt_buffer);
+            self.device.destroy_buffer(self.instances_buffer);
+            self.device.destroy_buffer(self.indices_buffer);
+            self.device.destroy_buffer(self.vertices_buffer);
+            self.device.destroy_buffer(self.uniform_buffer);
+            self.device.destroy_acceleration_structure(self.tlas);
+            self.device.destroy_acceleration_structure(self.blas);
+            self.device.destroy_texture_view(self.texture_view);
+            self.device.destroy_texture(self.texture);
+            self.device.destroy_ray_tracing_pipeline(self.pipeline);
+            self.device.destroy_pipeline_layout(self.pipeline_layout);
+            self.device.destroy_bind_group_layout(self.bgl);
+            self.device.destroy_shader_module(self.gen_shader_module);
+            self.device.destroy_shader_module(self.miss_shader_module);
+            self.device.destroy_shader_module(self.call_shader_module);
+            self.device.destroy_shader_module(self.hit_shader_module);
+
+            self.surface.unconfigure(&self.device);
+            self.device.exit(self.queue);
+            self.instance.destroy_surface(self.surface);
+            drop(self.adapter);
+        }
+    }
+}
+
+#[cfg(feature = "metal")]
+type Api = hal::api::Metal;
+#[cfg(all(feature = "vulkan", not(feature = "metal")))]
+type Api = hal::api::Vulkan;
+#[cfg(all(feature = "gles", not(feature = "metal"), not(feature = "vulkan")))]
+type Api = hal::api::Gles;
+#[cfg(all(
+    feature = "dx12",
+    not(feature = "metal"),
+    not(feature = "vulkan"),
+    not(feature = "gles")
+))]
+type Api = hal::api::Dx12;
+#[cfg(not(any(
+    feature = "metal",
+    feature = "vulkan",
+    feature = "gles",
+    feature = "dx12"
+)))]
+type Api = hal::api::Empty;
+
+fn main() {
+    env_logger::init();
+
+    let event_loop = winit::event_loop::EventLoop::new();
+    let window = winit::window::WindowBuilder::new()
+        .with_title("hal-ray-tracing-pipeline-example")
+        .with_inner_size(winit::dpi::PhysicalSize {
+            width: 512,
+            height: 512,
+        })
+        .with_resizable(false)
+        .build(&event_loop)
+        .unwrap();
+
+    let example_result = Example::<Api>::init(&window);
+    let mut example = Some(example_result.expect("Selected backend is not supported"));
+
+    event_loop.run(move |event, _, control_flow| {
+        let _ = &window; // force ownership by the closure
+        *control_flow = winit::event_loop::ControlFlow::Poll;
+        match event {
+            winit::event::Event::RedrawEventsCleared => {
+                window.request_redraw();
+            }
+            winit::event::Event::WindowEvent { event, .. } => match event {
+                winit::event::WindowEvent::KeyboardInput {
+                    input:
+                        winit::event::KeyboardInput {
+                            virtual_keycode: Some(winit::event::VirtualKeyCode::Escape),
+                            state: winit::event::ElementState::Pressed,
+                            ..
+                        },
+                    ..
+                }
+                | winit::event::WindowEvent::CloseRequested => {
+                    *control_flow = winit::event_loop::ControlFlow::Exit;
+                }
+                _ => {
+                    example.as_mut().unwrap().update(event);
+                }
+            },
+            winit::event::Event::RedrawRequested(_) => {
+                let ex = example.as_mut().unwrap();
+
+                ex.render();
+            }
+            winit::event::Event::LoopDestroyed => {
+                example.take().unwrap().exit();
+            }
+            _ => {}
+        }
+    });
+}
diff --git a/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rcall b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rcall
new file mode 100644
index 0000000000..eba02e7b4a
--- /dev/null
+++ b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rcall
@@ -0,0 +1,13 @@
+// glslc --target-spv=spv1.6 shader.rcall -o shader.rcall.spv
+#version 460 core
+#extension GL_EXT_ray_tracing : require
+
+struct call_payload_struct {
+    vec3 col;
+};
+
+layout(location = 1) callableDataInEXT call_payload_struct call_payload;
+
+void main() {
+    call_payload.col = call_payload.col.grb;
+}
\ No newline at end of file
diff --git a/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rcall.spv b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rcall.spv
new file mode 100644
index 0000000000..967777c1ca
Binary files /dev/null and b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rcall.spv differ
diff --git a/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rchit b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rchit
new file mode 100644
index 0000000000..a0a66500af
--- /dev/null
+++ b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rchit
@@ -0,0 +1,55 @@
+// glslc --target-spv=spv1.6 shader.rchit -o shader.rchit.spv
+#version 460 core
+#extension GL_EXT_ray_tracing : require
+#extension GL_EXT_scalar_block_layout : require
+
+layout(set = 0, binding = 2) uniform accelerationStructureEXT tlas;
+
+hitAttributeEXT vec2 barycentric_coord;
+
+layout(shaderRecordEXT, scalar) buffer shader_record {
+    vec4 col;
+} record;
+
+struct ray_payload_struct {
+    vec3 pos;
+    vec3 dir;
+    vec3 col;
+};
+
+layout
(location = 0) rayPayloadInEXT ray_payload_struct ray_payload; + +struct call_payload_struct { + vec3 col; +}; + +layout(location = 1) callableDataEXT call_payload_struct call_payload; + + +vec2 bary_lerp2(vec2 a, vec2 b, vec2 c, vec3 barycentrics) { + return a * barycentrics.x + b * barycentrics.y + c * barycentrics.z; +} + +vec3 bary_lerp3(vec3 a, vec3 b, vec3 c, vec3 barycentrics) { + return a * barycentrics.x + b * barycentrics.y + c * barycentrics.z; +} + +vec4 bary_lerp4(vec4 a, vec4 b, vec4 c, vec3 barycentrics) { + return a * barycentrics.x + b * barycentrics.y + c * barycentrics.z; +} + +void main() { + vec3 barycentrics = vec3(1.0f - barycentric_coord.x - barycentric_coord.y, barycentric_coord.x, barycentric_coord.y); + vec3 col = bary_lerp3(vec3(1,0,0),vec3(0,1,0),vec3(0,0,1), barycentrics); + + call_payload.col = col; + + if(gl_InstanceCustomIndexEXT == 1){ + executeCallableEXT( + 0, // SBT callable index + 1 // payload location + ); + } + + ray_payload.col = mix(call_payload.col,record.col.rgb,record.col.w); +} diff --git a/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rchit.spv b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rchit.spv new file mode 100644 index 0000000000..46830daeb8 Binary files /dev/null and b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rchit.spv differ diff --git a/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rgen b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rgen new file mode 100644 index 0000000000..abe152908e --- /dev/null +++ b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rgen @@ -0,0 +1,55 @@ +// glslc --target-spv=spv1.6 shader.rgen -o shader.rgen.spv +#version 460 core +#extension GL_EXT_ray_tracing : require + +layout(set = 0, binding = 0) uniform Uniforms +{ + mat4 viewInverse; + mat4 projInverse; +} cam; + +layout(set = 0, binding = 1, rgba8) uniform image2D image; +layout(set = 0, binding = 2) uniform accelerationStructureEXT tlas; + +struct ray_payload { + vec3 pos; + vec3 dir; + vec3 col; +}; +layout (location = 0) rayPayloadEXT ray_payload payload; + + +void main() { + ivec2 coords = ivec2(gl_LaunchIDEXT.xy); + + vec4 cam_position = cam.viewInverse * vec4(0.0, 0.0, 0.0, 1.0); + + payload.pos = cam_position.xyz; + payload.col = vec3(0,0,1); + + vec2 offset = vec2(0.5); + + vec2 pixel_center = vec2(coords) + offset; + vec2 uv = pixel_center / vec2(gl_LaunchSizeEXT.xy); + + vec4 target = cam.projInverse * vec4(uv * 2.0 - 1.0, 1.0, 1.0); + vec4 direction = cam.viewInverse * vec4(normalize(target.xyz), 0.0); + + payload.dir = direction.xyz; + + traceRayEXT( + tlas, + gl_RayFlagsOpaqueEXT, + 0xff, + 0, // SBT hit group index + 0, // SBT record stride + 0, // SBT miss index + payload.pos, + 0.001, // min distance + payload.dir, + 200, // max distance + 0 // payload location + ); + + imageStore(image, coords, vec4(payload.col, 1.0)); +} diff --git a/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rgen.spv b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rgen.spv new file mode 100644 index 0000000000..25d0258259 Binary files /dev/null and b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rgen.spv differ diff --git a/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rmiss b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rmiss new file mode 100644 index 0000000000..4db4bdd261 --- /dev/null +++ b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rmiss @@ -0,0 +1,25 @@ +// glslc --target-spv=spv1.6 shader.rmiss -o shader.rmiss.spv +#version 460 core 
+#extension GL_EXT_ray_tracing : require + +const float PI = 3.14159265; +const float INV_PI = 1.0 / PI; +const float INV_2PI = 0.5 / PI; + +struct ray_payload { + vec3 pos; + vec3 dir; + vec3 col; +}; +layout (location = 0) rayPayloadInEXT ray_payload payload; + +vec2 dir_to_uv(vec3 direction) +{ + vec2 uv = vec2(atan(direction.z, direction.x), asin(-direction.y)); + uv = vec2(uv.x * INV_2PI, uv.y * INV_PI) + 0.5; + return uv; +} + +void main() { + payload.col = vec3(dir_to_uv(normalize(payload.dir)),1.); +} diff --git a/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rmiss.spv b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rmiss.spv new file mode 100644 index 0000000000..712dfcebd1 Binary files /dev/null and b/wgpu-hal/examples/ray-tracing-pipeline-triangle/shader.rmiss.spv differ diff --git a/wgpu-hal/src/dx11/command.rs b/wgpu-hal/src/dx11/command.rs index 1c73f3c325..45c3eab4f6 100644 --- a/wgpu-hal/src/dx11/command.rs +++ b/wgpu-hal/src/dx11/command.rs @@ -265,4 +265,34 @@ impl crate::CommandEncoder for super::CommandEncoder { unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) { todo!() } + + unsafe fn build_acceleration_structures( + &mut self, + desc: &[&crate::BuildAccelerationStructureDescriptor], + ) { + unimplemented!() + } + + unsafe fn begin_ray_tracing_pass(&mut self, desc: &crate::RayTracingPassDescriptor) { + unimplemented!() + } + + unsafe fn end_ray_tracing_pass(&mut self) { + unimplemented!() + } + + unsafe fn set_ray_tracing_pipeline(&mut self, pipeline: &super::RayTracingPipeline) { + unimplemented!() + } + + unsafe fn trace_rays( + &mut self, + ray_gen_sbt: &crate::ShaderBindingTableReference, + miss_sbt: &crate::ShaderBindingTableReference, + callable_sbt: &crate::ShaderBindingTableReference, + hit_sbt: &crate::ShaderBindingTableReference, + dimensions: [u32; 3], + ) { + unimplemented!() + } } diff --git a/wgpu-hal/src/dx11/device.rs b/wgpu-hal/src/dx11/device.rs index 3b087c4311..a655e385de 100644 --- a/wgpu-hal/src/dx11/device.rs +++ b/wgpu-hal/src/dx11/device.rs @@ -200,6 +200,54 @@ impl crate::Device for super::Device { unsafe fn stop_capture(&self) { todo!() } + + unsafe fn create_acceleration_structure( + &self, + desc: &crate::AccelerationStructureDescriptor, + ) -> Result { + unimplemented!() + } + unsafe fn get_acceleration_structure_build_sizes( + &self, + desc: &crate::GetAccelerationStructureBuildSizesDescriptor, + ) -> crate::AccelerationStructureBuildSizes { + unimplemented!() + } + unsafe fn get_acceleration_structure_device_address( + &self, + acceleration_structure: &super::AccelerationStructure, + ) -> wgt::BufferAddress { + unimplemented!() + } + unsafe fn destroy_acceleration_structure( + &self, + acceleration_structure: super::AccelerationStructure, + ) { + unimplemented!() + } + + unsafe fn create_ray_tracing_pipeline( + &self, + _desc: &crate::RayTracingPipelineDescriptor, + ) -> Result { + unimplemented!() + } + + unsafe fn destroy_ray_tracing_pipeline(&self, _pipeline: super::RayTracingPipeline) { + unimplemented!() + } + + fn assemble_sbt_data<'a>( + &self, + handles: &'a [&'a [u8]], + records: &'a [&'a [u8]], + ) -> crate::ShaderBindingTableData<'a> { + unimplemented!() + } + + unsafe fn get_buffer_device_address(&self, buffer: &super::Buffer) -> wgt::BufferAddress { + unimplemented!() + } } impl crate::Queue for super::Queue { diff --git a/wgpu-hal/src/dx11/mod.rs b/wgpu-hal/src/dx11/mod.rs index 91827874b1..dab7e27279 100644 --- a/wgpu-hal/src/dx11/mod.rs +++ 
b/wgpu-hal/src/dx11/mod.rs @@ -36,6 +36,9 @@ impl crate::Api for Api { type ShaderModule = ShaderModule; type RenderPipeline = RenderPipeline; type ComputePipeline = ComputePipeline; + type RayTracingPipeline = RayTracingPipeline; + + type AccelerationStructure = AccelerationStructure; } pub struct Instance { @@ -108,6 +111,8 @@ pub struct BindGroup {} pub struct PipelineLayout {} #[derive(Debug)] pub struct ShaderModule {} +#[derive(Debug)] +pub struct AccelerationStructure {} pub struct RenderPipeline {} pub struct ComputePipeline {} @@ -135,3 +140,23 @@ impl crate::Surface for Surface { todo!() } } + +pub struct RayTracingPipeline {} + +impl crate::RayTracingPipeline for RayTracingPipeline { + fn gen_handles<'a>(&'a self) -> Vec<&'a [u8]> { + unimplemented!() + } + + fn miss_handles<'a>(&'a self) -> Vec<&'a [u8]> { + unimplemented!() + } + + fn call_handles<'a>(&'a self) -> Vec<&'a [u8]> { + unimplemented!() + } + + fn hit_handles<'a>(&'a self) -> Vec<&'a [u8]> { + unimplemented!() + } +} diff --git a/wgpu-hal/src/dx12/command.rs b/wgpu-hal/src/dx12/command.rs index c634110154..e6b7699c03 100644 --- a/wgpu-hal/src/dx12/command.rs +++ b/wgpu-hal/src/dx12/command.rs @@ -1140,4 +1140,36 @@ impl crate::CommandEncoder for super::CommandEncoder { ) }; } + + unsafe fn build_acceleration_structures( + &mut self, + _desc: &[&crate::BuildAccelerationStructureDescriptor], + ) { + // Implement using `BuildRaytracingAccelerationStructure`: + // https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#buildraytracingaccelerationstructure + todo!() + } + + unsafe fn begin_ray_tracing_pass(&mut self, _desc: &crate::RayTracingPassDescriptor) { + todo!() + } + + unsafe fn end_ray_tracing_pass(&mut self) { + todo!() + } + + unsafe fn set_ray_tracing_pipeline(&mut self, _pipeline: &super::RayTracingPipeline) { + todo!() + } + + unsafe fn trace_rays( + &mut self, + _ray_gen_sbt: &crate::ShaderBindingTableReference, + _miss_sbt: &crate::ShaderBindingTableReference, + _callable_sbt: &crate::ShaderBindingTableReference, + _hit_sbt: &crate::ShaderBindingTableReference, + _dimensions: [u32; 3], + ) { + todo!() + } } diff --git a/wgpu-hal/src/dx12/conv.rs b/wgpu-hal/src/dx12/conv.rs index 7b39e98ad2..57ce8254db 100644 --- a/wgpu-hal/src/dx12/conv.rs +++ b/wgpu-hal/src/dx12/conv.rs @@ -112,6 +112,7 @@ pub fn map_binding_type(ty: &wgt::BindingType) -> d3d12::DescriptorRangeType { .. } | Bt::StorageTexture { .. } => d3d12::DescriptorRangeType::UAV, + Bt::AccelerationStructure => todo!(), } } diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 24fea55663..0e4e065224 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -667,6 +667,7 @@ impl crate::Device for super::Device { num_texture_views += count } wgt::BindingType::Sampler { .. 
} => num_samplers += count, + wgt::BindingType::AccelerationStructure => todo!(), } } @@ -1199,6 +1200,7 @@ impl crate::Device for super::Device { cpu_samplers.as_mut().unwrap().stage.push(data.handle.raw); } } + wgt::BindingType::AccelerationStructure => todo!(), } } @@ -1586,4 +1588,61 @@ impl crate::Device for super::Device { .end_frame_capture(self.raw.as_mut_ptr() as *mut _, ptr::null_mut()) } } + + unsafe fn get_acceleration_structure_build_sizes( + &self, + _desc: &crate::GetAccelerationStructureBuildSizesDescriptor, + ) -> crate::AccelerationStructureBuildSizes { + // Implement using `GetRaytracingAccelerationStructurePrebuildInfo`: + // https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#getraytracingaccelerationstructureprebuildinfo + todo!() + } + + unsafe fn get_acceleration_structure_device_address( + &self, + _acceleration_structure: &super::AccelerationStructure, + ) -> wgt::BufferAddress { + // Implement using `GetGPUVirtualAddress`: + // https://docs.microsoft.com/en-us/windows/win32/api/d3d12/nf-d3d12-id3d12resource-getgpuvirtualaddress + todo!() + } + + unsafe fn create_acceleration_structure( + &self, + _desc: &crate::AccelerationStructureDescriptor, + ) -> Result { + // Create a D3D12 resource as per-usual. + todo!() + } + + unsafe fn destroy_acceleration_structure( + &self, + _acceleration_structure: super::AccelerationStructure, + ) { + // Destroy a D3D12 resource as per-usual. + todo!() + } + + unsafe fn create_ray_tracing_pipeline( + &self, + _desc: &crate::RayTracingPipelineDescriptor, + ) -> Result { + todo!() + } + + unsafe fn destroy_ray_tracing_pipeline(&self, _pipeline: super::RayTracingPipeline) { + todo!() + } + + fn assemble_sbt_data<'a>( + &self, + _handles: &'a [&'a [u8]], + _records: &'a [&'a [u8]], + ) -> crate::ShaderBindingTableData<'a> { + todo!() + } + + unsafe fn get_buffer_device_address(&self, _buffer: &super::Buffer) -> wgt::BufferAddress { + todo!() + } } diff --git a/wgpu-hal/src/dx12/mod.rs b/wgpu-hal/src/dx12/mod.rs index 390a5693aa..3978a3d87e 100644 --- a/wgpu-hal/src/dx12/mod.rs +++ b/wgpu-hal/src/dx12/mod.rs @@ -82,6 +82,9 @@ impl crate::Api for Api { type ShaderModule = ShaderModule; type RenderPipeline = RenderPipeline; type ComputePipeline = ComputePipeline; + type RayTracingPipeline = RayTracingPipeline; + + type AccelerationStructure = AccelerationStructure; } // Limited by D3D12's root signature size of 64. Each element takes 1 or 2 entries. 
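To make the shape of the new hal API concrete while the DX12 entry points above are still `todo!()`, here is a minimal caller-side sketch (not part of the diff) of the size-query/create flow. It assumes a `device` implementing the extended `Device` trait and an already-assembled `entries`; `create_blas` is a hypothetical helper name:

    // Hypothetical helper: query build sizes first, then allocate the BLAS.
    unsafe fn create_blas<A: hal::Api, D: hal::Device<A>>(
        device: &D,
        entries: &hal::AccelerationStructureEntries<A>,
    ) -> Result<A::AccelerationStructure, hal::DeviceError> {
        let flags = hal::AccelerationStructureBuildFlags::PREFER_FAST_TRACE;
        // Sizes can be queried before any input buffers are bound; buffers,
        // addresses and offsets are ignored by the size query.
        let sizes = device.get_acceleration_structure_build_sizes(
            &hal::GetAccelerationStructureBuildSizesDescriptor { entries, flags },
        );
        // The scratch buffer for the actual build is sized separately, from
        // `sizes.build_scratch_size`.
        device.create_acceleration_structure(&hal::AccelerationStructureDescriptor {
            label: Some("blas"),
            size: sizes.acceleration_structure_size,
            format: hal::AccelerationStructureFormat::BottomLevel,
        })
    }

The build itself is then recorded through `CommandEncoder::build_acceleration_structures`, as the Vulkan backend and the example below demonstrate.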
@@ -607,6 +610,9 @@ pub struct ComputePipeline { unsafe impl Send for ComputePipeline {} unsafe impl Sync for ComputePipeline {} +#[derive(Debug)] +pub struct AccelerationStructure {} + impl SwapChain { unsafe fn release_resources(self) -> d3d12::WeakPtr { for resource in self.resources { @@ -899,3 +905,23 @@ impl crate::Queue for Queue { (1_000_000_000.0 / frequency as f64) as f32 } } + +pub struct RayTracingPipeline {} + +impl crate::RayTracingPipeline for RayTracingPipeline { + fn gen_handles<'a>(&'a self) -> Vec<&'a [u8]> { + todo!() + } + + fn miss_handles<'a>(&'a self) -> Vec<&'a [u8]> { + todo!() + } + + fn call_handles<'a>(&'a self) -> Vec<&'a [u8]> { + todo!() + } + + fn hit_handles<'a>(&'a self) -> Vec<&'a [u8]> { + todo!() + } +} diff --git a/wgpu-hal/src/empty.rs b/wgpu-hal/src/empty.rs index 1497acad91..c6784fb4ef 100644 --- a/wgpu-hal/src/empty.rs +++ b/wgpu-hal/src/empty.rs @@ -29,6 +29,7 @@ impl crate::Api for Api { type Sampler = Resource; type QuerySet = Resource; type Fence = Resource; + type AccelerationStructure = Resource; type BindGroupLayout = Resource; type BindGroup = Resource; @@ -36,6 +37,7 @@ impl crate::Api for Api { type ShaderModule = Resource; type RenderPipeline = Resource; type ComputePipeline = Resource; + type RayTracingPipeline = Resource; } impl crate::Instance for Context { @@ -236,6 +238,46 @@ impl crate::Device for Context { false } unsafe fn stop_capture(&self) {} + unsafe fn create_acceleration_structure( + &self, + desc: &crate::AccelerationStructureDescriptor, + ) -> DeviceResult { + Ok(Resource) + } + unsafe fn get_acceleration_structure_build_sizes( + &self, + _desc: &crate::GetAccelerationStructureBuildSizesDescriptor, + ) -> crate::AccelerationStructureBuildSizes { + Default::default() + } + unsafe fn get_acceleration_structure_device_address( + &self, + _acceleration_structure: &Resource, + ) -> wgt::BufferAddress { + Default::default() + } + unsafe fn destroy_acceleration_structure(&self, _acceleration_structure: Resource) {} + + unsafe fn create_ray_tracing_pipeline( + &self, + desc: &crate::RayTracingPipelineDescriptor, + ) -> Result { + Ok(Resource) + } + + unsafe fn destroy_ray_tracing_pipeline(&self, pipeline: Resource) {} + + fn assemble_sbt_data<'a>( + &self, + _handles: &'a [&'a [u8]], + _records: &'a [&'a [u8]], + ) -> crate::ShaderBindingTableData<'a> { + unimplemented!() + } + + unsafe fn get_buffer_device_address(&self, buffer: &Resource) -> wgt::BufferAddress { + Default::default() + } } impl crate::CommandEncoder for Encoder { @@ -410,4 +452,44 @@ impl crate::CommandEncoder for Encoder { unsafe fn dispatch(&mut self, count: [u32; 3]) {} unsafe fn dispatch_indirect(&mut self, buffer: &Resource, offset: wgt::BufferAddress) {} + + unsafe fn build_acceleration_structures( + &mut self, + _desc: &[&crate::BuildAccelerationStructureDescriptor], + ) { + } + + unsafe fn begin_ray_tracing_pass(&mut self, desc: &crate::RayTracingPassDescriptor) {} + + unsafe fn end_ray_tracing_pass(&mut self) {} + + unsafe fn set_ray_tracing_pipeline(&mut self, pipeline: &Resource) {} + + unsafe fn trace_rays( + &mut self, + ray_gen_sbt: &crate::ShaderBindingTableReference, + miss_sbt: &crate::ShaderBindingTableReference, + callable_sbt: &crate::ShaderBindingTableReference, + hit_sbt: &crate::ShaderBindingTableReference, + dimensions: [u32; 3], + ) { + } +} + +impl crate::RayTracingPipeline for Resource { + fn gen_handles<'a>(&'a self) -> Vec<&'a [u8]> { + vec![] + } + + fn miss_handles<'a>(&'a self) -> Vec<&'a [u8]> { + vec![] + } + + fn 
call_handles<'a>(&'a self) -> Vec<&'a [u8]> { + vec![] + } + + fn hit_handles<'a>(&'a self) -> Vec<&'a [u8]> { + vec![] + } } diff --git a/wgpu-hal/src/gles/command.rs b/wgpu-hal/src/gles/command.rs index 23cf8a36b3..bedc0080b9 100644 --- a/wgpu-hal/src/gles/command.rs +++ b/wgpu-hal/src/gles/command.rs @@ -1060,4 +1060,34 @@ impl crate::CommandEncoder for super::CommandEncoder { indirect_offset: offset, }); } + + unsafe fn build_acceleration_structures( + &mut self, + _desc: &[&crate::BuildAccelerationStructureDescriptor], + ) { + unimplemented!() + } + + unsafe fn begin_ray_tracing_pass(&mut self, _desc: &crate::RayTracingPassDescriptor) { + unimplemented!() + } + + unsafe fn end_ray_tracing_pass(&mut self) { + unimplemented!() + } + + unsafe fn set_ray_tracing_pipeline(&mut self, _pipeline: &super::RayTracingPipeline) { + unimplemented!() + } + + unsafe fn trace_rays( + &mut self, + _ray_gen_sbt: &crate::ShaderBindingTableReference, + _miss_sbt: &crate::ShaderBindingTableReference, + _callable_sbt: &crate::ShaderBindingTableReference, + _hit_sbt: &crate::ShaderBindingTableReference, + _dimensions: [u32; 3], + ) { + unimplemented!() + } } diff --git a/wgpu-hal/src/gles/device.rs b/wgpu-hal/src/gles/device.rs index d994aa1d56..9f76e4683d 100644 --- a/wgpu-hal/src/gles/device.rs +++ b/wgpu-hal/src/gles/device.rs @@ -980,6 +980,7 @@ impl crate::Device for super::Device { ty: wgt::BufferBindingType::Storage { .. }, .. } => &mut num_storage_buffers, + wgt::BindingType::AccelerationStructure => unimplemented!(), }; binding_to_slot[entry.binding as usize] = *counter; @@ -1065,6 +1066,7 @@ impl crate::Device for super::Device { format: format_desc.internal, }) } + wgt::BindingType::AccelerationStructure => unimplemented!(), }; contents.push(binding); } @@ -1316,6 +1318,48 @@ impl crate::Device for super::Device { .end_frame_capture(ptr::null_mut(), ptr::null_mut()) } } + unsafe fn create_acceleration_structure( + &self, + _desc: &crate::AccelerationStructureDescriptor, + ) -> Result<(), crate::DeviceError> { + unimplemented!() + } + unsafe fn get_acceleration_structure_build_sizes( + &self, + _desc: &crate::GetAccelerationStructureBuildSizesDescriptor, + ) -> crate::AccelerationStructureBuildSizes { + unimplemented!() + } + unsafe fn get_acceleration_structure_device_address( + &self, + _acceleration_structure: &(), + ) -> wgt::BufferAddress { + unimplemented!() + } + unsafe fn destroy_acceleration_structure(&self, _acceleration_structure: ()) {} + + unsafe fn create_ray_tracing_pipeline( + &self, + _desc: &crate::RayTracingPipelineDescriptor, + ) -> Result { + unimplemented!() + } + + unsafe fn destroy_ray_tracing_pipeline(&self, _pipeline: super::RayTracingPipeline) { + unimplemented!() + } + + fn assemble_sbt_data<'a>( + &self, + _handles: &'a [&'a [u8]], + _records: &'a [&'a [u8]], + ) -> crate::ShaderBindingTableData<'a> { + unimplemented!() + } + + unsafe fn get_buffer_device_address(&self, _buffer: &super::Buffer) -> wgt::BufferAddress { + unimplemented!() + } } // SAFE: WASM doesn't have threads diff --git a/wgpu-hal/src/gles/mod.rs b/wgpu-hal/src/gles/mod.rs index d196b8bc46..b2b9d27389 100644 --- a/wgpu-hal/src/gles/mod.rs +++ b/wgpu-hal/src/gles/mod.rs @@ -119,6 +119,7 @@ impl crate::Api for Api { type Sampler = Sampler; type QuerySet = QuerySet; type Fence = Fence; + type AccelerationStructure = (); type BindGroupLayout = BindGroupLayout; type BindGroup = BindGroup; @@ -126,6 +127,7 @@ impl crate::Api for Api { type ShaderModule = ShaderModule; type RenderPipeline = RenderPipeline; 
type ComputePipeline = ComputePipeline;
+    type RayTracingPipeline = RayTracingPipeline;
}

bitflags::bitflags! {
@@ -857,3 +859,23 @@ impl fmt::Debug for CommandEncoder {
            .finish()
    }
}
+
+pub struct RayTracingPipeline {}
+
+impl crate::RayTracingPipeline for RayTracingPipeline {
+    fn gen_handles<'a>(&'a self) -> Vec<&'a [u8]> {
+        unimplemented!()
+    }
+
+    fn miss_handles<'a>(&'a self) -> Vec<&'a [u8]> {
+        unimplemented!()
+    }
+
+    fn call_handles<'a>(&'a self) -> Vec<&'a [u8]> {
+        unimplemented!()
+    }
+
+    fn hit_handles<'a>(&'a self) -> Vec<&'a [u8]> {
+        unimplemented!()
+    }
+}
diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs
index 814c451f06..f1e53757a0 100644
--- a/wgpu-hal/src/lib.rs
+++ b/wgpu-hal/src/lib.rs
@@ -177,6 +177,9 @@ pub trait Api: Clone + Sized {
    type ShaderModule: fmt::Debug + Send + Sync;
    type RenderPipeline: Send + Sync;
    type ComputePipeline: Send + Sync;
+    type RayTracingPipeline: RayTracingPipeline + Send + Sync;
+
+    type AccelerationStructure: fmt::Debug + Send + Sync + 'static;
}

pub trait Instance<A: Api>: Sized + Send + Sync {
@@ -334,6 +337,37 @@ pub trait Device<A: Api>: Send + Sync {
    unsafe fn start_capture(&self) -> bool;
    unsafe fn stop_capture(&self);
+
+    unsafe fn create_acceleration_structure(
+        &self,
+        desc: &AccelerationStructureDescriptor,
+    ) -> Result<A::AccelerationStructure, DeviceError>;
+    unsafe fn get_acceleration_structure_build_sizes(
+        &self,
+        desc: &GetAccelerationStructureBuildSizesDescriptor<A>,
+    ) -> AccelerationStructureBuildSizes;
+    unsafe fn get_acceleration_structure_device_address(
+        &self,
+        acceleration_structure: &A::AccelerationStructure,
+    ) -> wgt::BufferAddress;
+    unsafe fn destroy_acceleration_structure(
+        &self,
+        acceleration_structure: A::AccelerationStructure,
+    );
+
+    unsafe fn create_ray_tracing_pipeline(
+        &self,
+        desc: &RayTracingPipelineDescriptor<A>,
+    ) -> Result<A::RayTracingPipeline, PipelineError>;
+    unsafe fn destroy_ray_tracing_pipeline(&self, pipeline: A::RayTracingPipeline);
+
+    unsafe fn get_buffer_device_address(&self, buffer: &A::Buffer) -> wgt::BufferAddress;
+
+    fn assemble_sbt_data<'a>(
+        &self,
+        handles: &'a [&'a [u8]],
+        records: &'a [&'a [u8]],
+    ) -> ShaderBindingTableData<'a>;
}

pub trait Queue<A: Api>: Send + Sync {
@@ -548,6 +582,35 @@ pub trait CommandEncoder<A: Api>: Send + Sync + fmt::Debug {
    unsafe fn dispatch(&mut self, count: [u32; 3]);
    unsafe fn dispatch_indirect(&mut self, buffer: &A::Buffer, offset: wgt::BufferAddress);
+
+    /// To get the required sizes for the buffer allocations, use
+    /// `get_acceleration_structure_build_sizes` for each descriptor.
+    /// All buffers must be synchronized externally.
+    /// Any buffer region that is written to may only be passed once per function call,
+    /// with the exception of updates in the same descriptor.
+    /// Consequences of this limitation:
+    /// - scratch buffers need to be unique
+    /// - a tlas can't be built in the same call with a blas it contains
+    unsafe fn build_acceleration_structures(
+        &mut self,
+        descriptors: &[&BuildAccelerationStructureDescriptor<A>],
+    );
+
+    // ray-tracing passes
+
+    /// Begins a ray-tracing pass, clearing all active bindings.
+    unsafe fn begin_ray_tracing_pass(&mut self, desc: &RayTracingPassDescriptor);
+    unsafe fn end_ray_tracing_pass(&mut self);
+
+    unsafe fn set_ray_tracing_pipeline(&mut self, pipeline: &A::RayTracingPipeline);
+
+    unsafe fn trace_rays(
+        &mut self,
+        ray_gen_sbt: &ShaderBindingTableReference,
+        miss_sbt: &ShaderBindingTableReference,
+        callable_sbt: &ShaderBindingTableReference,
+        hit_sbt: &ShaderBindingTableReference,
+        dimensions: [u32; 3],
+    );
}

bitflags!(
@@ -709,12 +772,19 @@ bitflags::bitflags! {
        const STORAGE_READ_WRITE = 1 << 8;
        /// The indirect or count buffer in an indirect draw or dispatch.
        const INDIRECT = 1 << 9;
+        const ACCELERATION_STRUCTURE_SCRATCH = 1 << 10;
+        const BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT = 1 << 11;
+        const TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT = 1 << 12;
+        const SHADER_BINDING_TABLE = 1 << 13;
        /// The combination of states that a buffer may be in _at the same time_.
        const INCLUSIVE = Self::MAP_READ.bits | Self::COPY_SRC.bits |
            Self::INDEX.bits | Self::VERTEX.bits | Self::UNIFORM.bits |
-            Self::STORAGE_READ.bits | Self::INDIRECT.bits;
+            Self::STORAGE_READ.bits | Self::INDIRECT.bits |
+            Self::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT.bits | Self::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT.bits |
+            Self::SHADER_BINDING_TABLE.bits;
        /// The combination of states that a buffer must exclusively be in.
-        const EXCLUSIVE = Self::MAP_WRITE.bits | Self::COPY_DST.bits | Self::STORAGE_READ_WRITE.bits;
+        const EXCLUSIVE = Self::MAP_WRITE.bits | Self::COPY_DST.bits |
+            Self::STORAGE_READ_WRITE.bits | Self::ACCELERATION_STRUCTURE_SCRATCH.bits;
        /// The combination of all usages that are guaranteed to be ordered by the hardware.
        /// If a usage is ordered, then if the buffer state doesn't change between draw calls, there
        /// are no barriers needed for synchronization.
@@ -1021,6 +1091,7 @@ pub struct BindGroupDescriptor<'a, A: Api> {
    pub samplers: &'a [&'a A::Sampler],
    pub textures: &'a [TextureBinding<'a, A>],
    pub entries: &'a [BindGroupEntry],
+    pub acceleration_structures: &'a [&'a A::AccelerationStructure],
}

#[derive(Clone, Debug)]
@@ -1302,3 +1373,191 @@ fn test_default_limits() {
    let limits = wgt::Limits::default();
    assert!(limits.max_bind_groups <= MAX_BIND_GROUPS as u32);
}
+
+#[derive(Clone, Debug)]
+pub struct AccelerationStructureDescriptor<'a> {
+    pub label: Label<'a>,
+    pub size: wgt::BufferAddress,
+    pub format: AccelerationStructureFormat,
+}
+
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
+pub enum AccelerationStructureFormat {
+    TopLevel,
+    BottomLevel,
+}
+
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
+pub enum AccelerationStructureBuildMode {
+    Build,
+    Update,
+}
+
+/// The required sizes for an acceleration structure built from a corresponding entries struct (and build flags).
+#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
+pub struct AccelerationStructureBuildSizes {
+    pub acceleration_structure_size: wgt::BufferAddress,
+    pub update_scratch_size: wgt::BufferAddress,
+    pub build_scratch_size: wgt::BufferAddress,
+}
+
+/// Updates use `source_acceleration_structure` if present; otherwise the update is performed in place.
+/// For updates, only the data is allowed to change (not the metadata or sizes).
+#[derive(Clone, Debug)]
+pub struct BuildAccelerationStructureDescriptor<'a, A: Api> {
+    pub entries: &'a AccelerationStructureEntries<'a, A>,
+    pub mode: AccelerationStructureBuildMode,
+    pub flags: AccelerationStructureBuildFlags,
+    pub source_acceleration_structure: Option<&'a A::AccelerationStructure>,
+    pub destination_acceleration_structure: &'a A::AccelerationStructure,
+    pub scratch_buffer: &'a A::Buffer,
+}
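The restrictions documented on `build_acceleration_structures` above (unique scratch regions per call; a TLAS cannot be built together with a BLAS it contains) mean a dependent TLAS takes a second encoder call with a barrier in between, which is exactly what the ray-traced-triangle example does. A condensed sketch, assuming `encoder`, `scratch_buffer`, and the two build descriptors already exist:

    unsafe {
        // First call: build the BLAS.
        encoder.build_acceleration_structures(&[&blas_build_desc]);
        // The scratch buffer is shared, so transition it between the two builds.
        encoder.transition_buffers(std::iter::once(hal::BufferBarrier {
            buffer: &scratch_buffer,
            usage: hal::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT
                ..hal::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
        }));
        // Second call: the TLAS may now reference the BLAS built above.
        encoder.build_acceleration_structures(&[&tlas_build_desc]);
    }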
+/// - All buffers, buffer addresses and offsets will be ignored.
+/// - The build mode will be ignored.
+/// - Reducing the number of instances, triangle groups, or AABB groups (or the number of
+///   triangles/AABBs in the corresponding groups) may result in reduced size requirements.
+/// - Any other change may result in a bigger or smaller size requirement.
+#[derive(Clone, Debug)]
+pub struct GetAccelerationStructureBuildSizesDescriptor<'a, A: Api> {
+    pub entries: &'a AccelerationStructureEntries<'a, A>,
+    pub flags: AccelerationStructureBuildFlags,
+}
+
+/// Entries for a single descriptor:
+/// * `Instances` - Multiple instances for a top level acceleration structure
+/// * `Triangles` - Multiple triangle meshes for a bottom level acceleration structure
+/// * `AABBs` - List of lists of axis-aligned bounding boxes for a bottom level acceleration structure
+#[derive(Debug)]
+pub enum AccelerationStructureEntries<'a, A: Api> {
+    Instances(AccelerationStructureInstances<'a, A>),
+    Triangles(&'a [AccelerationStructureTriangles<'a, A>]),
+    AABBs(&'a [AccelerationStructureAABBs<'a, A>]),
+}
+
+/// * `first_vertex` - offset in the vertex buffer (as a number of vertices)
+/// * `indices` - optional index buffer with attributes
+/// * `transform` - optional transform
+#[derive(Clone, Debug)]
+pub struct AccelerationStructureTriangles<'a, A: Api> {
+    pub vertex_buffer: Option<&'a A::Buffer>,
+    pub vertex_format: wgt::VertexFormat,
+    pub first_vertex: u32,
+    pub vertex_count: u32,
+    pub vertex_stride: wgt::BufferAddress,
+    pub indices: Option<AccelerationStructureTriangleIndices<'a, A>>,
+    pub transform: Option<AccelerationStructureTriangleTransform<'a, A>>,
+    pub flags: AccelerationStructureGeometryFlags,
+}
+
+/// * `offset` - offset in bytes
+#[derive(Clone, Debug)]
+pub struct AccelerationStructureAABBs<'a, A: Api> {
+    pub buffer: Option<&'a A::Buffer>,
+    pub offset: u32,
+    pub count: u32,
+    pub stride: wgt::BufferAddress,
+    pub flags: AccelerationStructureGeometryFlags,
+}
+
+/// * `offset` - offset in bytes
+#[derive(Clone, Debug)]
+pub struct AccelerationStructureInstances<'a, A: Api> {
+    pub buffer: Option<&'a A::Buffer>,
+    pub offset: u32,
+    pub count: u32,
+}
+
+/// * `offset` - offset in bytes
+#[derive(Clone, Debug)]
+pub struct AccelerationStructureTriangleIndices<'a, A: Api> {
+    pub format: wgt::IndexFormat,
+    pub buffer: Option<&'a A::Buffer>,
+    pub offset: u32,
+    pub count: u32,
+}
+
+/// * `offset` - offset in bytes
+#[derive(Clone, Debug)]
+pub struct AccelerationStructureTriangleTransform<'a, A: Api> {
+    pub buffer: &'a A::Buffer,
+    pub offset: u32,
+}
+
+bitflags!(
+    /// Flags for acceleration structures
+    pub struct AccelerationStructureBuildFlags: u32 {
+        /// Allow for incremental updates (no change in size)
+        const ALLOW_UPDATE = 1 << 0;
+        /// Allow the acceleration structure to be compacted in a copy operation
+        const ALLOW_COMPACTION = 1 << 1;
+        /// Optimize for fast ray tracing performance
+        const PREFER_FAST_TRACE = 1 << 2;
+        /// Optimize for fast build time
+        const PREFER_FAST_BUILD = 1 << 3;
+        /// Optimize for low memory footprint (scratch and output)
+        const LOW_MEMORY = 1 << 4;
+    }
+);
+
+bitflags!(
+    pub struct AccelerationStructureGeometryFlags: u32 {
+        const OPAQUE = 1 << 0;
+        const NO_DUPLICATE_ANY_HIT_INVOCATION = 1 << 1;
+    }
+);
+
+#[derive(Clone, Debug)]
+pub struct RayTracingHitShaderGroup<'a, A: Api> {
+    pub closest_hit: Option<ProgrammableStage<'a, A>>,
+    pub any_hit: Option<ProgrammableStage<'a, A>>,
+    pub intersection: Option<ProgrammableStage<'a, A>>,
+}
+
+#[derive(Copy, Clone, Debug, Default)]
+pub enum SkipHitType {
+    #[default]
+    None,
+    Triangles,
+    Procedural,
+}
+
+#[derive(Clone, Debug)]
+pub struct RayTracingPipelineDescriptor<'a, A: Api> {
+    pub label: Label<'a>,
+    pub layout: &'a A::PipelineLayout,
+    pub max_recursion_depth: u32,
+    pub skip_hit_type: SkipHitType,
+    pub gen_groups: &'a [ProgrammableStage<'a, A>],
+    pub miss_groups: &'a [ProgrammableStage<'a, A>],
+    pub call_groups: &'a [ProgrammableStage<'a, A>],
+    pub hit_groups: &'a [RayTracingHitShaderGroup<'a, A>],
+}
+
+/// Unstable: may change for the DX12 implementation.
+pub trait RayTracingPipeline {
+    fn gen_handles<'a>(&'a self) -> Vec<&'a [u8]>;
+    fn miss_handles<'a>(&'a self) -> Vec<&'a [u8]>;
+    fn call_handles<'a>(&'a self) -> Vec<&'a [u8]>;
+    fn hit_handles<'a>(&'a self) -> Vec<&'a [u8]>;
+}
+
+#[derive(Clone, Debug, Default)]
+pub struct ShaderBindingTableReference {
+    pub address: wgt::BufferAddress,
+    pub stride: wgt::BufferAddress,
+    pub size: wgt::BufferAddress,
+}
+
+#[derive(Clone, Debug)]
+pub struct RayTracingPassDescriptor<'a> {
+    pub label: Label<'a>,
+}
+
+pub struct ShaderBindingTableData<'a> {
+    pub data: Box + 'a>,
+    pub stride: wgt::BufferAddress,
+    pub count: wgt::BufferAddress,
+    pub size: wgt::BufferAddress,
+    pub padded_size: wgt::BufferAddress,
+}
diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs
index 9303dd5395..493e9ecc50 100644
--- a/wgpu-hal/src/metal/command.rs
+++ b/wgpu-hal/src/metal/command.rs
@@ -971,4 +971,11 @@ impl crate::CommandEncoder for super::CommandEncoder {
        let encoder = self.state.compute.as_ref().unwrap();
        encoder.dispatch_thread_groups_indirect(&buffer.raw, offset, self.state.raw_wg_size);
    }
+
+    unsafe fn build_acceleration_structures(
+        &mut self,
+        _desc: &[&crate::BuildAccelerationStructureDescriptor],
+    ) {
+        unimplemented!()
+    }
}
diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs
index 52cc215126..7257d012fd 100644
--- a/wgpu-hal/src/metal/device.rs
+++ b/wgpu-hal/src/metal/device.rs
@@ -603,6 +603,7 @@ impl crate::Device for super::Device {
                wgt::StorageTextureAccess::ReadWrite => true,
            };
        }
+        wgt::BindingType::AccelerationStructure => unimplemented!(),
    }

    let br = naga::ResourceBinding {
@@ -766,6 +767,7 @@ impl crate::Device for super::Device {
                );
                counter.textures += size;
            }
+            wgt::BindingType::AccelerationStructure => unimplemented!(),
        }
    }
}
@@ -1145,4 +1147,32 @@ impl crate::Device for super::Device {
        }
        shared_capture_manager.stop_capture();
    }
+
+    unsafe fn get_acceleration_structure_build_sizes(
+        &self,
+        _desc: &crate::GetAccelerationStructureBuildSizesDescriptor,
+    ) -> crate::AccelerationStructureBuildSizes {
+        unimplemented!()
+    }
+
+    unsafe fn get_acceleration_structure_device_address(
+        &self,
+        _acceleration_structure: &super::AccelerationStructure,
+    ) -> wgt::BufferAddress {
+        unimplemented!()
+    }
+
+    unsafe fn create_acceleration_structure(
+        &self,
+        _desc: &crate::AccelerationStructureDescriptor,
+    ) -> Result<super::AccelerationStructure, crate::DeviceError> {
+        unimplemented!()
+    }
+
+    unsafe fn destroy_acceleration_structure(
+        &self,
+        _acceleration_structure: super::AccelerationStructure,
+    ) {
+        unimplemented!()
+    }
}
diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs
index 57083b585d..d6700995f3 100644
--- a/wgpu-hal/src/metal/mod.rs
+++ b/wgpu-hal/src/metal/mod.rs
@@ -60,6 +60,8 @@ impl crate::Api for Api {
    type ShaderModule = ShaderModule;
    type RenderPipeline = RenderPipeline;
    type ComputePipeline = ComputePipeline;
+
+    type AccelerationStructure = AccelerationStructure;
}

pub struct Instance {
@@ -795,3 +797,6 @@ pub struct CommandBuffer {
unsafe impl Send for CommandBuffer {}
unsafe impl Sync for CommandBuffer {}
+
+#[derive(Debug)]
+pub struct AccelerationStructure;
diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs
index 5efeed35e3..18ff6a673c 100644
--- a/wgpu-hal/src/vulkan/adapter.rs
+++ b/wgpu-hal/src/vulkan/adapter.rs
@@ -31,6 +31,10 @@ pub struct PhysicalDeviceFeatures {
        vk::PhysicalDeviceShaderFloat16Int8Features,
vk::PhysicalDevice16BitStorageFeatures, )>, + acceleration_structure: Option, + buffer_device_address: Option, + ray_query: Option, + ray_tracing_pipeline: Option, zero_initialize_workgroup_memory: Option, } @@ -74,6 +78,18 @@ impl PhysicalDeviceFeatures { if let Some(ref mut feature) = self.zero_initialize_workgroup_memory { info = info.push_next(feature); } + if let Some(ref mut feature) = self.acceleration_structure { + info = info.push_next(feature); + } + if let Some(ref mut feature) = self.buffer_device_address { + info = info.push_next(feature); + } + if let Some(ref mut feature) = self.ray_query { + info = info.push_next(feature); + } + if let Some(ref mut feature) = self.ray_tracing_pipeline { + info = info.push_next(feature); + } info } @@ -291,6 +307,48 @@ impl PhysicalDeviceFeatures { } else { None }, + acceleration_structure: if enabled_extensions + .contains(&vk::KhrAccelerationStructureFn::name()) + { + Some( + vk::PhysicalDeviceAccelerationStructureFeaturesKHR::builder() + .acceleration_structure(true) + .build(), + ) + } else { + None + }, + buffer_device_address: if enabled_extensions + .contains(&vk::KhrBufferDeviceAddressFn::name()) + { + Some( + vk::PhysicalDeviceBufferDeviceAddressFeaturesKHR::builder() + .buffer_device_address(true) + .build(), + ) + } else { + None + }, + ray_query: if enabled_extensions.contains(&vk::KhrRayQueryFn::name()) { + Some( + vk::PhysicalDeviceRayQueryFeaturesKHR::builder() + .ray_query(true) + .build(), + ) + } else { + None + }, + ray_tracing_pipeline: if enabled_extensions + .contains(&vk::KhrRayTracingPipelineFn::name()) + { + Some( + vk::PhysicalDeviceRayTracingPipelineFeaturesKHR::builder() + .ray_tracing_pipeline(true) + .build(), + ) + } else { + None + }, zero_initialize_workgroup_memory: if effective_api_version >= vk::API_VERSION_1_3 || enabled_extensions.contains(&vk::KhrZeroInitializeWorkgroupMemoryFn::name()) { @@ -526,6 +584,19 @@ impl PhysicalDeviceFeatures { features.set(F::DEPTH32FLOAT_STENCIL8, texture_d32_s8); + features.set( + F::RAY_TRACING, + caps.supports_extension(vk::KhrDeferredHostOperationsFn::name()) + && caps.supports_extension(vk::KhrAccelerationStructureFn::name()) + && caps.supports_extension(vk::KhrBufferDeviceAddressFn::name()) + && caps.supports_extension(vk::KhrRayTracingPipelineFn::name()), + ); + + features.set( + F::RAY_QUERY, + caps.supports_extension(vk::KhrRayQueryFn::name()), + ); + (features, dl_flags) } @@ -540,12 +611,14 @@ impl PhysicalDeviceFeatures { } /// Information gathered about a physical device capabilities. -#[derive(Default)] +#[derive(Default, Debug)] pub struct PhysicalDeviceCapabilities { supported_extensions: Vec, properties: vk::PhysicalDeviceProperties, maintenance_3: Option, descriptor_indexing: Option, + acceleration_structure: Option, + ray_tracing_pipeline: Option, driver: Option, /// The effective driver api version supported by the physical device. 
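As context for the extension plumbing in this file: a hal user opts into this path by requesting the new feature flags when opening the adapter, which is what triggers the extension pushes below. A rough sketch, assuming an `adapter` implementing hal's `Adapter` trait and its pre-existing `open(features, limits)` signature:

    unsafe {
        // RAY_TRACING pulls in VK_KHR_deferred_host_operations,
        // VK_KHR_acceleration_structure, VK_KHR_buffer_device_address and
        // VK_KHR_ray_tracing_pipeline; RAY_QUERY additionally requests
        // VK_KHR_ray_query.
        let hal::OpenDevice { device, queue } = adapter
            .open(
                wgt::Features::RAY_TRACING | wgt::Features::RAY_QUERY,
                &wgt::Limits::default(),
            )
            .unwrap();
    }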
/// @@ -684,6 +757,19 @@ impl PhysicalDeviceCapabilities { extensions.push(vk::KhrDrawIndirectCountFn::name()); } + // Require `VK_KHR_deferred_host_operations`, `VK_KHR_acceleration_structure` and `VK_KHR_buffer_device_address` if the feature `RAY_TRACING` was requested + if requested_features.contains(wgt::Features::RAY_TRACING) { + extensions.push(vk::KhrDeferredHostOperationsFn::name()); + extensions.push(vk::KhrAccelerationStructureFn::name()); + extensions.push(vk::KhrBufferDeviceAddressFn::name()); + extensions.push(vk::KhrRayTracingPipelineFn::name()) + } + + // Require `VK_KHR_ray_query` if the associated feature was requested + if requested_features.contains(wgt::Features::RAY_QUERY) { + extensions.push(vk::KhrRayQueryFn::name()); + } + // Require `VK_EXT_conservative_rasterization` if the associated feature was requested if requested_features.contains(wgt::Features::CONSERVATIVE_RASTERIZATION) { extensions.push(vk::ExtConservativeRasterizationFn::name()); @@ -798,6 +884,12 @@ impl super::InstanceShared { let supports_driver_properties = self.driver_api_version >= vk::API_VERSION_1_2 || capabilities.supports_extension(vk::KhrDriverPropertiesFn::name()); + let supports_acceleration_structure = + capabilities.supports_extension(vk::KhrAccelerationStructureFn::name()); + + let supports_ray_tracing_pipeline = + capabilities.supports_extension(vk::KhrRayTracingPipelineFn::name()); + let mut builder = vk::PhysicalDeviceProperties2KHR::builder(); if self.driver_api_version >= vk::API_VERSION_1_1 || capabilities.supports_extension(vk::KhrMaintenance3Fn::name()) @@ -814,6 +906,20 @@ impl super::InstanceShared { builder = builder.push_next(next); } + if supports_acceleration_structure { + let next = capabilities + .acceleration_structure + .insert(vk::PhysicalDeviceAccelerationStructurePropertiesKHR::default()); + builder = builder.push_next(next); + } + + if supports_ray_tracing_pipeline { + let next = capabilities + .ray_tracing_pipeline + .insert(vk::PhysicalDeviceRayTracingPipelinePropertiesKHR::default()); + builder = builder.push_next(next); + } + if supports_driver_properties { let next = capabilities .driver @@ -910,6 +1016,12 @@ impl super::InstanceShared { builder = builder.push_next(&mut next.0); builder = builder.push_next(&mut next.1); } + if capabilities.supports_extension(vk::KhrAccelerationStructureFn::name()) { + let next = features + .acceleration_structure + .insert(vk::PhysicalDeviceAccelerationStructureFeaturesKHR::default()); + builder = builder.push_next(next); + } // `VK_KHR_zero_initialize_workgroup_memory` is promoted to 1.3 if capabilities.effective_api_version >= vk::API_VERSION_1_3 @@ -1085,6 +1197,19 @@ impl super::Instance { .map_or(false, |ext| { ext.shader_zero_initialize_workgroup_memory == vk::TRUE }), + ray_tracing_device_properties: phd_capabilities.ray_tracing_pipeline.map(|x| { + super::RayTracingCapabilities { + shader_group_handle_size: x.shader_group_handle_size, + max_ray_recursion_depth: x.max_ray_recursion_depth, + max_shader_group_stride: x.max_shader_group_stride, + shader_group_base_alignment: x.shader_group_base_alignment, + shader_group_handle_capture_replay_size: x + .shader_group_handle_capture_replay_size, + max_ray_dispatch_invocation_count: x.max_ray_dispatch_invocation_count, + shader_group_handle_alignment: x.shader_group_handle_alignment, + max_ray_hit_attribute_size: x.max_ray_hit_attribute_size, + } + }), }; let capabilities = crate::Capabilities { limits: phd_capabilities.to_wgpu_limits(), @@ -1217,6 +1342,24 @@ impl 
super::Adapter { } else { None }; + let ray_tracing_fns = if enabled_extensions.contains(&khr::AccelerationStructure::name()) + && enabled_extensions.contains(&khr::BufferDeviceAddress::name()) + && enabled_extensions.contains(&khr::RayTracingPipeline::name()) + { + Some(super::RayTracingDeviceExtensionFunctions { + acceleration_structure: khr::AccelerationStructure::new( + &self.instance.raw, + &raw_device, + ), + buffer_device_address: khr::BufferDeviceAddress::new( + &self.instance.raw, + &raw_device, + ), + rt_pipeline: khr::RayTracingPipeline::new(&self.instance.raw, &raw_device), + }) + } else { + None + }; let naga_options = { use naga::back::spv; @@ -1317,6 +1460,7 @@ impl super::Adapter { extension_fns: super::DeviceExtensionFunctions { draw_indirect_count: indirect_count_fn, timeline_semaphore: timeline_semaphore_fn, + ray_tracing: ray_tracing_fns, }, vendor_id: self.phd_capabilities.properties.vendor_id, timestamp_period: self.phd_capabilities.properties.limits.timestamp_period, @@ -1372,7 +1516,8 @@ impl super::Adapter { size: memory_heap.size, }) .collect(), - buffer_device_address: false, + buffer_device_address: enabled_extensions + .contains(&khr::BufferDeviceAddress::name()), }; gpu_alloc::GpuAllocator::new(config, properties) }; @@ -1396,6 +1541,10 @@ impl super::Adapter { Ok(crate::OpenDevice { device, queue }) } + + pub fn ray_tracing_capabilities(&self) -> &Option { + &self.private_caps.ray_tracing_device_properties + } } impl crate::Adapter for super::Adapter { diff --git a/wgpu-hal/src/vulkan/command.rs b/wgpu-hal/src/vulkan/command.rs index f6c871026c..8bc339e267 100644 --- a/wgpu-hal/src/vulkan/command.rs +++ b/wgpu-hal/src/vulkan/command.rs @@ -366,6 +366,241 @@ impl crate::CommandEncoder for super::CommandEncoder { }; } + unsafe fn build_acceleration_structures( + &mut self, + descriptors: &[&crate::BuildAccelerationStructureDescriptor], + ) { + const CAPACITY_OUTER: usize = 8; + const CAPACITY_INNER: usize = 1; + + let ray_tracing_functions = match self.device.extension_fns.ray_tracing { + Some(ref functions) => functions, + None => panic!("Feature `RAY_TRACING` not enabled"), + }; + + let get_device_address = |buffer: Option<&super::Buffer>| unsafe { + match buffer { + Some(buffer) => ray_tracing_functions + .buffer_device_address + .get_buffer_device_address( + &vk::BufferDeviceAddressInfo::builder().buffer(buffer.raw), + ), + None => panic!("Buffers are required to build acceleration structures"), + } + }; + + // storage to all the data required for cmd_build_acceleration_structures + let mut ranges_storage = smallvec::SmallVec::< + [smallvec::SmallVec<[vk::AccelerationStructureBuildRangeInfoKHR; CAPACITY_INNER]>; + CAPACITY_OUTER], + >::with_capacity(descriptors.len()); + let mut geometries_storage = smallvec::SmallVec::< + [smallvec::SmallVec<[vk::AccelerationStructureGeometryKHR; CAPACITY_INNER]>; + CAPACITY_OUTER], + >::with_capacity(descriptors.len()); + + // pointers to all the data required for cmd_build_acceleration_structures + let mut geometry_infos = smallvec::SmallVec::< + [vk::AccelerationStructureBuildGeometryInfoKHR; CAPACITY_OUTER], + >::with_capacity(descriptors.len()); + let mut ranges_ptrs = smallvec::SmallVec::< + [&[vk::AccelerationStructureBuildRangeInfoKHR]; CAPACITY_OUTER], + >::with_capacity(descriptors.len()); + + for desc in descriptors { + let (geometries, ranges) = match *desc.entries { + crate::AccelerationStructureEntries::Instances(ref instances) => { + let instance_data = 
vk::AccelerationStructureGeometryInstancesDataKHR::builder( + ) + .data(vk::DeviceOrHostAddressConstKHR { + device_address: get_device_address(instances.buffer), + }); + + let geometry = vk::AccelerationStructureGeometryKHR::builder() + .geometry_type(vk::GeometryTypeKHR::INSTANCES) + .geometry(vk::AccelerationStructureGeometryDataKHR { + instances: *instance_data, + }); + + let range = vk::AccelerationStructureBuildRangeInfoKHR::builder() + .primitive_count(instances.count) + .primitive_offset(instances.offset); + + (smallvec::smallvec![*geometry], smallvec::smallvec![*range]) + } + crate::AccelerationStructureEntries::Triangles(in_geometries) => { + let mut ranges = smallvec::SmallVec::< + [vk::AccelerationStructureBuildRangeInfoKHR; CAPACITY_INNER], + >::with_capacity(in_geometries.len()); + let mut geometries = smallvec::SmallVec::< + [vk::AccelerationStructureGeometryKHR; CAPACITY_INNER], + >::with_capacity(in_geometries.len()); + for triangles in in_geometries { + let mut triangle_data = + vk::AccelerationStructureGeometryTrianglesDataKHR::builder() + .vertex_data(vk::DeviceOrHostAddressConstKHR { + device_address: get_device_address(triangles.vertex_buffer), + }) + .vertex_format(conv::map_vertex_format(triangles.vertex_format)) + .max_vertex(triangles.vertex_count) + .vertex_stride(triangles.vertex_stride); + + let mut range = vk::AccelerationStructureBuildRangeInfoKHR::builder(); + + if let Some(ref indices) = triangles.indices { + triangle_data = triangle_data + .index_data(vk::DeviceOrHostAddressConstKHR { + device_address: get_device_address(indices.buffer), + }) + .index_type(conv::map_index_format(indices.format)); + + range = range + .primitive_count(indices.count / 3) + .primitive_offset(indices.offset) + .first_vertex(triangles.first_vertex); + } else { + range = range + .primitive_count(triangles.vertex_count) + .first_vertex(triangles.first_vertex); + } + + if let Some(ref transform) = triangles.transform { + let transform_device_address = unsafe { + ray_tracing_functions + .buffer_device_address + .get_buffer_device_address( + &vk::BufferDeviceAddressInfo::builder() + .buffer(transform.buffer.raw), + ) + }; + triangle_data = + triangle_data.transform_data(vk::DeviceOrHostAddressConstKHR { + device_address: transform_device_address, + }); + + range = range.transform_offset(transform.offset); + } + + let geometry = vk::AccelerationStructureGeometryKHR::builder() + .geometry_type(vk::GeometryTypeKHR::TRIANGLES) + .geometry(vk::AccelerationStructureGeometryDataKHR { + triangles: *triangle_data, + }) + .flags(conv::map_acceleration_structure_geomety_flags( + triangles.flags, + )); + + geometries.push(*geometry); + ranges.push(*range); + } + (geometries, ranges) + } + crate::AccelerationStructureEntries::AABBs(in_geometries) => { + let mut ranges = smallvec::SmallVec::< + [vk::AccelerationStructureBuildRangeInfoKHR; CAPACITY_INNER], + >::with_capacity(in_geometries.len()); + let mut geometries = smallvec::SmallVec::< + [vk::AccelerationStructureGeometryKHR; CAPACITY_INNER], + >::with_capacity(in_geometries.len()); + for aabb in in_geometries { + let aabbs_data = vk::AccelerationStructureGeometryAabbsDataKHR::builder() + .data(vk::DeviceOrHostAddressConstKHR { + device_address: get_device_address(aabb.buffer), + }) + .stride(aabb.stride); + + let range = vk::AccelerationStructureBuildRangeInfoKHR::builder() + .primitive_count(aabb.count) + .primitive_offset(aabb.offset); + + let geometry = vk::AccelerationStructureGeometryKHR::builder() + 
.geometry_type(vk::GeometryTypeKHR::AABBS) + .geometry(vk::AccelerationStructureGeometryDataKHR { + aabbs: *aabbs_data, + }) + .flags(conv::map_acceleration_structure_geomety_flags(aabb.flags)); + + geometries.push(*geometry); + ranges.push(*range); + } + (geometries, ranges) + } + }; + + ranges_storage.push(ranges); + geometries_storage.push(geometries); + } + + for (i, desc) in descriptors.iter().enumerate() { + let scratch_device_address = unsafe { + ray_tracing_functions + .buffer_device_address + .get_buffer_device_address( + &vk::BufferDeviceAddressInfo::builder().buffer(desc.scratch_buffer.raw), + ) + }; + let ty = match *desc.entries { + crate::AccelerationStructureEntries::Instances(_) => { + vk::AccelerationStructureTypeKHR::TOP_LEVEL + } + _ => vk::AccelerationStructureTypeKHR::BOTTOM_LEVEL, + }; + let mut geometry_info = vk::AccelerationStructureBuildGeometryInfoKHR::builder() + .ty(ty) + .mode(conv::map_acceleration_structure_build_mode(desc.mode)) + .flags(conv::map_acceleration_structure_flags(desc.flags)) + .geometries(&geometries_storage[i]) // pointer must live + .dst_acceleration_structure(desc.destination_acceleration_structure.raw) + .scratch_data(vk::DeviceOrHostAddressKHR { + device_address: scratch_device_address, + }); + + if desc.mode == crate::AccelerationStructureBuildMode::Update { + geometry_info.src_acceleration_structure = desc + .source_acceleration_structure + .unwrap_or(desc.destination_acceleration_structure) + .raw; + } + + geometry_infos.push(*geometry_info); + ranges_ptrs.push(&ranges_storage[i]); + } + + unsafe { + ray_tracing_functions + .acceleration_structure + .cmd_build_acceleration_structures(self.active, &geometry_infos, &ranges_ptrs); + } + } + // render unsafe fn begin_render_pass(&mut self, desc: &crate::RenderPassDescriptor) { @@ -815,6 +1050,74 @@ impl crate::CommandEncoder for super::CommandEncoder { .cmd_dispatch_indirect(self.active, buffer.raw, offset) } } + + unsafe fn begin_ray_tracing_pass(&mut self, desc: &crate::RayTracingPassDescriptor) { + self.bind_point = vk::PipelineBindPoint::RAY_TRACING_KHR; + if let Some(label) = desc.label { + unsafe { self.begin_debug_marker(label) }; + self.rpass_debug_marker_active = true; + } + } + + unsafe fn end_ray_tracing_pass(&mut self) { + if self.rpass_debug_marker_active { + unsafe { self.end_debug_marker() }; + self.rpass_debug_marker_active = false; + } + } + + unsafe fn set_ray_tracing_pipeline(&mut self, pipeline: &super::RayTracingPipeline) { + unsafe { + self.device.raw.cmd_bind_pipeline( + self.active, + vk::PipelineBindPoint::RAY_TRACING_KHR, + pipeline.raw, + ) + }; + } + + unsafe fn trace_rays( + &mut self, + ray_gen_sbt: &crate::ShaderBindingTableReference, + miss_sbt:
&crate::ShaderBindingTableReference, + callable_sbt: &crate::ShaderBindingTableReference, + hit_sbt: &crate::ShaderBindingTableReference, + dimensions: [u32; 3], + ) { + let ray_tracing_functions = match self.device.extension_fns.ray_tracing { + Some(ref functions) => functions, + None => panic!("Feature `RAY_TRACING` not enabled"), + }; + + unsafe { + ray_tracing_functions.rt_pipeline.cmd_trace_rays( + self.active, + &vk::StridedDeviceAddressRegionKHR { + device_address: ray_gen_sbt.address, + stride: ray_gen_sbt.stride, + size: ray_gen_sbt.stride, // intentional + }, + &vk::StridedDeviceAddressRegionKHR { + device_address: miss_sbt.address, + stride: miss_sbt.stride, + size: miss_sbt.size, + }, + &vk::StridedDeviceAddressRegionKHR { + device_address: hit_sbt.address, + stride: hit_sbt.stride, + size: hit_sbt.size, + }, + &vk::StridedDeviceAddressRegionKHR { + device_address: callable_sbt.address, + stride: callable_sbt.stride, + size: callable_sbt.size, + }, + dimensions[0], + dimensions[1], + dimensions[2], + ) + }; + } } #[test] diff --git a/wgpu-hal/src/vulkan/conv.rs b/wgpu-hal/src/vulkan/conv.rs index a91479a835..2ad53475ce 100644 --- a/wgpu-hal/src/vulkan/conv.rs +++ b/wgpu-hal/src/vulkan/conv.rs @@ -508,6 +508,20 @@ pub fn map_buffer_usage(usage: crate::BufferUses) -> vk::BufferUsageFlags { if usage.contains(crate::BufferUses::INDIRECT) { flags |= vk::BufferUsageFlags::INDIRECT_BUFFER; } + if usage.contains(crate::BufferUses::ACCELERATION_STRUCTURE_SCRATCH) { + flags |= vk::BufferUsageFlags::STORAGE_BUFFER | vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS; + } + if usage.intersects( + crate::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT + | crate::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT, + ) { + flags |= vk::BufferUsageFlags::ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_KHR + | vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS; + } + if usage.contains(crate::BufferUses::SHADER_BINDING_TABLE) { + flags |= vk::BufferUsageFlags::SHADER_BINDING_TABLE_KHR + | vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS; + } flags } @@ -560,6 +574,15 @@ pub fn map_buffer_usage_to_barrier( stages |= vk::PipelineStageFlags::DRAW_INDIRECT; access |= vk::AccessFlags::INDIRECT_COMMAND_READ; } + if usage.intersects( + crate::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT + | crate::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT + | crate::BufferUses::ACCELERATION_STRUCTURE_SCRATCH, + ) { + stages |= vk::PipelineStageFlags::ACCELERATION_STRUCTURE_BUILD_KHR; + access |= vk::AccessFlags::ACCELERATION_STRUCTURE_READ_KHR + | vk::AccessFlags::ACCELERATION_STRUCTURE_WRITE_KHR; + } (stages, access) } @@ -674,6 +697,24 @@ pub fn map_shader_stage(stage: wgt::ShaderStages) -> vk::ShaderStageFlags { if stage.contains(wgt::ShaderStages::COMPUTE) { flags |= vk::ShaderStageFlags::COMPUTE; } + if stage.contains(wgt::ShaderStages::RAYGEN) { + flags |= vk::ShaderStageFlags::RAYGEN_KHR; + } + if stage.contains(wgt::ShaderStages::MISS) { + flags |= vk::ShaderStageFlags::MISS_KHR; + } + if stage.contains(wgt::ShaderStages::CALLABLE) { + flags |= vk::ShaderStageFlags::CALLABLE_KHR; + } + if stage.contains(wgt::ShaderStages::CLOSEST_HIT) { + flags |= vk::ShaderStageFlags::CLOSEST_HIT_KHR; + } + if stage.contains(wgt::ShaderStages::ANY_HIT) { + flags |= vk::ShaderStageFlags::ANY_HIT_KHR; + } + if stage.contains(wgt::ShaderStages::INTERSECTION) { + flags |= vk::ShaderStageFlags::INTERSECTION_KHR; + } flags } @@ -696,6 +737,7 @@ pub fn map_binding_type(ty: wgt::BindingType) -> vk::DescriptorType { 
wgt::BindingType::Sampler { .. } => vk::DescriptorType::SAMPLER, wgt::BindingType::Texture { .. } => vk::DescriptorType::SAMPLED_IMAGE, wgt::BindingType::StorageTexture { .. } => vk::DescriptorType::STORAGE_IMAGE, + wgt::BindingType::AccelerationStructure => vk::DescriptorType::ACCELERATION_STRUCTURE_KHR, } } @@ -823,3 +865,71 @@ pub fn map_pipeline_statistics( } flags } + +pub fn map_acceleration_structure_format( + format: crate::AccelerationStructureFormat, +) -> vk::AccelerationStructureTypeKHR { + match format { + crate::AccelerationStructureFormat::TopLevel => vk::AccelerationStructureTypeKHR::TOP_LEVEL, + crate::AccelerationStructureFormat::BottomLevel => { + vk::AccelerationStructureTypeKHR::BOTTOM_LEVEL + } + } +} + +pub fn map_acceleration_structure_build_mode( + format: crate::AccelerationStructureBuildMode, +) -> vk::BuildAccelerationStructureModeKHR { + match format { + crate::AccelerationStructureBuildMode::Build => { + vk::BuildAccelerationStructureModeKHR::BUILD + } + crate::AccelerationStructureBuildMode::Update => { + vk::BuildAccelerationStructureModeKHR::UPDATE + } + } +} + +pub fn map_acceleration_structure_flags( + flags: crate::AccelerationStructureBuildFlags, +) -> vk::BuildAccelerationStructureFlagsKHR { + let mut vk_flags = vk::BuildAccelerationStructureFlagsKHR::empty(); + + if flags.contains(crate::AccelerationStructureBuildFlags::PREFER_FAST_TRACE) { + vk_flags |= vk::BuildAccelerationStructureFlagsKHR::PREFER_FAST_TRACE; + } + + if flags.contains(crate::AccelerationStructureBuildFlags::PREFER_FAST_BUILD) { + vk_flags |= vk::BuildAccelerationStructureFlagsKHR::PREFER_FAST_BUILD; + } + + if flags.contains(crate::AccelerationStructureBuildFlags::ALLOW_UPDATE) { + vk_flags |= vk::BuildAccelerationStructureFlagsKHR::ALLOW_UPDATE; + } + + if flags.contains(crate::AccelerationStructureBuildFlags::LOW_MEMORY) { + vk_flags |= vk::BuildAccelerationStructureFlagsKHR::LOW_MEMORY; + } + + if flags.contains(crate::AccelerationStructureBuildFlags::ALLOW_COMPACTION) { + vk_flags |= vk::BuildAccelerationStructureFlagsKHR::ALLOW_COMPACTION; + } + + vk_flags +} + +pub fn map_acceleration_structure_geomety_flags( + flags: crate::AccelerationStructureGeometryFlags, +) -> vk::GeometryFlagsKHR { + let mut vk_flags = vk::GeometryFlagsKHR::empty(); + + if flags.contains(crate::AccelerationStructureGeometryFlags::OPAQUE) { + vk_flags |= vk::GeometryFlagsKHR::OPAQUE; + } + + if flags.contains(crate::AccelerationStructureGeometryFlags::NO_DUPLICATE_ANY_HIT_INVOCATION) { + vk_flags |= vk::GeometryFlagsKHR::NO_DUPLICATE_ANY_HIT_INVOCATION; + } + + vk_flags +} diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index 1d10d69b0a..b11d67448f 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -762,6 +762,34 @@ impl super::Device { }) } + fn compile_stage_temp_ray_tracing( + &self, + stage: &crate::ProgrammableStage<super::Api>, + stage_flags: wgt::ShaderStages, + _binding_map: &naga::back::spv::BindingMap, + ) -> Result<CompiledStage, crate::PipelineError> { + let vk_module = match *stage.module { + super::ShaderModule::Raw(raw) => raw, + _ => unimplemented!("naga support for ray tracing shaders not yet implemented"), + }; + + let entry_point = CString::new(stage.entry_point).unwrap(); + let create_info = vk::PipelineShaderStageCreateInfo::builder() + .stage(conv::map_shader_stage(stage_flags)) + .module(vk_module) + .name(&entry_point) + .build(); + + Ok(CompiledStage { + create_info, + _entry_point: entry_point, + temp_raw_module: match *stage.module { + super::ShaderModule::Raw(_) =>
None, + super::ShaderModule::Intermediate { .. } => Some(vk_module), + }, + }) + } + /// Returns the queue family index of the device's internal queue. /// /// This is useful for constructing memory barriers needed for queue family ownership transfer when @@ -841,12 +869,30 @@ impl crate::Device for super::Device { desc.memory_flags.contains(crate::MemoryFlags::TRANSIENT), ); + let mut alignment = req.alignment; + if desc + .usage + .contains(crate::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT) + { + alignment = std::cmp::max(alignment, 16) + } + if desc.usage.contains(crate::BufferUses::SHADER_BINDING_TABLE) { + alignment = std::cmp::max( + alignment, + self.shared + .private_caps + .ray_tracing_device_properties + .as_ref() + .expect("Feature `RAY_TRACING` not enabled") + .shader_group_base_alignment as u64, + ) + } let block = unsafe { self.mem_allocator.lock().alloc( &*self.shared, gpu_alloc::Request { size: req.size, - align_mask: req.alignment - 1, + align_mask: alignment - 1, usage: alloc_usage, memory_types: req.memory_type_bits & self.valid_ash_memory_types, }, @@ -1239,6 +1285,9 @@ impl crate::Device for super::Device { wgt::BindingType::StorageTexture { .. } => { desc_count.storage_image += count; } + wgt::BindingType::AccelerationStructure => { + desc_count.acceleration_structure += count; + } } } @@ -1413,6 +1462,10 @@ impl crate::Device for super::Device { let mut buffer_infos = Vec::with_capacity(desc.buffers.len()); let mut sampler_infos = Vec::with_capacity(desc.samplers.len()); let mut image_infos = Vec::with_capacity(desc.textures.len()); + let mut acceleration_structure_infos = + Vec::with_capacity(desc.acceleration_structures.len()); + let mut raw_acceleration_structures = + Vec::with_capacity(desc.acceleration_structures.len()); for entry in desc.entries { let (ty, size) = desc.layout.types[entry.binding as usize]; if size == 0 { @@ -1422,6 +1475,9 @@ impl crate::Device for super::Device { .dst_set(*set.raw()) .dst_binding(entry.binding) .descriptor_type(ty); + + let mut extra_descriptor_count = 0; + write = match ty { vk::DescriptorType::SAMPLER => { let index = sampler_infos.len(); @@ -1472,9 +1528,44 @@ impl crate::Device for super::Device { )); write.buffer_info(&buffer_infos[index..]) } + vk::DescriptorType::ACCELERATION_STRUCTURE_KHR => { + let index = acceleration_structure_infos.len(); + let start = entry.resource_index; + let end = start + entry.count; + + let raw_start = raw_acceleration_structures.len(); + + raw_acceleration_structures.extend( + desc.acceleration_structures[start as usize..end as usize] + .iter() + .map(|acceleration_structure| acceleration_structure.raw), + ); + + let acceleration_structure_info = + vk::WriteDescriptorSetAccelerationStructureKHR::builder() + .acceleration_structures(&raw_acceleration_structures[raw_start..]); + + // todo: Dereference the struct to get around lifetime issues. Safe as long as we never resize + // `raw_acceleration_structures`. 
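+ // The `Vec::with_capacity` calls above reserve one slot per entry in + // `desc.acceleration_structures`, so the pushes below should never reallocate + // and the slices handed to the builder stay valid for the whole loop.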
+ let acceleration_structure_info: vk::WriteDescriptorSetAccelerationStructureKHR = *acceleration_structure_info; + + assert!( + index < desc.acceleration_structures.len(), + "Encountered more acceleration structures than expected" + ); + acceleration_structure_infos.push(acceleration_structure_info); + + extra_descriptor_count += 1; + + write.push_next(&mut acceleration_structure_infos[index]) + } + _ => unreachable!(), }; - writes.push(write.build()); + + let mut write = write.build(); + write.descriptor_count += extra_descriptor_count; + + writes.push(write); } unsafe { self.shared.raw.update_descriptor_sets(&writes, &[]) }; @@ -2004,6 +2095,441 @@ impl crate::Device for super::Device { } } } + + unsafe fn get_acceleration_structure_build_sizes( + &self, + desc: &crate::GetAccelerationStructureBuildSizesDescriptor, + ) -> crate::AccelerationStructureBuildSizes { + const CAPACITY: usize = 8; + + let ray_tracing_functions = match self.shared.extension_fns.ray_tracing { + Some(ref functions) => functions, + None => panic!("Feature `RAY_TRACING` not enabled"), + }; + + let (geometries, primitive_counts) = match *desc.entries { + crate::AccelerationStructureEntries::Instances(ref instances) => { + let instance_data = vk::AccelerationStructureGeometryInstancesDataKHR::default(); + + let geometry = vk::AccelerationStructureGeometryKHR::builder() + .geometry_type(vk::GeometryTypeKHR::INSTANCES) + .geometry(vk::AccelerationStructureGeometryDataKHR { + instances: instance_data, + }); + + ( + smallvec::smallvec![*geometry], + smallvec::smallvec![instances.count], + ) + } + crate::AccelerationStructureEntries::Triangles(in_geometries) => { + let mut primitive_counts = + smallvec::SmallVec::<[u32; CAPACITY]>::with_capacity(in_geometries.len()); + let mut geometries = smallvec::SmallVec::< + [vk::AccelerationStructureGeometryKHR; CAPACITY], + >::with_capacity(in_geometries.len()); + + for triangles in in_geometries { + let mut triangle_data = + vk::AccelerationStructureGeometryTrianglesDataKHR::builder() + .vertex_format(conv::map_vertex_format(triangles.vertex_format)) + .max_vertex(triangles.vertex_count) + .vertex_stride(triangles.vertex_stride); + + let primitive_count = if let Some(ref indices) = triangles.indices { + triangle_data = + triangle_data.index_type(conv::map_index_format(indices.format)); + indices.count / 3 + } else { + triangles.vertex_count + }; + + let geometry = vk::AccelerationStructureGeometryKHR::builder() + .geometry_type(vk::GeometryTypeKHR::TRIANGLES) + .geometry(vk::AccelerationStructureGeometryDataKHR { + triangles: *triangle_data, + }) + .flags(conv::map_acceleration_structure_geomety_flags( + triangles.flags, + )); + + geometries.push(*geometry); + primitive_counts.push(primitive_count); + } + (geometries, primitive_counts) + } + crate::AccelerationStructureEntries::AABBs(in_geometries) => { + let mut primitive_counts = + smallvec::SmallVec::<[u32; CAPACITY]>::with_capacity(in_geometries.len()); + let mut geometries = smallvec::SmallVec::< + [vk::AccelerationStructureGeometryKHR; CAPACITY], + >::with_capacity(in_geometries.len()); + for aabb in in_geometries { + let aabbs_data = vk::AccelerationStructureGeometryAabbsDataKHR::builder() + .stride(aabb.stride); + + let geometry = vk::AccelerationStructureGeometryKHR::builder() + .geometry_type(vk::GeometryTypeKHR::AABBS) + .geometry(vk::AccelerationStructureGeometryDataKHR { aabbs: *aabbs_data }) + .flags(conv::map_acceleration_structure_geomety_flags(aabb.flags)); + + geometries.push(*geometry); +
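+ // As in the triangle path above, the size query only needs primitive counts; + // buffer addresses are left unset here and are only supplied at build time.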
primitive_counts.push(aabb.count); + } + (geometries, primitive_counts) + } + }; + + let ty = match *desc.entries { + crate::AccelerationStructureEntries::Instances(_) => { + vk::AccelerationStructureTypeKHR::TOP_LEVEL + } + _ => vk::AccelerationStructureTypeKHR::BOTTOM_LEVEL, + }; + + let geometry_info = vk::AccelerationStructureBuildGeometryInfoKHR::builder() + .ty(ty) + .flags(conv::map_acceleration_structure_flags(desc.flags)) + .geometries(&geometries); + + let raw = unsafe { + ray_tracing_functions + .acceleration_structure + .get_acceleration_structure_build_sizes( + vk::AccelerationStructureBuildTypeKHR::DEVICE, + &geometry_info, + &primitive_counts, + ) + }; + + crate::AccelerationStructureBuildSizes { + acceleration_structure_size: raw.acceleration_structure_size, + update_scratch_size: raw.update_scratch_size, + build_scratch_size: raw.build_scratch_size, + } + } + + unsafe fn get_acceleration_structure_device_address( + &self, + acceleration_structure: &super::AccelerationStructure, + ) -> wgt::BufferAddress { + let ray_tracing_functions = match self.shared.extension_fns.ray_tracing { + Some(ref functions) => functions, + None => panic!("Feature `RAY_TRACING` not enabled"), + }; + + unsafe { + ray_tracing_functions + .acceleration_structure + .get_acceleration_structure_device_address( + &vk::AccelerationStructureDeviceAddressInfoKHR::builder() + .acceleration_structure(acceleration_structure.raw), + ) + } + } + + unsafe fn create_acceleration_structure( + &self, + desc: &crate::AccelerationStructureDescriptor, + ) -> Result<super::AccelerationStructure, crate::DeviceError> { + let ray_tracing_functions = match self.shared.extension_fns.ray_tracing { + Some(ref functions) => functions, + None => panic!("Feature `RAY_TRACING` not enabled"), + }; + + let vk_buffer_info = vk::BufferCreateInfo::builder() + .size(desc.size) + .usage(vk::BufferUsageFlags::ACCELERATION_STRUCTURE_STORAGE_KHR) + .sharing_mode(vk::SharingMode::EXCLUSIVE); + + unsafe { + let raw_buffer = self.shared.raw.create_buffer(&vk_buffer_info, None)?; + let req = self.shared.raw.get_buffer_memory_requirements(raw_buffer); + + let block = self.mem_allocator.lock().alloc( + &*self.shared, + gpu_alloc::Request { + size: req.size, + align_mask: req.alignment - 1, + usage: gpu_alloc::UsageFlags::FAST_DEVICE_ACCESS, + memory_types: req.memory_type_bits & self.valid_ash_memory_types, + }, + )?; + + self.shared + .raw + .bind_buffer_memory(raw_buffer, *block.memory(), block.offset())?; + + if let Some(label) = desc.label { + self.shared + .set_object_name(vk::ObjectType::BUFFER, raw_buffer, label); + } + + let vk_info = vk::AccelerationStructureCreateInfoKHR::builder() + .buffer(raw_buffer) + .offset(0) + .size(desc.size) + .ty(conv::map_acceleration_structure_format(desc.format)); + + let raw_acceleration_structure = ray_tracing_functions + .acceleration_structure + .create_acceleration_structure(&vk_info, None)?; + + if let Some(label) = desc.label { + self.shared.set_object_name( + vk::ObjectType::ACCELERATION_STRUCTURE_KHR, + raw_acceleration_structure, + label, + ); + } + + Ok(super::AccelerationStructure { + raw: raw_acceleration_structure, + buffer: raw_buffer, + block: Mutex::new(block), + }) + } + } + + unsafe fn destroy_acceleration_structure( + &self, + acceleration_structure: super::AccelerationStructure, + ) { + let ray_tracing_functions = match self.shared.extension_fns.ray_tracing { + Some(ref functions) => functions, + None => panic!("Feature `RAY_TRACING` not enabled"), + }; + + unsafe { + ray_tracing_functions + .acceleration_structure
.destroy_acceleration_structure(acceleration_structure.raw, None); + self.shared + .raw + .destroy_buffer(acceleration_structure.buffer, None); + self.mem_allocator + .lock() + .dealloc(&*self.shared, acceleration_structure.block.into_inner()); + } + } + + unsafe fn create_ray_tracing_pipeline( + &self, + desc: &crate::RayTracingPipelineDescriptor, + ) -> Result<super::RayTracingPipeline, crate::PipelineError> { + let ray_tracing_functions = match self.shared.extension_fns.ray_tracing { + Some(ref functions) => functions, + None => panic!("Feature `RAY_TRACING` not enabled"), + }; + + let mut compiled_storage = Vec::<CompiledStage>::new(); + let mut get_create_info = |stage, stage_flags| -> Result<_, crate::PipelineError> { + let t = self.compile_stage_temp_ray_tracing( + stage, + stage_flags, + &desc.layout.binding_arrays, + )?; + compiled_storage.push(t); + Ok(compiled_storage.last().unwrap().create_info) + }; + + // Future work: don't add the same shader to stages multiple times + let mut stages = Vec::<vk::PipelineShaderStageCreateInfo>::new(); + let mut groups = Vec::<vk::RayTracingShaderGroupCreateInfoKHR>::new(); + + let mut next_shader_index = 0; + + for (entries, stage_flags) in [ + (desc.gen_groups, wgt::ShaderStages::RAYGEN), + (desc.miss_groups, wgt::ShaderStages::MISS), + (desc.call_groups, wgt::ShaderStages::CALLABLE), + ] { + for programmable_stage in entries { + let group = vk::RayTracingShaderGroupCreateInfoKHR::builder() + .ty(vk::RayTracingShaderGroupTypeKHR::GENERAL) + .general_shader(next_shader_index) + .any_hit_shader(vk::SHADER_UNUSED_KHR) + .closest_hit_shader(vk::SHADER_UNUSED_KHR) + .intersection_shader(vk::SHADER_UNUSED_KHR); + next_shader_index += 1; + + stages.push(get_create_info(programmable_stage, stage_flags)?); + groups.push(*group); + } + } + + for entry in desc.hit_groups { + let mut group = vk::RayTracingShaderGroupCreateInfoKHR::builder() + .ty(match entry.intersection { + None => vk::RayTracingShaderGroupTypeKHR::TRIANGLES_HIT_GROUP, + Some(_) => vk::RayTracingShaderGroupTypeKHR::PROCEDURAL_HIT_GROUP, + }) + .general_shader(vk::SHADER_UNUSED_KHR); + + if let Some(ref stage) = entry.closest_hit { + stages.push(get_create_info(stage, wgt::ShaderStages::CLOSEST_HIT)?); + group = group.closest_hit_shader(next_shader_index); + next_shader_index += 1; + } else { + group = group.closest_hit_shader(vk::SHADER_UNUSED_KHR); + } + if let Some(ref stage) = entry.any_hit { + stages.push(get_create_info(stage, wgt::ShaderStages::ANY_HIT)?); + group = group.any_hit_shader(next_shader_index); + next_shader_index += 1; + } else { + group = group.any_hit_shader(vk::SHADER_UNUSED_KHR); + } + if let Some(ref stage) = entry.intersection { + stages.push(get_create_info(stage, wgt::ShaderStages::INTERSECTION)?); + group = group.intersection_shader(next_shader_index); + next_shader_index += 1; + } else { + group = group.intersection_shader(vk::SHADER_UNUSED_KHR); + } + + groups.push(*group); + } + + let create_info = vk::RayTracingPipelineCreateInfoKHR::builder() + .stages(&stages) + .groups(&groups) + .max_pipeline_ray_recursion_depth(desc.max_recursion_depth) + .layout(desc.layout.raw); + + let raw = unsafe { + ray_tracing_functions + .rt_pipeline + .create_ray_tracing_pipelines( + vk::DeferredOperationKHR::null(), + vk::PipelineCache::null(), + &[*create_info], + None, + ) + .map_err(crate::DeviceError::from)?[0] + }; + + let handle_size = self + .shared + .private_caps + .ray_tracing_device_properties + .as_ref() + .unwrap() + .shader_group_handle_size as usize; + + let handle_data = unsafe {
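+ // Every group handle is fetched in one call; per the Vulkan spec they come back + // tightly packed at `shader_group_handle_size` bytes per group, in the order the + // groups were pushed above (gen, miss, callable, then hit groups).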
+ ray_tracing_functions + .rt_pipeline + .get_ray_tracing_shader_group_handles( + raw, + 0, + groups.len() as u32, + handle_size * groups.len(), + ) + } + .map_err(crate::DeviceError::from)?; + + let mut range_acc = 0; + + let ranges = [ + 0, + desc.gen_groups.len(), + desc.miss_groups.len(), + desc.call_groups.len(), + desc.hit_groups.len(), + ] + .map(|x| { + range_acc += x * handle_size; + range_acc + }); + + Ok(super::RayTracingPipeline { + raw, + handle_data, + handle_size, + ranges, + }) + } + + unsafe fn destroy_ray_tracing_pipeline(&self, pipeline: super::RayTracingPipeline) { + unsafe { self.shared.raw.destroy_pipeline(pipeline.raw, None) }; + } + + fn assemble_sbt_data<'a>( + &self, + handles: &'a [&'a [u8]], + records: &'a [&'a [u8]], + ) -> crate::ShaderBindingTableData<'a> { + assert!( + handles.len() == records.len(), + "the number of handles and records must match" + ); + + let cap = self + .shared + .private_caps + .ray_tracing_device_properties + .as_ref() + .expect("Feature `RAY_TRACING` not enabled"); + + let shader_group_handle_alignment = cap.shader_group_handle_alignment; + let shader_group_base_alignment = cap.shader_group_base_alignment; + let shader_group_handle_size = cap.shader_group_handle_size; + + let max_record_size = records.iter().map(|e| e.len()).max().unwrap_or(0) as u32; + + let stride = crate::auxil::align_to( + max_record_size + shader_group_handle_size, + shader_group_handle_alignment, + ) as u64; + let count = handles.len() as u64; + let size = stride * count; + let padded_size = crate::auxil::align_to(size as u32, shader_group_base_alignment) as u64; + let outer_padding = padded_size - size; + + let ret = std::iter::zip(handles, records) + .flat_map(move |(handle, record)| { + let inner_padding = stride - (handle.len() + record.len()) as u64; + handle + .iter() + .chain(record.iter()) + .copied() + .chain((0..inner_padding).map(|_| 0)) + }) + .chain((0..outer_padding).map(|_| 0)); + + crate::ShaderBindingTableData { + data: Box::new(ret), + stride, + count, + size, + padded_size, + } + } + + unsafe fn get_buffer_device_address(&self, buffer: &super::Buffer) -> wgt::BufferAddress { + let ray_tracing_functions = match self.shared.extension_fns.ray_tracing { + Some(ref functions) => functions, + None => panic!("Feature `RAY_TRACING` not enabled"), + }; + + unsafe { + ray_tracing_functions + .buffer_device_address + .get_buffer_device_address( + &vk::BufferDeviceAddressInfo::builder().buffer(buffer.raw), + ) + } + } } impl From<gpu_alloc::AllocationError> for crate::DeviceError { diff --git a/wgpu-hal/src/vulkan/mod.rs b/wgpu-hal/src/vulkan/mod.rs index af322e0ee8..af6dce2a7f 100644 --- a/wgpu-hal/src/vulkan/mod.rs +++ b/wgpu-hal/src/vulkan/mod.rs @@ -63,6 +63,7 @@ impl crate::Api for Api { type Sampler = Sampler; type QuerySet = QuerySet; type Fence = Fence; + type AccelerationStructure = AccelerationStructure; type BindGroupLayout = BindGroupLayout; type BindGroup = BindGroup; @@ -70,6 +71,7 @@ type ShaderModule = ShaderModule; type RenderPipeline = RenderPipeline; type ComputePipeline = ComputePipeline; + type RayTracingPipeline = RayTracingPipeline; } struct DebugUtils { @@ -147,6 +149,13 @@ enum ExtensionFn { struct DeviceExtensionFunctions { draw_indirect_count: Option<khr::DrawIndirectCount>, timeline_semaphore: Option<ExtensionFn<khr::TimelineSemaphore>>, + ray_tracing: Option<RayTracingDeviceExtensionFunctions>, } + +struct RayTracingDeviceExtensionFunctions { + acceleration_structure: khr::AccelerationStructure, + buffer_device_address: khr::BufferDeviceAddress, + rt_pipeline: khr::RayTracingPipeline, } /// Set of internal capabilities,
which don't show up in the exposed @@ -169,6 +178,19 @@ struct PrivateCapabilities { robust_buffer_access: bool, robust_image_access: bool, zero_initialize_workgroup_memory: bool, + ray_tracing_device_properties: Option<RayTracingCapabilities>, } + +#[derive(Clone, Debug)] +pub struct RayTracingCapabilities { + pub shader_group_handle_size: u32, + pub max_ray_recursion_depth: u32, + pub max_shader_group_stride: u32, + pub shader_group_base_alignment: u32, + pub shader_group_handle_capture_replay_size: u32, + pub max_ray_dispatch_invocation_count: u32, + pub shader_group_handle_alignment: u32, + pub max_ray_hit_attribute_size: u32, } bitflags::bitflags!( @@ -287,6 +309,46 @@ pub struct Buffer { block: Mutex<gpu_alloc::MemoryBlock<vk::DeviceMemory>>, } +#[derive(Debug)] +pub struct AccelerationStructure { + raw: vk::AccelerationStructureKHR, + buffer: vk::Buffer, + block: Mutex<gpu_alloc::MemoryBlock<vk::DeviceMemory>>, +} + +#[derive(Debug)] +pub struct RayTracingPipeline { + raw: vk::Pipeline, + handle_data: Vec<u8>, + handle_size: usize, + ranges: [usize; 5], +} + +fn get_handle_slices<'a>(pipeline: &'a RayTracingPipeline, range_index: usize) -> Vec<&'a [u8]> { + let range = pipeline.ranges[range_index]..pipeline.ranges[range_index + 1]; + pipeline.handle_data[range] + .chunks(pipeline.handle_size) + .collect() +} + +impl crate::RayTracingPipeline for RayTracingPipeline { + fn gen_handles<'a>(&'a self) -> Vec<&'a [u8]> { + get_handle_slices(self, 0) + } + + fn miss_handles<'a>(&'a self) -> Vec<&'a [u8]> { + get_handle_slices(self, 1) + } + + fn call_handles<'a>(&'a self) -> Vec<&'a [u8]> { + get_handle_slices(self, 2) + } + + fn hit_handles<'a>(&'a self) -> Vec<&'a [u8]> { + get_handle_slices(self, 3) + } +} + #[derive(Debug)] pub struct Texture { raw: vk::Image, diff --git a/wgpu-info/src/main.rs b/wgpu-info/src/main.rs index 4f4c0347f7..e6caadac2e 100644 --- a/wgpu-info/src/main.rs +++ b/wgpu-info/src/main.rs @@ -256,7 +256,7 @@ mod inner { for format in TEXTURE_FORMAT_LIST { let features = adapter.get_texture_format_features(format); let format_name = texture_format_name(format); - print!("\t\t{format_name:>0$}", max_format_name_size); + print!("\t\t{format_name:>max_format_name_size$}"); wgpu::TextureUsages::for_valid_bits(|bit, _i| { print!(" │ "); if features.allowed_usages.contains(bit) { @@ -285,7 +285,7 @@ mod inner { let features = adapter.get_texture_format_features(format); let format_name = texture_format_name(format); - print!("\t\t{format_name:>0$}", max_format_name_size); + print!("\t\t{format_name:>max_format_name_size$}"); wgpu::TextureFormatFeatureFlags::for_valid_bits(|bit, _i| { print!(" │ "); if features.flags.contains(bit) { diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index e6a5ec1bd6..cb0eb324f9 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -702,8 +702,15 @@ bitflags::bitflags! { /// /// This is a native only feature. const VERTEX_ATTRIBUTE_64BIT = 1 << 53; + /// Allows for the creation of ray-tracing acceleration structures. + /// + /// Supported platforms: + /// - Vulkan + /// + /// This is a native-only feature. + const RAY_TRACING = 1 << 54; - // 54..59 available + // 55..59 available // Shader: @@ -746,8 +753,15 @@ bitflags::bitflags! { /// /// This is a native only feature. const SHADER_EARLY_DEPTH_TEST = 1 << 62; + /// Allows for the creation of ray-tracing queries within shaders. + /// + /// Supported platforms: + /// - Vulkan + /// + /// This is a native-only feature. + const RAY_QUERY = 1 << 63; - // 62..64 available + // 64 available } } @@ -1357,6 +1371,18 @@ bitflags::bitflags!
{ const FRAGMENT = 1 << 1; /// Binding is visible from the compute shader of a compute pipeline. const COMPUTE = 1 << 2; + /// Binding is visible from a ray generation shader of a ray-tracing pipeline. + const RAYGEN = 1 << 3; + /// Binding is visible from a miss shader of a ray-tracing pipeline. + const MISS = 1 << 4; + /// Binding is visible from a callable shader of a ray-tracing pipeline. + const CALLABLE = 1 << 5; + /// Binding is visible from a closest hit shader of a ray-tracing pipeline. + const CLOSEST_HIT = 1 << 6; + /// Binding is visible from an any hit shader of a ray-tracing pipeline. + const ANY_HIT = 1 << 7; + /// Binding is visible from an intersection shader of a ray-tracing pipeline. + const INTERSECTION = 1 << 8; /// Binding is visible from the vertex and fragment shaders of a render pipeline. const VERTEX_FRAGMENT = Self::VERTEX.bits | Self::FRAGMENT.bits; } @@ -5566,6 +5592,15 @@ pub enum BindingType { /// Dimension of the texture view that is going to be sampled. view_dimension: TextureViewDimension, }, + + /// A ray-tracing acceleration structure binding. + /// + /// Example GLSL syntax: + /// ```cpp,ignore + /// layout(binding = 0) + /// uniform accelerationStructureEXT as; + /// ``` + AccelerationStructure, } impl BindingType { diff --git a/wgpu/src/backend/web.rs b/wgpu/src/backend/web.rs index cfa22e3e6b..5b588208ab 100644 --- a/wgpu/src/backend/web.rs +++ b/wgpu/src/backend/web.rs @@ -1321,6 +1321,7 @@ impl crate::context::Context for Context { storage_texture.view_dimension(map_texture_view_dimension(view_dimension)); mapped_entry.storage_texture(&storage_texture); } + wgt::BindingType::AccelerationStructure => todo!(), } mapped_entry