-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathScene.cs
358 lines (292 loc) · 18.2 KB
/
Scene.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
using RayTracingTutorial08.RTX;
using RayTracingTutorial08.Structs;
using System;
using System.Runtime.CompilerServices;
using System.Threading;
using Vortice.Direct3D12;
using Vortice.Direct3D12.Debug;
using Vortice.DXGI;
using Vortice.Mathematics;
namespace RayTracingTutorial08
{
    /// <summary>
    /// DXR tutorial scene. Creates the D3D12 device, swap chain, acceleration structures,
    /// ray-tracing pipeline state, shader resources and shader table, and on each
    /// <see cref="DrawFrame"/> call dispatches rays into an output texture that is then
    /// copied into the current back buffer.
    /// </summary>
    public class Scene : IDisposable
    {
        // Value of D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING: the identity swizzle (0, 1, 2, 3) encoded as
        // 0 | (1 << 3) | (2 << 6) | (3 << 9) | (1 << 12).
        private const int D3D12DefaultShader4ComponentMapping = 5768;
        private const int kRtvHeapSize = 3;
        private Color4 clearColor = new Color4(0.4f, 0.6f, 0.2f, 1.0f);
        private readonly Window Window;
        private D3D12GraphicsContext context;
        private IntPtr mHwnd;
        private ID3D12Device5 mpDevice;
        private ID3D12CommandQueue mpCmdQueue;
        private IDXGISwapChain3 mpSwapChain;
        private ID3D12GraphicsCommandList4 mpCmdList;
        private HeapData mRtvHeap;
        private FrameObject[] mFrameObjects;
        private ID3D12Fence mpFence;
        private uint mFenceValue = 0;
        private EventWaitHandle mFenceEvent;
        private Rect mSwapChainRect;
        private ID3D12Resource mpTopLevelAS;
        private ID3D12Resource mpBottomLevelAS;
        private ID3D12StateObject mpPipelineState;
        private ID3D12RootSignature mpEmptyRootSig;
        private AccelerationStructures acs;
        private ID3D12Resource mpOutputResource;
        private ID3D12DescriptorHeap mpSrvUavHeap;
        private ID3D12Resource mpShaderTable;
        private uint mShaderTableEntrySize;
        // Set by Dispose() so that a second call does not wait on an already-disposed event.
        private bool mDisposed;

        /// <summary>
        /// Builds every GPU object the scene needs, in tutorial order.
        /// </summary>
        /// <param name="window">Window providing the native handle and the swap-chain dimensions.</param>
        public Scene(Window window)
        {
            this.Window = window;
            this.context = new D3D12GraphicsContext(window.Width, window.Height);

            // InitDXR Tutorial 02
            this.InitDXR((IntPtr)window.Handle, window.Width, window.Height);
            // Acceleration Structures Tutorial 03
            this.CreateAccelerationStructures();
            // RtPipeline Tutorial 04
            this.CreateRtPipelineState();
            // ShaderResources Tutorial 06. Must run before CreateShaderTable(), which writes the
            // SRV/UAV heap's GPU handle into the ray-gen shader record.
            this.CreateShaderResources();
            // ShaderTable Tutorial 05
            this.CreateShaderTable();
        }

