-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathScene.cs
411 lines (338 loc) · 21.2 KB
/
Scene.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
using RayTracingTutorial18.RTX;
using RayTracingTutorial18.RTX.Structs;
using RayTracingTutorial18.Structs;
using System;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Threading;
using Vortice.Direct3D12;
using Vortice.Direct3D12.Debug;
using Vortice.DXGI;
using Vortice.Mathematics;
namespace RayTracingTutorial18
{
/// <summary>
/// Ray-traced scene for the DXR tutorial. Creates the D3D12 device objects, the bottom/top-level
/// acceleration structures, the ray-tracing pipeline state object, the shader resources and the
/// shader table. Each <see cref="DrawFrame"/> call dispatches rays into an output UAV texture and
/// copies that texture into the current swap-chain back buffer.
/// </summary>
public class Scene
{
    // D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING (identity swizzle, 0x1688).
    private const int D3D12DefaultShader4ComponentMapping = 5768;
    // Capacity of the RTV descriptor heap; one RTV per swap-chain buffer is created from it in InitDXR.
    private const int kRtvHeapSize = 3;
    private Color4 clearColor = new Color4(0.4f, 0.6f, 0.2f, 1.0f);
    private readonly Window Window;
    private D3D12GraphicsContext context;
    private IntPtr mHwnd;
    private ID3D12Device5 mpDevice;
    private ID3D12CommandQueue mpCmdQueue;
    private IDXGISwapChain3 mpSwapChain;
    private ID3D12GraphicsCommandList4 mpCmdList;
    private HeapData mRtvHeap;
    private FrameObject[] mFrameObjects;
    private ID3D12Fence mpFence;
    private uint mFenceValue = 0;
    private EventWaitHandle mFenceEvent;
    private Rect mSwapChainRect;
    private ID3D12Resource mpTopLevelAS;
    // Bottom-level AS results. The TLAS instances reference these by GPU virtual address only,
    // so they must stay alive for as long as the TLAS is used.
    private ID3D12Resource[] mpBottomLevelAS;
    private ID3D12StateObject mpPipelineState;
    private ID3D12RootSignature mpEmptyRootSig;
    private AccelerationStructures acs;
    private ID3D12Resource mpOutputResource;
    private ID3D12DescriptorHeap mpSrvUavHeap;
    private ID3D12Resource mpShaderTable;
    private uint mShaderTableEntrySize;
    private long mTlasSize = 0;
    private CpuDescriptorHandle indexSRVHandle;
    private CpuDescriptorHandle vertexSRVHandle;

    /// <summary>
    /// Initializes all GPU objects needed to ray trace into <paramref name="window"/>.
    /// </summary>
    /// <param name="window">Target window; its handle and size drive the swap chain and output texture.</param>
    public Scene(Window window)
    {
        this.Window = window;
        this.context = new D3D12GraphicsContext(window.Width, window.Height);
        // InitDXR Tutorial 02
        this.InitDXR((IntPtr)window.Handle, window.Width, window.Height);
        // Acceleration Structures Tutorial 03
        this.CreateAccelerationStructures();
        // RtPipeline Tutorial 04
        this.CreateRtPipelineState();
        // ShaderResources Tutorial 06. Must run before the shader table is built, because the
        // shader records embed the GPU handle of the SRV/UAV heap created here.
        this.CreateShaderResources();
        // ShaderTable Tutorial 05
        this.CreateShaderTable();
    }

