-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added DML and CUDA provider support in onnxruntime-node (#16050)
### Description I've added changes to support CUDA and DML (only on Windows, on other platforms it will throw an error) ### Motivation and Context It fixes this feature request #14127 which is tracked here #14529 I was working on StableDiffusion implementation for node.js and it is very slow on CPU, so GPU support is essential. Here is a working demo with a patched and precompiled version https://github.com/dakenf/stable-diffusion-nodejs ---------
- Loading branch information
Showing
19 changed files
with
292 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#ifdef _WIN32 | ||
#include "common.h" | ||
#include "windows.h" | ||
|
||
void LoadDirectMLDll(Napi::Env env) { | ||
DWORD pathLen = MAX_PATH; | ||
std::wstring path(pathLen, L'\0'); | ||
HMODULE moduleHandle = nullptr; | ||
|
||
GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, | ||
reinterpret_cast<LPCSTR>(&LoadDirectMLDll), &moduleHandle); | ||
|
||
DWORD getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t *>(path.c_str()), pathLen); | ||
while (getModuleFileNameResult == 0 || getModuleFileNameResult == pathLen) { | ||
int ret = GetLastError(); | ||
if (ret == ERROR_INSUFFICIENT_BUFFER && pathLen < 32768) { | ||
pathLen *= 2; | ||
path.resize(pathLen); | ||
getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t *>(path.c_str()), pathLen); | ||
} else { | ||
ORT_NAPI_THROW_ERROR(env, "Failed getting path to load DirectML.dll, error code: ", ret); | ||
} | ||
} | ||
|
||
path.resize(path.rfind(L'\\') + 1); | ||
path.append(L"DirectML.dll"); | ||
HMODULE libraryLoadResult = LoadLibraryW(path.c_str()); | ||
|
||
if (!libraryLoadResult) { | ||
int ret = GetLastError(); | ||
ORT_NAPI_THROW_ERROR(env, "Failed loading bundled DirectML.dll, error code: ", ret); | ||
} | ||
} | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#if defined(USE_DML) && defined(_WIN32) | ||
void LoadDirectMLDll(Napi::Env env); | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.