        /// <summary>
        /// Creates the device, command queue, swap chain, RTV heap, per-frame objects,
        /// the command list and the fence/event pair used for CPU/GPU synchronization.
        /// </summary>
        /// <param name="winHandle">Native window handle the swap chain presents to.</param>
        /// <param name="winWidth">Swap-chain width in pixels.</param>
        /// <param name="winHeight">Swap-chain height in pixels.</param>
        private void InitDXR(IntPtr winHandle, int winWidth, int winHeight)
        {
            mHwnd = winHandle;
            this.mSwapChainRect = new Rect(0, 0, winWidth, winHeight);

            // Initialize the debug layer for debug builds
#if DEBUG
            if (D3D12.D3D12GetDebugInterface<ID3D12Debug>(out var pDx12Debug).Success)
            {
                // The interface is only needed to flip the switch; release our reference afterwards.
                using (pDx12Debug)
                {
                    pDx12Debug.EnableDebugLayer();
                }
            }
#endif

            // Create the DXGI factory
            // NOTE(review): the factory reference is never released here; it is handed to
            // context.CreateDevice, so confirm whether the context keeps it before disposing it.
            IDXGIFactory4 pDXGIFactory;
            DXGI.CreateDXGIFactory1<IDXGIFactory4>(out pDXGIFactory);
            mpDevice = this.context.CreateDevice(pDXGIFactory);
            mpCmdQueue = this.context.CreateCommandQueue(mpDevice);
            mpSwapChain = this.context.CreateDXGISwapChain(pDXGIFactory, mHwnd, winWidth, winHeight, Format.R8G8B8A8_UNorm, mpCmdQueue);

            // Create a RTV descriptor heap
            mRtvHeap.Heap = this.context.CreateDescriptorHeap(mpDevice, kRtvHeapSize, DescriptorHeapType.RenderTargetView, false);

            // Create the per-frame objects: one command allocator, back buffer and RTV per swap-chain buffer.
            this.mFrameObjects = new FrameObject[this.context.kDefaultSwapChainBuffers];
            for (int i = 0; i < this.context.kDefaultSwapChainBuffers; i++)
            {
                mFrameObjects[i].pCmdAllocator = mpDevice.CreateCommandAllocator(CommandListType.Direct);
                mFrameObjects[i].pSwapChainBuffer = mpSwapChain.GetBuffer<ID3D12Resource>(i);
                mFrameObjects[i].rtvHandle = context.CreateRTV(mpDevice, mFrameObjects[i].pSwapChainBuffer, mRtvHeap.Heap, ref mRtvHeap.usedEntries, Format.R8G8B8A8_UNorm_SRgb);
            }

            // Create the command-list. QueryInterface hands back a new wrapper holding its own COM
            // reference, so the base-interface wrapper is released once the
            // ID3D12GraphicsCommandList4 view has been obtained.
            using (var cmdList = mpDevice.CreateCommandList(0, CommandListType.Direct, mFrameObjects[0].pCmdAllocator, null))
            {
                this.mpCmdList = cmdList.QueryInterface<ID3D12GraphicsCommandList4>();
            }

            // Create a fence and the event
            this.mpFence = mpDevice.CreateFence(0, FenceFlags.None);
            this.mFenceEvent = new EventWaitHandle(false, EventResetMode.AutoReset);
        }

        /// <summary>
        /// Records the bottom- and top-level acceleration-structure builds, submits them,
        /// waits for the GPU to finish, and keeps the resulting BLAS/TLAS resources.
        /// </summary>
        public void CreateAccelerationStructures()
        {
            acs = new AccelerationStructures();
            long tlasSize = 0;

            var mpVertexBuffer = acs.CreateTriangleVB(mpDevice);
            AccelerationStructureBuffers bottomLevelBuffers = acs.CreateBottomLevelAS(mpDevice, mpCmdList, mpVertexBuffer);
            AccelerationStructureBuffers topLevelBuffers = acs.CreateTopLevelAS(mpDevice, mpCmdList, bottomLevelBuffers.pResult, ref tlasSize);

            // The tutorial doesn't have any resource lifetime management, so we flush and sync here.
            // This is not required by the DXR spec - you can submit the list whenever you like as long
            // as you take care of the resources lifetime.
            mFenceValue = context.SubmitCommandList(mpCmdList, mpCmdQueue, mpFence, mFenceValue);
            mpFence.SetEventOnCompletion(mFenceValue, mFenceEvent);
            mFenceEvent.WaitOne();
            mpCmdList.Reset(mFrameObjects[0].pCmdAllocator, null);

            // Store the AS buffers.
            // NOTE(review): nothing in this method disposes the vertex buffer or the other buffers
            // returned in bottomLevelBuffers/topLevelBuffers; confirm whether they should be released.
            mpTopLevelAS = topLevelBuffers.pResult;
            mpBottomLevelAS = bottomLevelBuffers.pResult;
        }

        /// <summary>
        /// Assembles the 10 state subobjects of the ray-tracing pipeline and creates the
        /// state object. Also stores the empty global root signature in <c>mpEmptyRootSig</c>.
        /// </summary>
        public void CreateRtPipelineState()
        {
            var rtpipeline = new RTPipeline();

            // Need 10 subobjects:
            //  1 for the DXIL library
            //  1 for hit-group
            //  2 for RayGen root-signature (root-signature and the subobject association)
            //  2 for the root-signature shared between miss and hit shaders (signature and association)
            //  2 for shader config (shared between all programs. 1 for the config, 1 for association)
            //  1 for pipeline config
            //  1 for the global root signature
            StateSubObject[] subobjects = new StateSubObject[10];
            int index = 0;

            // Create the DXIL library
            DxilLibrary dxilLib = rtpipeline.CreateDxilLibrary();
            subobjects[index++] = dxilLib.stateSubObject; // 0 Library

            HitProgram hitProgram = new HitProgram(null, RTPipeline.kClosestHitShader, RTPipeline.kHitGroup);
            subobjects[index++] = hitProgram.subObject; // 1 Hit Group

            // Create the ray-gen root-signature and association
            Structs.LocalRootSignature rgsRootSignature = new Structs.LocalRootSignature(mpDevice, rtpipeline.CreateRayGenRootDesc());
            subobjects[index] = rgsRootSignature.subobject; // 2 RayGen Root Sig

            int rgsRootIndex = index++; // 2
            ExportAssociation rgsRootAssociation = new ExportAssociation(new string[] { RTPipeline.kRayGenShader }, subobjects[rgsRootIndex]);
            subobjects[index++] = rgsRootAssociation.subobject; // 3 Associate Root Sig to RGS

            // Create the miss- and hit-programs root-signature and association
            RootSignatureDescription emptyDesc = new RootSignatureDescription(RootSignatureFlags.LocalRootSignature);
            Structs.LocalRootSignature hitMissRootSignature = new Structs.LocalRootSignature(mpDevice, emptyDesc);
            subobjects[index] = hitMissRootSignature.subobject; // 4 Root Sig to be shared between Miss and CHS

            int hitMissRootIndex = index++; // 4
            string[] missHitExportName = new string[] { RTPipeline.kMissShader, RTPipeline.kClosestHitShader };
            ExportAssociation missHitRootAssociation = new ExportAssociation(missHitExportName, subobjects[hitMissRootIndex]);
            subobjects[index++] = missHitRootAssociation.subobject; // 5 Associate Root Sig to Miss and CHS

            // Bind the payload size to the programs
            ShaderConfig shaderConfig = new ShaderConfig(sizeof(float) * 2, sizeof(float) * 3);
            subobjects[index] = shaderConfig.subObject; // 6 Shader Config;

            int shaderConfigIndex = index++; // 6
            string[] shaderExports = new string[] { RTPipeline.kMissShader, RTPipeline.kClosestHitShader, RTPipeline.kRayGenShader };
            ExportAssociation configAssociation = new ExportAssociation(shaderExports, subobjects[shaderConfigIndex]);
            subobjects[index++] = configAssociation.subobject; // 7 Associate Shader Config to Miss, CHS, RGS

            // Create the pipeline config
            PipelineConfig config = new PipelineConfig(1);
            subobjects[index++] = config.suboject; // 8

            // Create the global root signature and store the empty signature
            Structs.GlobalRootSignature root = new Structs.GlobalRootSignature(mpDevice, new RootSignatureDescription());
            mpEmptyRootSig = root.pRootSig.RootSignature;
            subobjects[index++] = root.suboject; // 9

            // Create the state
            StateObjectDescription desc = new StateObjectDescription(StateObjectType.RaytracingPipeline, subobjects);
            mpPipelineState = mpDevice.CreateStateObject(desc);
        }

        // D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES
        private const uint D3D12ShaderIdentifierSizeInBytes = 32;
        // D3D12_RAYTRACING_SHADER_RECORD_BYTE_ALIGNMENT
        private const uint D3D12RaytracingShaderRecordByteAlignment = 32;