    /// <summary>
    /// Creates the device, command queue, swap chain, RTV heap, per-frame objects, command list and fence.
    /// </summary>
    /// <exception cref="InvalidOperationException">The DXGI factory could not be created.</exception>
    private void InitDXR(IntPtr winHandle, int winWidth, int winHeight)
    {
        mHwnd = winHandle;
        this.mSwapChainRect = new Rect(0, 0, winWidth, winHeight);

        // Initialize the debug layer for debug builds
#if DEBUG
        if (D3D12.D3D12GetDebugInterface<ID3D12Debug>(out var pDx12Debug).Success)
        {
            pDx12Debug.EnableDebugLayer();
            // Only the EnableDebugLayer call is needed; release our reference to the debug interface.
            pDx12Debug.Dispose();
        }
#endif

        // Create the DXGI factory
        IDXGIFactory4 pDXGIFactory;
        if (!DXGI.CreateDXGIFactory1<IDXGIFactory4>(out pDXGIFactory).Success)
        {
            throw new InvalidOperationException("Failed to create the DXGI factory.");
        }
        mpDevice = this.context.CreateDevice(pDXGIFactory);
        mpCmdQueue = this.context.CreateCommandQueue(mpDevice);
        mpSwapChain = this.context.CreateDXGISwapChain(pDXGIFactory, mHwnd, winWidth, winHeight, Format.R8G8B8A8_UNorm, mpCmdQueue);

        // Create a RTV descriptor heap
        mRtvHeap.Heap = this.context.CreateDescriptorHeap(mpDevice, kRtvHeapSize, DescriptorHeapType.RenderTargetView, false);

        // Create the per-frame objects: a command allocator, the swap-chain buffer and its sRGB RTV.
        this.mFrameObjects = new FrameObject[this.context.kDefaultSwapChainBuffers];
        for (int i = 0; i < this.context.kDefaultSwapChainBuffers; i++)
        {
            mFrameObjects[i].pCmdAllocator = mpDevice.CreateCommandAllocator(CommandListType.Direct);
            mFrameObjects[i].pSwapChainBuffer = mpSwapChain.GetBuffer<ID3D12Resource>(i);
            mFrameObjects[i].rtvHandle = context.CreateRTV(mpDevice, mFrameObjects[i].pSwapChainBuffer, mRtvHeap.Heap, ref mRtvHeap.usedEntries, Format.R8G8B8A8_UNorm_SRgb);
        }

        // Create the command-list. The base interface is only needed to query for
        // ID3D12GraphicsCommandList4 (which holds its own reference), so release it afterwards.
        using (var cmdList = mpDevice.CreateCommandList(0, CommandListType.Direct, mFrameObjects[0].pCmdAllocator, null))
        {
            this.mpCmdList = cmdList.QueryInterface<ID3D12GraphicsCommandList4>();
        }

        // Create a fence and the event used to wait on it
        this.mpFence = mpDevice.CreateFence(0, FenceFlags.None);
        this.mFenceEvent = new EventWaitHandle(false, EventResetMode.AutoReset);
    }

    /// <summary>
    /// Records the bottom-level builds (plane and primitive) and the top-level build that
    /// instances them. Submits the work, waits for the GPU to finish, then resets the command list.
    /// </summary>
    public void CreateAccelerationStructures()
    {
        acs = new AccelerationStructures();
        AccelerationStructureBuffers[] bottomLevelBuffers = new AccelerationStructureBuffers[2];
        bottomLevelBuffers[0] = acs.CreatePlaneBottomLevelAS(mpDevice, mpCmdList);
        bottomLevelBuffers[1] = acs.CreatePrimitiveBottomLevelAS(mpDevice, mpCmdList);
        AccelerationStructureBuffers topLevelBuffers = acs.CreateTopLevelAS(mpDevice, mpCmdList, bottomLevelBuffers, ref mTlasSize);

        // The tutorial doesn't have any resource lifetime management, so we flush and sync here.
        // This is not required by the DXR spec - you can submit the list whenever you like as long
        // as you take care of the resources' lifetime.
        mFenceValue = context.SubmitCommandList(mpCmdList, mpCmdQueue, mpFence, mFenceValue);
        mpFence.SetEventOnCompletion(mFenceValue, mFenceEvent);
        mFenceEvent.WaitOne();
        mpCmdList.Reset(mFrameObjects[0].pCmdAllocator, null);

        // Store the AS result buffers. The TLAS references each BLAS by GPU address, so the BLAS
        // results must be kept too; the remaining (scratch/instance) buffers are no longer needed.
        mpBottomLevelAS = new ID3D12Resource[bottomLevelBuffers.Length];
        for (int i = 0; i < bottomLevelBuffers.Length; i++)
        {
            mpBottomLevelAS[i] = bottomLevelBuffers[i].pResult;
        }
        mpTopLevelAS = topLevelBuffers.pResult;
    }

    /// <summary>
    /// Builds the ray-tracing pipeline state object from 12 subobjects and stores it in
    /// <c>mpPipelineState</c>. Also stores the empty global root signature in <c>mpEmptyRootSig</c>.
    /// </summary>
    public void CreateRtPipelineState()
    {
        var rtpipeline = new RTPipeline();
        // Subobject layout (12 in total):
        //   0      DXIL library
        //   1      hit group
        //   2, 3   ray-gen local root signature + its export association
        //   4, 5   hit local root signature + its export association
        //   6, 7   empty miss local root signature + association to both miss shaders
        //   8, 9   shader config + association to all shaders
        //   10     pipeline config
        //   11     empty global root signature
        StateSubObject[] subobjects = new StateSubObject[12];
        int index = 0;

        // Create the DXIL library
        DxilLibrary dxilLib = rtpipeline.CreateDxilLibrary();
        subobjects[index++] = dxilLib.stateSubObject; // 0 Library

        HitProgram hitProgram = new HitProgram(null, RTPipeline.kClosestHitShader, RTPipeline.kHitGroup);
        subobjects[index++] = hitProgram.subObject; // 1 Hit Group

        // Create the ray-gen root-signature and association
        Structs.LocalRootSignature rgsRootSignature = new Structs.LocalRootSignature(mpDevice, rtpipeline.CreateRayGenRootDesc());
        subobjects[index] = rgsRootSignature.subobject; // 2 RayGen Root Sig
        int rgsRootIndex = index++;
        ExportAssociation rgsRootAssociation = new ExportAssociation(new string[] { RTPipeline.kRayGenShader }, subobjects[rgsRootIndex]);
        subobjects[index++] = rgsRootAssociation.subobject; // 3 Associate Root Sig to RGS

        // Create the hit root-signature and association
        Structs.LocalRootSignature hitRootSignature = new Structs.LocalRootSignature(mpDevice, rtpipeline.CreateHitRootDesc());
        subobjects[index] = hitRootSignature.subobject; // 4 Hit Root Sig
        int hitRootIndex = index++;
        ExportAssociation hitRootAssociation = new ExportAssociation(new string[] { RTPipeline.kClosestHitShader }, subobjects[hitRootIndex]);
        subobjects[index++] = hitRootAssociation.subobject; // 5 Associate Hit Root Sig to the closest-hit shader

        // Create the miss root-signature (empty) and associate it with both miss shaders
        RootSignatureDescription emptyDesc = new RootSignatureDescription(RootSignatureFlags.LocalRootSignature);
        Structs.LocalRootSignature missRootSignature = new Structs.LocalRootSignature(mpDevice, emptyDesc);
        subobjects[index] = missRootSignature.subobject; // 6 Miss Root Sig
        int missRootIndex = index++;
        ExportAssociation missRootAssociation = new ExportAssociation(new string[] { RTPipeline.kMissShader, RTPipeline.kShadowMiss }, subobjects[missRootIndex]);
        subobjects[index++] = missRootAssociation.subobject; // 7 Associate Miss Root Sig to both miss shaders

        // Bind the attribute and payload sizes to the programs. Presumably the arguments are
        // (maxAttributeSize, maxPayloadSize): float2 attributes and a float4 + uint payload - confirm in ShaderConfig.
        ShaderConfig shaderConfig = new ShaderConfig(sizeof(float) * 2, sizeof(float) * (4 + 1));
        subobjects[index] = shaderConfig.subObject; // 8 Shader Config
        int shaderConfigIndex = index++;
        string[] shaderExports = new string[] { RTPipeline.kMissShader, RTPipeline.kClosestHitShader, RTPipeline.kRayGenShader, RTPipeline.kShadowMiss };
        ExportAssociation configAssociation = new ExportAssociation(shaderExports, subobjects[shaderConfigIndex]);
        subobjects[index++] = configAssociation.subobject; // 9 Associate Shader Config to both misses, CHS and RGS

        // Create the pipeline config (argument presumably the max trace recursion depth - confirm in PipelineConfig)
        PipelineConfig config = new PipelineConfig(4 + 1);
        subobjects[index++] = config.suboject; // 10

        // Create the global root signature and store the empty signature
        Structs.GlobalRootSignature root = new Structs.GlobalRootSignature(mpDevice, new RootSignatureDescription());
        mpEmptyRootSig = root.pRootSig.RootSignature;
        subobjects[index++] = root.suboject; // 11

        // Every slot must be filled; an unfilled default subobject would make the state object invalid.
        System.Diagnostics.Debug.Assert(index == subobjects.Length, "Subobject count does not match the allocated array.");

        // Create the state
        StateObjectDescription desc = new StateObjectDescription(StateObjectType.RaytracingPipeline, subobjects);
        mpPipelineState = mpDevice.CreateStateObject(desc);
    }

