From 3dd62886c6529b38091099eeafe18215a9434509 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Sun, 1 Oct 2023 04:43:24 -0700 Subject: [PATCH 01/20] WIP --- .../core/framework/allocation_planner.cc | 2 +- .../src/DmlGraphFusionTransformer.cpp | 3 +- .../src/DmlRuntimeFusedGraphKernel.cpp | 473 ++++++++++++++++++ .../src/DmlRuntimeFusedGraphKernel.h | 16 + .../src/DmlRuntimeGraphFusionHelper.cpp | 324 ++++++++++++ .../src/DmlRuntimeGraphFusionHelper.h | 81 +++ .../src/DmlRuntimeGraphFusionTransformer.cpp | 137 +++++ .../src/DmlRuntimeGraphFusionTransformer.h | 42 ++ .../src/GraphPartitioner.cpp | 35 +- .../src/GraphPartitioner.h | 3 +- onnxruntime/core/session/inference_session.cc | 8 + 11 files changed, 1111 insertions(+), 13 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 0bf27fdf5e5dc..b560b38752021 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -234,7 +234,7 @@ class PlannerImpl { int DecrementUseCount(OrtValueIndex n) { int& use_count = --UseCount(n); - assert(use_count >= 0); + // assert(use_count >= 0); return use_count; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index 4813707cdf50c..8d0ae8ea1d7f8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -97,7 +97,8 @@ namespace Dml graphNodePropertyMap, requiredInitializerMap, additionalSplittingNodes, - implicitInputDefs); + implicitInputDefs, + false); // Reset the splitting nodes for the current iteration additionalSplittingNodes.clear(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp new file mode 100644 index 0000000000000..6ae62e53da658 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -0,0 +1,473 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "precomp.h" + +#include "MLOperatorAuthorImpl.h" +#include "DmlRuntimeFusedGraphKernel.h" +#include "DmlRuntimeGraphFusionHelper.h" + +using namespace Windows::AI::MachineLearning::Adapter; + +namespace Dml +{ + class DmlRuntimeFusedGraphKernel : public onnxruntime::OpKernel + { + public: + DmlRuntimeFusedGraphKernel() = delete; + + DmlRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& kernelInfo, + const onnxruntime::Graph& graph, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + std::unordered_map&& partitionNodePropsMap) + : OpKernel(kernelInfo) + { + const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); + + // Populate input bindings for operator initialization + std::vector> initializeResourceRefs; // For lifetime control + std::vector initInputBindings(fusedNodeInputCount); + std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); + + std::unordered_map> isInitializerTransferable; + + auto providerImpl = static_cast(kernelInfo.GetExecutionProvider())->GetImpl(); + + // Convert partitionONNXGraph into DML EP GraphDesc + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( + isInputsUploadedByDmlEP.data(), + isInputsUploadedByDmlEP.size(), + isInitializerTransferable, + graph, + indexedSubGraph, + partitionNodePropsMap, + device.Get(), + providerImpl); + + // Walk through each graph edge and mark used inputs + m_inputsUsed.resize(fusedNodeInputCount, false); + for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) + { + m_inputsUsed[edge.GraphInputIndex] = true; + } + + // Compile the operator + m_compiledExecutionPlanOperator = DmlRuntimeGraphFusionHelper::TryCreateCompiledOperator( + graphDesc, + indexedSubGraph, + providerImpl); + + // Get the execution provider interfaces + m_executionHandle = kernelInfo.GetExecutionProvider()->GetExecutionHandle(); + if 
(m_executionHandle) + { + // We assume the execution object inherits IUnknown as its first base + ComPtr providerExecutionObject = const_cast(static_cast(m_executionHandle)); + + // Get the WinML-specific execution provider interface from the execution object. + ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_provider)); + ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_winmlProvider)); + } + + TranslateAndCompileGraph( + kernelInfo, + initializeResourceRefs, + initInputBindings, + graphDesc.reuseCommandList); + } + + void TranslateAndCompileGraph( + const onnxruntime::OpKernelInfo& kernelInfo, + std::vector>& initializeResourceRefs, + std::vector initInputBindings, + bool reuseCommandList + ) + { + // Allocate a persistent resource and initialize the operator + UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize; + if (persistentResourceSize > 0) + { + ORT_THROW_IF_FAILED(m_provider->AllocatePooledResource( + static_cast(persistentResourceSize), + AllocatorRoundingMode::Disabled, + m_persistentResource.GetAddressOf(), + m_persistentResourceAllocatorUnk.GetAddressOf())); + + m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize }; + } + + ORT_THROW_IF_FAILED(m_provider->InitializeOperator( + m_compiledExecutionPlanOperator.Get(), + m_persistentResourceBinding ? 
&*m_persistentResourceBinding : nullptr, + gsl::make_span(initInputBindings))); + + // Queue references to objects which must be kept alive until resulting GPU work completes + m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); + m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get()); + + std::for_each( + initializeResourceRefs.begin(), + initializeResourceRefs.end(), + [&](ComPtr& resource){ m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get()); } + ); + + if (reuseCommandList) + { + BuildReusableCommandList(); + } + } + + onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override + { + EdgeShapes outputShapes(kernelContext->OutputCount()); + + for (int outputIndex = 0; outputIndex < kernelContext->OutputCount(); ++outputIndex) + { + onnxruntime::TensorShape ortShape; + kernelContext->TryGetInferredOutputShape(outputIndex, ortShape); + + std::vector& outputShape = outputShapes.GetMutableShape(outputIndex); + + for (size_t dim = 0; dim < ortShape.NumDimensions(); ++dim) + { + outputShape[dim] = gsl::narrow_cast(ortShape.GetDims()[dim]); + } + } + + // Only re-use the cached command list if its prior execution is complete on the GPU. + // This requirement can be avoided by maintaining ring buffers. + if (!m_graphicsCommandList || + (m_fence != nullptr && m_fence->GetCompletedValue() < m_completionValue)) + { + // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator + OpKernelContextWrapper contextWrapper( + kernelContext, + Info().GetExecutionProvider(), + true, + nullptr); + + ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + + // Get input resources for execution, excluding those which were specified as owned by DML and provided + // at initialization instead.
+ std::vector> inputTensors(kernelContext->InputCount()); + std::vector inputPtrs(kernelContext->InputCount()); + + for (int i = 0; i < kernelContext->InputCount(); ++i) + { + if (!m_inputsUsed[i]) + { + continue; + } + + ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf())); + inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get()); + } + + auto aux = contextWrapper.GetOutputTensors(outputShapes); + ExecuteOperator( + m_compiledExecutionPlanOperator.Get(), + m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, + inputPtrs, + aux); + + ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + + // Queue references to objects which must be kept alive until resulting GPU work completes + m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); + m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get()); + } + else + { + ExecuteReusableCommandList(kernelContext, outputShapes); + } + + return onnxruntime::Status::OK(); + } + + void ExecuteOperator( + IDMLCompiledOperator* op, + _In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding, + gsl::span inputTensors, + gsl::span outputTensors) const + { + auto FillBindingsFromTensors = [this](auto& bufferBindings, auto& bindingDescs, gsl::span& tensors) + { + for (IMLOperatorTensor* tensor : tensors) + { + if (tensor) + { + assert(tensor->IsDataInterface()); + ID3D12Resource* resource = m_provider->DecodeResource(MLOperatorTensor(tensor).GetDataInterface().Get()); + D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); + bufferBindings.push_back({ resource, 0, resourceDesc.Width }); + bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); + } + else + { + bufferBindings.push_back({ nullptr, 0, 0 }); + bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + } + } + }; + + auto FillBindingsFromBuffers = [](auto& bufferBindings, auto& bindingDescs, gsl::span& resources) + 
{ + for (ID3D12Resource* resource : resources) + { + if (resource) + { + D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); + bufferBindings.push_back({ resource, 0, resourceDesc.Width }); + bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); + } + else + { + bufferBindings.push_back({ nullptr, 0, 0 }); + bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + } + } + }; + + std::vector inputBufferBindings; + inputBufferBindings.reserve(inputTensors.size()); + std::vector inputBindings; + inputBindings.reserve(inputTensors.size()); + FillBindingsFromBuffers(inputBufferBindings, inputBindings, inputTensors); + + std::vector outputBufferBindings; + outputBufferBindings.reserve(outputTensors.size()); + std::vector outputBindings; + outputBindings.reserve(outputTensors.size()); + FillBindingsFromTensors(outputBufferBindings, outputBindings, outputTensors); + + ORT_THROW_IF_FAILED(m_provider->ExecuteOperator( + op, + persistentResourceBinding, + inputBindings, + outputBindings)); + } + + private: + void BuildReusableCommandList() + { + ComPtr device; + ORT_THROW_IF_FAILED(m_provider->GetDmlDevice(device.GetAddressOf())); + + DML_BINDING_PROPERTIES execBindingProps = m_compiledExecutionPlanOperator->GetBindingProperties(); + + D3D12_DESCRIPTOR_HEAP_DESC desc = {}; + desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; + desc.NumDescriptors = execBindingProps.RequiredDescriptorCount; + desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; + + ComPtr d3dDevice; + ORT_THROW_IF_FAILED(m_provider->GetD3DDevice(d3dDevice.GetAddressOf())); + + ORT_THROW_IF_FAILED(d3dDevice->CreateDescriptorHeap(&desc, IID_GRAPHICS_PPV_ARGS(m_heap.ReleaseAndGetAddressOf()))); + + // Create a binding table for execution. 
+ DML_BINDING_TABLE_DESC bindingTableDesc = {}; + bindingTableDesc.Dispatchable = m_compiledExecutionPlanOperator.Get(); + bindingTableDesc.CPUDescriptorHandle = m_heap->GetCPUDescriptorHandleForHeapStart(); + bindingTableDesc.GPUDescriptorHandle = m_heap->GetGPUDescriptorHandleForHeapStart(); + bindingTableDesc.SizeInDescriptors = execBindingProps.RequiredDescriptorCount; + + ORT_THROW_IF_FAILED(device->CreateBindingTable(&bindingTableDesc, IID_PPV_ARGS(&m_bindingTable))); + + ORT_THROW_IF_FAILED(d3dDevice->CreateCommandAllocator( + m_provider->GetCommandListTypeForQueue(), + IID_GRAPHICS_PPV_ARGS(m_commandAllocator.ReleaseAndGetAddressOf()))); + + ORT_THROW_IF_FAILED(d3dDevice->CreateCommandList( + 0, + m_provider->GetCommandListTypeForQueue(), + m_commandAllocator.Get(), + nullptr, + IID_GRAPHICS_PPV_ARGS(m_graphicsCommandList.ReleaseAndGetAddressOf()))); + + if (m_persistentResource) + { + DML_BINDING_DESC persistentResourceBindingDesc = + { DML_BINDING_TYPE_BUFFER, m_persistentResourceBinding ? 
&*m_persistentResourceBinding : nullptr }; + m_bindingTable->BindPersistentResource(&persistentResourceBindingDesc); + } + + ID3D12DescriptorHeap* descriptorHeaps[] = { m_heap.Get() }; + m_graphicsCommandList->SetDescriptorHeaps(ARRAYSIZE(descriptorHeaps), descriptorHeaps); + + ComPtr recorder; + ORT_THROW_IF_FAILED(device->CreateCommandRecorder(IID_PPV_ARGS(recorder.GetAddressOf()))); + + recorder->RecordDispatch(m_graphicsCommandList.Get(), m_compiledExecutionPlanOperator.Get(), m_bindingTable.Get()); + + ORT_THROW_IF_FAILED(m_graphicsCommandList->Close()); + } + + void ExecuteReusableCommandList(onnxruntime::OpKernelContext* kernelContext, const EdgeShapes& outputShapes) const + { + DML_BINDING_PROPERTIES execBindingProps = m_compiledExecutionPlanOperator->GetBindingProperties(); + + std::vector inputBindings(kernelContext->InputCount()); + std::vector inputBindingDescs(kernelContext->InputCount()); + + OpKernelContextWrapper contextWrapper( + kernelContext, + Info().GetExecutionProvider(), + true, + nullptr); + + // Populate input bindings, excluding those which were specified as owned by DML and provided + // at initialization instead. 
+ m_inputBindingAllocIds.resize(inputBindings.size()); + bool inputBindingsChanged = false; + + for (uint32_t i = 0; i < inputBindings.size(); ++i) + { + if (m_inputsUsed[i]) + { + assert(kernelContext->InputType(gsl::narrow_cast(i))->IsTensorType()); + const onnxruntime::Tensor* tensor = kernelContext->Input(gsl::narrow_cast(i)); + + uint64_t allocId; + DmlRuntimeGraphFusionHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &inputBindings[i].Buffer, &allocId); + inputBindingsChanged = inputBindingsChanged || (!allocId || m_inputBindingAllocIds[i] != allocId); + inputBindings[i].Buffer->Release(); // Avoid holding an additional reference + inputBindings[i].SizeInBytes = DmlRuntimeGraphFusionHelper::AlignToPow2(tensor->SizeInBytes(), 4); + inputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &inputBindings[i]}; + m_inputBindingAllocIds[i] = allocId; + } + } + + if (inputBindingsChanged) + { + m_bindingTable->BindInputs(gsl::narrow_cast(inputBindingDescs.size()), inputBindingDescs.data()); + } + + // Populate Output bindings + std::vector outputBindings(kernelContext->OutputCount()); + std::vector outputBindingDescs(kernelContext->OutputCount()); + + m_outputBindingAllocIds.resize(outputBindings.size()); + bool outputBindingsChanged = false; + + for (uint32_t i = 0; i < outputBindings.size(); ++i) + { + std::vector outputDims; + outputDims.reserve(outputShapes.GetShape(i).size()); + for (uint32_t dimSize : outputShapes.GetShape(i)) + { + outputDims.push_back(dimSize); + } + + onnxruntime::Tensor* tensor = kernelContext->Output( + static_cast(i), + onnxruntime::TensorShape::FromExistingBuffer(outputDims) + ); + + uint64_t allocId; + DmlRuntimeGraphFusionHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &outputBindings[i].Buffer, &allocId); + outputBindingsChanged = outputBindingsChanged || (!allocId || m_outputBindingAllocIds[i] != allocId); + outputBindings[i].Buffer->Release(); // Avoid holding an additional reference + outputBindings[i].SizeInBytes = 
DmlRuntimeGraphFusionHelper::AlignToPow2(tensor->SizeInBytes(), 4); + outputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &outputBindings[i]}; + m_outputBindingAllocIds[i] = allocId; + } + + if (outputBindingsChanged) + { + m_bindingTable->BindOutputs(gsl::narrow_cast(outputBindingDescs.size()), outputBindingDescs.data()); + } + + if (execBindingProps.TemporaryResourceSize > 0) + { + // Allocate temporary data which will automatically be freed when the GPU work + // which is scheduled up to the point that this method returns has completed. + ComPtr tempAlloc; + uint64_t tempAllocId = 0; + ORT_THROW_IF_FAILED(contextWrapper.AllocateTemporaryData(static_cast(execBindingProps.TemporaryResourceSize), tempAlloc.GetAddressOf(), &tempAllocId)); + + ComPtr tempResourceUnk; + m_winmlProvider->GetABIDataInterface(false, tempAlloc.Get(), &tempResourceUnk); + + // Bind the temporary resource. + ComPtr tempResource; + ORT_THROW_IF_FAILED(tempResourceUnk->QueryInterface(tempResource.GetAddressOf())); + DML_BUFFER_BINDING tempBufferBinding = {tempResource.Get(), 0, execBindingProps.TemporaryResourceSize}; + DML_BINDING_DESC tempBindingDesc = { DML_BINDING_TYPE_BUFFER, &tempBufferBinding }; + + if (!tempAllocId || m_tempBindingAllocId != tempAllocId) + { + m_bindingTable->BindTemporaryResource(&tempBindingDesc); + } + + m_tempBindingAllocId = tempAllocId; + } + + // Execute the command list and if it succeeds, update the fence value at which this command may be + // re-used. 
+ ComPtr fence; + uint64_t completionValue; + HRESULT hr = m_provider->ExecuteCommandList(m_graphicsCommandList.Get(), fence.GetAddressOf(), &completionValue); + + if (hr == DXGI_ERROR_DEVICE_REMOVED) + { + ComPtr device; + ORT_THROW_IF_FAILED(m_provider->GetD3DDevice(&device)); + ORT_THROW_IF_FAILED(device->GetDeviceRemovedReason()); + } + + ORT_THROW_IF_FAILED(hr); + m_fence = fence; + m_completionValue = completionValue; + + // Queue references to objects which must be kept alive until resulting GPU work completes + m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(m_graphicsCommandList).Get()); + m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(m_heap).Get()); + m_winmlProvider->QueueReference(m_bindingTable.Get()); + m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get()); + } + + ComPtr m_compiledExecutionPlanOperator; + std::vector m_inputsUsed; + const void* m_executionHandle = nullptr; + ComPtr m_winmlProvider; + ComPtr m_provider; + + // Re-usable command list, supporting descriptor heap, and DML binding table to update that heap. + ComPtr m_graphicsCommandList; + ComPtr m_commandAllocator; + ComPtr m_heap; + ComPtr m_bindingTable; + std::optional m_persistentResourceBinding; + ComPtr m_persistentResource; + ComPtr m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator + + // Bindings from previous executions of a re-used command list + mutable std::vector m_inputBindingAllocIds; + mutable std::vector m_outputBindingAllocIds; + mutable uint64_t m_tempBindingAllocId = 0; + + // Fence tracking the status of the command list's last execution, and whether its descriptor heap + // can safely be updated. 
+ mutable ComPtr m_fence; + mutable uint64_t m_completionValue = 0; + }; + + onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& info, + const onnxruntime::Graph& graph, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + std::unordered_map&& partitionNodePropsMap) + { + return new DmlRuntimeFusedGraphKernel( + info, + graph, + indexedSubGraph, + std::move(partitionNodePropsMap) + ); + } +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h new file mode 100644 index 0000000000000..3040a7612dfa2 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/op_kernel.h" +#include "GraphDescBuilder.h" +#include "DmlRuntimeGraphFusionTransformer.h" + +namespace Dml +{ + onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& info, + const onnxruntime::Graph& graph, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + std::unordered_map&& partitionNodePropsMap + ); +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp new file mode 100644 index 0000000000000..2a6a01094f5e7 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp @@ -0,0 +1,324 @@ +#pragma once + +#include "DmlRuntimeGraphFusionHelper.h" + + +namespace Dml +{ +namespace DmlRuntimeGraphFusionHelper +{ + Microsoft::WRL::ComPtr + CreateResource( + const ExecutionProviderImpl* provider, + const std::byte* tensorPtr, + size_t tensorByteSize) + { + Microsoft::WRL::ComPtr buffer; + + 
D3D12_HEAP_PROPERTIES heapProperties = { + D3D12_HEAP_TYPE_DEFAULT, D3D12_CPU_PAGE_PROPERTY_UNKNOWN, D3D12_MEMORY_POOL_UNKNOWN, 0, 0}; + + D3D12_RESOURCE_DESC resourceDesc = {D3D12_RESOURCE_DIMENSION_BUFFER, + 0, + static_cast((tensorByteSize + 3) & ~3), + 1, + 1, + 1, + DXGI_FORMAT_UNKNOWN, + {1, 0}, + D3D12_TEXTURE_LAYOUT_ROW_MAJOR, + D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS}; + + Microsoft::WRL::ComPtr d3dDevice; + ORT_THROW_IF_FAILED(provider->GetD3DDevice(d3dDevice.GetAddressOf())); + + ORT_THROW_IF_FAILED(d3dDevice->CreateCommittedResource( + &heapProperties, + D3D12_HEAP_FLAG_NONE, + &resourceDesc, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + nullptr, + IID_GRAPHICS_PPV_ARGS(buffer.GetAddressOf()))); + + ORT_THROW_IF_FAILED(provider->UploadToResource(buffer.Get(), tensorPtr, tensorByteSize)); + + return buffer; + } + + Microsoft::WRL::ComPtr + CreateCpuResource( + const ExecutionProviderImpl* provider, + const std::byte* tensorPtr, + size_t tensorByteSize) + { + Microsoft::WRL::ComPtr buffer; + + D3D12_HEAP_PROPERTIES heapProperties = { + D3D12_HEAP_TYPE_CUSTOM, D3D12_CPU_PAGE_PROPERTY_WRITE_COMBINE, D3D12_MEMORY_POOL_L0, 0, 0}; + + D3D12_RESOURCE_DESC resourceDesc = {D3D12_RESOURCE_DIMENSION_BUFFER, + 0, + static_cast((tensorByteSize + 3) & ~3), + 1, + 1, + 1, + DXGI_FORMAT_UNKNOWN, + {1, 0}, + D3D12_TEXTURE_LAYOUT_ROW_MAJOR, + D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS}; + + Microsoft::WRL::ComPtr d3dDevice; + ORT_THROW_IF_FAILED(provider->GetD3DDevice(d3dDevice.GetAddressOf())); + + ORT_THROW_IF_FAILED(d3dDevice->CreateCommittedResource( + &heapProperties, + D3D12_HEAP_FLAG_NONE, + &resourceDesc, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + nullptr, + IID_GRAPHICS_PPV_ARGS(buffer.GetAddressOf()))); + + // Map the buffer and copy the data + void* bufferData = nullptr; + D3D12_RANGE range = {0, tensorByteSize}; + ORT_THROW_IF_FAILED(buffer->Map(0, &range, &bufferData)); + memcpy(bufferData, tensorPtr, tensorByteSize); + buffer->Unmap(0, &range); + + return 
buffer; + } + + void UnwrapTensor( + Windows::AI::MachineLearning::Adapter::IWinmlExecutionProvider* winmlProvider, + const onnxruntime::Tensor* tensor, + ID3D12Resource** resource, + uint64_t* allocId) + { + IUnknown* allocationUnk = static_cast(const_cast(tensor->DataRaw())); + Microsoft::WRL::ComPtr resourceUnk; + winmlProvider->GetABIDataInterface(false, allocationUnk, &resourceUnk); + + *allocId = winmlProvider->TryGetPooledAllocationId(allocationUnk, 0); + + ORT_THROW_IF_FAILED(resourceUnk->QueryInterface(resource)); + } + + std::unordered_map> + GetInitializerToPartitionMap( + const onnxruntime::GraphViewer& graph, + gsl::span> partitions + ) + { + std::unordered_map> initializerPartitionMap; + for (uint32_t partitionIndex = 0; partitionIndex < gsl::narrow_cast(partitions.size()); ++partitionIndex) + { + auto& partition = partitions[partitionIndex]; + + // Skip partitions which have been merged into other partitions + if (partition->GetRootMergedPartition() != partition.get()) + { + continue; + } + + for (const std::string& input : partition->GetInputs()) + { + const onnx::TensorProto* tensor = nullptr; + if (graph.GetInitializedTensor(input, tensor)) + { + initializerPartitionMap[tensor].push_back(partitionIndex); + } + } + } + + return initializerPartitionMap; + } + + void ConvertGraphDesc( + const Dml::GraphDescBuilder::GraphDesc& graphDesc, + _Out_ DML_GRAPH_DESC& dmlGraphDesc, + const uint32_t inputCount, + const uint32_t outputCount, + _Inout_ std::vector& dmlOperatorGraphNodes, + _Inout_ std::vector& dmlGraphNodes, + _Inout_ std::vector& dmlInputEdges, + _Inout_ std::vector& dmlOutputEdges, + _Inout_ std::vector& dmlIntermediateEdges) + { + for (size_t i = 0; i < graphDesc.nodes.size(); ++i) + { + auto& nodeInfo = graphDesc.nodes[i]; + dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{nodeInfo.op.Get(), nodeInfo.name.data()}; + dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; + } + + for (size_t i 
= 0; i < graphDesc.inputEdges.size(); ++i) + { + dmlInputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INPUT, &graphDesc.inputEdges[i]}; + } + + for (size_t i = 0; i < graphDesc.outputEdges.size(); ++i) + { + dmlOutputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_OUTPUT, &graphDesc.outputEdges[i]}; + } + + for (size_t i = 0; i < graphDesc.intermediateEdges.size(); ++i) + { + dmlIntermediateEdges[i] = + DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, &graphDesc.intermediateEdges[i]}; + } + + dmlGraphDesc.InputCount = inputCount; + dmlGraphDesc.OutputCount = outputCount; + dmlGraphDesc.NodeCount = gsl::narrow_cast(dmlGraphNodes.size()); + dmlGraphDesc.Nodes = dmlGraphNodes.data(); + dmlGraphDesc.InputEdgeCount = gsl::narrow_cast(dmlInputEdges.size()); + dmlGraphDesc.InputEdges = dmlInputEdges.data(); + dmlGraphDesc.OutputEdgeCount = gsl::narrow_cast(dmlOutputEdges.size()); + dmlGraphDesc.OutputEdges = dmlOutputEdges.data(); + dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast(dmlIntermediateEdges.size()); + dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data(); + } + + onnxruntime::IndexedSubGraph CreateIndexedSubGraph( + GraphPartition* partition, + uint32_t partitionIndex, + const std::string& partitionKernelPrefix) + { + assert(partition->IsDmlGraphPartition()); + + onnxruntime::IndexedSubGraph indexedSubGraph; + // Create a definition for the node. The name must be unique. 
+ auto def = std::make_unique(); + def->name = DmlRuntimeGraphFusionTransformer::DML_GRAPH_FUSION_NODE_NAME_PREFIX + partitionKernelPrefix + std::to_string(partitionIndex); + def->domain = DmlRuntimeGraphFusionTransformer::DML_GRAPH_FUSION_NODE_DOMAIN; + def->since_version = 1; + def->inputs.insert(def->inputs.begin(), partition->GetInputs().begin(), partition->GetInputs().end()); + def->outputs.insert(def->outputs.begin(), partition->GetOutputs().begin(), partition->GetOutputs().end()); + + indexedSubGraph.SetMetaDef(std::move(def)); + indexedSubGraph.nodes = std::move(partition->GetNodeIndices()); + + return indexedSubGraph; + } + + std::unordered_map CreatePartitionNodePropsMap( + const onnxruntime::Graph& graph, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + std::unordered_map&& graphNodePropertyMap) + { + // Populate properties which will be passed to OpKernel for this graph via the function below + std::unordered_map partitionNodePropsMap; + for (auto nodeIndex : indexedSubGraph.nodes) + { + const onnxruntime::Node* node = graph.GetNode(nodeIndex); + +#ifdef PRINT_PARTITON_INFO + printf("Partition %u\t%s\n", partitionIndex, GraphDescBuilder::GetUniqueNodeName(*node).c_str()); +#endif + partitionNodePropsMap.insert(std::make_pair( + GraphDescBuilder::GetUniqueNodeName(*node), std::move(graphNodePropertyMap[node]))); + } + +#ifdef PRINT_PARTITON_INFO + printf("\n"); +#endif + + return partitionNodePropsMap; + } + + Microsoft::WRL::ComPtr TryCreateCompiledOperator( + const GraphDescBuilder::GraphDesc& graphDesc, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + const ExecutionProviderImpl* providerImpl) + { + const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); + const uint32_t fusedNodeOutputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->outputs.size()); + + // convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator + DML_GRAPH_DESC dmlGraphDesc = {}; + std::vector 
dmlOperatorGraphNodes(graphDesc.nodes.size()); + std::vector dmlGraphNodes(graphDesc.nodes.size()); + std::vector dmlInputEdges(graphDesc.inputEdges.size()); + std::vector dmlOutputEdges(graphDesc.outputEdges.size()); + std::vector dmlIntermediateEdges(graphDesc.intermediateEdges.size()); + ConvertGraphDesc( + graphDesc, + dmlGraphDesc, + fusedNodeInputCount, + fusedNodeOutputCount, + dmlOperatorGraphNodes, + dmlGraphNodes, + dmlInputEdges, + dmlOutputEdges, + dmlIntermediateEdges); + + DML_EXECUTION_FLAGS executionFlags = DML_EXECUTION_FLAG_NONE; + if (graphDesc.reuseCommandList) + { + executionFlags |= DML_EXECUTION_FLAG_DESCRIPTORS_VOLATILE; + } + + // Query DML execution provider to see if metacommands are enabled + if (!providerImpl->MetacommandsEnabled()) + { + executionFlags |= DML_EXECUTION_FLAG_DISABLE_META_COMMANDS; + } + + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + + ComPtr device1; + ORT_THROW_IF_FAILED(device.As(&device1)); + + ComPtr compiledExecutionPlanOperator; + ORT_THROW_IF_FAILED(device1->CompileGraph( + &dmlGraphDesc, + executionFlags, + IID_PPV_ARGS(&compiledExecutionPlanOperator))); + + // UINT32_MAX is currently the maximum number of bytes allowed by D3D12 for the offset of a view over a resource + if (compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize > UINT32_MAX) + { + return nullptr; + } + + return compiledExecutionPlanOperator; + } + + void RegisterKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const ExecutionProviderImpl* providerImpl, + std::unordered_map graphNodePropertyMap, + std::shared_ptr indexedSubGraph) + { + auto partitionNodePropsMap = DmlRuntimeGraphFusionHelper::CreatePartitionNodePropsMap( + graph, + *indexedSubGraph, + std::move(graphNodePropertyMap)); + + // lambda captures for the kernel registration + auto fused_kernel_func = [ + &graph, + indexedSubGraph, + partitionNodePropsMap = 
std::move(partitionNodePropsMap)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr& out) mutable ->onnxruntime::Status + { + out.reset(CreateRuntimeFusedGraphKernel(info, graph, *indexedSubGraph, std::move(partitionNodePropsMap))); + return Status::OK(); + }; + + // build the kernel definition on the fly, and register it to the fused_kernel_registry. + onnxruntime::KernelDefBuilder builder; + builder.SetName(indexedSubGraph->GetMetaDef()->name) + .SetDomain(indexedSubGraph->GetMetaDef()->domain) + .SinceVersion(indexedSubGraph->GetMetaDef()->since_version) + .Provider(onnxruntime::kDmlExecutionProvider); + ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func)); + + auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name); + fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); + + graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode); + } +} +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h new file mode 100644 index 0000000000000..b8713dd0736e3 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h @@ -0,0 +1,81 @@ +#pragma once + +#include "precomp.h" +#include "GraphDescBuilder.h" +#include "ExecutionProvider.h" +#include "GraphPartitioner.h" +#include "DmlRuntimeFusedGraphKernel.h" +#include "MLOperatorAuthorImpl.h" + + +namespace Dml +{ +namespace DmlRuntimeGraphFusionHelper +{ + template + static T AlignToPow2(T offset, T alignment) + { + static_assert(std::is_unsigned_v); + assert(alignment != 0); + assert((alignment & (alignment - 1)) == 0); + return (offset + alignment - 1) & ~(alignment - 1); + } + + Microsoft::WRL::ComPtr + CreateResource( + const ExecutionProviderImpl* provider, + const std::byte* tensorPtr, + size_t tensorByteSize);
+ + Microsoft::WRL::ComPtr + CreateCpuResource( + const ExecutionProviderImpl* provider, + const std::byte* tensorPtr, + size_t tensorByteSize); + + void UnwrapTensor( + Windows::AI::MachineLearning::Adapter::IWinmlExecutionProvider* winmlProvider, + const onnxruntime::Tensor* tensor, + ID3D12Resource** resource, + uint64_t* allocId); + + std::unordered_map> + GetInitializerToPartitionMap( + const onnxruntime::GraphViewer& graph, + gsl::span> partitions + ); + + void ConvertGraphDesc( + const Dml::GraphDescBuilder::GraphDesc& graphDesc, + _Out_ DML_GRAPH_DESC& dmlGraphDesc, + const uint32_t inputCount, + const uint32_t outputCount, + _Inout_ std::vector& dmlOperatorGraphNodes, + _Inout_ std::vector& dmlGraphNodes, + _Inout_ std::vector& dmlInputEdges, + _Inout_ std::vector& dmlOutputEdges, + _Inout_ std::vector& dmlIntermediateEdges); + + onnxruntime::IndexedSubGraph CreateIndexedSubGraph( + GraphPartition* partition, + uint32_t partitionIndex, + const std::string& partitionKernelPrefix); + + std::unordered_map CreatePartitionNodePropsMap( + const onnxruntime::Graph& graph, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + std::unordered_map&& graphNodePropertyMap); + + Microsoft::WRL::ComPtr TryCreateCompiledOperator( + const GraphDescBuilder::GraphDesc& graphDesc, + const onnxruntime::IndexedSubGraph& indexedSubGraph, + const ExecutionProviderImpl* providerImpl); + + void RegisterKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const ExecutionProviderImpl* providerImpl, + std::unordered_map graphNodePropertyMap, + std::shared_ptr indexedSubGraph); +} +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp new file mode 100644 index 0000000000000..8286804cca0fa --- /dev/null +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -0,0 +1,137 @@ +#pragma once + +#include "precomp.h" +#include "GraphDescBuilder.h" +#include "ExecutionProvider.h" +#include "DmlRuntimeGraphFusionTransformer.h" +#include "GraphPartitioner.h" +#include "core/framework/kernel_type_str_resolver.h" +#include "core/framework/kernel_lookup.h" +#include "core/optimizer/constant_sharing.h" +#include "DmlRuntimeFusedGraphKernel.h" +#include "MLOperatorAuthorImpl.h" +#include "DmlRuntimeGraphFusionHelper.h" + + +namespace Dml +{ + DmlRuntimeGraphFusionTransformer::DmlRuntimeGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ) + :onnxruntime::GraphTransformer(name), + m_providerImpl(static_cast(provider)->GetImpl()) + { + } + + onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImpl( + onnxruntime::Graph& graph, + bool& modified, + int graph_level, + const onnxruntime::logging::Logger& logger) const + { + return ApplyImplHelper(graph, modified, graph_level, logger, {}); + } + + onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graph_level, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const + { + onnxruntime::ProviderType provider_type = onnxruntime::kDmlExecutionProvider; + const gsl::not_null registry = m_providerImpl->GetKernelRegistry().get(); + const auto kernel_type_str_resolver = onnxruntime::OpSchemaKernelTypeStrResolver{}; + const auto kernel_lookup = onnxruntime::KernelLookup{provider_type, + gsl::make_span(®istry, 1), + kernel_type_str_resolver}; + + onnxruntime::GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) + { + auto* node = graph.GetNode(node_index); + if (!node) + { + continue; // node was removed + } + + 
std::unordered_map subgraphImplicitInputDefs; + for (const onnxruntime::NodeArg* inputDef : node->ImplicitInputDefs()) + { + subgraphImplicitInputDefs[inputDef->Name()] = inputDef; + } + + for (auto& entry : node->GetAttributeNameToMutableSubgraphMap()) + { + auto& subgraph = *entry.second; + ORT_RETURN_IF_ERROR(ApplyImplHelper(subgraph, modified, graph_level + 1, logger, subgraphImplicitInputDefs)); + } + } + + // Initializers needed by any graph partition + std::vector additionalSplittingNodes; + std::unordered_map graphNodePropertyMap; + std::unordered_set requiredInitializerMap; + onnxruntime::GraphViewer graphViewer(graph); + std::vector> partitions = BuildPartitions( + graphViewer, + *m_providerImpl->GetInternalRegistrationInfoMap(), + kernel_lookup, + m_providerImpl->GetSupportedDeviceDataTypeMask(), + graphNodePropertyMap, + requiredInitializerMap, + additionalSplittingNodes, + implicitInputDefs, + true); + + // Reset the splitting nodes for the current iteration + additionalSplittingNodes.clear(); + + // Reset the compiled operators for the current iteration + std::vector> indexedSubGraphs(partitions.size()); + + // Create a map between each initialized tensor and the partition(s) it is part of. 
+ auto initializerPartitionMap = DmlRuntimeGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions); + + for (uint32_t partitionIndex = 0; partitionIndex < partitions.size(); ++partitionIndex) + { + auto& partition = partitions[partitionIndex]; + + if (partition->GetRootMergedPartition() != partition.get() || + !partition->IsDmlPartition()) + { + continue; + } + + if (partition->IsDmlGraphPartition()) + { + std::unordered_map> isInitializerTransferable; + + std::string partitionKernelPrefix = std::to_string(m_providerImpl->GetPartitionKernelPrefixVal()) + "_"; + m_providerImpl->IncreasePartitionKernelPrefixVal(); + + indexedSubGraphs[partitionIndex] = std::make_shared( + DmlRuntimeGraphFusionHelper::CreateIndexedSubGraph(partition.get(), partitionIndex, partitionKernelPrefix)); + } + } + + for (auto&& indexedSubGraph : indexedSubGraphs) + { + // Null compiled operators were not DML partitions + if (indexedSubGraph) + { + DmlRuntimeGraphFusionHelper::RegisterKernel( + graph, + m_providerImpl->GetKernelRegistry().get(), + m_providerImpl, + graphNodePropertyMap, + std::move(indexedSubGraph)); + } + } + + return onnxruntime::common::Status::OK(); + } +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h new file mode 100644 index 0000000000000..602a6373d483c --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include "core/optimizer/graph_transformer.h" +#include "core/framework/execution_providers.h" + +namespace Dml +{ +class ExecutionProviderImpl; + +class DmlRuntimeGraphFusionTransformer : public onnxruntime::GraphTransformer +{ +public: + DmlRuntimeGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ); + +public: + static inline const char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlRuntimeFusedNode_"; + static inline const char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlRuntimeFusedNodeDomain"; + +private: + onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, + bool& modified, + int graph_level, + const onnxruntime::logging::Logger& logger) const final; + + onnxruntime::common::Status ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graph_level, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const; + +private: + const ExecutionProviderImpl* m_providerImpl = nullptr; +}; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 18943878ccedc..77419a8412c5b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -151,6 +151,7 @@ namespace Dml _In_opt_ const std::unordered_map* nodeNameToPartitionMap, _Inout_ std::unordered_map& dmlNodePropertyMap, _Inout_ std::unordered_set& requiredInitializerMap, + bool allowDmlGraphDynamicShapes, _Out_ bool* isDmlGraphNode ) { @@ -192,16 +193,28 @@ namespace Dml requiredInitializerMap.insert(inputName); } - std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; - if (requiredCpuInputsConstant && - TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && - 
!ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && - TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && - (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + if (allowDmlGraphDynamicShapes) { - *isDmlGraphNode = true; - graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; + if (requiredCpuInputsConstant && (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + { + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + } + } + else + { + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; + if (requiredCpuInputsConstant && + TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && + TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && + (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + { + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + } } } } @@ -380,7 +393,8 @@ namespace Dml std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, gsl::span additionalSplittingNodes, - const std::unordered_map& implicitInputs) + const std::unordered_map& implicitInputs, + bool allowDmlGraphDynamicShapes) { // Nodes are uniquely identified by the name of their first output argument std::vector> 
partitions; @@ -443,6 +457,7 @@ namespace Dml &nodeNameToPartitionMap, graphNodePropertyMap, requiredInitializerMap, + allowDmlGraphDynamicShapes, /*out*/ &isDmlGraphNode ); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h index 37d577f647fb5..407fdd52e2cf4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h @@ -51,5 +51,6 @@ namespace Dml std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, gsl::span additionalSplittingNodes, - const std::unordered_map& implicitInputs); + const std::unordered_map& implicitInputs, + bool allowDmlGraphDynamicShapes); } // namespace Dml diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 21c8fbe0cd2c9..ada5cfb1c04e7 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -52,6 +52,7 @@ #include "core/providers/cpu/cpu_execution_provider.h" #ifdef USE_DML // TODO: This is necessary for the workaround in TransformGraph #include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h" #include "core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h" #include "core/providers/dml/dml_session_options_config_keys.h" #endif @@ -1550,6 +1551,13 @@ common::Status InferenceSession::Initialize() { return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr"); } ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3)); + + std::unique_ptr dmlRuntimeGraphFusionTransformer = std::make_unique("DmlRuntimeGraphFusionTransformer", + 
execution_providers_.Get(kDmlExecutionProvider)); + if (dmlRuntimeGraphFusionTransformer == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "DmlRuntimeGraphFusionTransformer is nullptr"); + } + ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlRuntimeGraphFusionTransformer), onnxruntime::TransformerLevel::Level3)); } // This transformer applies DML-specific fusions that go beyond what ORT offers by default From a06548afa8e0abb69d3e33f228c4f043e4845ad8 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Tue, 3 Oct 2023 14:50:10 -0700 Subject: [PATCH 02/20] WIP --- .../inc/IWinmlExecutionProvider.h | 2 + .../src/AbiCustomRegistry.cpp | 2 + .../DmlExecutionProvider/src/DmlEdgeShapes.h | 42 ++ .../src/DmlGraphFusionTransformer.cpp | 58 ++- .../src/DmlRuntimeFusedGraphKernel.cpp | 457 +++++++----------- .../src/DmlRuntimeFusedGraphKernel.h | 12 +- .../src/DmlRuntimeGraphFusionHelper.cpp | 152 +++++- .../src/DmlRuntimeGraphFusionHelper.h | 3 +- .../src/DmlRuntimeGraphFusionTransformer.cpp | 32 +- .../src/GraphDescBuilder.cpp | 34 +- .../src/GraphDescBuilder.h | 15 +- .../src/MLOperatorAuthorImpl.h | 37 +- 12 files changed, 475 insertions(+), 371 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 04381b6ce355c..63e2424602f83 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -9,6 +9,7 @@ #include #include "core/framework/op_kernel.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" struct AbstractOperatorDesc; interface IMLOperatorTensor; @@ -86,6 +87,7 @@ namespace Windows::AI::MachineLearning::Adapter std::vector inputEdges; std::vector 
outputEdges; std::vector intermediateEdges; + EdgeShapes outputShapes; }; using GraphNodeFactory = std::functionoutputShapes = outputShapes; + // Create the kernel while allowing input shape and output shape queries according to options ComPtr kernelInfoWrapper = wil::MakeOrThrow( &protoHelper, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h new file mode 100644 index 0000000000000..5ff70493252bd --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace Windows::AI::MachineLearning::Adapter +{ + // edges and unused edges have an empty array of dimensions. + class EdgeShapes + { + public: + EdgeShapes() = default; + + EdgeShapes(size_t count) : m_shapes(count) {} + + const std::vector& GetShape(size_t edgeIndex) const + { + return m_shapes[edgeIndex]; + } + + std::vector& GetMutableShape(size_t edgeIndex) + { + return m_shapes[edgeIndex]; + } + + size_t EdgeCount() const { return m_shapes.size(); } + + void Reset(size_t edge_count) + { + m_shapes.clear(); + m_shapes.resize(edge_count); + } + + bool operator!=(const EdgeShapes& other) const noexcept + { + return (m_shapes != other.m_shapes); + } + + private: + std::vector> m_shapes; + }; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index 8d0ae8ea1d7f8..f60880b17a08e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -15,6 +15,18 @@ namespace Dml { + namespace + { + struct CompiledPartitionInfo + { + Microsoft::WRL::ComPtr compiledOperator; + 
onnxruntime::IndexedSubGraph indexedSubGraph; + std::vector isInputsUploadedByDmlEP; + GraphDescBuilder::GraphDesc graphDesc; + std::unordered_map> isInitializerTransferable; + }; + } + DmlGraphFusionTransformer::DmlGraphFusionTransformer( const std::string& name, const onnxruntime::IExecutionProvider* provider @@ -24,15 +36,6 @@ namespace Dml { } - struct CompiledPartitionInfo - { - Microsoft::WRL::ComPtr compiledOperator; - onnxruntime::IndexedSubGraph indexedSubGraph; - std::vector isInputsUploadedByDmlEP; - GraphDescBuilder::GraphDesc graphDesc; - std::unordered_map> isInitializerTransferable; - }; - onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImpl( onnxruntime::Graph& graph, bool& modified, @@ -191,17 +194,48 @@ namespace Dml std::move(graphNodePropertyMap)); // Convert partitionONNXGraph into DML EP GraphDesc + auto modelPath = graph.ModelPath(); + + const gsl::span subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs; + const gsl::span subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs; + + std::vector subgraphNodes; + subgraphNodes.reserve(indexedSubGraph.nodes.size()); + + std::vector subgraphInputs; + subgraphInputs.reserve(subGraphInputArgNames.size()); + + std::vector subgraphOutputs; + subgraphOutputs.reserve(subGraphOutputArgNames.size()); + + for (size_t sortedNodeIndex : indexedSubGraph.nodes) + { + subgraphNodes.push_back(graph.GetNode(sortedNodeIndex)); + } + + for (const std::string& graphInputName : subGraphInputArgNames) + { + subgraphInputs.push_back(graph.GetNodeArg(graphInputName)); + } + + for (const std::string& graphOutputName : subGraphOutputArgNames) + { + subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); + } + ComPtr device; ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf())); GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( isInputsUploadedByDmlEP.data(), isInputsUploadedByDmlEP.size(), isInitializerTransferable, - graph, - indexedSubGraph, 
partitionNodePropsMap, device.Get(), - m_providerImpl); + m_providerImpl, + modelPath, + subgraphNodes, + subgraphInputs, + subgraphOutputs); // Compile the operator auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 6ae62e53da658..51581fa14adb6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -18,74 +18,58 @@ namespace Dml DmlRuntimeFusedGraphKernel( const onnxruntime::OpKernelInfo& kernelInfo, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, - std::unordered_map&& partitionNodePropsMap) - : OpKernel(kernelInfo) + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::shared_ptr>> inputDimParams, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers) + : OpKernel(kernelInfo), + m_indexedSubGraph(std::move(indexedSubGraph)), + m_modelPath(modelPath), + m_inputDimParams(std::move(inputDimParams)), + m_subgraphNodes(std::move(subgraphNodes)), + m_subgraphInputs(std::move(subgraphInputs)), + m_subgraphOutputs(std::move(subgraphOutputs)), + m_intermediateNodeArgs(std::move(intermediateNodeArgs)), + m_partitionNodePropsMap(std::move(partitionNodePropsMap)), + m_ownedInitializers(std::move(ownedInitializers)) { - const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); - - // Populate input bindings for operator initialization - std::vector> initializeResourceRefs; // For lifetime control - std::vector initInputBindings(fusedNodeInputCount); - 
std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); - - std::unordered_map> isInitializerTransferable; - - auto providerImpl = static_cast(kernelInfo.GetExecutionProvider())->GetImpl(); - - // Convert partitionONNXGraph into DML EP GraphDesc - ComPtr device; - ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); - GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( - isInputsUploadedByDmlEP.data(), - isInputsUploadedByDmlEP.size(), - isInitializerTransferable, - graph, - indexedSubGraph, - partitionNodePropsMap, - device.Get(), - providerImpl); - - // Walk through each graph edge and mark used inputs - m_inputsUsed.resize(fusedNodeInputCount, false); - for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) + for (const auto& initializer : m_ownedInitializers) { - m_inputsUsed[edge.GraphInputIndex] = true; + m_isInitializerTransferable[initializer.name()] = std::make_pair(&initializer, false); } - // Compile the operator - m_compiledExecutionPlanOperator = DmlRuntimeGraphFusionHelper::TryCreateCompiledOperator( - graphDesc, - indexedSubGraph, - providerImpl); - // Get the execution provider interfaces - m_executionHandle = kernelInfo.GetExecutionProvider()->GetExecutionHandle(); - if (m_executionHandle) + auto executionHandle = kernelInfo.GetExecutionProvider()->GetExecutionHandle(); + if (executionHandle) { // We assume the execution object inherits IUnknown as its first base - ComPtr providerExecutionObject = const_cast(static_cast(m_executionHandle)); + ComPtr providerExecutionObject = const_cast(static_cast(executionHandle)); // Get the WinML-specific execution provider interface from the execution object. 
ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_provider)); ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_winmlProvider)); } - TranslateAndCompileGraph( - kernelInfo, - initializeResourceRefs, - initInputBindings, - graphDesc.reuseCommandList); + m_subgraphNodePointers.reserve(m_subgraphNodes.size()); + + for (auto& subgraphNode : m_subgraphNodes) + { + m_subgraphNodePointers.push_back(subgraphNode.get()); + } } void TranslateAndCompileGraph( const onnxruntime::OpKernelInfo& kernelInfo, std::vector>& initializeResourceRefs, - std::vector initInputBindings, - bool reuseCommandList - ) + std::vector initInputBindings) const { + std::optional persistentResourceBinding; + // Allocate a persistent resource and initialize the operator UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize; if (persistentResourceSize > 0) @@ -96,12 +80,12 @@ namespace Dml m_persistentResource.GetAddressOf(), m_persistentResourceAllocatorUnk.GetAddressOf())); - m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize }; + persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize }; } ORT_THROW_IF_FAILED(m_provider->InitializeOperator( m_compiledExecutionPlanOperator.Get(), - m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, + persistentResourceBinding ? 
&*persistentResourceBinding : nullptr, gsl::make_span(initInputBindings))); // Queue references to objects which must be kept alive until resulting GPU work completes @@ -113,78 +97,128 @@ namespace Dml initializeResourceRefs.end(), [&](ComPtr& resource){ m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get()); } ); - - if (reuseCommandList) - { - BuildReusableCommandList(); - } } onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override { - EdgeShapes outputShapes(kernelContext->OutputCount()); + const uint32_t fusedNodeInputCount = gsl::narrow_cast(m_indexedSubGraph->GetMetaDef()->inputs.size()); - for (int outputIndex = 0; outputIndex < kernelContext->OutputCount(); ++outputIndex) - { - onnxruntime::TensorShape ortShape; - kernelContext->TryGetInferredOutputShape(outputIndex, ortShape); + // Populate input bindings for operator initialization + std::vector> initializeResourceRefs; // For lifetime control + std::vector initInputBindings(fusedNodeInputCount); + std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); - std::vector& outputShape = outputShapes.GetMutableShape(outputIndex); + auto providerImpl = static_cast(Info().GetExecutionProvider())->GetImpl(); - for (size_t dim = 0; dim < ortShape.NumDimensions(); ++dim) - { - outputShape[dim] = gsl::narrow_cast(ortShape.GetDims()[dim]); - } - } + std::unordered_map dynamicDimOverrides; - // Only re-use the cached command list if its prior execution is complete on the GPU. - // This requirement can be avoided by mantaining ring buffers. 
- if (!m_graphicsCommandList || - (m_fence != nullptr && m_fence->GetCompletedValue() < m_completionValue)) + ORT_THROW_HR_IF(E_UNEXPECTED, m_inputDimParams->size() != kernelContext->InputCount()); + for (int inputIndex = 0; inputIndex < m_inputDimParams->size(); ++inputIndex) { - // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator - OpKernelContextWrapper contextWrapper( - kernelContext, - Info().GetExecutionProvider(), - true, - nullptr); + const auto& input = kernelContext->RequiredInput(inputIndex); + ORT_THROW_HR_IF(E_UNEXPECTED, input.Shape().NumDimensions() != (*m_inputDimParams)[inputIndex].size()); - ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + for (int i = 0; i < input.Shape().NumDimensions(); ++i) + { + const std::string& dimParam = (*m_inputDimParams)[inputIndex][i]; - // Get input resources for execution, excluding those which were specified as owned by DML and provided - // at initialization instead. - std::vector> inputTensors(kernelContext->InputCount()); - std::vector inputPtrs(kernelContext->InputCount()); + if (!dimParam.empty()) + { + dynamicDimOverrides[dimParam] = input.Shape().GetDims()[i]; + } + } + } - for (int i = 0; i < kernelContext->InputCount(); ++i) + for (auto& subgraphNode : m_subgraphNodes) + { + for (onnxruntime::NodeArg* inputDef : subgraphNode->MutableInputDefs()) { - if (!m_inputsUsed[i]) + ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->TypeAsProto()->has_tensor_type()); + auto tensorShape = inputDef->TypeAsProto()->tensor_type().shape(); + + for (int i = 0; i < tensorShape.dim_size(); ++i) { - continue; + if (tensorShape.dim(i).has_dim_param()) + { + tensorShape.mutable_dim(i)->set_dim_value(dynamicDimOverrides[tensorShape.dim(i).dim_param()]); + } } - ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf())); - inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get()); + inputDef->SetShape(tensorShape); } + } - auto aux = 
contextWrapper.GetOutputTensors(outputShapes); - ExecuteOperator( - m_compiledExecutionPlanOperator.Get(), - m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, - inputPtrs, - aux); - - ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + // Convert partitionONNXGraph into DML EP GraphDesc + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( + isInputsUploadedByDmlEP.data(), + isInputsUploadedByDmlEP.size(), + m_isInitializerTransferable, + m_partitionNodePropsMap, + device.Get(), + providerImpl, + m_modelPath, + m_subgraphNodePointers, + m_subgraphInputs, + m_subgraphOutputs); - // Queue references to objects which must be kept alive until resulting GPU work completes - m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); - m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get()); + // Walk through each graph edge and mark used inputs + m_inputsUsed.resize(fusedNodeInputCount, false); + for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) + { + m_inputsUsed[edge.GraphInputIndex] = true; } - else + + // Compile the operator + m_compiledExecutionPlanOperator = DmlRuntimeGraphFusionHelper::TryCreateCompiledOperator( + graphDesc, + *m_indexedSubGraph, + providerImpl); + + TranslateAndCompileGraph( + Info(), + initializeResourceRefs, + initInputBindings); + + // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator + OpKernelContextWrapper contextWrapper( + kernelContext, + Info().GetExecutionProvider(), + true, + nullptr); + + ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + + // Get input resources for execution, excluding those which were specified as owned by DML and provided + // at initialization instead. 
+ std::vector> inputTensors(kernelContext->InputCount()); + std::vector inputPtrs(kernelContext->InputCount()); + + for (int i = 0; i < kernelContext->InputCount(); ++i) { - ExecuteReusableCommandList(kernelContext, outputShapes); + if (!m_inputsUsed[i]) + { + continue; + } + + ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf())); + inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get()); } + auto aux = contextWrapper.GetOutputTensors(graphDesc.outputShapes); + ExecuteOperator( + m_compiledExecutionPlanOperator.Get(), + m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, + inputPtrs, + aux); + + ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + + // Queue references to objects which must be kept alive until resulting GPU work completes + m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); + m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get()); + return onnxruntime::Status::OK(); } @@ -252,188 +286,6 @@ namespace Dml } private: - void BuildReusableCommandList() - { - ComPtr device; - ORT_THROW_IF_FAILED(m_provider->GetDmlDevice(device.GetAddressOf())); - - DML_BINDING_PROPERTIES execBindingProps = m_compiledExecutionPlanOperator->GetBindingProperties(); - - D3D12_DESCRIPTOR_HEAP_DESC desc = {}; - desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; - desc.NumDescriptors = execBindingProps.RequiredDescriptorCount; - desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; - - ComPtr d3dDevice; - ORT_THROW_IF_FAILED(m_provider->GetD3DDevice(d3dDevice.GetAddressOf())); - - ORT_THROW_IF_FAILED(d3dDevice->CreateDescriptorHeap(&desc, IID_GRAPHICS_PPV_ARGS(m_heap.ReleaseAndGetAddressOf()))); - - // Create a binding table for execution. 
- DML_BINDING_TABLE_DESC bindingTableDesc = {}; - bindingTableDesc.Dispatchable = m_compiledExecutionPlanOperator.Get(); - bindingTableDesc.CPUDescriptorHandle = m_heap->GetCPUDescriptorHandleForHeapStart(); - bindingTableDesc.GPUDescriptorHandle = m_heap->GetGPUDescriptorHandleForHeapStart(); - bindingTableDesc.SizeInDescriptors = execBindingProps.RequiredDescriptorCount; - - ORT_THROW_IF_FAILED(device->CreateBindingTable(&bindingTableDesc, IID_PPV_ARGS(&m_bindingTable))); - - ORT_THROW_IF_FAILED(d3dDevice->CreateCommandAllocator( - m_provider->GetCommandListTypeForQueue(), - IID_GRAPHICS_PPV_ARGS(m_commandAllocator.ReleaseAndGetAddressOf()))); - - ORT_THROW_IF_FAILED(d3dDevice->CreateCommandList( - 0, - m_provider->GetCommandListTypeForQueue(), - m_commandAllocator.Get(), - nullptr, - IID_GRAPHICS_PPV_ARGS(m_graphicsCommandList.ReleaseAndGetAddressOf()))); - - if (m_persistentResource) - { - DML_BINDING_DESC persistentResourceBindingDesc = - { DML_BINDING_TYPE_BUFFER, m_persistentResourceBinding ? 
&*m_persistentResourceBinding : nullptr }; - m_bindingTable->BindPersistentResource(&persistentResourceBindingDesc); - } - - ID3D12DescriptorHeap* descriptorHeaps[] = { m_heap.Get() }; - m_graphicsCommandList->SetDescriptorHeaps(ARRAYSIZE(descriptorHeaps), descriptorHeaps); - - ComPtr recorder; - ORT_THROW_IF_FAILED(device->CreateCommandRecorder(IID_PPV_ARGS(recorder.GetAddressOf()))); - - recorder->RecordDispatch(m_graphicsCommandList.Get(), m_compiledExecutionPlanOperator.Get(), m_bindingTable.Get()); - - ORT_THROW_IF_FAILED(m_graphicsCommandList->Close()); - } - - void ExecuteReusableCommandList(onnxruntime::OpKernelContext* kernelContext, const EdgeShapes& outputShapes) const - { - DML_BINDING_PROPERTIES execBindingProps = m_compiledExecutionPlanOperator->GetBindingProperties(); - - std::vector inputBindings(kernelContext->InputCount()); - std::vector inputBindingDescs(kernelContext->InputCount()); - - OpKernelContextWrapper contextWrapper( - kernelContext, - Info().GetExecutionProvider(), - true, - nullptr); - - // Populate input bindings, excluding those which were specified as owned by DML and provided - // at initialization instead. 
- m_inputBindingAllocIds.resize(inputBindings.size()); - bool inputBindingsChanged = false; - - for (uint32_t i = 0; i < inputBindings.size(); ++i) - { - if (m_inputsUsed[i]) - { - assert(kernelContext->InputType(gsl::narrow_cast(i))->IsTensorType()); - const onnxruntime::Tensor* tensor = kernelContext->Input(gsl::narrow_cast(i)); - - uint64_t allocId; - DmlRuntimeGraphFusionHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &inputBindings[i].Buffer, &allocId); - inputBindingsChanged = inputBindingsChanged || (!allocId || m_inputBindingAllocIds[i] != allocId); - inputBindings[i].Buffer->Release(); // Avoid holding an additional reference - inputBindings[i].SizeInBytes = DmlRuntimeGraphFusionHelper::AlignToPow2(tensor->SizeInBytes(), 4); - inputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &inputBindings[i]}; - m_inputBindingAllocIds[i] = allocId; - } - } - - if (inputBindingsChanged) - { - m_bindingTable->BindInputs(gsl::narrow_cast(inputBindingDescs.size()), inputBindingDescs.data()); - } - - // Populate Output bindings - std::vector outputBindings(kernelContext->OutputCount()); - std::vector outputBindingDescs(kernelContext->OutputCount()); - - m_outputBindingAllocIds.resize(outputBindings.size()); - bool outputBindingsChanged = false; - - for (uint32_t i = 0; i < outputBindings.size(); ++i) - { - std::vector outputDims; - outputDims.reserve(outputShapes.GetShape(i).size()); - for (uint32_t dimSize : outputShapes.GetShape(i)) - { - outputDims.push_back(dimSize); - } - - onnxruntime::Tensor* tensor = kernelContext->Output( - static_cast(i), - onnxruntime::TensorShape::FromExistingBuffer(outputDims) - ); - - uint64_t allocId; - DmlRuntimeGraphFusionHelper::UnwrapTensor(m_winmlProvider.Get(), tensor, &outputBindings[i].Buffer, &allocId); - outputBindingsChanged = outputBindingsChanged || (!allocId || m_outputBindingAllocIds[i] != allocId); - outputBindings[i].Buffer->Release(); // Avoid holding an additional reference - outputBindings[i].SizeInBytes = 
DmlRuntimeGraphFusionHelper::AlignToPow2(tensor->SizeInBytes(), 4); - outputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &outputBindings[i]}; - m_outputBindingAllocIds[i] = allocId; - } - - if (outputBindingsChanged) - { - m_bindingTable->BindOutputs(gsl::narrow_cast(outputBindingDescs.size()), outputBindingDescs.data()); - } - - if (execBindingProps.TemporaryResourceSize > 0) - { - // Allocate temporary data which will automatically be freed when the GPU work - // which is scheduled up to the point that this method returns has completed. - ComPtr tempAlloc; - uint64_t tempAllocId = 0; - ORT_THROW_IF_FAILED(contextWrapper.AllocateTemporaryData(static_cast(execBindingProps.TemporaryResourceSize), tempAlloc.GetAddressOf(), &tempAllocId)); - - ComPtr tempResourceUnk; - m_winmlProvider->GetABIDataInterface(false, tempAlloc.Get(), &tempResourceUnk); - - // Bind the temporary resource. - ComPtr tempResource; - ORT_THROW_IF_FAILED(tempResourceUnk->QueryInterface(tempResource.GetAddressOf())); - DML_BUFFER_BINDING tempBufferBinding = {tempResource.Get(), 0, execBindingProps.TemporaryResourceSize}; - DML_BINDING_DESC tempBindingDesc = { DML_BINDING_TYPE_BUFFER, &tempBufferBinding }; - - if (!tempAllocId || m_tempBindingAllocId != tempAllocId) - { - m_bindingTable->BindTemporaryResource(&tempBindingDesc); - } - - m_tempBindingAllocId = tempAllocId; - } - - // Execute the command list and if it succeeds, update the fence value at which this command may be - // re-used. 
- ComPtr fence; - uint64_t completionValue; - HRESULT hr = m_provider->ExecuteCommandList(m_graphicsCommandList.Get(), fence.GetAddressOf(), &completionValue); - - if (hr == DXGI_ERROR_DEVICE_REMOVED) - { - ComPtr device; - ORT_THROW_IF_FAILED(m_provider->GetD3DDevice(&device)); - ORT_THROW_IF_FAILED(device->GetDeviceRemovedReason()); - } - - ORT_THROW_IF_FAILED(hr); - m_fence = fence; - m_completionValue = completionValue; - - // Queue references to objects which must be kept alive until resulting GPU work completes - m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(m_graphicsCommandList).Get()); - m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(m_heap).Get()); - m_winmlProvider->QueueReference(m_bindingTable.Get()); - m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get()); - } - - ComPtr m_compiledExecutionPlanOperator; - std::vector m_inputsUsed; - const void* m_executionHandle = nullptr; ComPtr m_winmlProvider; ComPtr m_provider; @@ -443,13 +295,26 @@ namespace Dml ComPtr m_heap; ComPtr m_bindingTable; std::optional m_persistentResourceBinding; - ComPtr m_persistentResource; - ComPtr m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator + std::shared_ptr m_indexedSubGraph; + const onnxruntime::Path& m_modelPath; + std::shared_ptr>> m_inputDimParams; + std::vector> m_subgraphNodes; + std::vector m_subgraphInputs; + std::vector m_subgraphOutputs; + std::vector> m_intermediateNodeArgs; + std::unordered_map m_partitionNodePropsMap; + std::vector m_ownedInitializers; + std::unordered_map> m_isInitializerTransferable; + std::vector m_subgraphNodePointers; // Bindings from previous executions of a re-used command list + mutable ComPtr m_compiledExecutionPlanOperator; mutable std::vector m_inputBindingAllocIds; mutable std::vector m_outputBindingAllocIds; mutable uint64_t m_tempBindingAllocId = 0; + mutable std::vector m_inputsUsed; + mutable ComPtr m_persistentResource; + mutable ComPtr 
m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator // Fence tracking the status of the command list's last execution, and whether its descriptor heap // can safely be updated. @@ -459,15 +324,27 @@ namespace Dml onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( const onnxruntime::OpKernelInfo& info, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, - std::unordered_map&& partitionNodePropsMap) + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::shared_ptr>> inputDimParams, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers) { return new DmlRuntimeFusedGraphKernel( info, - graph, - indexedSubGraph, - std::move(partitionNodePropsMap) + std::move(indexedSubGraph), + modelPath, + std::move(inputDimParams), + std::move(subgraphNodes), + std::move(subgraphInputs), + std::move(subgraphOutputs), + std::move(intermediateNodeArgs), + std::move(partitionNodePropsMap), + std::move(ownedInitializers) ); } } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h index 3040a7612dfa2..d18a6d4671bc4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h @@ -9,8 +9,14 @@ namespace Dml { onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( const onnxruntime::OpKernelInfo& info, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, - std::unordered_map&& partitionNodePropsMap + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::shared_ptr>> inputDimParams, + 
std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers ); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp index 2a6a01094f5e7..e69395c3380fc 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp @@ -285,25 +285,149 @@ namespace DmlRuntimeGraphFusionHelper return compiledExecutionPlanOperator; } + struct NodeInfo + { + std::string name; + std::string opType; + std::string description; + std::string domain; + onnxruntime::NodeAttributes attributes; + std::vector inputDefPointers; + std::vector outputDefPointers; + }; + void RegisterKernel( onnxruntime::Graph& graph, onnxruntime::KernelRegistry* registryForPartitionKernels, const ExecutionProviderImpl* providerImpl, std::unordered_map graphNodePropertyMap, - std::shared_ptr indexedSubGraph) + std::shared_ptr indexedSubGraph, + std::unordered_map>&& isInitializerTransferable) { auto partitionNodePropsMap = DmlRuntimeGraphFusionHelper::CreatePartitionNodePropsMap( graph, *indexedSubGraph, std::move(graphNodePropertyMap)); + auto modelPath = graph.ModelPath(); + + const gsl::span subGraphInputArgNames = indexedSubGraph->GetMetaDef()->inputs; + const gsl::span subGraphOutputArgNames = indexedSubGraph->GetMetaDef()->outputs; + + std::vector nodesInfo; + nodesInfo.reserve(indexedSubGraph->nodes.size()); + + std::vector subgraphInputs; + subgraphInputs.reserve(subGraphInputArgNames.size()); + + std::vector subgraphOutputs; + subgraphOutputs.reserve(subGraphOutputArgNames.size()); + + std::vector nodeAttributes; + nodeAttributes.reserve(indexedSubGraph->nodes.size()); + + 
std::vector> intermediateNodeArgs; + + for (size_t sortedNodeIndex : indexedSubGraph->nodes) + { + auto node = graph.GetNode(sortedNodeIndex); + + nodeAttributes.push_back(node->GetAttributes()); + + NodeInfo nodeInfo{}; + nodeInfo.name = node->Name(); + nodeInfo.opType = node->OpType(); + nodeInfo.description = node->Description(); + nodeInfo.domain = node->Domain(); + nodeInfo.attributes = node->GetAttributes(); + nodeInfo.inputDefPointers.reserve(node->InputDefs().size()); + nodeInfo.outputDefPointers.reserve(node->OutputDefs().size()); + + for (const onnxruntime::NodeArg* inputDef : node->InputDefs()) + { + intermediateNodeArgs.emplace_back(std::make_shared(inputDef->Name(), inputDef->TypeAsProto())); + nodeInfo.inputDefPointers.push_back(intermediateNodeArgs.back().get()); + } + + for (const onnxruntime::NodeArg* outputDef : node->OutputDefs()) + { + intermediateNodeArgs.emplace_back(std::make_shared(outputDef->Name(), outputDef->TypeAsProto())); + nodeInfo.outputDefPointers.push_back(intermediateNodeArgs.back().get()); + } + + nodesInfo.push_back(std::move(nodeInfo)); + } + + for (const std::string& graphInputName : subGraphInputArgNames) + { + subgraphInputs.push_back(graph.GetNodeArg(graphInputName)); + } + + for (const std::string& graphOutputName : subGraphOutputArgNames) + { + subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); + } + + // We store the input dim params that haven't been overriden yet so that we can map their value at runtime once the real inputs are provided + auto inputDimParams = std::make_shared>>(); + + // We need to keep the initializers alive since they will be freed once the nodes are removed from the graph + std::vector ownedInitializers; + ownedInitializers.reserve(isInitializerTransferable.size()); + + for (auto& kvp : isInitializerTransferable) + { + ONNX_NAMESPACE::TensorProto tensorProto; + tensorProto.set_data_type(kvp.second.first->data_type()); + tensorProto.set_raw_data(kvp.second.first->raw_data()); + 
tensorProto.set_name(kvp.second.first->name()); + + for (int i = 0; i < kvp.second.first->dims_size(); ++i) + { + tensorProto.add_dims(kvp.second.first->dims(i)); + } + ownedInitializers.push_back(std::move(tensorProto)); + kvp.second.first = &ownedInitializers.back(); + } + // lamda captures for the kernel registration auto fused_kernel_func = [ - &graph, + inputDimParams, indexedSubGraph, - partitionNodePropsMap = std::move(partitionNodePropsMap)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr& out) mutable ->onnxruntime::Status + &modelPath, + nodesInfo = std::move(nodesInfo), + intermediateNodeArgs = std::move(intermediateNodeArgs), + subgraphInputs = std::move(subgraphInputs), + subgraphOutputs = std::move(subgraphOutputs), + partitionNodePropsMap = std::move(partitionNodePropsMap), + ownedInitializers = std::move(ownedInitializers)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr& out) mutable ->onnxruntime::Status { - out.reset(CreateRuntimeFusedGraphKernel(info, graph, *indexedSubGraph, std::move(partitionNodePropsMap))); + std::vector> subgraphNodes; + subgraphNodes.reserve(nodesInfo.size()); + + for (const NodeInfo& nodeInfo : nodesInfo) + { + subgraphNodes.emplace_back(std::make_shared( + nodeInfo.name, + nodeInfo.opType, + nodeInfo.description, + nodeInfo.inputDefPointers, + nodeInfo.outputDefPointers, + &nodeInfo.attributes, + nodeInfo.domain)); + } + + out.reset(CreateRuntimeFusedGraphKernel( + info, + indexedSubGraph, + modelPath, + std::move(inputDimParams), + std::move(subgraphNodes), + std::move(subgraphInputs), + std::move(subgraphOutputs), + std::move(intermediateNodeArgs), + std::move(partitionNodePropsMap), + std::move(ownedInitializers))); return Status::OK(); }; @@ -318,6 +442,26 @@ namespace DmlRuntimeGraphFusionHelper auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name); 
fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); + inputDimParams->resize(fusedNode.InputDefs().size()); + + for (int inputIndex = 0; inputIndex < fusedNode.InputDefs().size(); ++inputIndex) + { + const onnxruntime::NodeArg* inputDef = fusedNode.InputDefs()[inputIndex]; + + ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->TypeAsProto()->has_tensor_type()); + const auto& tensorShape = inputDef->TypeAsProto()->tensor_type().shape(); + + (*inputDimParams)[inputIndex].resize(tensorShape.dim_size()); + + for (int i = 0; i < tensorShape.dim_size(); ++i) + { + if (tensorShape.dim(i).has_dim_param()) + { + (*inputDimParams)[inputIndex][i] = tensorShape.dim(i).dim_param(); + } + } + } + graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h index b8713dd0736e3..b43693378cac8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h @@ -76,6 +76,7 @@ namespace DmlRuntimeGraphFusionHelper onnxruntime::KernelRegistry* registryForPartitionKernels, const ExecutionProviderImpl* providerImpl, std::unordered_map graphNodePropertyMap, - std::shared_ptr indexedSubGraph); + std::shared_ptr indexedSubGraph, + std::unordered_map>&& isInitializerTransferable); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp index 8286804cca0fa..7ba357b719c53 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -15,6 +15,15 @@ namespace Dml { + namespace + { + 
struct CompiledPartitionInfo + { + std::shared_ptr indexedSubGraph; + std::unordered_map> isInitializerTransferable; + }; + } + DmlRuntimeGraphFusionTransformer::DmlRuntimeGraphFusionTransformer( const std::string& name, const onnxruntime::IExecutionProvider* provider @@ -91,7 +100,7 @@ namespace Dml additionalSplittingNodes.clear(); // Reset the compiled operators for the current iteration - std::vector> indexedSubGraphs(partitions.size()); + std::vector> compiledPartitionInfos(partitions.size()); // Create a map between each initialized tensor and the partition(s) it is part of. auto initializerPartitionMap = DmlRuntimeGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions); @@ -113,22 +122,35 @@ namespace Dml std::string partitionKernelPrefix = std::to_string(m_providerImpl->GetPartitionKernelPrefixVal()) + "_"; m_providerImpl->IncreasePartitionKernelPrefixVal(); - indexedSubGraphs[partitionIndex] = std::make_shared( + // populate isInitializerTransferable + for (const auto& input : partition->GetInputs()) + { + const onnx::TensorProto* tensor = nullptr; + if (graph.GetInitializedTensor(input, tensor) && requiredInitializerMap.find(input) != requiredInitializerMap.end()) + { + isInitializerTransferable[input] = {tensor, false}; + } + } + + compiledPartitionInfos[partitionIndex] = std::make_shared(); + compiledPartitionInfos[partitionIndex]->indexedSubGraph = std::make_shared( DmlRuntimeGraphFusionHelper::CreateIndexedSubGraph(partition.get(), partitionIndex, partitionKernelPrefix)); + compiledPartitionInfos[partitionIndex]->isInitializerTransferable = std::move(isInitializerTransferable); } } - for (auto&& indexedSubGraph : indexedSubGraphs) + for (auto&& compiledPartitionInfo : compiledPartitionInfos) { // Null compiled operators were not DML partitions - if (indexedSubGraph) + if (compiledPartitionInfo) { DmlRuntimeGraphFusionHelper::RegisterKernel( graph, m_providerImpl->GetKernelRegistry().get(), m_providerImpl, graphNodePropertyMap, - 
std::move(indexedSubGraph)); + std::move(compiledPartitionInfo->indexedSubGraph), + std::move(compiledPartitionInfo->isInitializerTransferable)); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 636f46428ce99..2cb5f9157d703 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -147,14 +147,14 @@ namespace Dml::GraphDescBuilder const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle) + const void* executionHandle, + const onnxruntime::Path& modelPath, + gsl::span subgraphNodes, + gsl::span subgraphInputs, + gsl::span subgraphOutputs) { - const gsl::span subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs; - const gsl::span subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs; struct NodeAndIndex { uint32_t nodeIndex; // The index of the node itself @@ -164,12 +164,14 @@ namespace Dml::GraphDescBuilder // Map from Lotus node argument names to the new node and index where it will be produced std::unordered_map nameToNodeAndIndexMap; + std::unordered_map nodeOutputShapes; + // Map from Lotus node argument names to input indices of the fused kernel node. 
std::unordered_map nameToDmlFusedNodeInputIndex; - for (size_t inputIndex = 0; inputIndex < subGraphInputArgNames.size(); ++inputIndex) + for (size_t inputIndex = 0; inputIndex < subgraphInputs.size(); ++inputIndex) { - const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(subGraphInputArgNames[inputIndex]); + const onnxruntime::NodeArg* graphInput = subgraphInputs[inputIndex]; if (!graphInput) { @@ -196,13 +198,11 @@ namespace Dml::GraphDescBuilder const uint32_t minNodeCountToReuseCommandList = 5; bool reuseCommandList = false; - if (indexedSubGraph.nodes.size() >= minNodeCountToReuseCommandList) + if (subgraphNodes.size() >= minNodeCountToReuseCommandList) { reuseCommandList = true; } - auto modelPath = graph.ModelPath(); - auto constantCpuGraphInputGetter = [&isInitializerTransferable, &modelPath](const std::string& argName) { ComPtr tensorWrapper; @@ -219,9 +219,9 @@ namespace Dml::GraphDescBuilder // Iterate through each node and create a corresponding node in the new graph // We can iterate the nodes in any order because the edge connectivity will take care of the topological order - for (size_t sortedNodeIndex : indexedSubGraph.nodes) + for (const onnxruntime::Node* subgraphNode : subgraphNodes) { - const onnxruntime::Node& node = *graph.GetNode(sortedNodeIndex); + const onnxruntime::Node& node = *subgraphNode; const GraphNodeProperties& graphNodeProps = graphNodePropertyMap.find(GetUniqueNodeName(node))->second; const auto& requiredConstantCpuInputs = graphNodeProps.internalRegInfo->requiredConstantCpuInputs; @@ -347,6 +347,8 @@ namespace Dml::GraphDescBuilder operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphOutputEdge.FromNodeIndex], operatorGraphOutputEdge.FromNodeOutputIndex }; + + nodeOutputShapes[arg->Name()] = graphNodeCreateInfo.outputShapes; } } @@ -367,10 +369,12 @@ namespace Dml::GraphDescBuilder } } + EdgeShapes graphOutputShapes(subgraphOutputs.size()); + // Add graph output nodes, which might be in a different order from the 
encapsulating node - for (size_t outputIndex = 0; outputIndex < subGraphOutputArgNames.size(); ++outputIndex) + for (size_t outputIndex = 0; outputIndex < subgraphOutputs.size(); ++outputIndex) { - const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(subGraphOutputArgNames[outputIndex]); + const onnxruntime::NodeArg* graphOutput = subgraphOutputs[outputIndex]; ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg"); const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); @@ -380,6 +384,7 @@ namespace Dml::GraphDescBuilder edge.FromNodeOutputIndex = outputNodeAndIndex.targetIndex; edge.GraphOutputIndex = gsl::narrow_cast(outputIndex); graphOutputEdges.push_back(edge); + graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex); } RemoveUnconnectedNodes(graphNodes, graphInputEdges, graphIntermediateEdges, graphOutputEdges); @@ -390,6 +395,7 @@ namespace Dml::GraphDescBuilder graphDesc.outputEdges = std::move(graphOutputEdges); graphDesc.intermediateEdges = std::move(graphIntermediateEdges); graphDesc.reuseCommandList = reuseCommandList; + graphDesc.outputShapes = std::move(graphOutputShapes); return graphDesc; } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index 5c04962e55557..0039678c00e59 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -9,10 +9,10 @@ namespace Dml { struct GraphNodeProperties { - std::shared_ptr + std::shared_ptr internalRegInfo; - // These are currently passed from the partitioning step since the only DML operators current + // These are currently passed from the partitioning step since the only DML operators current // supporting graph nodes don't 
customize the order of edges or shapes, other than coercing // dimension count. This will change as the supported set of operators as graph nodes increases. Windows::AI::MachineLearning::Adapter::EdgeShapes inputShapes; @@ -38,16 +38,19 @@ namespace Dml std::vector outputEdges; std::vector intermediateEdges; bool reuseCommandList; + Windows::AI::MachineLearning::Adapter::EdgeShapes outputShapes; }; GraphDesc BuildGraphDesc( const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle); + const void* executionHandle, + const onnxruntime::Path& modelPath, + gsl::span subgraphNodes, + gsl::span subgraphInputs, + gsl::span subgraphOutputs); } -} \ No newline at end of file +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index a7f8bebb2de78..d29d2b6b9262b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -4,6 +4,7 @@ #pragma once #include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h" #include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" #include "core/framework/op_kernel.h" #include "core/framework/customregistry.h" #include "core/framework/tensorprotoutils.h" @@ -93,42 +94,6 @@ struct AttributeValue using AttributeMap = std::map; -// Encapsulation of shapes across different edges of an operator. Non-tensor -// edges and unused edges have an empty array of dimensions. 
-class EdgeShapes -{ -public: - EdgeShapes() = default; - - EdgeShapes(size_t count) : m_shapes(count) {} - - const std::vector& GetShape(size_t edgeIndex) const - { - return m_shapes[edgeIndex]; - } - - std::vector& GetMutableShape(size_t edgeIndex) - { - return m_shapes[edgeIndex]; - } - - size_t EdgeCount() const { return m_shapes.size(); } - - void Reset(size_t edge_count) - { - m_shapes.clear(); - m_shapes.resize(edge_count); - } - - bool operator!=(const EdgeShapes& other) const noexcept - { - return (m_shapes != other.m_shapes); - } - - private: - std::vector> m_shapes; -}; - // Base class for ABI objects which may be "Closed", at which point calls will predictably // fail or return a dummy value. This is used for transient ABI context objects which // are passed to methods on kernel or inferencers, and which wrap Lotus objects whose lifetimes From eb33dc3c5041807f252374b1174fe3c1d025a54e Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Thu, 5 Oct 2023 19:40:40 -0700 Subject: [PATCH 03/20] WIP --- .../src/DmlRuntimeFusedGraphKernel.cpp | 152 ++++++++++++------ .../src/ExecutionProvider.cpp | 1 - .../src/GraphPartitioner.cpp | 53 +++++- 3 files changed, 150 insertions(+), 56 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 51581fa14adb6..b7094a062f0f6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -61,6 +61,33 @@ namespace Dml { m_subgraphNodePointers.push_back(subgraphNode.get()); } + + m_nodeDimParams.resize(m_subgraphNodes.size()); + + for (int nodeIndex = 0; nodeIndex < m_subgraphNodes.size(); ++nodeIndex) + { + m_nodeDimParams[nodeIndex].resize(m_subgraphNodes[nodeIndex]->InputDefs().size()); + + for (int inputIndex = 0; inputIndex < 
m_subgraphNodes[nodeIndex]->InputDefs().size(); ++inputIndex) + { + auto* inputDef = m_subgraphNodes[nodeIndex]->MutableInputDefs()[inputIndex]; + + ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->TypeAsProto()->has_tensor_type()); + auto tensorShape = inputDef->TypeAsProto()->tensor_type().shape(); + + m_nodeDimParams[nodeIndex][inputIndex].resize(tensorShape.dim_size()); + + for (int i = 0; i < tensorShape.dim_size(); ++i) + { + if (tensorShape.dim(i).has_dim_param()) + { + m_nodeDimParams[nodeIndex][inputIndex][i] = tensorShape.dim(i).dim_param(); + } + } + + inputDef->SetShape(tensorShape); + } + } } void TranslateAndCompileGraph( @@ -101,16 +128,7 @@ namespace Dml onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override { - const uint32_t fusedNodeInputCount = gsl::narrow_cast(m_indexedSubGraph->GetMetaDef()->inputs.size()); - - // Populate input bindings for operator initialization - std::vector> initializeResourceRefs; // For lifetime control - std::vector initInputBindings(fusedNodeInputCount); - std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); - - auto providerImpl = static_cast(Info().GetExecutionProvider())->GetImpl(); - - std::unordered_map dynamicDimOverrides; + bool recompiledNeeded = false; ORT_THROW_HR_IF(E_UNEXPECTED, m_inputDimParams->size() != kernelContext->InputCount()); for (int inputIndex = 0; inputIndex < m_inputDimParams->size(); ++inputIndex) @@ -124,62 +142,87 @@ namespace Dml if (!dimParam.empty()) { - dynamicDimOverrides[dimParam] = input.Shape().GetDims()[i]; + auto iter = m_dynamicDimOverrides.find(dimParam); + if (iter == m_dynamicDimOverrides.end()) + { + m_dynamicDimOverrides[dimParam] = input.Shape().GetDims()[i]; + recompiledNeeded = true; + } + else if (iter->second != input.Shape().GetDims()[i]) + { + iter->second = input.Shape().GetDims()[i]; + recompiledNeeded = true; + } } } } - for (auto& subgraphNode : m_subgraphNodes) + if (recompiledNeeded) { - for (onnxruntime::NodeArg* inputDef : 
subgraphNode->MutableInputDefs()) + // Populate input bindings for operator initialization + const uint32_t fusedNodeInputCount = gsl::narrow_cast(m_indexedSubGraph->GetMetaDef()->inputs.size()); + std::vector> initializeResourceRefs; // For lifetime control + std::vector initInputBindings(fusedNodeInputCount); + std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); + auto providerImpl = static_cast(Info().GetExecutionProvider())->GetImpl(); + + for (int nodeIndex = 0; nodeIndex < m_subgraphNodes.size(); ++nodeIndex) { - ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->TypeAsProto()->has_tensor_type()); - auto tensorShape = inputDef->TypeAsProto()->tensor_type().shape(); - - for (int i = 0; i < tensorShape.dim_size(); ++i) + for (int inputIndex = 0; inputIndex < m_subgraphNodes[nodeIndex]->InputDefs().size(); ++inputIndex) { - if (tensorShape.dim(i).has_dim_param()) + auto* inputDef = m_subgraphNodes[nodeIndex]->MutableInputDefs()[inputIndex]; + + ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->TypeAsProto()->has_tensor_type()); + auto tensorShape = inputDef->TypeAsProto()->tensor_type().shape(); + + for (int i = 0; i < tensorShape.dim_size(); ++i) { - tensorShape.mutable_dim(i)->set_dim_value(dynamicDimOverrides[tensorShape.dim(i).dim_param()]); + const std::string& dimParam = m_nodeDimParams[nodeIndex][inputIndex][i]; + if (!dimParam.empty()) + { + tensorShape.mutable_dim(i)->set_dim_value(m_dynamicDimOverrides[dimParam]); + } } - } - inputDef->SetShape(tensorShape); + inputDef->SetShape(tensorShape); + } } - } - // Convert partitionONNXGraph into DML EP GraphDesc - ComPtr device; - ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); - GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( - isInputsUploadedByDmlEP.data(), - isInputsUploadedByDmlEP.size(), - m_isInitializerTransferable, - m_partitionNodePropsMap, - device.Get(), - providerImpl, - m_modelPath, - m_subgraphNodePointers, - m_subgraphInputs, - m_subgraphOutputs); - - // 
Walk through each graph edge and mark used inputs - m_inputsUsed.resize(fusedNodeInputCount, false); - for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) - { - m_inputsUsed[edge.GraphInputIndex] = true; - } + // Convert partitionONNXGraph into DML EP GraphDesc + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( + isInputsUploadedByDmlEP.data(), + isInputsUploadedByDmlEP.size(), + m_isInitializerTransferable, + m_partitionNodePropsMap, + device.Get(), + providerImpl, + m_modelPath, + m_subgraphNodePointers, + m_subgraphInputs, + m_subgraphOutputs); + + m_outputShapes = graphDesc.outputShapes; + + // Walk through each graph edge and mark used inputs + m_inputsUsed.resize(fusedNodeInputCount, false); + for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) + { + m_inputsUsed[edge.GraphInputIndex] = true; + } - // Compile the operator - m_compiledExecutionPlanOperator = DmlRuntimeGraphFusionHelper::TryCreateCompiledOperator( - graphDesc, - *m_indexedSubGraph, - providerImpl); + // Compile the operator + m_compiledExecutionPlanOperator = DmlRuntimeGraphFusionHelper::TryCreateCompiledOperator( + graphDesc, + *m_indexedSubGraph, + providerImpl); - TranslateAndCompileGraph( - Info(), - initializeResourceRefs, - initInputBindings); + TranslateAndCompileGraph( + Info(), + initializeResourceRefs, + initInputBindings); + } // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator OpKernelContextWrapper contextWrapper( @@ -206,7 +249,7 @@ namespace Dml inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get()); } - auto aux = contextWrapper.GetOutputTensors(graphDesc.outputShapes); + auto aux = contextWrapper.GetOutputTensors(m_outputShapes); ExecuteOperator( m_compiledExecutionPlanOperator.Get(), m_persistentResourceBinding ? 
&*m_persistentResourceBinding : nullptr, @@ -306,6 +349,7 @@ namespace Dml std::vector m_ownedInitializers; std::unordered_map> m_isInitializerTransferable; std::vector m_subgraphNodePointers; + std::vector>> m_nodeDimParams; // Bindings from previous executions of a re-used command list mutable ComPtr m_compiledExecutionPlanOperator; @@ -315,6 +359,8 @@ namespace Dml mutable std::vector m_inputsUsed; mutable ComPtr m_persistentResource; mutable ComPtr m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator + mutable std::unordered_map m_dynamicDimOverrides; + mutable Windows::AI::MachineLearning::Adapter::EdgeShapes m_outputShapes; // Fence tracking the status of the command list's last execution, and whether its descriptor heap // can safely be updated. diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index f97b72aa2d385..47313db3c4d13 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -152,7 +152,6 @@ namespace Dml m_dmlDevice(dmlDevice), m_areMetacommandsEnabled(enableMetacommands) { - D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {}; D3D_FEATURE_LEVEL featureLevelsList[] = { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 77419a8412c5b..bd7ede4bf1348 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -198,8 +198,57 @@ namespace Dml std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; if (requiredCpuInputsConstant && (requiredInputCount == std::nullopt || 
*requiredInputCount == node.InputDefs().size())) { - *isDmlGraphNode = true; - graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + bool valid = true; + onnxruntime::ProtoHelperNodeContext protoContext(node); + onnxruntime::OpNodeProtoHelper info(&protoContext); + + graphNodeProperty.first->second.inputShapes.Reset(info.GetInputCount()); + + for (size_t inputIndex = 0; inputIndex < graphNodeProperty.first->second.inputShapes.EdgeCount(); ++inputIndex) + { + const onnx::TypeProto* inputProto = info.GetInputType(inputIndex); + // Skip this input if it is not valid. + if (inputProto == nullptr) + { + continue; + } + + if (inputProto->value_case() != onnx::TypeProto::kTensorType) + { + continue; + } + + const auto& tensorType = inputProto->tensor_type(); + + if (!tensorType.has_shape()) + { + valid = false; + break; + } + + const auto& shape = tensorType.shape(); + graphNodeProperty.first->second.inputShapes.GetMutableShape(inputIndex).resize(shape.dim_size()); + + for (uint32_t dimIndex = 0; dimIndex < static_cast(shape.dim_size()); ++dimIndex) + { + if (!shape.dim(dimIndex).has_dim_value() && !shape.dim(dimIndex).has_dim_param()) + { + valid = false; + break; + } + } + + if (!valid) + { + break; + } + } + + if (valid) + { + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + } } } else From 32a2b956212af03bea281a86bc7d950e6e34eabe Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Fri, 6 Oct 2023 02:41:07 -0700 Subject: [PATCH 04/20] Working implementation --- .../inc/IWinmlExecutionProvider.h | 2 + .../src/AbiCustomRegistry.cpp | 4 +- .../src/DmlRuntimeFusedGraphKernel.cpp | 103 ++++++------------ .../src/ExecutionProvider.cpp | 2 + .../src/GraphDescBuilder.cpp | 32 ++++++ .../src/GraphPartitioner.cpp | 53 +-------- .../src/MLOperatorAuthorImpl.cpp | 3 +- .../src/MLOperatorAuthorImpl.h | 1 + 8 files changed, 77 insertions(+), 123 deletions(-) diff --git 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 63e2424602f83..51bce09fcbe9a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -88,12 +88,14 @@ namespace Windows::AI::MachineLearning::Adapter std::vector outputEdges; std::vector intermediateEdges; EdgeShapes outputShapes; + const std::unordered_map>* inferredOutputShapes; }; using GraphNodeFactory = std::function; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index fce26623a7b0e..e86737fa9665a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -491,6 +491,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, + const EdgeShapes* inputShapesOverrides, /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo ) { @@ -499,7 +500,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( // Use the same list of required constant inputs for the shape inferrer and the kernel. 
EdgeShapes outputShapes; - InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes); + InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, outputShapes); graphNodeCreateInfo->outputShapes = outputShapes; @@ -508,6 +509,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( &protoHelper, executionHandle, true, + inputShapesOverrides, &outputShapes, &defaultAttributesCapture, graphNodeCreateInfo, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index b7094a062f0f6..d558caf24795e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -61,33 +61,6 @@ namespace Dml { m_subgraphNodePointers.push_back(subgraphNode.get()); } - - m_nodeDimParams.resize(m_subgraphNodes.size()); - - for (int nodeIndex = 0; nodeIndex < m_subgraphNodes.size(); ++nodeIndex) - { - m_nodeDimParams[nodeIndex].resize(m_subgraphNodes[nodeIndex]->InputDefs().size()); - - for (int inputIndex = 0; inputIndex < m_subgraphNodes[nodeIndex]->InputDefs().size(); ++inputIndex) - { - auto* inputDef = m_subgraphNodes[nodeIndex]->MutableInputDefs()[inputIndex]; - - ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->TypeAsProto()->has_tensor_type()); - auto tensorShape = inputDef->TypeAsProto()->tensor_type().shape(); - - m_nodeDimParams[nodeIndex][inputIndex].resize(tensorShape.dim_size()); - - for (int i = 0; i < tensorShape.dim_size(); ++i) - { - if (tensorShape.dim(i).has_dim_param()) - { - m_nodeDimParams[nodeIndex][inputIndex][i] = tensorShape.dim(i).dim_param(); - } - } - - inputDef->SetShape(tensorShape); - } - } } void 
TranslateAndCompileGraph( @@ -128,66 +101,55 @@ namespace Dml onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override { - bool recompiledNeeded = false; + ORT_THROW_HR_IF(E_UNEXPECTED, m_subgraphInputs.size() != kernelContext->InputCount()); + + bool recompiledNeeded = m_compiledExecutionPlanOperator == nullptr; - ORT_THROW_HR_IF(E_UNEXPECTED, m_inputDimParams->size() != kernelContext->InputCount()); - for (int inputIndex = 0; inputIndex < m_inputDimParams->size(); ++inputIndex) + for (int inputIndex = 0; inputIndex < kernelContext->InputCount(); ++inputIndex) { const auto& input = kernelContext->RequiredInput(inputIndex); - ORT_THROW_HR_IF(E_UNEXPECTED, input.Shape().NumDimensions() != (*m_inputDimParams)[inputIndex].size()); + const std::string& inputName = m_subgraphInputs[inputIndex]->Name(); + auto iter = m_inferredInputShapes.find(inputName); - for (int i = 0; i < input.Shape().NumDimensions(); ++i) + if (iter == m_inferredInputShapes.end()) { - const std::string& dimParam = (*m_inputDimParams)[inputIndex][i]; - - if (!dimParam.empty()) - { - auto iter = m_dynamicDimOverrides.find(dimParam); - if (iter == m_dynamicDimOverrides.end()) - { - m_dynamicDimOverrides[dimParam] = input.Shape().GetDims()[i]; - recompiledNeeded = true; - } - else if (iter->second != input.Shape().GetDims()[i]) - { - iter->second = input.Shape().GetDims()[i]; - recompiledNeeded = true; - } - } + m_inferredInputShapes[inputName] = input.Shape(); + recompiledNeeded = true; + } + else if (iter->second != input.Shape()) + { + iter->second = input.Shape(); + recompiledNeeded = true; } } if (recompiledNeeded) { - // Populate input bindings for operator initialization - const uint32_t fusedNodeInputCount = gsl::narrow_cast(m_indexedSubGraph->GetMetaDef()->inputs.size()); - std::vector> initializeResourceRefs; // For lifetime control - std::vector initInputBindings(fusedNodeInputCount); - std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); - auto 
providerImpl = static_cast(Info().GetExecutionProvider())->GetImpl(); - - for (int nodeIndex = 0; nodeIndex < m_subgraphNodes.size(); ++nodeIndex) + // Go through all the node args and replace their shapes with the real ones + for (auto& nodeArg : m_intermediateNodeArgs) { - for (int inputIndex = 0; inputIndex < m_subgraphNodes[nodeIndex]->InputDefs().size(); ++inputIndex) + auto iter = m_inferredInputShapes.find(nodeArg->Name()); + if (iter != m_inferredInputShapes.end()) { - auto* inputDef = m_subgraphNodes[nodeIndex]->MutableInputDefs()[inputIndex]; - - ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->TypeAsProto()->has_tensor_type()); - auto tensorShape = inputDef->TypeAsProto()->tensor_type().shape(); + auto tensorShape = *nodeArg->Shape(); + ORT_THROW_HR_IF(E_UNEXPECTED, tensorShape.dim_size() != iter->second.NumDimensions()); for (int i = 0; i < tensorShape.dim_size(); ++i) { - const std::string& dimParam = m_nodeDimParams[nodeIndex][inputIndex][i]; - if (!dimParam.empty()) - { - tensorShape.mutable_dim(i)->set_dim_value(m_dynamicDimOverrides[dimParam]); - } + tensorShape.mutable_dim(i)->set_dim_value(iter->second.GetDims()[i]); } - inputDef->SetShape(tensorShape); + nodeArg->SetShape(tensorShape); } } + // Populate input bindings for operator initialization + const uint32_t fusedNodeInputCount = gsl::narrow_cast(m_indexedSubGraph->GetMetaDef()->inputs.size()); + std::vector> initializeResourceRefs; // For lifetime control + std::vector initInputBindings(fusedNodeInputCount); + std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); + auto providerImpl = static_cast(Info().GetExecutionProvider())->GetImpl(); + // Convert partitionONNXGraph into DML EP GraphDesc ComPtr device; ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); @@ -340,16 +302,17 @@ namespace Dml std::optional m_persistentResourceBinding; std::shared_ptr m_indexedSubGraph; const onnxruntime::Path& m_modelPath; + + // TODO (pavignol): Remove m_inputDimParams if truly not needed 
std::shared_ptr>> m_inputDimParams; std::vector> m_subgraphNodes; std::vector m_subgraphInputs; std::vector m_subgraphOutputs; - std::vector> m_intermediateNodeArgs; + mutable std::vector> m_intermediateNodeArgs; std::unordered_map m_partitionNodePropsMap; std::vector m_ownedInitializers; std::unordered_map> m_isInitializerTransferable; std::vector m_subgraphNodePointers; - std::vector>> m_nodeDimParams; // Bindings from previous executions of a re-used command list mutable ComPtr m_compiledExecutionPlanOperator; @@ -359,8 +322,8 @@ namespace Dml mutable std::vector m_inputsUsed; mutable ComPtr m_persistentResource; mutable ComPtr m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator - mutable std::unordered_map m_dynamicDimOverrides; mutable Windows::AI::MachineLearning::Adapter::EdgeShapes m_outputShapes; + mutable std::unordered_map m_inferredInputShapes; // Fence tracking the status of the command list's last execution, and whether its descriptor heap // can safely be updated. 
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 47313db3c4d13..7d171f2dddf19 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -152,6 +152,8 @@ namespace Dml m_dmlDevice(dmlDevice), m_areMetacommandsEnabled(enableMetacommands) { + // TODO (pavignol): Remove me + m_areMetacommandsEnabled = false; D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {}; D3D_FEATURE_LEVEL featureLevelsList[] = { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 2cb5f9157d703..546fb56bc780e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -219,6 +219,8 @@ namespace Dml::GraphDescBuilder // Iterate through each node and create a corresponding node in the new graph // We can iterate the nodes in any order because the edge connectivity will take care of the topological order + std::unordered_map> inferredOutputShapes; + for (const onnxruntime::Node* subgraphNode : subgraphNodes) { const onnxruntime::Node& node = *subgraphNode; @@ -245,13 +247,43 @@ namespace Dml::GraphDescBuilder }; DmlGraphNodeCreateInfo graphNodeCreateInfo; + EdgeShapes inputShapesOverrides(node.InputDefs().size()); + + // Override the input shapes with shapes that were previously inferred + for (int inputIndex = 0; inputIndex < node.InputDefs().size(); ++inputIndex) + { + auto inputDef = node.InputDefs()[inputIndex]; + + auto outputShapesIter = inferredOutputShapes.find(inputDef->Name()); + if (outputShapesIter != inferredOutputShapes.end()) + { + inputShapesOverrides.GetMutableShape(inputIndex) = outputShapesIter->second; 
+ } + else + { + for (int i = 0; i < inputDef->Shape()->dim_size(); ++i) + { + ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->Shape()->dim(i).has_dim_value()); + inputShapesOverrides.GetMutableShape(inputIndex).push_back(gsl::narrow_cast(inputDef->Shape()->dim(i).dim_value())); + } + } + } + + graphNodeCreateInfo.inferredOutputShapes = &inferredOutputShapes; graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory( node, constantCpuNodeInputGetter, executionHandle, + &inputShapesOverrides, /*out*/ &graphNodeCreateInfo ); + ORT_THROW_HR_IF(E_UNEXPECTED, graphNodeCreateInfo.outputShapes.EdgeCount() != node.OutputDefs().size()); + for (int i = 0; i < node.OutputDefs().size(); ++i) + { + inferredOutputShapes[node.OutputDefs()[i]->Name()] = graphNodeCreateInfo.outputShapes.GetShape(i); + } + // Create a map between operatorGraphNodeIndex to mainGraphNodeIndex. std::unordered_map operatorGraphNodeIndexToMainGraphNodeIndexMap; uint32_t graphNodeCount = gsl::narrow_cast(graphNodes.size()); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index bd7ede4bf1348..77419a8412c5b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -198,57 +198,8 @@ namespace Dml std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; if (requiredCpuInputsConstant && (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) { - bool valid = true; - onnxruntime::ProtoHelperNodeContext protoContext(node); - onnxruntime::OpNodeProtoHelper info(&protoContext); - - graphNodeProperty.first->second.inputShapes.Reset(info.GetInputCount()); - - for (size_t inputIndex = 0; inputIndex < graphNodeProperty.first->second.inputShapes.EdgeCount(); ++inputIndex) - { - const onnx::TypeProto* 
inputProto = info.GetInputType(inputIndex); - // Skip this input if it is not valid. - if (inputProto == nullptr) - { - continue; - } - - if (inputProto->value_case() != onnx::TypeProto::kTensorType) - { - continue; - } - - const auto& tensorType = inputProto->tensor_type(); - - if (!tensorType.has_shape()) - { - valid = false; - break; - } - - const auto& shape = tensorType.shape(); - graphNodeProperty.first->second.inputShapes.GetMutableShape(inputIndex).resize(shape.dim_size()); - - for (uint32_t dimIndex = 0; dimIndex < static_cast(shape.dim_size()); ++dimIndex) - { - if (!shape.dim(dimIndex).has_dim_value() && !shape.dim(dimIndex).has_dim_param()) - { - valid = false; - break; - } - } - - if (!valid) - { - break; - } - } - - if (valid) - { - *isDmlGraphNode = true; - graphNodeProperty.first->second.internalRegInfo = internalRegInfo; - } + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; } } else diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 6cd10e14e08d2..4deec620fe5fb 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1356,13 +1356,14 @@ namespace Windows::AI::MachineLearning::Adapter const onnxruntime::OpNodeProtoHelper* protoHelper, const void* executionHandle, bool isInternalOperator, + const EdgeShapes* inputShapesOverrides, const EdgeShapes* inferredOutputShapes, const AttributeMap* defaultAttributes, DmlGraphNodeCreateInfo* graphNodeCreateInfo, gsl::span requiredConstantCpuInputs, MLOperatorTensorGetter& constantInputGetter ) - : OpNodeInfoWrapper(protoHelper, nullptr, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr), + : OpNodeInfoWrapper(protoHelper, inputShapesOverrides, defaultAttributes, 
requiredConstantCpuInputs, constantInputGetter, nullptr), m_inferredOutputShapes(inferredOutputShapes), m_internalOperator(isInternalOperator), m_graphNodeCreateInfo(graphNodeCreateInfo) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index d29d2b6b9262b..913997ff4ad49 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -399,6 +399,7 @@ class DmlGraphOpKernelInfoWrapper : public OpNodeInfoWrapper< const onnxruntime::OpNodeProtoHelper * protoHelper, const void* executionHandle, bool isInternalOperator, + const EdgeShapes* inputShapesOverrides, const EdgeShapes* inferredOutputShapes, const AttributeMap* defaultAttributes, DmlGraphNodeCreateInfo* graphNodeCreateInfo, From d69d39981817d10ebe705989ed050c27f564ea3d Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Fri, 6 Oct 2023 16:27:55 -0700 Subject: [PATCH 05/20] Working version with good performance (no queue reuse) --- .../src/DmlGraphFusionTransformer.cpp | 2 + .../src/DmlRuntimeFusedGraphKernel.cpp | 12 +++- .../src/DmlRuntimeGraphFusionHelper.cpp | 11 ++++ .../src/DmlRuntimeGraphFusionHelper.h | 1 + .../src/DmlRuntimeGraphFusionTransformer.cpp | 3 + .../src/ExecutionProvider.cpp | 2 +- .../src/GraphPartitioner.cpp | 57 +++++++++++++------ .../src/GraphPartitioner.h | 1 + 8 files changed, 70 insertions(+), 19 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index f60880b17a08e..679738b639ec9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -90,6 +90,7 @@ namespace Dml { 
// Initializers needed by any graph partition std::unordered_set requiredInitializerMap; + std::unordered_set dynamicCpuInputMap; std::unordered_map graphNodePropertyMap; onnxruntime::GraphViewer graphViewer(graph); std::vector> partitions = BuildPartitions( @@ -99,6 +100,7 @@ namespace Dml m_providerImpl->GetSupportedDeviceDataTypeMask(), graphNodePropertyMap, requiredInitializerMap, + dynamicCpuInputMap, additionalSplittingNodes, implicitInputDefs, false); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index d558caf24795e..099777b85452d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -121,6 +121,15 @@ namespace Dml iter->second = input.Shape(); recompiledNeeded = true; } + + // If we have CPU inputs that are not initializers (i.e. 
they were computed at runtime), add them to the initializer list + if (input.Location().device.Type() == OrtDevice::CPU) + { + // TODO (pavignol): Force recompile if CPU data changed + auto inputProto = onnxruntime::utils::TensorToTensorProto(input, inputName); + m_ownedCpuInputs.push_back(std::make_unique(std::move(inputProto))); + m_isInitializerTransferable[inputName] = std::make_pair(m_ownedCpuInputs.back().get(), false); + } } if (recompiledNeeded) @@ -311,10 +320,11 @@ namespace Dml mutable std::vector> m_intermediateNodeArgs; std::unordered_map m_partitionNodePropsMap; std::vector m_ownedInitializers; - std::unordered_map> m_isInitializerTransferable; + mutable std::unordered_map> m_isInitializerTransferable; std::vector m_subgraphNodePointers; // Bindings from previous executions of a re-used command list + mutable std::vector> m_ownedCpuInputs; mutable ComPtr m_compiledExecutionPlanOperator; mutable std::vector m_inputBindingAllocIds; mutable std::vector m_outputBindingAllocIds; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp index e69395c3380fc..bd787dbfb4382 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp @@ -301,6 +301,7 @@ namespace DmlRuntimeGraphFusionHelper onnxruntime::KernelRegistry* registryForPartitionKernels, const ExecutionProviderImpl* providerImpl, std::unordered_map graphNodePropertyMap, + const std::unordered_set& dynamicCpuInputMap, std::shared_ptr indexedSubGraph, std::unordered_map>&& isInitializerTransferable) { @@ -437,6 +438,16 @@ namespace DmlRuntimeGraphFusionHelper .SetDomain(indexedSubGraph->GetMetaDef()->domain) .SinceVersion(indexedSubGraph->GetMetaDef()->since_version) .Provider(onnxruntime::kDmlExecutionProvider); + + // Force the CPU inputs 
to be allocated on the CPU + for (int i = 0; i < subGraphInputArgNames.size(); ++i) + { + if (dynamicCpuInputMap.find(subGraphInputArgNames[i]) != dynamicCpuInputMap.end()) + { + builder.InputMemoryType(OrtMemTypeCPUInput, i); + } + } + ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func)); auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h index b43693378cac8..3da482bd72aaa 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h @@ -76,6 +76,7 @@ namespace DmlRuntimeGraphFusionHelper onnxruntime::KernelRegistry* registryForPartitionKernels, const ExecutionProviderImpl* providerImpl, std::unordered_map graphNodePropertyMap, + const std::unordered_set& dynamicCpuInputMap, std::shared_ptr indexedSubGraph, std::unordered_map>&& isInitializerTransferable); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp index 7ba357b719c53..f1c7e723d4b18 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -84,6 +84,7 @@ namespace Dml std::vector additionalSplittingNodes; std::unordered_map graphNodePropertyMap; std::unordered_set requiredInitializerMap; + std::unordered_set dynamicCpuInputMap; onnxruntime::GraphViewer graphViewer(graph); std::vector> partitions = BuildPartitions( graphViewer, @@ -92,6 +93,7 @@ namespace Dml m_providerImpl->GetSupportedDeviceDataTypeMask(), 
graphNodePropertyMap, requiredInitializerMap, + dynamicCpuInputMap, additionalSplittingNodes, implicitInputDefs, true); @@ -149,6 +151,7 @@ namespace Dml m_providerImpl->GetKernelRegistry().get(), m_providerImpl, graphNodePropertyMap, + dynamicCpuInputMap, std::move(compiledPartitionInfo->indexedSubGraph), std::move(compiledPartitionInfo->isInitializerTransferable)); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 7d171f2dddf19..1a48e909a6434 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -153,7 +153,7 @@ namespace Dml m_areMetacommandsEnabled(enableMetacommands) { // TODO (pavignol): Remove me - m_areMetacommandsEnabled = false; + // m_areMetacommandsEnabled = false; D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {}; D3D_FEATURE_LEVEL featureLevelsList[] = { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 77419a8412c5b..f7a4743801d81 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -151,6 +151,7 @@ namespace Dml _In_opt_ const std::unordered_map* nodeNameToPartitionMap, _Inout_ std::unordered_map& dmlNodePropertyMap, _Inout_ std::unordered_set& requiredInitializerMap, + _Inout_ std::unordered_set& dynamicCpuInputMap, bool allowDmlGraphDynamicShapes, _Out_ bool* isDmlGraphNode ) @@ -173,30 +174,30 @@ namespace Dml if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration) { - bool requiredCpuInputsConstant = true; - for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) + if (allowDmlGraphDynamicShapes) { - if (inputIndex >= 
node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) { - continue; - } + if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + { + continue; + } - const onnx::TensorProto* tensor = nullptr; - const std::string& inputName = node.InputDefs()[inputIndex]->Name(); + const onnx::TensorProto* tensor = nullptr; + const std::string& inputName = node.InputDefs()[inputIndex]->Name(); - if (!graph.GetInitializedTensor(inputName, tensor)) - { - requiredCpuInputsConstant = false; - break; + if (graph.GetInitializedTensor(inputName, tensor)) + { + requiredInitializerMap.insert(inputName); + } + else + { + dynamicCpuInputMap.insert(inputName); + } } - requiredInitializerMap.insert(inputName); - } - - if (allowDmlGraphDynamicShapes) - { std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; - if (requiredCpuInputsConstant && (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + if (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size()) { *isDmlGraphNode = true; graphNodeProperty.first->second.internalRegInfo = internalRegInfo; @@ -204,6 +205,26 @@ namespace Dml } else { + bool requiredCpuInputsConstant = true; + for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) + { + if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + { + continue; + } + + const onnx::TensorProto* tensor = nullptr; + const std::string& inputName = node.InputDefs()[inputIndex]->Name(); + + if (!graph.GetInitializedTensor(inputName, tensor)) + { + requiredCpuInputsConstant = false; + break; + } + + requiredInitializerMap.insert(inputName); + } + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; if (requiredCpuInputsConstant && TryGetStaticInputShapes( node, 
graphNodeProperty.first->second.inputShapes) && @@ -392,6 +413,7 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, + std::unordered_set& dynamicCpuInputMap, gsl::span additionalSplittingNodes, const std::unordered_map& implicitInputs, bool allowDmlGraphDynamicShapes) @@ -457,6 +479,7 @@ namespace Dml &nodeNameToPartitionMap, graphNodePropertyMap, requiredInitializerMap, + dynamicCpuInputMap, allowDmlGraphDynamicShapes, /*out*/ &isDmlGraphNode ); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h index 407fdd52e2cf4..3bddb5ae16086 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h @@ -50,6 +50,7 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. 
std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, + std::unordered_set& dynamicCpuInputMap, gsl::span additionalSplittingNodes, const std::unordered_map& implicitInputs, bool allowDmlGraphDynamicShapes); From 0252c1363f40b5bc1a88a1cf1e5247b8ab578284 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Fri, 6 Oct 2023 17:51:44 -0700 Subject: [PATCH 06/20] Disable metacommands --- .../dml/DmlExecutionProvider/src/ExecutionProvider.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 1a48e909a6434..7d171f2dddf19 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -153,7 +153,7 @@ namespace Dml m_areMetacommandsEnabled(enableMetacommands) { // TODO (pavignol): Remove me - // m_areMetacommandsEnabled = false; + m_areMetacommandsEnabled = false; D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {}; D3D_FEATURE_LEVEL featureLevelsList[] = { From 779197df9725398380d737d3148104c255f54b0d Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Fri, 6 Oct 2023 21:45:57 -0700 Subject: [PATCH 07/20] Add options to disable metacommands and enable dynamic graph fusions --- .../inc/DmlExecutionProvider.h | 6 +++- .../inc/IWinmlExecutionProvider.h | 2 ++ .../src/ExecutionProvider.cpp | 22 +++++++----- .../src/ExecutionProvider.h | 13 ++++--- .../src/IExecutionProvider.h | 7 ++++ .../providers/dml/dml_provider_factory.cc | 34 +++++++++++-------- .../dml/dml_provider_factory_creator.h | 4 +-- .../core/providers/dml/dml_provider_options.h | 11 ++++++ onnxruntime/core/session/inference_session.cc | 19 +++++++---- .../python/onnxruntime_pybind_schema.cc | 7 +++- .../python/onnxruntime_pybind_state.cc | 25 ++++++++++++-- 
.../python/onnxruntime_pybind_state_common.h | 2 +- onnxruntime/test/util/default_providers.cc | 5 ++- 13 files changed, 115 insertions(+), 42 deletions(-) create mode 100644 onnxruntime/core/providers/dml/dml_provider_options.h diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h index 52018500b134c..cdb0338157561 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h @@ -3,6 +3,9 @@ #pragma once interface IMLOperatorRegistry; +interface IDMLDevice; +interface ID3D12CommandQueue; +interface ID3D12Resource; #include "core/common/status.h" #include "core/framework/data_transfer.h" @@ -28,7 +31,8 @@ namespace Dml std::unique_ptr CreateExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands = true); + bool enableMetacommands, + bool enableDynamicGraphFusion); ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr); void FlushContext(onnxruntime::IExecutionProvider* provider); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 51bce09fcbe9a..431113b3e1650 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -7,12 +7,14 @@ #include #include #include +#include #include "core/framework/op_kernel.h" #include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" struct AbstractOperatorDesc; interface IMLOperatorTensor; +interface IDMLOperator; struct DML_INPUT_GRAPH_EDGE_DESC; struct DML_OUTPUT_GRAPH_EDGE_DESC; struct DML_INTERMEDIATE_GRAPH_EDGE_DESC; diff --git 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 7d171f2dddf19..277da1591b1e3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -67,7 +67,8 @@ namespace Dml ExecutionProvider::ExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands) : + bool enableMetacommands, + bool enableDynamicGraphFusion) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)) { D3D12_COMMAND_LIST_TYPE queueType = commandQueue->GetDesc().Type; @@ -80,7 +81,7 @@ namespace Dml ComPtr device; GRAPHICS_THROW_IF_FAILED(commandQueue->GetDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf()))); - m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), commandQueue, enableMetacommands); + m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), commandQueue, enableMetacommands, enableDynamicGraphFusion); } std::vector> @@ -147,13 +148,12 @@ namespace Dml // Task 24384515: Update ORT AIInfra release agent pool to install 19H1 SDK on VM bootstrap #define D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE ((D3D_FEATURE_LEVEL)0x1000) - ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands) + ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands, bool enableDynamicGraphFusion) : m_d3d12Device(d3d12Device), m_dmlDevice(dmlDevice), - m_areMetacommandsEnabled(enableMetacommands) + m_areMetacommandsEnabled(enableMetacommands), + m_dynamicGraphFusionEnabled(enableDynamicGraphFusion) { - // TODO (pavignol): Remove me - m_areMetacommandsEnabled = false; D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {}; 
D3D_FEATURE_LEVEL featureLevelsList[] = { @@ -1094,6 +1094,11 @@ namespace Dml return m_areMetacommandsEnabled; } + bool ExecutionProviderImpl::DynamicGraphFusionEnabled() const noexcept + { + return m_dynamicGraphFusionEnabled; + } + std::shared_ptr ExecutionProviderImpl::GetInternalRegistrationInfoMap() const { @@ -1130,9 +1135,10 @@ namespace Dml std::unique_ptr CreateExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands) + bool enableMetacommands, + bool enableDynamicGraphFusion) { - return std::make_unique(dmlDevice, commandQueue, enableMetacommands); + return std::make_unique(dmlDevice, commandQueue, enableMetacommands, enableDynamicGraphFusion); } ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 31b893a2f25d7..3aaa11cdee479 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -5,6 +5,7 @@ #include "GraphTransformer.h" #include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h" +#include "core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h" #include #include @@ -34,7 +35,8 @@ namespace Dml IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, - bool enableMetacommands = true); + bool enableMetacommands, + bool enableDynamicGraphFusion); void ReleaseCompletedReferences(); @@ -150,6 +152,7 @@ namespace Dml STDMETHOD_(bool, IsMcdmDevice)() const noexcept final; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final; + bool DynamicGraphFusionEnabled() const noexcept; std::shared_ptr GetGpuAllocator(); std::shared_ptr GetCpuInputAllocator(); @@ -184,6 +187,7 @@ namespace Dml ComPtr m_dmlDevice; bool m_isMcdmDevice = false; 
bool m_areMetacommandsEnabled = true; + bool m_dynamicGraphFusionEnabled = false; bool m_native16BitShaderOpsSupported = false; std::shared_ptr m_context; std::unique_ptr m_uploadHeap; @@ -236,7 +240,8 @@ namespace Dml explicit ExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands = true + bool enableMetacommands, + bool enableDynamicGraphFusion ); std::unique_ptr GetDataTransfer() const final override @@ -299,9 +304,9 @@ namespace Dml return m_impl.Get(); } - void MetacommandsEnabled() + bool DynamicGraphFusionEnabled() const { - m_impl->MetacommandsEnabled(); + return m_impl->DynamicGraphFusionEnabled(); } virtual std::vector CreatePreferredAllocators() override diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h index d7a0a607cdec9..a8a6d6745e908 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h @@ -2,8 +2,15 @@ // Licensed under the MIT License. 
#pragma once + +#include + #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" +interface IDMLCompiledOperator; +struct DML_BUFFER_BINDING; +struct DML_BINDING_DESC; + namespace Dml { struct Binding diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index fde61e73c2124..486629afced8e 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -18,6 +18,7 @@ using Microsoft::WRL::ComPtr; #include "core/session/allocator_adapters.h" #include "core/session/ort_apis.h" #include "core/framework/error_code_helper.h" +#include "core/providers/dml/dml_provider_options.h" #include "DmlExecutionProvider/src/ErrorHandling.h" #include "DmlExecutionProvider/src/GraphicsUnknownHelper.h" #include "DmlExecutionProvider/inc/DmlExecutionProvider.h" @@ -27,8 +28,11 @@ namespace onnxruntime { struct DMLProviderFactory : IExecutionProviderFactory { DMLProviderFactory(IDMLDevice* dml_device, - ID3D12CommandQueue* cmd_queue) : dml_device_(dml_device), - cmd_queue_(cmd_queue) {} + ID3D12CommandQueue* cmd_queue, + const OrtDmlProviderOptions& provider_options) : dml_device_(dml_device), + cmd_queue_(cmd_queue), + metacommands_enabled_(!provider_options.disable_metacommands), + dynamic_graph_fusion_enabled_(provider_options.enable_dynamic_graph_fusion) {} ~DMLProviderFactory() override {} std::unique_ptr CreateProvider() override; @@ -39,10 +43,11 @@ struct DMLProviderFactory : IExecutionProviderFactory { ComPtr dml_device_{}; ComPtr cmd_queue_{}; bool metacommands_enabled_ = true; + bool dynamic_graph_fusion_enabled_ = false; }; std::unique_ptr DMLProviderFactory::CreateProvider() { - auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get(), metacommands_enabled_); + auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get(), metacommands_enabled_, 
dynamic_graph_fusion_enabled_); return provider; } @@ -51,7 +56,8 @@ void DMLProviderFactory::SetMetacommandsEnabled(bool metacommands_enabled) { } std::shared_ptr CreateExecutionProviderFactory_DML(IDMLDevice* dml_device, - ID3D12CommandQueue* cmd_queue) { + ID3D12CommandQueue* cmd_queue, + const OrtDmlProviderOptions& provider_options) { #ifndef _GAMING_XBOX // Validate that the D3D12 devices match between DML and the command queue. This specifically asks for IUnknown in // order to be able to compare the pointers for COM object identity. @@ -70,7 +76,7 @@ std::shared_ptr CreateExecutionProviderFactory_DML(ID const Env& env = Env::Default(); auto luid = d3d12_device->GetAdapterLuid(); env.GetTelemetryProvider().LogExecutionProviderEvent(&luid); - return std::make_shared(dml_device, cmd_queue); + return std::make_shared(dml_device, cmd_queue, provider_options); } void DmlConfigureProviderFactoryMetacommandsEnabled(IExecutionProviderFactory* factory, bool metacommandsEnabled) { @@ -92,10 +98,6 @@ bool IsSoftwareAdapter(IDXGIAdapter1* adapter) { return isSoftwareAdapter || (isBasicRenderDriverVendorId && isBasicRenderDriverDeviceId); } -std::shared_ptr DMLProviderFactoryCreator::Create(int device_id) { - return Create(device_id, /*skip_software_device_check*/ false); -} - Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateD3D12Device(int device_id, bool skip_software_device_check) { #ifdef _GAMING_XBOX @@ -153,8 +155,8 @@ Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID return dml_device; } -std::shared_ptr DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) { - ComPtr d3d12_device = CreateD3D12Device(device_id, skip_software_device_check); +std::shared_ptr DMLProviderFactoryCreator::Create(const OrtDmlProviderOptions& provider_options) { + ComPtr d3d12_device = CreateD3D12Device(provider_options.device_id, provider_options.skip_software_device_check); D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; 
cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; @@ -164,7 +166,7 @@ std::shared_ptr DMLProviderFactoryCreator::Create(int ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf()))); auto dml_device = CreateDMLDevice(d3d12_device.Get()); - return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get()); + return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get(), provider_options); } } // namespace onnxruntime @@ -174,7 +176,9 @@ std::shared_ptr DMLProviderFactoryCreator::Create(int // The OrtSessionOptionsAppendExecutionProvider_DML export on the OrtDmlApi should be used instead. ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id) { API_IMPL_BEGIN - options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::Create(device_id)); + OrtDmlProviderOptions provider_options{}; + provider_options.device_id = device_id; + options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::Create(provider_options)); API_IMPL_END return nullptr; } @@ -185,8 +189,10 @@ API_IMPL_END ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options, _In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue) { API_IMPL_BEGIN + OrtDmlProviderOptions provider_options{}; options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_DML(dml_device, - cmd_queue)); + cmd_queue, + provider_options)); API_IMPL_END return nullptr; } diff --git a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h index 574f4410fe3e3..b0a9ffed2c410 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h +++ b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h @@ -9,12 +9,12 @@ #include #include "core/providers/providers.h" #include 
"core/providers/dml/dml_provider_factory.h" +#include "core/providers/dml/dml_provider_options.h" namespace onnxruntime { struct DMLProviderFactoryCreator { - static std::shared_ptr Create(int device_id); - static std::shared_ptr Create(int device_id, bool skip_software_device_check); + static std::shared_ptr Create(const OrtDmlProviderOptions& provider_options); static Microsoft::WRL::ComPtr CreateD3D12Device(int device_id, bool skip_software_device_check); static Microsoft::WRL::ComPtr CreateDMLDevice(ID3D12Device* d3d12_device); }; diff --git a/onnxruntime/core/providers/dml/dml_provider_options.h b/onnxruntime/core/providers/dml/dml_provider_options.h new file mode 100644 index 0000000000000..3f98e70310787 --- /dev/null +++ b/onnxruntime/core/providers/dml/dml_provider_options.h @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +struct OrtDmlProviderOptions { + int device_id = 0; + bool skip_software_device_check = false; + bool disable_metacommands = false; + bool enable_dynamic_graph_fusion = false; +}; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index ada5cfb1c04e7..225c362676692 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -55,6 +55,7 @@ #include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h" #include "core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h" #include "core/providers/dml/dml_session_options_config_keys.h" +#include "core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h" #endif #include "core/session/environment.h" #include "core/session/IOBinding.h" @@ -1536,7 +1537,9 @@ common::Status InferenceSession::Initialize() { record_runtime_optimization_produced_op_schema)); #ifdef USE_DML - if (execution_providers_.Get(kDmlExecutionProvider)) { + const IExecutionProvider* 
dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider); + + if (dmlExecutionProvider) { // DML graph fusion is an important runtime optimization that cannot be done ahead of time; it must be disabled // when running in "offline mode" and saving an optimized model to disk. To support users that want to optimize // models offline, and then disable graph optimizations when running "online", this transformer ignores the ORT @@ -1546,18 +1549,20 @@ common::Status InferenceSession::Initialize() { if (dml_graph_fusion_enabled) { std::unique_ptr dmlGraphFusionTransformer = std::make_unique("DmlGraphFusionTransformer", - execution_providers_.Get(kDmlExecutionProvider)); + dmlExecutionProvider); if (dmlGraphFusionTransformer == nullptr) { return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr"); } ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3)); - std::unique_ptr dmlRuntimeGraphFusionTransformer = std::make_unique("DmlRuntimeGraphFusionTransformer", - execution_providers_.Get(kDmlExecutionProvider)); - if (dmlRuntimeGraphFusionTransformer == nullptr) { - return Status(common::ONNXRUNTIME, common::FAIL, "DmlRuntimeGraphFusionTransformer is nullptr"); + if (static_cast(dmlExecutionProvider)->DynamicGraphFusionEnabled()) { + std::unique_ptr dmlRuntimeGraphFusionTransformer = std::make_unique("DmlRuntimeGraphFusionTransformer", + dmlExecutionProvider); + if (dmlRuntimeGraphFusionTransformer == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "DmlRuntimeGraphFusionTransformer is nullptr"); + } + ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlRuntimeGraphFusionTransformer), onnxruntime::TransformerLevel::Level3)); } - ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlRuntimeGraphFusionTransformer), onnxruntime::TransformerLevel::Level3)); } // This transformer applies 
DML-specific fusions that go beyond what ORT offers by default diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index a8c217b0ff1f6..f41ad399009de 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -59,7 +59,12 @@ void addGlobalSchemaFunctions(pybind11::module& m) { onnxruntime::ArmNNProviderFactoryCreator::Create(0), #endif #ifdef USE_DML - onnxruntime::DMLProviderFactoryCreator::Create(0, /*skip_software_device_check*/ true), + []() { + OrtDmlProviderOptions provider_options{}; + provider_options.device_id = 0; + provider_options.skip_software_device_check = true; + return onnxruntime::DMLProviderFactoryCreator::Create(provider_options); + }(), #endif #ifdef USE_NNAPI onnxruntime::NnapiProviderFactoryCreator::Create(0, std::optional()), diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 907ea0ec41e23..045fb427ab2e6 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -31,6 +31,7 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/provider_bridge_ort.h" +#include "core/providers/dml/dml_provider_options.h" #ifdef ENABLE_ATEN #include "contrib_ops/cpu/aten_ops/aten_op_executor.h" @@ -887,18 +888,36 @@ std::unique_ptr CreateExecutionProviderInstance( #endif } else if (type == kDmlExecutionProvider) { #ifdef USE_DML - int device_id = 0; + + OrtDmlProviderOptions provider_options{}; + auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { for (auto option : it->second) { if (option.first == "device_id") { if (!option.second.empty()) { - device_id = std::stoi(option.second); + provider_options.device_id = std::stoi(option.second); + } + } else if (option.first == "disable_metacommands") { + if 
(option.second == "True" || option.second == "true") { + provider_options.disable_metacommands = true; + } else if (option.second == "False" || option.second == "false") { + provider_options.disable_metacommands = false; + } else { + ORT_THROW("[ERROR] [DirectML] The value for the key 'disable_metacommands' should be 'True' or 'False'. Default value is 'False'.\n"); + } + } else if (option.first == "enable_dynamic_graph_fusion") { + if (option.second == "True" || option.second == "true") { + provider_options.enable_dynamic_graph_fusion = true; + } else if (option.second == "False" || option.second == "false") { + provider_options.enable_dynamic_graph_fusion = false; + } else { + ORT_THROW("[ERROR] [DirectML] The value for the key 'enable_dynamic_graph_fusion' should be 'True' or 'False'. Default value is 'False'.\n"); } } } } - return onnxruntime::DMLProviderFactoryCreator::Create(device_id)->CreateProvider(); + return onnxruntime::DMLProviderFactoryCreator::Create(provider_options)->CreateProvider(); #endif } else if (type == kNnapiExecutionProvider) { #if defined(USE_NNAPI) diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 18a9079b5c4f2..e7292b757f786 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -449,7 +449,7 @@ std::shared_ptr CreateExecutionProviderFactory_VITISA const char* load_runtime_module); std::shared_ptr CreateExecutionProviderFactory_ACL(int use_arena); std::shared_ptr CreateExecutionProviderFactory_ArmNN(int use_arena); -std::shared_ptr CreateExecutionProviderFactory_DML(int device_id); +std::shared_ptr CreateExecutionProviderFactory_DML(const OrtDmlProviderOptions& params); std::shared_ptr CreateExecutionProviderFactory_Nnapi( uint32_t flags, const optional& partitioning_stop_ops_list); std::shared_ptr CreateExecutionProviderFactory_Rknpu(); diff --git 
a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 28af61e15b2b5..b79be1c58a2d7 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -268,7 +268,10 @@ std::unique_ptr DefaultCannExecutionProvider() { std::unique_ptr DefaultDmlExecutionProvider() { #ifdef USE_DML - if (auto factory = DMLProviderFactoryCreator::Create(0)) + OrtDmlProviderOptions provider_options{}; + provider_options.device_id = 0; + + if (auto factory = DMLProviderFactoryCreator::Create(provider_options)) return factory->CreateProvider(); #endif return nullptr; From b809a5b26e8a2b48c8cbcbbb18f92ca6ac08dd85 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Fri, 6 Oct 2023 22:25:39 -0700 Subject: [PATCH 08/20] Remove unused variables --- .../src/DmlRuntimeFusedGraphKernel.cpp | 27 +++---------------- .../src/DmlRuntimeFusedGraphKernel.h | 1 - .../src/DmlRuntimeGraphFusionHelper.cpp | 25 ----------------- 3 files changed, 3 insertions(+), 50 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 099777b85452d..2c9d27b0546f6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -20,7 +20,6 @@ namespace Dml const onnxruntime::OpKernelInfo& kernelInfo, std::shared_ptr indexedSubGraph, const onnxruntime::Path& modelPath, - std::shared_ptr>> inputDimParams, std::vector>&& subgraphNodes, std::vector&& subgraphInputs, std::vector&& subgraphOutputs, @@ -30,7 +29,6 @@ namespace Dml : OpKernel(kernelInfo), m_indexedSubGraph(std::move(indexedSubGraph)), m_modelPath(modelPath), - m_inputDimParams(std::move(inputDimParams)), m_subgraphNodes(std::move(subgraphNodes)), m_subgraphInputs(std::move(subgraphInputs)), 
m_subgraphOutputs(std::move(subgraphOutputs)), @@ -68,8 +66,6 @@ namespace Dml std::vector>& initializeResourceRefs, std::vector initInputBindings) const { - std::optional persistentResourceBinding; - // Allocate a persistent resource and initialize the operator UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize; if (persistentResourceSize > 0) @@ -80,12 +76,12 @@ namespace Dml m_persistentResource.GetAddressOf(), m_persistentResourceAllocatorUnk.GetAddressOf())); - persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize }; + m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize }; } ORT_THROW_IF_FAILED(m_provider->InitializeOperator( m_compiledExecutionPlanOperator.Get(), - persistentResourceBinding ? &*persistentResourceBinding : nullptr, + m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, gsl::make_span(initInputBindings))); // Queue references to objects which must be kept alive until resulting GPU work completes @@ -303,17 +299,10 @@ namespace Dml ComPtr m_winmlProvider; ComPtr m_provider; - // Re-usable command list, supporting descriptor heap, and DML binding table to update that heap. 
- ComPtr m_graphicsCommandList; - ComPtr m_commandAllocator; - ComPtr m_heap; - ComPtr m_bindingTable; - std::optional m_persistentResourceBinding; + mutable std::optional m_persistentResourceBinding; std::shared_ptr m_indexedSubGraph; const onnxruntime::Path& m_modelPath; - // TODO (pavignol): Remove m_inputDimParams if truly not needed - std::shared_ptr>> m_inputDimParams; std::vector> m_subgraphNodes; std::vector m_subgraphInputs; std::vector m_subgraphOutputs; @@ -326,26 +315,17 @@ namespace Dml // Bindings from previous executions of a re-used command list mutable std::vector> m_ownedCpuInputs; mutable ComPtr m_compiledExecutionPlanOperator; - mutable std::vector m_inputBindingAllocIds; - mutable std::vector m_outputBindingAllocIds; - mutable uint64_t m_tempBindingAllocId = 0; mutable std::vector m_inputsUsed; mutable ComPtr m_persistentResource; mutable ComPtr m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator mutable Windows::AI::MachineLearning::Adapter::EdgeShapes m_outputShapes; mutable std::unordered_map m_inferredInputShapes; - - // Fence tracking the status of the command list's last execution, and whether its descriptor heap - // can safely be updated. 
- mutable ComPtr m_fence; - mutable uint64_t m_completionValue = 0; }; onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( const onnxruntime::OpKernelInfo& info, std::shared_ptr indexedSubGraph, const onnxruntime::Path& modelPath, - std::shared_ptr>> inputDimParams, std::vector>&& subgraphNodes, std::vector&& subgraphInputs, std::vector&& subgraphOutputs, @@ -357,7 +337,6 @@ namespace Dml info, std::move(indexedSubGraph), modelPath, - std::move(inputDimParams), std::move(subgraphNodes), std::move(subgraphInputs), std::move(subgraphOutputs), diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h index d18a6d4671bc4..d679c5aa5667c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h @@ -11,7 +11,6 @@ namespace Dml const onnxruntime::OpKernelInfo& info, std::shared_ptr indexedSubGraph, const onnxruntime::Path& modelPath, - std::shared_ptr>> inputDimParams, std::vector>&& subgraphNodes, std::vector&& subgraphInputs, std::vector&& subgraphOutputs, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp index bd787dbfb4382..71ef8e7962f6c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp @@ -369,9 +369,6 @@ namespace DmlRuntimeGraphFusionHelper subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); } - // We store the input dim params that haven't been overriden yet so that we can map their value at runtime once the real inputs are provided - auto inputDimParams = std::make_shared>>(); - // We need to keep the initializers 
alive since they will be freed once the nodes are removed from the graph std::vector ownedInitializers; ownedInitializers.reserve(isInitializerTransferable.size()); @@ -393,7 +390,6 @@ namespace DmlRuntimeGraphFusionHelper // lamda captures for the kernel registration auto fused_kernel_func = [ - inputDimParams, indexedSubGraph, &modelPath, nodesInfo = std::move(nodesInfo), @@ -422,7 +418,6 @@ namespace DmlRuntimeGraphFusionHelper info, indexedSubGraph, modelPath, - std::move(inputDimParams), std::move(subgraphNodes), std::move(subgraphInputs), std::move(subgraphOutputs), @@ -453,26 +448,6 @@ namespace DmlRuntimeGraphFusionHelper auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name); fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); - inputDimParams->resize(fusedNode.InputDefs().size()); - - for (int inputIndex = 0; inputIndex < fusedNode.InputDefs().size(); ++inputIndex) - { - const onnxruntime::NodeArg* inputDef = fusedNode.InputDefs()[inputIndex]; - - ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->TypeAsProto()->has_tensor_type()); - const auto& tensorShape = inputDef->TypeAsProto()->tensor_type().shape(); - - (*inputDimParams)[inputIndex].resize(tensorShape.dim_size()); - - for (int i = 0; i < tensorShape.dim_size(); ++i) - { - if (tensorShape.dim(i).has_dim_param()) - { - (*inputDimParams)[inputIndex][i] = tensorShape.dim(i).dim_param(); - } - } - } - graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode); } } From bc322c1cdaf9938e2cf3bd8725465f6725db19b8 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Fri, 6 Oct 2023 23:00:18 -0700 Subject: [PATCH 09/20] Add check in case CPU inputs changed --- .../src/DmlRuntimeFusedGraphKernel.cpp | 44 +++++++++++++++---- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp 
index 2c9d27b0546f6..fd26a90be87e7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -99,36 +99,62 @@ namespace Dml { ORT_THROW_HR_IF(E_UNEXPECTED, m_subgraphInputs.size() != kernelContext->InputCount()); - bool recompiledNeeded = m_compiledExecutionPlanOperator == nullptr; + bool recompileNeeded = m_compiledExecutionPlanOperator == nullptr; for (int inputIndex = 0; inputIndex < kernelContext->InputCount(); ++inputIndex) { const auto& input = kernelContext->RequiredInput(inputIndex); const std::string& inputName = m_subgraphInputs[inputIndex]->Name(); - auto iter = m_inferredInputShapes.find(inputName); + auto shapeIter = m_inferredInputShapes.find(inputName); - if (iter == m_inferredInputShapes.end()) + if (shapeIter == m_inferredInputShapes.end()) { m_inferredInputShapes[inputName] = input.Shape(); - recompiledNeeded = true; + recompileNeeded = true; } - else if (iter->second != input.Shape()) + else if (shapeIter->second != input.Shape()) { - iter->second = input.Shape(); - recompiledNeeded = true; + shapeIter->second = input.Shape(); + recompileNeeded = true; } // If we have CPU inputs that are not initializers (i.e. 
they were computed at runtime), add them to the initializer list if (input.Location().device.Type() == OrtDevice::CPU) { - // TODO (pavignol): Force recompile if CPU data changed auto inputProto = onnxruntime::utils::TensorToTensorProto(input, inputName); + + // We can only avoid recompiling the graph when all CPU inputs are identical + auto initializerIter = m_isInitializerTransferable.find(inputName); + + if (initializerIter != m_isInitializerTransferable.end()) + { + if (initializerIter->second.first->raw_data().length() == inputProto.raw_data().length()) + { + for (int i = 0; i < inputProto.raw_data().length(); ++i) + { + if (initializerIter->second.first->raw_data()[i] != inputProto.raw_data()[i]) + { + recompileNeeded = true; + break; + } + } + } + else + { + recompileNeeded = true; + } + } + else + { + recompileNeeded = true; + } + m_ownedCpuInputs.push_back(std::make_unique(std::move(inputProto))); m_isInitializerTransferable[inputName] = std::make_pair(m_ownedCpuInputs.back().get(), false); } } - if (recompiledNeeded) + if (recompileNeeded) { // Go through all the node args and replace their shapes with the real ones for (auto& nodeArg : m_intermediateNodeArgs) From f850a2d340f2b25a000bf235420e5b94ea8ccabb Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Sat, 7 Oct 2023 02:34:01 -0700 Subject: [PATCH 10/20] Refactor --- .../src/DmlGraphFusionHelper.cpp | 168 ++++++- .../src/DmlGraphFusionHelper.h | 9 + .../src/DmlRuntimeFusedGraphKernel.cpp | 8 +- .../src/DmlRuntimeGraphFusionHelper.cpp | 454 ------------------ .../src/DmlRuntimeGraphFusionHelper.h | 83 ---- .../src/DmlRuntimeGraphFusionTransformer.cpp | 9 +- 6 files changed, 184 insertions(+), 547 deletions(-) delete mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp delete mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h diff --git 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 51b93efb3a646..cd74e7fa92940 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -1,7 +1,7 @@ #pragma once #include "DmlGraphFusionHelper.h" - +#include "DmlRuntimeFusedGraphKernel.h" namespace Dml { @@ -501,5 +501,171 @@ namespace DmlGraphFusionHelper graph.FinalizeFuseSubGraph(indexedSubGraph, fusedNode); } + + void RegisterDynamicKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const ExecutionProviderImpl* providerImpl, + std::unordered_map graphNodePropertyMap, + const std::unordered_set& dynamicCpuInputMap, + std::shared_ptr indexedSubGraph, + std::unordered_map>&& isInitializerTransferable) + { + struct NodeInfo + { + std::string name; + std::string opType; + std::string description; + std::string domain; + onnxruntime::NodeAttributes attributes; + std::vector inputDefPointers; + std::vector outputDefPointers; + }; + + auto partitionNodePropsMap = DmlGraphFusionHelper::CreatePartitionNodePropsMap( + graph, + *indexedSubGraph, + std::move(graphNodePropertyMap)); + + auto modelPath = graph.ModelPath(); + + const gsl::span subGraphInputArgNames = indexedSubGraph->GetMetaDef()->inputs; + const gsl::span subGraphOutputArgNames = indexedSubGraph->GetMetaDef()->outputs; + + std::vector nodesInfo; + nodesInfo.reserve(indexedSubGraph->nodes.size()); + + std::vector subgraphInputs; + subgraphInputs.reserve(subGraphInputArgNames.size()); + + std::vector subgraphOutputs; + subgraphOutputs.reserve(subGraphOutputArgNames.size()); + + std::vector nodeAttributes; + nodeAttributes.reserve(indexedSubGraph->nodes.size()); + + std::vector> intermediateNodeArgs; + + for (size_t sortedNodeIndex : indexedSubGraph->nodes) + { + auto 
node = graph.GetNode(sortedNodeIndex); + + nodeAttributes.push_back(node->GetAttributes()); + + NodeInfo nodeInfo{}; + nodeInfo.name = node->Name(); + nodeInfo.opType = node->OpType(); + nodeInfo.description = node->Description(); + nodeInfo.domain = node->Domain(); + nodeInfo.attributes = node->GetAttributes(); + nodeInfo.inputDefPointers.reserve(node->InputDefs().size()); + nodeInfo.outputDefPointers.reserve(node->OutputDefs().size()); + + for (const onnxruntime::NodeArg* inputDef : node->InputDefs()) + { + intermediateNodeArgs.emplace_back(std::make_shared(inputDef->Name(), inputDef->TypeAsProto())); + nodeInfo.inputDefPointers.push_back(intermediateNodeArgs.back().get()); + } + + for (const onnxruntime::NodeArg* outputDef : node->OutputDefs()) + { + intermediateNodeArgs.emplace_back(std::make_shared(outputDef->Name(), outputDef->TypeAsProto())); + nodeInfo.outputDefPointers.push_back(intermediateNodeArgs.back().get()); + } + + nodesInfo.push_back(std::move(nodeInfo)); + } + + for (const std::string& graphInputName : subGraphInputArgNames) + { + subgraphInputs.push_back(graph.GetNodeArg(graphInputName)); + } + + for (const std::string& graphOutputName : subGraphOutputArgNames) + { + subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); + } + + // We need to keep the initializers alive since they will be freed once the nodes are removed from the graph + std::vector ownedInitializers; + ownedInitializers.reserve(isInitializerTransferable.size()); + + for (auto& kvp : isInitializerTransferable) + { + ONNX_NAMESPACE::TensorProto tensorProto; + tensorProto.set_data_type(kvp.second.first->data_type()); + tensorProto.set_raw_data(kvp.second.first->raw_data()); + tensorProto.set_name(kvp.second.first->name()); + + for (int i = 0; i < kvp.second.first->dims_size(); ++i) + { + tensorProto.add_dims(kvp.second.first->dims(i)); + } + ownedInitializers.push_back(std::move(tensorProto)); + kvp.second.first = &ownedInitializers.back(); + } + + // lamda captures for the 
kernel registration + auto fused_kernel_func = [ + indexedSubGraph, + &modelPath, + nodesInfo = std::move(nodesInfo), + intermediateNodeArgs = std::move(intermediateNodeArgs), + subgraphInputs = std::move(subgraphInputs), + subgraphOutputs = std::move(subgraphOutputs), + partitionNodePropsMap = std::move(partitionNodePropsMap), + ownedInitializers = std::move(ownedInitializers)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr& out) mutable ->onnxruntime::Status + { + std::vector> subgraphNodes; + subgraphNodes.reserve(nodesInfo.size()); + + for (const NodeInfo& nodeInfo : nodesInfo) + { + subgraphNodes.emplace_back(std::make_shared( + nodeInfo.name, + nodeInfo.opType, + nodeInfo.description, + nodeInfo.inputDefPointers, + nodeInfo.outputDefPointers, + &nodeInfo.attributes, + nodeInfo.domain)); + } + + out.reset(CreateRuntimeFusedGraphKernel( + info, + indexedSubGraph, + modelPath, + std::move(subgraphNodes), + std::move(subgraphInputs), + std::move(subgraphOutputs), + std::move(intermediateNodeArgs), + std::move(partitionNodePropsMap), + std::move(ownedInitializers))); + return Status::OK(); + }; + + // build the kernel definition on the fly, and register it to the fused_kernel_regisitry. 
+ onnxruntime::KernelDefBuilder builder; + builder.SetName(indexedSubGraph->GetMetaDef()->name) + .SetDomain(indexedSubGraph->GetMetaDef()->domain) + .SinceVersion(indexedSubGraph->GetMetaDef()->since_version) + .Provider(onnxruntime::kDmlExecutionProvider); + + // Force the CPU inputs to be allocated on the CPU + for (int i = 0; i < subGraphInputArgNames.size(); ++i) + { + if (dynamicCpuInputMap.find(subGraphInputArgNames[i]) != dynamicCpuInputMap.end()) + { + builder.InputMemoryType(OrtMemTypeCPUInput, i); + } + } + + ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func)); + + auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name); + fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); + + graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode); + } } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index 030cffc2a8794..f8f6162aaa1e0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -80,5 +80,14 @@ namespace DmlGraphFusionHelper std::vector&& isInputsUploadedByDmlEP, const GraphDescBuilder::GraphDesc& graphDesc, Microsoft::WRL::ComPtr compiledExecutionPlanOperator); + + void RegisterDynamicKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const ExecutionProviderImpl* providerImpl, + std::unordered_map graphNodePropertyMap, + const std::unordered_set& dynamicCpuInputMap, + std::shared_ptr indexedSubGraph, + std::unordered_map>&& isInitializerTransferable); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 
fd26a90be87e7..649b6d29659c1 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -3,9 +3,9 @@ #include "precomp.h" -#include "MLOperatorAuthorImpl.h" -#include "DmlRuntimeFusedGraphKernel.h" -#include "DmlRuntimeGraphFusionHelper.h" +#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h" using namespace Windows::AI::MachineLearning::Adapter; @@ -206,7 +206,7 @@ namespace Dml } // Compile the operator - m_compiledExecutionPlanOperator = DmlRuntimeGraphFusionHelper::TryCreateCompiledOperator( + m_compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( graphDesc, *m_indexedSubGraph, providerImpl); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp deleted file mode 100644 index 71ef8e7962f6c..0000000000000 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.cpp +++ /dev/null @@ -1,454 +0,0 @@ -#pragma once - -#include "DmlRuntimeGraphFusionHelper.h" - - -namespace Dml -{ -namespace DmlRuntimeGraphFusionHelper -{ - Microsoft::WRL::ComPtr - CreateResource( - const ExecutionProviderImpl* provider, - const std::byte* tensorPtr, - size_t tensorByteSize) - { - Microsoft::WRL::ComPtr buffer; - - D3D12_HEAP_PROPERTIES heapProperties = { - D3D12_HEAP_TYPE_DEFAULT, D3D12_CPU_PAGE_PROPERTY_UNKNOWN, D3D12_MEMORY_POOL_UNKNOWN, 0, 0}; - - D3D12_RESOURCE_DESC resourceDesc = {D3D12_RESOURCE_DIMENSION_BUFFER, - 0, - static_cast((tensorByteSize + 3) & ~3), - 1, - 1, - 1, - DXGI_FORMAT_UNKNOWN, - {1, 0}, - D3D12_TEXTURE_LAYOUT_ROW_MAJOR, - 
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS}; - - Microsoft::WRL::ComPtr d3dDevice; - ORT_THROW_IF_FAILED(provider->GetD3DDevice(d3dDevice.GetAddressOf())); - - ORT_THROW_IF_FAILED(d3dDevice->CreateCommittedResource( - &heapProperties, - D3D12_HEAP_FLAG_NONE, - &resourceDesc, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - nullptr, - IID_GRAPHICS_PPV_ARGS(buffer.GetAddressOf()))); - - ORT_THROW_IF_FAILED(provider->UploadToResource(buffer.Get(), tensorPtr, tensorByteSize)); - - return buffer; - } - - Microsoft::WRL::ComPtr - CreateCpuResource( - const ExecutionProviderImpl* provider, - const std::byte* tensorPtr, - size_t tensorByteSize) - { - Microsoft::WRL::ComPtr buffer; - - D3D12_HEAP_PROPERTIES heapProperties = { - D3D12_HEAP_TYPE_CUSTOM, D3D12_CPU_PAGE_PROPERTY_WRITE_COMBINE, D3D12_MEMORY_POOL_L0, 0, 0}; - - D3D12_RESOURCE_DESC resourceDesc = {D3D12_RESOURCE_DIMENSION_BUFFER, - 0, - static_cast((tensorByteSize + 3) & ~3), - 1, - 1, - 1, - DXGI_FORMAT_UNKNOWN, - {1, 0}, - D3D12_TEXTURE_LAYOUT_ROW_MAJOR, - D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS}; - - Microsoft::WRL::ComPtr d3dDevice; - ORT_THROW_IF_FAILED(provider->GetD3DDevice(d3dDevice.GetAddressOf())); - - ORT_THROW_IF_FAILED(d3dDevice->CreateCommittedResource( - &heapProperties, - D3D12_HEAP_FLAG_NONE, - &resourceDesc, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - nullptr, - IID_GRAPHICS_PPV_ARGS(buffer.GetAddressOf()))); - - // Map the buffer and copy the data - void* bufferData = nullptr; - D3D12_RANGE range = {0, tensorByteSize}; - ORT_THROW_IF_FAILED(buffer->Map(0, &range, &bufferData)); - memcpy(bufferData, tensorPtr, tensorByteSize); - buffer->Unmap(0, &range); - - return buffer; - } - - void UnwrapTensor( - Windows::AI::MachineLearning::Adapter::IWinmlExecutionProvider* winmlProvider, - const onnxruntime::Tensor* tensor, - ID3D12Resource** resource, - uint64_t* allocId) - { - IUnknown* allocationUnk = static_cast(const_cast(tensor->DataRaw())); - Microsoft::WRL::ComPtr resourceUnk; - 
winmlProvider->GetABIDataInterface(false, allocationUnk, &resourceUnk); - - *allocId = winmlProvider->TryGetPooledAllocationId(allocationUnk, 0); - - ORT_THROW_IF_FAILED(resourceUnk->QueryInterface(resource)); - } - - std::unordered_map> - GetInitializerToPartitionMap( - const onnxruntime::GraphViewer& graph, - gsl::span> partitions - ) - { - std::unordered_map> initializerPartitionMap; - for (uint32_t partitionIndex = 0; partitionIndex < gsl::narrow_cast(partitions.size()); ++partitionIndex) - { - auto& partition = partitions[partitionIndex]; - - // Skip partitions which have been merged into other partitions - if (partition->GetRootMergedPartition() != partition.get()) - { - continue; - } - - for (const std::string& input : partition->GetInputs()) - { - const onnx::TensorProto* tensor = nullptr; - if (graph.GetInitializedTensor(input, tensor)) - { - initializerPartitionMap[tensor].push_back(partitionIndex); - } - } - } - - return initializerPartitionMap; - } - - void ConvertGraphDesc( - const Dml::GraphDescBuilder::GraphDesc& graphDesc, - _Out_ DML_GRAPH_DESC& dmlGraphDesc, - const uint32_t inputCount, - const uint32_t outputCount, - _Inout_ std::vector& dmlOperatorGraphNodes, - _Inout_ std::vector& dmlGraphNodes, - _Inout_ std::vector& dmlInputEdges, - _Inout_ std::vector& dmlOutputEdges, - _Inout_ std::vector& dmlIntermediateEdges) - { - for (size_t i = 0; i < graphDesc.nodes.size(); ++i) - { - auto& nodeInfo = graphDesc.nodes[i]; - dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{nodeInfo.op.Get(), nodeInfo.name.data()}; - dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; - } - - for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i) - { - dmlInputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INPUT, &graphDesc.inputEdges[i]}; - } - - for (size_t i = 0; i < graphDesc.outputEdges.size(); ++i) - { - dmlOutputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_OUTPUT, &graphDesc.outputEdges[i]}; - } - - 
for (size_t i = 0; i < graphDesc.intermediateEdges.size(); ++i) - { - dmlIntermediateEdges[i] = - DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, &graphDesc.intermediateEdges[i]}; - } - - dmlGraphDesc.InputCount = inputCount; - dmlGraphDesc.OutputCount = outputCount; - dmlGraphDesc.NodeCount = gsl::narrow_cast(dmlGraphNodes.size()); - dmlGraphDesc.Nodes = dmlGraphNodes.data(); - dmlGraphDesc.InputEdgeCount = gsl::narrow_cast(dmlInputEdges.size()); - dmlGraphDesc.InputEdges = dmlInputEdges.data(); - dmlGraphDesc.OutputEdgeCount = gsl::narrow_cast(dmlOutputEdges.size()); - dmlGraphDesc.OutputEdges = dmlOutputEdges.data(); - dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast(dmlIntermediateEdges.size()); - dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data(); - } - - onnxruntime::IndexedSubGraph CreateIndexedSubGraph( - GraphPartition* partition, - uint32_t partitionIndex, - const std::string& partitionKernelPrefix) - { - assert(partition->IsDmlGraphPartition()); - - onnxruntime::IndexedSubGraph indexedSubGraph; - // Create a definition for the node. The name must be unique. 
- auto def = std::make_unique(); - def->name = DmlRuntimeGraphFusionTransformer::DML_GRAPH_FUSION_NODE_NAME_PREFIX + partitionKernelPrefix + std::to_string(partitionIndex); - def->domain = DmlRuntimeGraphFusionTransformer::DML_GRAPH_FUSION_NODE_DOMAIN; - def->since_version = 1; - def->inputs.insert(def->inputs.begin(), partition->GetInputs().begin(), partition->GetInputs().end()); - def->outputs.insert(def->outputs.begin(), partition->GetOutputs().begin(), partition->GetOutputs().end()); - - indexedSubGraph.SetMetaDef(std::move(def)); - indexedSubGraph.nodes = std::move(partition->GetNodeIndices()); - - return indexedSubGraph; - } - - std::unordered_map CreatePartitionNodePropsMap( - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, - std::unordered_map&& graphNodePropertyMap) - { - // Populate properties which will be passed to OpKernel for this graph via the function below - std::unordered_map partitionNodePropsMap; - for (auto nodeIndex : indexedSubGraph.nodes) - { - const onnxruntime::Node* node = graph.GetNode(nodeIndex); - -#ifdef PRINT_PARTITON_INFO - printf("Partition %u\t%s\n", partitionIndex, GraphDescBuilder::GetUniqueNodeName(*node).c_str()); -#endif - partitionNodePropsMap.insert(std::make_pair( - GraphDescBuilder::GetUniqueNodeName(*node), std::move(graphNodePropertyMap[node]))); - } - -#ifdef PRINT_PARTITON_INFO - printf("\n"); -#endif - - return partitionNodePropsMap; - } - - Microsoft::WRL::ComPtr TryCreateCompiledOperator( - const GraphDescBuilder::GraphDesc& graphDesc, - const onnxruntime::IndexedSubGraph& indexedSubGraph, - const ExecutionProviderImpl* providerImpl) - { - const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); - const uint32_t fusedNodeOutputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->outputs.size()); - - // convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator - DML_GRAPH_DESC dmlGraphDesc = {}; - std::vector 
dmlOperatorGraphNodes(graphDesc.nodes.size()); - std::vector dmlGraphNodes(graphDesc.nodes.size()); - std::vector dmlInputEdges(graphDesc.inputEdges.size()); - std::vector dmlOutputEdges(graphDesc.outputEdges.size()); - std::vector dmlIntermediateEdges(graphDesc.intermediateEdges.size()); - ConvertGraphDesc( - graphDesc, - dmlGraphDesc, - fusedNodeInputCount, - fusedNodeOutputCount, - dmlOperatorGraphNodes, - dmlGraphNodes, - dmlInputEdges, - dmlOutputEdges, - dmlIntermediateEdges); - - DML_EXECUTION_FLAGS executionFlags = DML_EXECUTION_FLAG_NONE; - if (graphDesc.reuseCommandList) - { - executionFlags |= DML_EXECUTION_FLAG_DESCRIPTORS_VOLATILE; - } - - // Query DML execution provider to see if metacommands is enabled - if (!providerImpl->MetacommandsEnabled()) - { - executionFlags |= DML_EXECUTION_FLAG_DISABLE_META_COMMANDS; - } - - ComPtr device; - ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); - - ComPtr device1; - ORT_THROW_IF_FAILED(device.As(&device1)); - - ComPtr compiledExecutionPlanOperator; - ORT_THROW_IF_FAILED(device1->CompileGraph( - &dmlGraphDesc, - executionFlags, - IID_PPV_ARGS(&compiledExecutionPlanOperator))); - - // UINT32_MAX is currently the maximum number of bytes allowed by D3D12 for the offset of a view over a resource - if (compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize > UINT32_MAX) - { - return nullptr; - } - - return compiledExecutionPlanOperator; - } - - struct NodeInfo - { - std::string name; - std::string opType; - std::string description; - std::string domain; - onnxruntime::NodeAttributes attributes; - std::vector inputDefPointers; - std::vector outputDefPointers; - }; - - void RegisterKernel( - onnxruntime::Graph& graph, - onnxruntime::KernelRegistry* registryForPartitionKernels, - const ExecutionProviderImpl* providerImpl, - std::unordered_map graphNodePropertyMap, - const std::unordered_set& dynamicCpuInputMap, - std::shared_ptr indexedSubGraph, - std::unordered_map>&& 
isInitializerTransferable) - { - auto partitionNodePropsMap = DmlRuntimeGraphFusionHelper::CreatePartitionNodePropsMap( - graph, - *indexedSubGraph, - std::move(graphNodePropertyMap)); - - auto modelPath = graph.ModelPath(); - - const gsl::span subGraphInputArgNames = indexedSubGraph->GetMetaDef()->inputs; - const gsl::span subGraphOutputArgNames = indexedSubGraph->GetMetaDef()->outputs; - - std::vector nodesInfo; - nodesInfo.reserve(indexedSubGraph->nodes.size()); - - std::vector subgraphInputs; - subgraphInputs.reserve(subGraphInputArgNames.size()); - - std::vector subgraphOutputs; - subgraphOutputs.reserve(subGraphOutputArgNames.size()); - - std::vector nodeAttributes; - nodeAttributes.reserve(indexedSubGraph->nodes.size()); - - std::vector> intermediateNodeArgs; - - for (size_t sortedNodeIndex : indexedSubGraph->nodes) - { - auto node = graph.GetNode(sortedNodeIndex); - - nodeAttributes.push_back(node->GetAttributes()); - - NodeInfo nodeInfo{}; - nodeInfo.name = node->Name(); - nodeInfo.opType = node->OpType(); - nodeInfo.description = node->Description(); - nodeInfo.domain = node->Domain(); - nodeInfo.attributes = node->GetAttributes(); - nodeInfo.inputDefPointers.reserve(node->InputDefs().size()); - nodeInfo.outputDefPointers.reserve(node->OutputDefs().size()); - - for (const onnxruntime::NodeArg* inputDef : node->InputDefs()) - { - intermediateNodeArgs.emplace_back(std::make_shared(inputDef->Name(), inputDef->TypeAsProto())); - nodeInfo.inputDefPointers.push_back(intermediateNodeArgs.back().get()); - } - - for (const onnxruntime::NodeArg* outputDef : node->OutputDefs()) - { - intermediateNodeArgs.emplace_back(std::make_shared(outputDef->Name(), outputDef->TypeAsProto())); - nodeInfo.outputDefPointers.push_back(intermediateNodeArgs.back().get()); - } - - nodesInfo.push_back(std::move(nodeInfo)); - } - - for (const std::string& graphInputName : subGraphInputArgNames) - { - subgraphInputs.push_back(graph.GetNodeArg(graphInputName)); - } - - for (const 
std::string& graphOutputName : subGraphOutputArgNames) - { - subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); - } - - // We need to keep the initializers alive since they will be freed once the nodes are removed from the graph - std::vector ownedInitializers; - ownedInitializers.reserve(isInitializerTransferable.size()); - - for (auto& kvp : isInitializerTransferable) - { - ONNX_NAMESPACE::TensorProto tensorProto; - tensorProto.set_data_type(kvp.second.first->data_type()); - tensorProto.set_raw_data(kvp.second.first->raw_data()); - tensorProto.set_name(kvp.second.first->name()); - - for (int i = 0; i < kvp.second.first->dims_size(); ++i) - { - tensorProto.add_dims(kvp.second.first->dims(i)); - } - ownedInitializers.push_back(std::move(tensorProto)); - kvp.second.first = &ownedInitializers.back(); - } - - // lamda captures for the kernel registration - auto fused_kernel_func = [ - indexedSubGraph, - &modelPath, - nodesInfo = std::move(nodesInfo), - intermediateNodeArgs = std::move(intermediateNodeArgs), - subgraphInputs = std::move(subgraphInputs), - subgraphOutputs = std::move(subgraphOutputs), - partitionNodePropsMap = std::move(partitionNodePropsMap), - ownedInitializers = std::move(ownedInitializers)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr& out) mutable ->onnxruntime::Status - { - std::vector> subgraphNodes; - subgraphNodes.reserve(nodesInfo.size()); - - for (const NodeInfo& nodeInfo : nodesInfo) - { - subgraphNodes.emplace_back(std::make_shared( - nodeInfo.name, - nodeInfo.opType, - nodeInfo.description, - nodeInfo.inputDefPointers, - nodeInfo.outputDefPointers, - &nodeInfo.attributes, - nodeInfo.domain)); - } - - out.reset(CreateRuntimeFusedGraphKernel( - info, - indexedSubGraph, - modelPath, - std::move(subgraphNodes), - std::move(subgraphInputs), - std::move(subgraphOutputs), - std::move(intermediateNodeArgs), - std::move(partitionNodePropsMap), - std::move(ownedInitializers))); - return 
Status::OK(); - }; - - // build the kernel definition on the fly, and register it to the fused_kernel_regisitry. - onnxruntime::KernelDefBuilder builder; - builder.SetName(indexedSubGraph->GetMetaDef()->name) - .SetDomain(indexedSubGraph->GetMetaDef()->domain) - .SinceVersion(indexedSubGraph->GetMetaDef()->since_version) - .Provider(onnxruntime::kDmlExecutionProvider); - - // Force the CPU inputs to be allocated on the CPU - for (int i = 0; i < subGraphInputArgNames.size(); ++i) - { - if (dynamicCpuInputMap.find(subGraphInputArgNames[i]) != dynamicCpuInputMap.end()) - { - builder.InputMemoryType(OrtMemTypeCPUInput, i); - } - } - - ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func)); - - auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name); - fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); - - graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode); - } -} -} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h deleted file mode 100644 index 3da482bd72aaa..0000000000000 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionHelper.h +++ /dev/null @@ -1,83 +0,0 @@ -#pragma once - -#include "precomp.h" -#include "GraphDescBuilder.h" -#include "ExecutionProvider.h" -#include "GraphPartitioner.h" -#include "DmlRuntimeFusedGraphKernel.h" -#include "MLOperatorAuthorImpl.h" - - -namespace Dml -{ -namespace DmlRuntimeGraphFusionHelper -{ - template - static T AlignToPow2(T offset, T alignment) - { - static_assert(std::is_unsigned_v); - assert(alignment != 0); - assert((alignment & (alignment - 1)) == 0); - return (offset + alignment - 1) & ~(alignment - 1); - } - - Microsoft::WRL::ComPtr - CreateResource( - const ExecutionProviderImpl* provider, - const std::byte* tensorPtr, - size_t tensorByteSize); - - 
Microsoft::WRL::ComPtr - CreateCpuResource( - const ExecutionProviderImpl* provider, - const std::byte* tensorPtr, - size_t tensorByteSize); - - void UnwrapTensor( - Windows::AI::MachineLearning::Adapter::IWinmlExecutionProvider* winmlProvider, - const onnxruntime::Tensor* tensor, - ID3D12Resource** resource, - uint64_t* allocId); - - std::unordered_map> - GetInitializerToPartitionMap( - const onnxruntime::GraphViewer& graph, - gsl::span> partitions - ); - - void ConvertGraphDesc( - const Dml::GraphDescBuilder::GraphDesc& graphDesc, - _Out_ DML_GRAPH_DESC& dmlGraphDesc, - const uint32_t inputCount, - const uint32_t outputCount, - _Inout_ std::vector& dmlOperatorGraphNodes, - _Inout_ std::vector& dmlGraphNodes, - _Inout_ std::vector& dmlInputEdges, - _Inout_ std::vector& dmlOutputEdges, - _Inout_ std::vector& dmlIntermediateEdges); - - onnxruntime::IndexedSubGraph CreateIndexedSubGraph( - GraphPartition* partition, - uint32_t partitionIndex, - const std::string& partitionKernelPrefix); - - std::unordered_map CreatePartitionNodePropsMap( - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, - std::unordered_map&& graphNodePropertyMap); - - Microsoft::WRL::ComPtr TryCreateCompiledOperator( - const GraphDescBuilder::GraphDesc& graphDesc, - const onnxruntime::IndexedSubGraph& indexedSubGraph, - const ExecutionProviderImpl* providerImpl); - - void RegisterKernel( - onnxruntime::Graph& graph, - onnxruntime::KernelRegistry* registryForPartitionKernels, - const ExecutionProviderImpl* providerImpl, - std::unordered_map graphNodePropertyMap, - const std::unordered_set& dynamicCpuInputMap, - std::shared_ptr indexedSubGraph, - std::unordered_map>&& isInitializerTransferable); -} -} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp index f1c7e723d4b18..38d8b960f3b53 100644 --- 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -10,8 +10,7 @@ #include "core/optimizer/constant_sharing.h" #include "DmlRuntimeFusedGraphKernel.h" #include "MLOperatorAuthorImpl.h" -#include "DmlRuntimeGraphFusionHelper.h" - +#include "DmlGraphFusionHelper.h" namespace Dml { @@ -105,7 +104,7 @@ namespace Dml std::vector> compiledPartitionInfos(partitions.size()); // Create a map between each initialized tensor and the partition(s) it is part of. - auto initializerPartitionMap = DmlRuntimeGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions); + auto initializerPartitionMap = DmlGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions); for (uint32_t partitionIndex = 0; partitionIndex < partitions.size(); ++partitionIndex) { @@ -136,7 +135,7 @@ namespace Dml compiledPartitionInfos[partitionIndex] = std::make_shared(); compiledPartitionInfos[partitionIndex]->indexedSubGraph = std::make_shared( - DmlRuntimeGraphFusionHelper::CreateIndexedSubGraph(partition.get(), partitionIndex, partitionKernelPrefix)); + DmlGraphFusionHelper::CreateIndexedSubGraph(partition.get(), partitionIndex, partitionKernelPrefix)); compiledPartitionInfos[partitionIndex]->isInitializerTransferable = std::move(isInitializerTransferable); } } @@ -146,7 +145,7 @@ namespace Dml // Null compiled operators were not DML partitions if (compiledPartitionInfo) { - DmlRuntimeGraphFusionHelper::RegisterKernel( + DmlGraphFusionHelper::RegisterDynamicKernel( graph, m_providerImpl->GetKernelRegistry().get(), m_providerImpl, From 339abded335887a74d6358e1c89b559d7199bf23 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Sat, 7 Oct 2023 03:03:08 -0700 Subject: [PATCH 11/20] More refactoring --- .../inc/IWinmlExecutionProvider.h | 3 +-- .../DmlExecutionProvider/src/AbiCustomRegistry.cpp | 8 +++----- 
.../dml/DmlExecutionProvider/src/GraphDescBuilder.cpp | 11 ++++++----- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 431113b3e1650..074f13b309181 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -89,8 +89,6 @@ namespace Windows::AI::MachineLearning::Adapter std::vector inputEdges; std::vector outputEdges; std::vector intermediateEdges; - EdgeShapes outputShapes; - const std::unordered_map>* inferredOutputShapes; }; using GraphNodeFactory = std::function; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index e86737fa9665a..eb068087de4ad 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -492,6 +492,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, const EdgeShapes* inputShapesOverrides, + /*out*/ EdgeShapes* outputShapes, /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo ) { @@ -499,10 +500,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( onnxruntime::OpNodeProtoHelper protoHelper(&nodeContext); // Use the same list of required constant inputs for the shape inferrer and the kernel. 
- EdgeShapes outputShapes; - InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, outputShapes); - - graphNodeCreateInfo->outputShapes = outputShapes; + InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, *outputShapes); // Create the kernel while allowing input shape and output shape queries according to options ComPtr kernelInfoWrapper = wil::MakeOrThrow( @@ -510,7 +508,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( executionHandle, true, inputShapesOverrides, - &outputShapes, + outputShapes, &defaultAttributesCapture, graphNodeCreateInfo, constantCpuInputCapture, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 546fb56bc780e..c620859495b15 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -246,7 +246,6 @@ namespace Dml::GraphDescBuilder return tensor; }; - DmlGraphNodeCreateInfo graphNodeCreateInfo; EdgeShapes inputShapesOverrides(node.InputDefs().size()); // Override the input shapes with shapes that were previously inferred @@ -269,19 +268,21 @@ namespace Dml::GraphDescBuilder } } - graphNodeCreateInfo.inferredOutputShapes = &inferredOutputShapes; + EdgeShapes outputShapes; + DmlGraphNodeCreateInfo graphNodeCreateInfo; graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory( node, constantCpuNodeInputGetter, executionHandle, &inputShapesOverrides, + /*out*/ &outputShapes, /*out*/ &graphNodeCreateInfo ); - ORT_THROW_HR_IF(E_UNEXPECTED, graphNodeCreateInfo.outputShapes.EdgeCount() != node.OutputDefs().size()); + ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != 
node.OutputDefs().size()); for (int i = 0; i < node.OutputDefs().size(); ++i) { - inferredOutputShapes[node.OutputDefs()[i]->Name()] = graphNodeCreateInfo.outputShapes.GetShape(i); + inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i); } // Create a map between operatorGraphNodeIndex to mainGraphNodeIndex. @@ -380,7 +381,7 @@ namespace Dml::GraphDescBuilder operatorGraphOutputEdge.FromNodeOutputIndex }; - nodeOutputShapes[arg->Name()] = graphNodeCreateInfo.outputShapes; + nodeOutputShapes[arg->Name()] = outputShapes; } } From ebfe95ff6df9d34ffdc6ab7220bf695bc4ca8e7a Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Tue, 10 Oct 2023 14:00:41 -0700 Subject: [PATCH 12/20] Fix crash when empty shapes --- .../providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index c620859495b15..3fc8f415e5a58 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -258,7 +258,7 @@ namespace Dml::GraphDescBuilder { inputShapesOverrides.GetMutableShape(inputIndex) = outputShapesIter->second; } - else + else if (inputDef->HasTensorOrScalarShape()) { for (int i = 0; i < inputDef->Shape()->dim_size(); ++i) { From e5fa87eac980960fb0ac6a7a5a5853e4418f7105 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Tue, 10 Oct 2023 16:00:53 -0700 Subject: [PATCH 13/20] Uncomment assert --- onnxruntime/core/framework/allocation_planner.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index b560b38752021..0bf27fdf5e5dc 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ 
b/onnxruntime/core/framework/allocation_planner.cc @@ -234,7 +234,7 @@ class PlannerImpl { int DecrementUseCount(OrtValueIndex n) { int& use_count = --UseCount(n); - // assert(use_count >= 0); + assert(use_count >= 0); return use_count; } From 5dc0aaf7033002e3346be8d773410257d8642448 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 16 Oct 2023 10:47:56 -0700 Subject: [PATCH 14/20] Address PR comments --- .../src/DmlRuntimeFusedGraphKernel.cpp | 2 +- .../src/DmlRuntimeGraphFusionTransformer.cpp | 27 +++++++++---------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 649b6d29659c1..1d9b21941c972 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -319,7 +319,7 @@ namespace Dml persistentResourceBinding, inputBindings, outputBindings)); - } + } private: ComPtr m_winmlProvider; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp index 38d8b960f3b53..ee59171d1844c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -35,32 +35,32 @@ namespace Dml onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImpl( onnxruntime::Graph& graph, bool& modified, - int graph_level, + int graphLevel, const onnxruntime::logging::Logger& logger) const { - return ApplyImplHelper(graph, modified, graph_level, logger, {}); + return ApplyImplHelper(graph, modified, graphLevel, logger, {}); } onnxruntime::common::Status 
DmlRuntimeGraphFusionTransformer::ApplyImplHelper( onnxruntime::Graph& graph, bool& modified, - int graph_level, + int graphLevel, const onnxruntime::logging::Logger& logger, const std::unordered_map& implicitInputDefs) const { - onnxruntime::ProviderType provider_type = onnxruntime::kDmlExecutionProvider; + onnxruntime::ProviderType providerType = onnxruntime::kDmlExecutionProvider; const gsl::not_null registry = m_providerImpl->GetKernelRegistry().get(); - const auto kernel_type_str_resolver = onnxruntime::OpSchemaKernelTypeStrResolver{}; - const auto kernel_lookup = onnxruntime::KernelLookup{provider_type, + const auto kernelTypeStrResolver = onnxruntime::OpSchemaKernelTypeStrResolver{}; + const auto kernelLookup = onnxruntime::KernelLookup{providerType, gsl::make_span(®istry, 1), - kernel_type_str_resolver}; + kernelTypeStrResolver}; - onnxruntime::GraphViewer graph_viewer(graph); - const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + onnxruntime::GraphViewer graphViewer(graph); + const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder(); - for (auto node_index : node_topology_list) + for (auto nodeIndex : nodeTopologyList) { - auto* node = graph.GetNode(node_index); + auto* node = graph.GetNode(nodeIndex); if (!node) { continue; // node was removed @@ -75,7 +75,7 @@ namespace Dml for (auto& entry : node->GetAttributeNameToMutableSubgraphMap()) { auto& subgraph = *entry.second; - ORT_RETURN_IF_ERROR(ApplyImplHelper(subgraph, modified, graph_level + 1, logger, subgraphImplicitInputDefs)); + ORT_RETURN_IF_ERROR(ApplyImplHelper(subgraph, modified, graphLevel + 1, logger, subgraphImplicitInputDefs)); } } @@ -84,11 +84,10 @@ namespace Dml std::unordered_map graphNodePropertyMap; std::unordered_set requiredInitializerMap; std::unordered_set dynamicCpuInputMap; - onnxruntime::GraphViewer graphViewer(graph); std::vector> partitions = BuildPartitions( graphViewer, *m_providerImpl->GetInternalRegistrationInfoMap(), - kernel_lookup, 
+ kernelLookup, m_providerImpl->GetSupportedDeviceDataTypeMask(), graphNodePropertyMap, requiredInitializerMap, From 2ac62ab1c6b1eb0392f8ed6bd84a5b97352606a9 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Wed, 18 Oct 2023 21:06:32 -0700 Subject: [PATCH 15/20] Address PR comments --- .../src/DmlRuntimeGraphFusionTransformer.cpp | 7 ++++--- .../src/DmlRuntimeGraphFusionTransformer.h | 6 +++--- onnxruntime/core/providers/dml/dml_provider_factory.cc | 6 +++--- .../core/providers/dml/dml_provider_factory_creator.h | 2 +- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp index ee59171d1844c..6318b0d5e2865 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -51,9 +51,10 @@ namespace Dml onnxruntime::ProviderType providerType = onnxruntime::kDmlExecutionProvider; const gsl::not_null registry = m_providerImpl->GetKernelRegistry().get(); const auto kernelTypeStrResolver = onnxruntime::OpSchemaKernelTypeStrResolver{}; - const auto kernelLookup = onnxruntime::KernelLookup{providerType, - gsl::make_span(®istry, 1), - kernelTypeStrResolver}; + const auto kernelLookup = onnxruntime::KernelLookup( + providerType, + gsl::make_span(®istry, 1), + kernelTypeStrResolver); onnxruntime::GraphViewer graphViewer(graph); const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h index 602a6373d483c..bddb2feceb523 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h @@ -25,9 +25,9 @@ class DmlRuntimeGraphFusionTransformer : public onnxruntime::GraphTransformer private: onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, - bool& modified, - int graph_level, - const onnxruntime::logging::Logger& logger) const final; + bool& modified, + int graph_level, + const onnxruntime::logging::Logger& logger) const final; onnxruntime::common::Status ApplyImplHelper( onnxruntime::Graph& graph, diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index ef8b8e82f31e1..d587424fe01f8 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -33,9 +33,9 @@ struct DMLProviderFactory : IExecutionProviderFactory { ID3D12CommandQueue* cmd_queue, bool disable_metacommands, bool enable_dynamic_graph_fusion) : dml_device_(dml_device), - cmd_queue_(cmd_queue), - metacommands_enabled_(!disable_metacommands), - dynamic_graph_fusion_enabled_(enable_dynamic_graph_fusion) {} + cmd_queue_(cmd_queue), + metacommands_enabled_(!disable_metacommands), + dynamic_graph_fusion_enabled_(enable_dynamic_graph_fusion) {} ~DMLProviderFactory() override {} std::unique_ptr CreateProvider() override; diff --git a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h index ca2cdee24f3a2..0fab9fe902526 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h +++ b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h @@ -32,7 +32,7 @@ struct DMLProviderFactoryCreator { bool enable_dynamic_graph_fusion); static std::shared_ptr CreateFromAdapterList( - std::vector>&& dxcore_devices, + std::vector>&& dxcore_devices, bool disable_metacommands, bool enable_dynamic_graph_fusion); From ba27e0dff0f72757d7a5df54fbd06187dce2e707 Mon Sep 17 
00:00:00 2001 From: Patrice Vignola Date: Wed, 18 Oct 2023 21:36:54 -0700 Subject: [PATCH 16/20] Revert unneeded change --- onnxruntime/python/onnxruntime_pybind_schema.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index bcb8c650d0966..3a977772873f3 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -59,9 +59,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) { onnxruntime::ArmNNProviderFactoryCreator::Create(0), #endif #ifdef USE_DML - []() { - return onnxruntime::DMLProviderFactoryCreator::Create(0, false, false, false); - }(), + onnxruntime::DMLProviderFactoryCreator::Create(0, false, false, false), #endif #ifdef USE_NNAPI onnxruntime::NnapiProviderFactoryCreator::Create(0, std::optional()), From 95b4779e5e4deba23db047db1d43943e916503d4 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Thu, 19 Oct 2023 15:48:01 -0700 Subject: [PATCH 17/20] Small fix --- .../DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 1d9b21941c972..34b826cc462e3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -84,10 +84,6 @@ namespace Dml m_persistentResourceBinding ? 
&*m_persistentResourceBinding : nullptr, gsl::make_span(initInputBindings))); - // Queue references to objects which must be kept alive until resulting GPU work completes - m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); - m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get()); - std::for_each( initializeResourceRefs.begin(), initializeResourceRefs.end(), From 3ccfd0644b008433b2e9dd10b10ebf743c2772d7 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Thu, 19 Oct 2023 16:41:32 -0700 Subject: [PATCH 18/20] Address PR comments --- .../src/DmlRuntimeFusedGraphKernel.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 34b826cc462e3..59d8ac7e6357b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -73,8 +73,8 @@ namespace Dml ORT_THROW_IF_FAILED(m_provider->AllocatePooledResource( static_cast(persistentResourceSize), AllocatorRoundingMode::Disabled, - m_persistentResource.GetAddressOf(), - m_persistentResourceAllocatorUnk.GetAddressOf())); + m_persistentResource.ReleaseAndGetAddressOf(), + m_persistentResourceAllocatorUnk.ReleaseAndGetAddressOf())); m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize }; } @@ -207,6 +207,9 @@ namespace Dml *m_indexedSubGraph, providerImpl); + // Queue references to objects which must be kept alive until resulting GPU work completes + m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); + TranslateAndCompileGraph( Info(), initializeResourceRefs, @@ -247,10 +250,6 @@ namespace Dml ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); - // Queue references to objects which must be 
kept alive until resulting GPU work completes - m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); - m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get()); - return onnxruntime::Status::OK(); } From 1f8d9eaf493ced1b03331594e1385ea838ed3362 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Sun, 22 Oct 2023 23:18:01 -0700 Subject: [PATCH 19/20] Address PR comments --- .../src/DmlRuntimeGraphFusionTransformer.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h index bddb2feceb523..cfa743e1f2b85 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h @@ -26,13 +26,13 @@ class DmlRuntimeGraphFusionTransformer : public onnxruntime::GraphTransformer private: onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, - int graph_level, + int graphLevel, const onnxruntime::logging::Logger& logger) const final; onnxruntime::common::Status ApplyImplHelper( onnxruntime::Graph& graph, bool& modified, - int graph_level, + int graphLevel, const onnxruntime::logging::Logger& logger, const std::unordered_map& implicitInputDefs) const; From c5c844c06bbec46d68ac012dee7493ad385d9d7f Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Sun, 22 Oct 2023 23:19:31 -0700 Subject: [PATCH 20/20] Address PR comments --- .../DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 59d8ac7e6357b..1db22ac92e527 100644 --- 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -241,12 +241,12 @@ namespace Dml inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get()); } - auto aux = contextWrapper.GetOutputTensors(m_outputShapes); + auto outputTensors = contextWrapper.GetOutputTensors(m_outputShapes); ExecuteOperator( m_compiledExecutionPlanOperator.Get(), m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, inputPtrs, - aux); + outputTensors); ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier());