Skip to content

Commit

Permalink
chore: 修复编译
Browse files Browse the repository at this point in the history
  • Loading branch information
Blinue committed May 13, 2024
1 parent b568096 commit 1f2693f
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 38 deletions.
13 changes: 5 additions & 8 deletions src/Magpie.Core/CudaInferenceBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "Logger.h"
#include "DirectXHelper.h"
#include "Utils.h"
#include "OnnxHelper.h"

#pragma comment(lib, "cudart.lib")

Expand Down Expand Up @@ -336,21 +337,17 @@ bool CudaInferenceBackend::_CreateSession(
) {
const OrtApi& ortApi = Ort::GetApi();

OrtCUDAProviderOptionsV2* cudaOptions;
Ort::ThrowOnError(ortApi.CreateCUDAProviderOptions(&cudaOptions));

Utils::ScopeExit se([cudaOptions]() {
Ort::GetApi().ReleaseCUDAProviderOptions(cudaOptions);
});
OnnxHelper::unique_cuda_provider_options cudaOptions;
Ort::ThrowOnError(ortApi.CreateCUDAProviderOptions(cudaOptions.put()));

{
const char* keys[]{ "device_id", "has_user_compute_stream" };
std::string deviceIdStr = std::to_string(deviceId);
const char* values[]{ deviceIdStr.c_str(), "1" };
Ort::ThrowOnError(ortApi.UpdateCUDAProviderOptions(cudaOptions, keys, values, std::size(keys)));
Ort::ThrowOnError(ortApi.UpdateCUDAProviderOptions(cudaOptions.get(), keys, values, std::size(keys)));
}

sessionOptions.AppendExecutionProvider_CUDA_V2(*cudaOptions);
sessionOptions.AppendExecutionProvider_CUDA_V2(*cudaOptions.get());

_session = Ort::Session(_env, modelPath, sessionOptions);
return true;
Expand Down
16 changes: 6 additions & 10 deletions src/Magpie.Core/DirectMLInferenceBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,14 @@ static winrt::com_ptr<ID3D12Resource> ShareTextureWithD3D12(ID3D11Texture2D* tex
return result;
}

HANDLE sharedHandle;
hr = dxgiResource->CreateSharedHandle(nullptr, access, nullptr, &sharedHandle);
wil::unique_handle sharedHandle;
hr = dxgiResource->CreateSharedHandle(nullptr, access, nullptr, sharedHandle.put());
if (FAILED(hr)) {
Logger::Get().ComError("CreateSharedHandle 失败", hr);
return result;
}

Win32Utils::ScopedHandle scopedSharedHandle(sharedHandle);

hr = d3d12Device->OpenSharedHandle(sharedHandle, IID_PPV_ARGS(&result));
hr = d3d12Device->OpenSharedHandle(sharedHandle.get(), IID_PPV_ARGS(&result));
if (FAILED(hr)) {
Logger::Get().ComError("OpenSharedHandle 失败", hr);
return result;
Expand Down Expand Up @@ -349,16 +347,14 @@ bool DirectMLInferenceBackend::_CreateFence(ID3D11Device5* d3d11Device, ID3D12De
return false;
}

HANDLE sharedHandle;
hr = _d3d11Fence->CreateSharedHandle(nullptr, GENERIC_ALL, nullptr, &sharedHandle);
wil::unique_handle sharedHandle;
hr = _d3d11Fence->CreateSharedHandle(nullptr, GENERIC_ALL, nullptr, sharedHandle.put());
if (FAILED(hr)) {
Logger::Get().ComError("CreateSharedHandle 失败", hr);
return false;
}

Win32Utils::ScopedHandle scopedSharedHandle(sharedHandle);

hr = d3d12Device->OpenSharedHandle(sharedHandle, IID_PPV_ARGS(&_d3d12Fence));
hr = d3d12Device->OpenSharedHandle(sharedHandle.get(), IID_PPV_ARGS(&_d3d12Fence));
if (FAILED(hr)) {
Logger::Get().ComError("OpenSharedHandle 失败", hr);
return false;
Expand Down
3 changes: 2 additions & 1 deletion src/Magpie.Core/Magpie.Core.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
<ClInclude Include="include\Magpie.Core.h" />
<ClInclude Include="InferenceBackendBase.h" />
<ClInclude Include="OnnxEffectDrawer.h" />
<ClInclude Include="OnnxHelper.h" />
<ClInclude Include="OverlayDrawer.h" />
<ClInclude Include="Renderer.h" />
<ClInclude Include="ScalingOptions.h" />
Expand Down Expand Up @@ -171,4 +172,4 @@
<Error Condition="!Exists('..\..\packages\Microsoft.AI.DirectML.1.13.1\build\Microsoft.AI.DirectML.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.AI.DirectML.1.13.1\build\Microsoft.AI.DirectML.props'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.AI.DirectML.1.13.1\build\Microsoft.AI.DirectML.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.AI.DirectML.1.13.1\build\Microsoft.AI.DirectML.targets'))" />
</Target>
</Project>
</Project>
5 changes: 4 additions & 1 deletion src/Magpie.Core/Magpie.Core.vcxproj.filters
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@
<Filter>ONNX</Filter>
</ClInclude>
<ClInclude Include="ExclModeHelper.h" />
<ClInclude Include="OnnxHelper.h">
<Filter>Helpers</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="ScalingRuntime.cpp" />
Expand Down Expand Up @@ -220,4 +223,4 @@
<ItemGroup>
<None Include="packages.config" />
</ItemGroup>
</Project>
</Project>
25 changes: 25 additions & 0 deletions src/Magpie.Core/OnnxHelper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once
#include "pch.h"
#include <onnxruntime_cxx_api.h>

namespace Magpie::Core {

struct OnnxHelper {
private:
static void _CloseCUDAProviderOptions(OrtCUDAProviderOptionsV2* options) {
Ort::GetApi().ReleaseCUDAProviderOptions(options);
}

static void _CloseTensorRTProviderOptions(OrtTensorRTProviderOptionsV2* options) {
Ort::GetApi().ReleaseTensorRTProviderOptions(options);
}

public:
using unique_cuda_provider_options = wil::unique_any<OrtCUDAProviderOptionsV2*,
decltype(_CloseCUDAProviderOptions), _CloseCUDAProviderOptions>;

using unique_tensorrt_provider_options = wil::unique_any<OrtTensorRTProviderOptionsV2*,
decltype(_CloseTensorRTProviderOptions), _CloseTensorRTProviderOptions>;
};

}
30 changes: 12 additions & 18 deletions src/Magpie.Core/TensorRTInferenceBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "HashHelper.h"
#include "CommonSharedConstants.h"
#include "Utils.h"
#include "OnnxHelper.h"

#pragma warning(push)
// C4100: “pluginFactory”: 未引用的形参
Expand Down Expand Up @@ -102,21 +103,18 @@ bool TensorRTInferenceBackend::_CreateSession(
optimizationLevel,
enableFP16
);
if (!Win32Utils::CreateDir(cacheDir, true)) {
Logger::Get().Error("CreateDir 失败");
HRESULT hr = wil::CreateDirectoryDeepNoThrow(cacheDir.c_str());
if (FAILED(hr)) {
Logger::Get().ComError("CreateDirectoryDeepNoThrow 失败", hr);
return false;
}

const std::wstring cacheCtxPath = cacheDir + L"\\ctx.onnx";

const OrtApi& ortApi = Ort::GetApi();

OrtTensorRTProviderOptionsV2* trtOptions;
Ort::ThrowOnError(ortApi.CreateTensorRTProviderOptions(&trtOptions));

Utils::ScopeExit se1([trtOptions]() {
Ort::GetApi().ReleaseTensorRTProviderOptions(trtOptions);
});
OnnxHelper::unique_tensorrt_provider_options trtOptions;
Ort::ThrowOnError(ortApi.CreateTensorRTProviderOptions(trtOptions.put()));

const std::string deviceIdStr = std::to_string(deviceId);
{
Expand Down Expand Up @@ -154,24 +152,20 @@ bool TensorRTInferenceBackend::_CreateSession(
"1",
cacheCtxPathANSI.c_str()
};
Ort::ThrowOnError(ortApi.UpdateTensorRTProviderOptions(trtOptions, keys, values, std::size(keys)));
Ort::ThrowOnError(ortApi.UpdateTensorRTProviderOptions(trtOptions.get(), keys, values, std::size(keys)));
}

OrtCUDAProviderOptionsV2* cudaOptions;
Ort::ThrowOnError(ortApi.CreateCUDAProviderOptions(&cudaOptions));

Utils::ScopeExit se2([cudaOptions]() {
Ort::GetApi().ReleaseCUDAProviderOptions(cudaOptions);
});
OnnxHelper::unique_cuda_provider_options cudaOptions;
Ort::ThrowOnError(ortApi.CreateCUDAProviderOptions(cudaOptions.put()));

{
const char* keys[]{ "device_id", "has_user_compute_stream" };
const char* values[]{ deviceIdStr.c_str(), "1" };
Ort::ThrowOnError(ortApi.UpdateCUDAProviderOptions(cudaOptions, keys, values, std::size(keys)));
Ort::ThrowOnError(ortApi.UpdateCUDAProviderOptions(cudaOptions.get(), keys, values, std::size(keys)));
}

sessionOptions.AppendExecutionProvider_TensorRT_V2(*trtOptions);
sessionOptions.AppendExecutionProvider_CUDA_V2(*cudaOptions);
sessionOptions.AppendExecutionProvider_TensorRT_V2(*trtOptions.get());
sessionOptions.AppendExecutionProvider_CUDA_V2(*cudaOptions.get());

if (Win32Utils::FileExists(cacheCtxPath.c_str())) {
Logger::Get().Info("读取缓存 " + StrUtils::UTF16ToUTF8(cacheCtxPath));
Expand Down
3 changes: 3 additions & 0 deletions src/Magpie/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "Win32Utils.h"
#include "TouchHelper.h"
#include "CommonSharedConstants.h"
#include "StrUtils.h"

// 将当前目录设为程序所在目录
static std::wstring SetWorkingDir() noexcept {
Expand All @@ -30,6 +31,8 @@ static std::wstring SetWorkingDir() noexcept {
));

FAIL_FAST_IF_WIN32_BOOL_FALSE(SetCurrentDirectory(path.c_str()));

path.resize(StrUtils::StrLen(path.c_str()));
return path;
}

Expand Down

0 comments on commit 1f2693f

Please sign in to comment.