    // D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES
    private const uint D3D12ShaderIdentifierSizeInBytes = 32;
    // D3D12_RAYTRACING_SHADER_RECORD_BYTE_ALIGNMENT
    private const uint D3D12RaytracingShaderRecordByteAlignment = 32;

    // Shader-table layout; every record is mShaderTableEntrySize bytes:
    //   0 = ray-gen, 1 = primary miss, 2 = shadow miss, 3 = hit group.
    private const uint kRayGenEntryIndex = 0;
    private const uint kMissEntryIndex = 1;
    private const uint kMissEntryCount = 2;
    private const uint kHitEntryIndex = 3;
    private const uint kShaderTableEntryCount = 4;

    /// <summary>Rounds <paramref name="_val"/> up to the next multiple of <paramref name="_alignment"/>.</summary>
    private static uint align_to(uint _alignment, uint _val)
    {
        return (((_val + _alignment - 1) / _alignment) * _alignment);
    }

    /// <summary>
    /// Copies the shader identifier of export <paramref name="exportName"/> to the start of the
    /// shader record at <paramref name="record"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">The pipeline returned no identifier for the export.</exception>
    private static unsafe void WriteShaderIdentifier(byte* record, ID3D12StateObjectProperties rtsoProps, string exportName)
    {
        void* identifier = (void*)rtsoProps.GetShaderIdentifier(exportName);
        if (identifier == null)
        {
            throw new InvalidOperationException($"No shader identifier found for export '{exportName}'.");
        }
        Unsafe.CopyBlock(record, identifier, D3D12ShaderIdentifierSizeInBytes);
    }