        /// <summary>Rounds <paramref name="_val"/> up to the next multiple of <paramref name="_alignment"/>.</summary>
        private static uint align_to(uint _alignment, uint _val)
        {
            return (((_val + _alignment - 1) / _alignment) * _alignment);
        }

        /// <summary>
        /// Creates the shader table on the upload heap and fills its three records.
        /// </summary>
        /// <remarks>
        /// The shader-table layout is as follows:
        ///   Entry 0 - Ray-gen program
        ///   Entry 1 - Miss program
        ///   Entry 2 - Hit program
        /// All entries in the shader-table must have the same size, so we choose it based on the
        /// largest required entry. The ray-gen program requires the largest entry -
        /// sizeof(program identifier) + 8 bytes for a descriptor-table.
        /// The entry size must be aligned up to D3D12_RAYTRACING_SHADER_RECORD_BYTE_ALIGNMENT.
        /// Requires <c>mpPipelineState</c> and <c>mpSrvUavHeap</c> to have been created.
        /// </remarks>
        public unsafe void CreateShaderTable()
        {
            // Calculate the size and create the buffer
            mShaderTableEntrySize = D3D12ShaderIdentifierSizeInBytes;
            mShaderTableEntrySize += 8; // the ray-gen's descriptor table
            mShaderTableEntrySize = align_to(D3D12RaytracingShaderRecordByteAlignment, mShaderTableEntrySize);
            uint shaderTableSize = mShaderTableEntrySize * 3;

            // For simplicity, we create the shader-table on the upload heap. You can also create it on the default heap
            mpShaderTable = this.acs.CreateBuffer(mpDevice, shaderTableSize, ResourceFlags.None, ResourceStates.GenericRead, AccelerationStructures.kUploadHeapProps);

            // Map the buffer; the finally block guarantees it is unmapped even if a write throws.
            IntPtr pData = mpShaderTable.Map(0, null);
            try
            {
                // QueryInterface adds a COM reference; the using block releases it when done.
                using (ID3D12StateObjectProperties pRtsoProps = mpPipelineState.QueryInterface<ID3D12StateObjectProperties>())
                {
                    // Entry 0 - ray-gen program ID and descriptor data
                    Unsafe.CopyBlock((void*)pData, (void*)pRtsoProps.GetShaderIdentifier(RTPipeline.kRayGenShader), D3D12ShaderIdentifierSizeInBytes);
                    // The ray-gen descriptor table follows the identifier: the GPU start handle of the SRV/UAV heap.
                    ulong heapStart = (ulong)mpSrvUavHeap.GetGPUDescriptorHandleForHeapStart().Ptr;
                    Unsafe.Write<ulong>((pData + (int)D3D12ShaderIdentifierSizeInBytes).ToPointer(), heapStart);

                    // Entry 1 - miss program
                    pData += (int)mShaderTableEntrySize; // +1 skips ray-gen
                    Unsafe.CopyBlock((void*)pData, (void*)pRtsoProps.GetShaderIdentifier(RTPipeline.kMissShader), D3D12ShaderIdentifierSizeInBytes);

                    // Entry 2 - hit program
                    pData += (int)mShaderTableEntrySize; // +1 skips miss entries
                    Unsafe.CopyBlock((void*)pData, (void*)pRtsoProps.GetShaderIdentifier(RTPipeline.kHitGroup), D3D12ShaderIdentifierSizeInBytes);
                }
            }
            finally
            {
                // Unmap
                mpShaderTable.Unmap(0, null);
            }
        }

