Skip to content

Commit 1d9fe7e

Browse files
author
GBDixonAlex
committed
- added support for tlas update
1 parent c014e9e commit 1d9fe7e

File tree

3 files changed

+151
-53
lines changed

3 files changed

+151
-53
lines changed

plugins/ecs_examples/src/raytraced_shadows.rs

+67-35
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ pub struct BLAS {
1616

1717
#[derive(Component)]
1818
pub struct TLAS {
19-
pub tlas: Option<gfx_platform::RaytracingTLAS>
19+
pub tlas: Option<gfx_platform::RaytracingTLAS>,
20+
pub instance_buffer: Option<gfx_platform::Buffer>
2021
}
2122

2223
/// Setup multiple draw calls with draw indexed and per draw call push constants for transformation matrix etc.
@@ -28,9 +29,10 @@ pub fn raytraced_shadows(client: &mut Client<gfx_platform::Device, os_platform::
2829
"setup_raytraced_shadows_scene"
2930
],
3031
update: systems![
31-
"setup_raytraced_shadows_tlas",
32+
"animate_meshes",
3233
"animate_lights",
33-
"batch_lights"
34+
"batch_lights",
35+
"update_tlas"
3436
],
3537
render_graph: "mesh_lit_rt_shadow",
3638
..Default::default()
@@ -64,8 +66,9 @@ pub fn setup_raytraced_shadows_scene(
6466
mut commands: Commands) -> Result<(), hotline_rs::Error> {
6567

6668
let cube_mesh = hotline_rs::primitives::create_cube_mesh(&mut device.0);
69+
6770
let tourus_mesh = hotline_rs::primitives::create_tourus_mesh(&mut device.0, 32);
68-
let helix_mesh = hotline_rs::primitives::create_helix_mesh(&mut device.0, 32, 4);
71+
let teapot_mesh = hotline_rs::primitives::create_teapot_mesh(&mut device.0, 32);
6972
let tube_mesh = hotline_rs::primitives::create_tube_prism_mesh(&mut device.0, 32, 0, 32, true, true, 1.0, 0.66, 1.0);
7073
let triangle_mesh = hotline_rs::primitives::create_tube_prism_mesh(&mut device.0, 3, 0, 3, false, true, 0.33, 0.66, 1.0);
7174

@@ -115,9 +118,9 @@ pub fn setup_raytraced_shadows_scene(
115118
Position(vec3f(shape_bounds * -0.3, shape_bounds * -0.6, shape_bounds * 0.8)),
116119
Scale(splat3f(tourus_size * 2.0)),
117120
Rotation(Quatf::identity()),
118-
MeshComponent(helix_mesh.clone()),
121+
MeshComponent(teapot_mesh.clone()),
119122
WorldMatrix(Mat34f::identity()),
120-
blas_from_mesh(&mut device, &helix_mesh)?
123+
blas_from_mesh(&mut device, &teapot_mesh)?
121124
));
122125

123126
// tube
@@ -190,15 +193,16 @@ pub fn setup_raytraced_shadows_scene(
190193

191194
commands.spawn(
192195
TLAS {
193-
tlas: None
196+
tlas: None,
197+
instance_buffer: None
194198
}
195199
);
196200

197201
Ok(())
198202
}
199203

200204
#[export_update_fn]
201-
pub fn setup_raytraced_shadows_tlas(
205+
pub fn update_tlas(
202206
mut device: ResMut<DeviceRes>,
203207
mut pmfx: ResMut<PmfxRes>,
204208
mut entities_query: Query<(&mut Position, &mut Scale, &mut Rotation, &BLAS)>,
@@ -207,31 +211,31 @@ pub fn setup_raytraced_shadows_tlas(
207211

208212
// ..
209213
for mut t in &mut tlas_query {
214+
let mut instances = Vec::new();
215+
for (index, (position, scale, rotation, blas)) in &mut entities_query.iter().enumerate() {
216+
let translate = Mat34f::from_translation(position.0);
217+
let rotate = Mat34f::from(rotation.0);
218+
let scale = Mat34f::from_scale(scale.0);
219+
instances.push(
220+
gfx::RaytracingInstanceInfo::<gfx_platform::Device> {
221+
transform: (translate * rotate * scale).m,
222+
instance_id: index as u32,
223+
instance_mask: 0xff,
224+
hit_group_index: 0,
225+
instance_flags: 0,
226+
blas: &blas.blas
227+
}
228+
);
229+
}
210230
if t.tlas.is_none() {
211-
let mut instances = Vec::new();
212-
for (index, (position, scale, rotation, blas)) in &mut entities_query.iter().enumerate() {
213-
let translate = Mat34f::from_translation(position.0);
214-
let rotate = Mat34f::from(rotation.0);
215-
let scale = Mat34f::from_scale(scale.0);
216-
instances.push(
217-
gfx::RaytracingInstanceInfo::<gfx_platform::Device> {
218-
transform: (translate * rotate * scale).m,
219-
instance_id: index as u32,
220-
instance_mask: 0xff,
221-
hit_group_index: 0,
222-
instance_flags: 0,
223-
blas: &blas.blas
224-
}
225-
);
226-
let tlas = device.create_raytracing_tlas_with_heap(&gfx::RaytracingTLASInfo {
227-
instances: &instances,
228-
build_flags: gfx::AccelerationStructureBuildFlags::PREFER_FAST_TRACE,
229-
},
230-
&mut pmfx.shader_heap
231-
)?;
232-
233-
t.tlas = Some(tlas);
234-
}
231+
let tlas = device.create_raytracing_tlas_with_heap(&gfx::RaytracingTLASInfo {
232+
instances: &instances,
233+
build_flags: gfx::AccelerationStructureBuildFlags::PREFER_FAST_TRACE |
234+
gfx::AccelerationStructureBuildFlags::ALLOW_UPDATE
235+
},
236+
&mut pmfx.shader_heap
237+
)?;
238+
t.tlas = Some(tlas);
235239
}
236240
}
237241

@@ -246,7 +250,6 @@ pub fn animate_lights (
246250
let extent = 60.0;
247251
for (mut position, _) in &mut light_query {
248252
position.0 = vec3f(sin(time.accumulated), cos(time.accumulated), cos(time.accumulated)) * extent;
249-
position.0 += vec3f(100.0, 10.0, 100.0);
250253
}
251254

252255
Ok(())
@@ -259,16 +262,45 @@ pub fn animate_lights (
259262

260263
#[export_compute_fn]
261264
pub fn render_meshes_raytraced(
262-
device: ResMut<DeviceRes>,
265+
mut device: ResMut<DeviceRes>,
263266
pmfx: &Res<PmfxRes>,
264267
pass: &pmfx::ComputePass<gfx_platform::Device>,
265-
tlas_query: Query<&TLAS>
268+
mut entities_query: Query<(&mut Position, &mut Scale, &mut Rotation, &BLAS)>,
269+
mut tlas_query: Query<&mut TLAS>,
266270
) -> Result<(), hotline_rs::Error> {
267271
let pmfx = &pmfx.0;
268272

273+
let mut heap = pmfx.shader_heap.clone();
274+
269275
let output_size = pmfx.get_texture_2d_size("staging_output").expect("expected staging_output");
270276
let output_tex = pmfx.get_texture("staging_output").expect("expected staging_output");
271277

278+
// update tlas
279+
for mut t in &mut tlas_query {
280+
let mut instances = Vec::new();
281+
for (index, (position, scale, rotation, blas)) in &mut entities_query.iter().enumerate() {
282+
let translate = Mat34f::from_translation(position.0);
283+
let rotate = Mat34f::from(rotation.0);
284+
let scale = Mat34f::from_scale(scale.0);
285+
instances.push(
286+
gfx::RaytracingInstanceInfo::<gfx_platform::Device> {
287+
transform: (translate * rotate * scale).m,
288+
instance_id: index as u32,
289+
instance_mask: 0xff,
290+
hit_group_index: 0,
291+
instance_flags: 0,
292+
blas: &blas.blas
293+
}
294+
);
295+
}
296+
297+
if let Some(tlas) = t.tlas.as_ref() {
298+
let instance_buffer = device.create_raytracing_instance_buffer(&instances, &mut heap)?;
299+
pass.cmd_buf.update_raytracing_tlas(tlas, &instance_buffer, instances.len(), gfx::AccelerationStructureRebuildMode::Refit);
300+
t.instance_buffer = Some(instance_buffer);
301+
}
302+
}
303+
272304
let camera = pmfx.get_camera_constants("main_camera");
273305
if let Ok(camera) = camera {
274306
for t in &tlas_query {

src/gfx.rs

+15-2
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,13 @@ pub enum RaytracingHitGeometry {
788788
ProceduralPrimitive
789789
}
790790

791+
pub enum AccelerationStructureRebuildMode {
792+
/// Quick refit will update the transforms only
793+
Refit,
794+
/// Full update will update BLAS topologies
795+
Full
796+
}
797+
791798
#[derive(Serialize, Deserialize, Clone)]
792799
pub struct RaytracingHitGroup {
793800
pub name: String,
@@ -1214,6 +1221,12 @@ pub trait Device: 'static + Send + Sync + Sized + Any + Clone {
12141221
data: Option<&[T]>,
12151222
heap: &mut Self::Heap
12161223
) -> Result<Self::Buffer, Error>;
1224+
/// Create an upload buffer that can be used specifically for raytracing acceleration structure instances
1225+
fn create_raytracing_instance_buffer(
1226+
&mut self,
1227+
instances: &Vec<RaytracingInstanceInfo<Self>>,
1228+
heap: &mut Self::Heap // TODO; need to refactor heap handling
1229+
) -> Result<Self::Buffer, Error>;
12171230
/// Create a `Buffer` specifically for reading back data from the GPU mainly for `Query` use
12181231
fn create_read_back_buffer(
12191232
&mut self,
@@ -1398,8 +1411,6 @@ pub trait CmdBuf<D: Device>: Send + Sync + Clone {
13981411
/// Binds the heap with offset (texture srv, uav) on to the `slot` of a pipeline.
13991412
/// this is like a traditional bindful render architecture `cmd.set_binding(pipeline, heap, 0, texture1_id)`
14001413
fn set_binding<T: Pipeline>(&self, pipeline: &T, heap: &D::Heap, slot: u32, offset: usize);
1401-
// TODO:
1402-
fn set_tlas(&self, tlas: &D::RaytracingTLAS);
14031414
/// Push a small amount of data into the command buffer for a render pipeline, num values and dest offset are the number of 32bit values
14041415
fn push_render_constants<T: Sized>(&self, slot: u32, num_values: u32, dest_offset: u32, data: &[T]);
14051416
/// Push a small amount of data into the command buffer for a compute pipeline, num values and dest offset are the number of 32bit values
@@ -1435,6 +1446,8 @@ pub trait CmdBuf<D: Device>: Send + Sync + Clone {
14351446
);
14361447
/// Issue dispatch call for ray tracing with the specified `RaytracingShaderBindingTable` which is associated with the bound `RaytracingPipeline`
14371448
fn dispatch_rays(&self, sbt: &D::RaytracingShaderBindingTable, numthreads: Size3);
1449+
/// Update a raytracing TLAS with instance transform info contained in `instance_buffer` of length `instance_count`. Use `mode` to control quick refit or full rebuild
1450+
fn update_raytracing_tlas(&self, tlas: &D::RaytracingTLAS, instance_buffer: &D::Buffer, instance_count: usize, mode: AccelerationStructureRebuildMode);
14381451
/// Resolves the `subresource` (mip index, 3d texture slice or array slice)
14391452
fn resolve_texture_subresource(&self, texture: &D::Texture, subresource: u32) -> Result<(), Error>;
14401453
/// Generates a full mip chain for the specified `texture` where `heap` is the shader heap the texture was created on

src/gfx/d3d12.rs

+69-16
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,13 @@ fn to_d3d12_raytracing_acceleration_structure_build_flags
812812
d3d12_flags
813813
}
814814

815+
fn to_d3d12_raytracing_acceleration_structure_update_flags(mode: AccelerationStructureRebuildMode) -> D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS {
816+
match mode {
817+
AccelerationStructureRebuildMode::Refit => D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PERFORM_UPDATE,
818+
AccelerationStructureRebuildMode::Full => D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_NONE
819+
}
820+
}
821+
815822
fn get_d3d12_error_blob_string(blob: &ID3DBlob) -> String {
816823
unsafe {
817824
String::from_raw_parts(
@@ -1720,6 +1727,7 @@ pub struct RaytracingBLAS {
17201727

17211728
pub struct RaytracingTLAS {
17221729
pub(crate) tlas_buffer: Buffer,
1730+
pub(crate) scratch_buffer: Option<Buffer>,
17231731
pub(crate) shader_heap_id: u16
17241732
}
17251733

@@ -3436,19 +3444,19 @@ impl super::Device for Device {
34363444
}
34373445
}
34383446

3439-
fn create_raytracing_tlas_with_heap(
3447+
fn create_raytracing_instance_buffer(
34403448
&mut self,
3441-
info: &RaytracingTLASInfo<Self>,
3442-
heap: &mut Heap
3443-
) -> result::Result<RaytracingTLAS, super::Error> {
3449+
instances: &Vec<RaytracingInstanceInfo<Self>>,
3450+
heap: &mut Heap
3451+
) -> result::Result<Buffer, super::Error> {
34443452
// pack 24: 8 bits
34453453
let pack_24_8 = |a, b| {
34463454
(a & 0x00ffffff) | ((b & 0x000000ff) << 24)
34473455
};
34483456

3449-
// create instance descslea
3450-
let num_instances = info.instances.len();
3451-
let instance_descs: Vec<D3D12_RAYTRACING_INSTANCE_DESC> = info.instances.iter()
3457+
// create instance descs
3458+
let num_instances = instances.len();
3459+
let instance_descs: Vec<D3D12_RAYTRACING_INSTANCE_DESC> = instances.iter()
34523460
.map(|x|
34533461
D3D12_RAYTRACING_INSTANCE_DESC {
34543462
Transform: x.transform,
@@ -3460,6 +3468,7 @@ impl super::Device for Device {
34603468

34613469
// create upload buffer for instance descs
34623470
let stride = std::mem::size_of::<D3D12_RAYTRACING_INSTANCE_DESC>();
3471+
34633472
let instance_buffer = self.create_buffer_with_heap(&BufferInfo {
34643473
usage: super::BufferUsage::UPLOAD,
34653474
cpu_access: super::CpuAccessFlags::NONE,
@@ -3470,13 +3479,25 @@ impl super::Device for Device {
34703479
},
34713480
Some(instance_descs.as_slice()),
34723481
heap
3473-
).expect(format!("hotline_rs::gfx::d3d12: failed to create a scratch buffer for raytracing blas of size {}", stride * num_instances).as_str());
3482+
)?;
3483+
3484+
Ok(instance_buffer)
3485+
}
3486+
3487+
fn create_raytracing_tlas_with_heap(
3488+
&mut self,
3489+
info: &RaytracingTLASInfo<Self>,
3490+
heap: &mut Heap
3491+
) -> result::Result<RaytracingTLAS, super::Error> {
3492+
3493+
// create instance buffer
3494+
let instance_buffer = self.create_raytracing_instance_buffer(info.instances, heap)?;
34743495

34753496
// create acceleration structure inputs
34763497
let inputs = D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS {
34773498
Type: D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL,
34783499
Flags: to_d3d12_raytracing_acceleration_structure_build_flags(info.build_flags),
3479-
NumDescs: num_instances as u32,
3500+
NumDescs: info.instances.len() as u32,
34803501
DescsLayout: D3D12_ELEMENTS_LAYOUT_ARRAY,
34813502
Anonymous: D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS_0 {
34823503
InstanceDescs: instance_buffer.d3d_virtual_address(),
@@ -3541,6 +3562,12 @@ impl super::Device for Device {
35413562
// return the result
35423563
Ok(RaytracingTLAS {
35433564
tlas_buffer,
3565+
scratch_buffer: if info.build_flags.contains(AccelerationStructureBuildFlags::ALLOW_UPDATE) {
3566+
Some(scratch_buffer)
3567+
}
3568+
else {
3569+
None
3570+
},
35443571
shader_heap_id: self.shader_heap.as_ref().map(|x| x.id).unwrap_or(0)
35453572
})
35463573
}
@@ -4342,13 +4369,6 @@ impl super::CmdBuf<Device> for CmdBuf {
43424369
}
43434370
}
43444371

4345-
fn set_tlas(&self, tlas: &RaytracingTLAS) {
4346-
let cmd = self.cmd().cast::<ID3D12GraphicsCommandList4>().unwrap();
4347-
unsafe {
4348-
cmd.SetComputeRootShaderResourceView(0, tlas.tlas_buffer.resource.as_ref().unwrap().GetGPUVirtualAddress());
4349-
}
4350-
}
4351-
43524372
fn dispatch_rays(&self, sbt: &RaytracingShaderBindingTable, numthreads: Size3) {
43534373
unsafe {
43544374
let dispatch_desc = D3D12_DISPATCH_RAYS_DESC {
@@ -4381,6 +4401,39 @@ impl super::CmdBuf<Device> for CmdBuf {
43814401
}
43824402
}
43834403

4404+
fn update_raytracing_tlas(&self, tlas: &RaytracingTLAS, instance_buffer: &Buffer, instance_count: usize, mode: AccelerationStructureRebuildMode) {
4405+
4406+
// create acceleration structure inputs
4407+
let inputs = D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS {
4408+
Type: D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL,
4409+
Flags: to_d3d12_raytracing_acceleration_structure_update_flags(mode),
4410+
NumDescs: instance_count as u32,
4411+
DescsLayout: D3D12_ELEMENTS_LAYOUT_ARRAY,
4412+
Anonymous: D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS_0 {
4413+
InstanceDescs: instance_buffer.d3d_virtual_address(),
4414+
}
4415+
};
4416+
4417+
// create blas desc
4418+
let blas_desc = D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC {
4419+
Inputs: inputs,
4420+
SourceAccelerationStructureData: tlas.tlas_buffer.d3d_virtual_address(),
4421+
DestAccelerationStructureData: tlas.tlas_buffer.d3d_virtual_address(),
4422+
ScratchAccelerationStructureData: tlas.scratch_buffer
4423+
.as_ref()
4424+
.expect("hotline_rs::gfx::d3d12: tlas is required to be created with ALLOW_UPDATE in order to update it")
4425+
.d3d_virtual_address()
4426+
};
4427+
4428+
// build blas
4429+
unsafe {
4430+
let bb = self.bb_index;
4431+
let cmd = self.command_list[bb].cast::<ID3D12GraphicsCommandList4>()
4432+
.expect("hotline_rs::gfx::d3d12: expected ID3D12GraphicsCommandList4 availability to create raytracing blas");
4433+
cmd.BuildRaytracingAccelerationStructure(&blas_desc, None);
4434+
}
4435+
}
4436+
43844437
fn read_back_backbuffer(&mut self, swap_chain: &SwapChain) -> result::Result<ReadBackRequest, super::Error> {
43854438
let bb = self.bb_index;
43864439
let bbz = self.bb_index as u32;

0 commit comments

Comments
 (0)