    /// <summary>
    /// Creates the shader table on the upload heap and fills its four records: ray-gen, primary
    /// miss, shadow miss and hit group. The ray-gen and hit records also carry the GPU handle of
    /// the start of the SRV/UAV descriptor heap as their descriptor-table argument.
    /// </summary>
    public unsafe void CreateShaderTable()
    {
        // All records must have the same size, so size them for the largest one: a shader
        // identifier plus an 8-byte descriptor-table handle, rounded up to
        // D3D12_RAYTRACING_SHADER_RECORD_BYTE_ALIGNMENT.
        mShaderTableEntrySize = D3D12ShaderIdentifierSizeInBytes;
        mShaderTableEntrySize += 8; // the descriptor table
        mShaderTableEntrySize = align_to(D3D12RaytracingShaderRecordByteAlignment, mShaderTableEntrySize);
        uint shaderTableSize = mShaderTableEntrySize * kShaderTableEntryCount;

        // For simplicity, we create the shader table on the upload heap. You can also create it on the default heap.
        mpShaderTable = this.acs.CreateBuffer(mpDevice, shaderTableSize, ResourceFlags.None, ResourceStates.GenericRead, AccelerationStructures.kUploadHeapProps);

        ulong heapStart = (ulong)mpSrvUavHeap.GetGPUDescriptorHandleForHeapStart().Ptr;

        using (ID3D12StateObjectProperties pRtsoProps = mpPipelineState.QueryInterface<ID3D12StateObjectProperties>())
        {
            byte* pData = (byte*)mpShaderTable.Map(0, null);
            try
            {
                // Entry 0 - ray-gen program ID and descriptor-table handle
                byte* rayGenRecord = pData + kRayGenEntryIndex * mShaderTableEntrySize;
                WriteShaderIdentifier(rayGenRecord, pRtsoProps, RTPipeline.kRayGenShader);
                Unsafe.Write<ulong>(rayGenRecord + D3D12ShaderIdentifierSizeInBytes, heapStart);

                // Entries 1 and 2 - primary miss and shadow miss programs
                byte* missRecord = pData + kMissEntryIndex * mShaderTableEntrySize;
                WriteShaderIdentifier(missRecord, pRtsoProps, RTPipeline.kMissShader);
                WriteShaderIdentifier(missRecord + mShaderTableEntrySize, pRtsoProps, RTPipeline.kShadowMiss);

                // Entry 3 - hit group ID and descriptor-table handle
                byte* hitRecord = pData + kHitEntryIndex * mShaderTableEntrySize;
                WriteShaderIdentifier(hitRecord, pRtsoProps, RTPipeline.kHitGroup);
                Unsafe.Write<ulong>(hitRecord + D3D12ShaderIdentifierSizeInBytes, heapStart);
            }
            finally
            {
                mpShaderTable.Unmap(0, null);
            }
        }
    }