        /// <summary>
        /// Creates the ray-tracing output texture and a shader-visible SRV/UAV heap holding
        /// the output UAV (entry 0) followed by the TLAS SRV (entry 1).
        /// </summary>
        public void CreateShaderResources()
        {
            // Create the output resource. The dimensions and format should match the swap-chain
            ResourceDescription resDesc = new ResourceDescription();
            resDesc.DepthOrArraySize = 1;
            resDesc.Dimension = ResourceDimension.Texture2D;
            resDesc.Format = Format.R8G8B8A8_UNorm; // The backbuffer is actually DXGI_FORMAT_R8G8B8A8_UNORM_SRGB, but sRGB formats can't be used with UAVs. We will convert to sRGB ourselves in the shader
            resDesc.Flags = ResourceFlags.AllowUnorderedAccess;
            resDesc.Height = mSwapChainRect.Height;
            resDesc.Layout = TextureLayout.Unknown;
            resDesc.MipLevels = 1;
            resDesc.SampleDescription = new SampleDescription(1, 0);
            resDesc.Width = mSwapChainRect.Width;
            mpOutputResource = mpDevice.CreateCommittedResource(AccelerationStructures.kDefaultHeapProps, HeapFlags.None, resDesc, ResourceStates.CopySource, null); // Starting as copy-source to simplify DrawFrame()

            // Create an SRV/UAV descriptor heap. Need 2 entries - 1 SRV for the scene and 1 UAV for the output
            mpSrvUavHeap = this.context.CreateDescriptorHeap(mpDevice, 2, DescriptorHeapType.ConstantBufferViewShaderResourceViewUnorderedAccessView, true);

            // Create the UAV. Based on the root signature we created it should be the first entry
            UnorderedAccessViewDescription uavDesc = new UnorderedAccessViewDescription();
            uavDesc.ViewDimension = UnorderedAccessViewDimension.Texture2D;
            mpDevice.CreateUnorderedAccessView(mpOutputResource, null, uavDesc, mpSrvUavHeap.GetCPUDescriptorHandleForHeapStart());

            // Create the TLAS SRV right after the UAV. Note that we are using a different SRV desc here
            ShaderResourceViewDescription srvDesc = new ShaderResourceViewDescription();
            srvDesc.ViewDimension = ShaderResourceViewDimension.RaytracingAccelerationStructure;
            srvDesc.Shader4ComponentMapping = D3D12DefaultShader4ComponentMapping;
            srvDesc.RaytracingAccelerationStructure = new RaytracingAccelerationStructureShaderResourceView();
            srvDesc.RaytracingAccelerationStructure.Location = mpTopLevelAS.GPUVirtualAddress;
            CpuDescriptorHandle srvHandle = mpSrvUavHeap.GetCPUDescriptorHandleForHeapStart();
            srvHandle.Ptr += mpDevice.GetDescriptorHandleIncrementSize(DescriptorHeapType.ConstantBufferViewShaderResourceViewUnorderedAccessView);
            mpDevice.CreateShaderResourceView(null, srvDesc, srvHandle);
        }

        /// <summary>Binds the SRV/UAV heap and returns the index of the current back buffer.</summary>
        private int BeginFrame()
        {
            // Bind the descriptor heaps
            ID3D12DescriptorHeap[] heaps = new ID3D12DescriptorHeap[] { mpSrvUavHeap };
            mpCmdList.SetDescriptorHeaps(1, heaps);
            return this.mpSwapChain.GetCurrentBackBufferIndex();
        }