    /// <summary>
    /// Creates the ray-tracing output texture and a 4-entry CBV/SRV/UAV descriptor heap. The heap
    /// holds, in order: [0] output UAV, [1] TLAS SRV, [2] index-buffer SRV, [3] vertex-buffer SRV.
    /// </summary>
    public void CreateShaderResources()
    {
        // Create the output resource. The dimensions and format should match the swap-chain
        ResourceDescription resDesc = new ResourceDescription();
        resDesc.DepthOrArraySize = 1;
        resDesc.Dimension = ResourceDimension.Texture2D;
        resDesc.Format = Format.R8G8B8A8_UNorm; // The backbuffer is actually DXGI_FORMAT_R8G8B8A8_UNORM_SRGB, but sRGB formats can't be used with UAVs. We will convert to sRGB ourselves in the shader
        resDesc.Flags = ResourceFlags.AllowUnorderedAccess;
        resDesc.Height = mSwapChainRect.Height;
        resDesc.Layout = TextureLayout.Unknown;
        resDesc.MipLevels = 1;
        resDesc.SampleDescription = new SampleDescription(1, 0);
        resDesc.Width = mSwapChainRect.Width;
        mpOutputResource = mpDevice.CreateCommittedResource(AccelerationStructures.kDefaultHeapProps, HeapFlags.None, resDesc, ResourceStates.CopySource, null); // Starting as copy-source to simplify DrawFrame()

        // Create the descriptor heap (4 entries; the final `true` presumably requests a shader-visible heap)
        mpSrvUavHeap = this.context.CreateDescriptorHeap(mpDevice, 4, DescriptorHeapType.ConstantBufferViewShaderResourceViewUnorderedAccessView, true);
        var descriptorSize = mpDevice.GetDescriptorHandleIncrementSize(DescriptorHeapType.ConstantBufferViewShaderResourceViewUnorderedAccessView);

        // [0] UAV. Based on the root signature we created it should be the first entry
        UnorderedAccessViewDescription uavDesc = new UnorderedAccessViewDescription();
        uavDesc.ViewDimension = UnorderedAccessViewDimension.Texture2D;
        mpDevice.CreateUnorderedAccessView(mpOutputResource, null, uavDesc, mpSrvUavHeap.GetCPUDescriptorHandleForHeapStart());

        // [1] TLAS SRV, right after the UAV. The resource argument is null; the AS is addressed via Location.
        ShaderResourceViewDescription srvDesc = new ShaderResourceViewDescription();
        srvDesc.ViewDimension = ShaderResourceViewDimension.RaytracingAccelerationStructure;
        srvDesc.Shader4ComponentMapping = D3D12DefaultShader4ComponentMapping;
        srvDesc.RaytracingAccelerationStructure = new RaytracingAccelerationStructureShaderResourceView();
        srvDesc.RaytracingAccelerationStructure.Location = mpTopLevelAS.GPUVirtualAddress;
        CpuDescriptorHandle srvHandle = mpSrvUavHeap.GetCPUDescriptorHandleForHeapStart();
        srvHandle.Ptr += descriptorSize;
        mpDevice.CreateShaderResourceView(null, srvDesc, srvHandle);

        // [2] Index SRV: raw view exposing IndexCount * 2 bytes (presumably 16-bit indices) as 32-bit elements.
        // NOTE(review): the integer division drops the last 16-bit index when IndexCount is odd - confirm the
        // index count is always even or that the buffer is padded before rounding up.
        var indexSRVDesc = new ShaderResourceViewDescription()
        {
            ViewDimension = ShaderResourceViewDimension.Buffer,
            Shader4ComponentMapping = D3D12DefaultShader4ComponentMapping,
            Format = Format.R32_Typeless,
            Buffer =
            {
                NumElements = (int)(this.acs.IndexCount * 2 / 4),
                Flags = BufferShaderResourceViewFlags.Raw,
                StructureByteStride = 0,
            }
        };
        srvHandle.Ptr += descriptorSize;
        indexSRVHandle = srvHandle;
        mpDevice.CreateShaderResourceView(this.acs.IndexBuffer, indexSRVDesc, indexSRVHandle);

        // [3] Vertex SRV: structured view, one element per VertexPositionNormalTangentTexture.
        var vertexSRVDesc = new ShaderResourceViewDescription()
        {
            ViewDimension = ShaderResourceViewDimension.Buffer,
            Shader4ComponentMapping = D3D12DefaultShader4ComponentMapping,
            Format = Format.Unknown,
            Buffer =
            {
                NumElements = (int)this.acs.VertexCount,
                Flags = BufferShaderResourceViewFlags.None,
                StructureByteStride = Unsafe.SizeOf<VertexPositionNormalTangentTexture>(),
            }
        };
        srvHandle.Ptr += descriptorSize;
        vertexSRVHandle = srvHandle;
        mpDevice.CreateShaderResourceView(this.acs.VertexBuffer, vertexSRVDesc, vertexSRVHandle);
    }

    /// <summary>Binds the SRV/UAV heap and returns the index of the current back buffer.</summary>
    private int BeginFrame()
    {
        // Bind the descriptor heaps
        ID3D12DescriptorHeap[] heaps = new ID3D12DescriptorHeap[] { mpSrvUavHeap };
        mpCmdList.SetDescriptorHeaps(1, heaps);
        return this.mpSwapChain.GetCurrentBackBufferIndex();
    }