        /// <summary>
        /// Ray-traces one frame into the output texture, copies it into the current back
        /// buffer, then submits and presents.
        /// </summary>
        /// <param name="draw">Not used by this implementation.</param>
        /// <param name="frameName">Not used by this implementation.</param>
        /// <returns>Always <c>true</c>.</returns>
        public bool DrawFrame(Action<int, int> draw, [CallerMemberName] string frameName = null)
        {
            int rtvIndex = BeginFrame();

            // Let's raytrace
            context.ResourceBarrier(mpCmdList, mpOutputResource, ResourceStates.CopySource, ResourceStates.UnorderedAccess);
            DispatchRaysDescription raytraceDesc = new DispatchRaysDescription();
            raytraceDesc.Width = mSwapChainRect.Width;
            raytraceDesc.Height = mSwapChainRect.Height;
            raytraceDesc.Depth = 1;

            // RayGen is the first entry in the shader-table
            raytraceDesc.RayGenerationShaderRecord.StartAddress = mpShaderTable.GPUVirtualAddress + 0 * mShaderTableEntrySize;
            raytraceDesc.RayGenerationShaderRecord.SizeInBytes = mShaderTableEntrySize;

            // Miss is the second entry in the shader-table
            uint missOffset = 1 * mShaderTableEntrySize;
            raytraceDesc.MissShaderTable.StartAddress = mpShaderTable.GPUVirtualAddress + missOffset;
            raytraceDesc.MissShaderTable.StrideInBytes = mShaderTableEntrySize;
            raytraceDesc.MissShaderTable.SizeInBytes = mShaderTableEntrySize; // Only a single miss-entry

            // Hit is the third entry in the shader-table
            uint hitOffset = 2 * mShaderTableEntrySize;
            raytraceDesc.HitGroupTable.StartAddress = mpShaderTable.GPUVirtualAddress + hitOffset;
            raytraceDesc.HitGroupTable.StrideInBytes = mShaderTableEntrySize;
            raytraceDesc.HitGroupTable.SizeInBytes = mShaderTableEntrySize;

            // Bind the empty root signature
            mpCmdList.SetComputeRootSignature(mpEmptyRootSig);

            // Dispatch
            mpCmdList.SetPipelineState1(mpPipelineState);
            mpCmdList.DispatchRays(raytraceDesc);

            // Copy the results to the back-buffer
            context.ResourceBarrier(mpCmdList, mpOutputResource, ResourceStates.UnorderedAccess, ResourceStates.CopySource);
            context.ResourceBarrier(mpCmdList, mFrameObjects[rtvIndex].pSwapChainBuffer, ResourceStates.Present, ResourceStates.CopyDestination);
            mpCmdList.CopyResource(mFrameObjects[rtvIndex].pSwapChainBuffer, mpOutputResource);

            EndFrame(rtvIndex);
            return true;
        }

        /// <summary>
        /// Transitions the back buffer to Present, submits the command list, presents, and
        /// resets the command list with the allocator of the next back buffer.
        /// </summary>
        /// <param name="rtvIndex">Index of the back buffer rendered this frame.</param>
        private void EndFrame(int rtvIndex)
        {
            context.ResourceBarrier(mpCmdList, mFrameObjects[rtvIndex].pSwapChainBuffer, ResourceStates.CopyDestination, ResourceStates.Present);
            mFenceValue = context.SubmitCommandList(mpCmdList, mpCmdQueue, mpFence, mFenceValue);
            mpSwapChain.Present(0, 0);

            // Prepare the command list for the next frame
            int bufferIndex = mpSwapChain.GetCurrentBackBufferIndex();

            // Make sure the new back-buffer is ready: wait until the fence reaches
            // (current value - buffer count + 1) before reusing that buffer's allocator.
            // Presumably SubmitCommandList signals a value one higher per submission - TODO confirm.
            if (mFenceValue > context.kDefaultSwapChainBuffers)
            {
                mpFence.SetEventOnCompletion(mFenceValue - context.kDefaultSwapChainBuffers + 1, mFenceEvent);
                this.mFenceEvent.WaitOne();
            }

            mFrameObjects[bufferIndex].pCmdAllocator.Reset();
            mpCmdList.Reset(mFrameObjects[bufferIndex].pCmdAllocator, null);
        }

        /// <summary>
        /// Blocks until the GPU has finished all work submitted to the command queue, then
        /// disposes the fence event. Calls after the first one do nothing.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the D3D12/DXGI objects held by this scene are not released here;
        /// confirm whether they should be disposed as well.
        /// </remarks>
        public void Dispose()
        {
            if (mDisposed)
            {
                return;
            }
            mDisposed = true;

            // Signal a new fence value behind all queued work and wait for the GPU to reach it.
            mFenceValue++;
            mpCmdQueue.Signal(mpFence, mFenceValue);
            mpFence.SetEventOnCompletion(mFenceValue, mFenceEvent);
            mFenceEvent.WaitOne();

            mFenceEvent.Dispose();
        }
    }
}