    /// <summary>
    /// Records, submits and presents one ray-traced frame. It dispatches rays into the output UAV
    /// and copies the result into the current back buffer.
    /// </summary>
    /// <param name="draw">Not used by this implementation.</param>
    /// <param name="frameName">Not used; filled with the caller's member name by the compiler.</param>
    /// <returns>Always <c>true</c>.</returns>
    public bool DrawFrame(Action<int, int> draw, [CallerMemberName] string frameName = null)
    {
        int rtvIndex = BeginFrame();

        // Let's raytrace
        context.ResourceBarrier(mpCmdList, mpOutputResource, ResourceStates.CopySource, ResourceStates.UnorderedAccess);
        DispatchRaysDescription raytraceDesc = new DispatchRaysDescription();
        raytraceDesc.Width = mSwapChainRect.Width;
        raytraceDesc.Height = mSwapChainRect.Height;
        raytraceDesc.Depth = 1;

        // RayGen is entry 0 of the shader table
        raytraceDesc.RayGenerationShaderRecord.StartAddress = mpShaderTable.GPUVirtualAddress + kRayGenEntryIndex * mShaderTableEntrySize;
        raytraceDesc.RayGenerationShaderRecord.SizeInBytes = mShaderTableEntrySize;

        // Primary miss and shadow miss are consecutive records (entries 1 and 2)
        raytraceDesc.MissShaderTable.StartAddress = mpShaderTable.GPUVirtualAddress + kMissEntryIndex * mShaderTableEntrySize;
        raytraceDesc.MissShaderTable.StrideInBytes = mShaderTableEntrySize;
        raytraceDesc.MissShaderTable.SizeInBytes = mShaderTableEntrySize * kMissEntryCount;

        // The hit group is entry 3
        raytraceDesc.HitGroupTable.StartAddress = mpShaderTable.GPUVirtualAddress + kHitEntryIndex * mShaderTableEntrySize;
        raytraceDesc.HitGroupTable.StrideInBytes = mShaderTableEntrySize;
        raytraceDesc.HitGroupTable.SizeInBytes = mShaderTableEntrySize;

        // Bind the empty root signature
        mpCmdList.SetComputeRootSignature(mpEmptyRootSig);

        // Dispatch
        mpCmdList.SetPipelineState1(mpPipelineState);
        mpCmdList.DispatchRays(raytraceDesc);

        // Copy the results to the back-buffer
        context.ResourceBarrier(mpCmdList, mpOutputResource, ResourceStates.UnorderedAccess, ResourceStates.CopySource);
        context.ResourceBarrier(mpCmdList, mFrameObjects[rtvIndex].pSwapChainBuffer, ResourceStates.Present, ResourceStates.CopyDestination);
        mpCmdList.CopyResource(mFrameObjects[rtvIndex].pSwapChainBuffer, mpOutputResource);

        EndFrame(rtvIndex);
        return true;
    }

    /// <summary>
    /// Transitions the back buffer to Present, submits the command list and presents. It then
    /// waits until the next back buffer's allocator is free, and resets that allocator and the
    /// command list.
    /// </summary>
    private void EndFrame(int rtvIndex)
    {
        context.ResourceBarrier(mpCmdList, mFrameObjects[rtvIndex].pSwapChainBuffer, ResourceStates.CopyDestination, ResourceStates.Present);
        mFenceValue = context.SubmitCommandList(mpCmdList, mpCmdQueue, mpFence, mFenceValue);
        mpSwapChain.Present(0, 0);

        // Prepare the command list for the next frame
        int bufferIndex = mpSwapChain.GetCurrentBackBufferIndex();

        // Make sure the new back buffer is ready: wait for the submission that last used it
        if (mFenceValue > context.kDefaultSwapChainBuffers)
        {
            mpFence.SetEventOnCompletion(mFenceValue - context.kDefaultSwapChainBuffers + 1, mFenceEvent);
            this.mFenceEvent.WaitOne();
        }

        mFrameObjects[bufferIndex].pCmdAllocator.Reset();
        mpCmdList.Reset(mFrameObjects[bufferIndex].pCmdAllocator, null);
    }

    /// <summary>
    /// Signals the fence with a new value and blocks until the GPU reaches it, so all submitted
    /// work has finished. No GPU objects are released here.
    /// </summary>
    public void Dispose()
    {
        mFenceValue++;
        mpCmdQueue.Signal(mpFence, mFenceValue);
        mpFence.SetEventOnCompletion(mFenceValue, mFenceEvent);
        mFenceEvent.WaitOne();
    }
}
}