Skip to content

Commit

Permalink
Added DML and CUDA provider support in onnxruntime-node (microsoft#16050
Browse files Browse the repository at this point in the history
)

### Description
I've added changes to support CUDA and DML (only on Windows, on other
platforms it will throw an error)



### Motivation and Context
It fixes this feature request
microsoft#14127 which is tracked
here microsoft#14529

I was working on StableDiffusion implementation for node.js and it is
very slow on CPU, so GPU support is essential.

Here is a working demo with a patched and precompiled version
https://github.com/dakenf/stable-diffusion-nodejs

---------
  • Loading branch information
dakenf authored Aug 25, 2023
1 parent a9e75d4 commit 5ab0896
Show file tree
Hide file tree
Showing 14 changed files with 225 additions and 14 deletions.
16 changes: 15 additions & 1 deletion common/lib/inference-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,14 @@ export declare namespace InferenceSession {
// Backend React Native: supports 'cpu', 'xnnpack', 'coreml' (iOS), 'nnapi' (Android).
interface ExecutionProviderOptionMap {
cpu: CpuExecutionProviderOption;
coreml: CoreMlExecutionProviderOption;
cuda: CudaExecutionProviderOption;
dml: DmlExecutionProviderOption;
tensorrt: TensorRtExecutionProviderOption;
wasm: WebAssemblyExecutionProviderOption;
webgl: WebGLExecutionProviderOption;
xnnpack: XnnpackExecutionProviderOption;
webnn: WebNNExecutionProviderOption;
coreml: CoreMLExecutionProviderOption;
nnapi: NnapiExecutionProviderOption;
}

Expand All @@ -194,6 +196,18 @@ export declare namespace InferenceSession {
readonly name: 'cuda';
deviceId?: number;
}
export interface CoreMlExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'coreml';
coreMlFlags?: number;
}
export interface DmlExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'dml';
deviceId?: number;
}
export interface TensorRtExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'tensorrt';
deviceId?: number;
}
export interface WebAssemblyExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'wasm';
}
Expand Down
29 changes: 29 additions & 0 deletions node/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,29 @@ endif()
# include dirs
include_directories(${CMAKE_JS_INC})
include_directories(${CMAKE_SOURCE_DIR}/../../include/onnxruntime/core/session)
include_directories(${CMAKE_SOURCE_DIR}/../../include/onnxruntime)
include_directories(${CMAKE_SOURCE_DIR}/../../onnxruntime)
include_directories(${CMAKE_SOURCE_DIR}/node_modules/node-addon-api)

# optional providers
option(USE_DML "Build with DirectML support" OFF)
option(USE_CUDA "Build with CUDA support" OFF)
option(USE_TENSORRT "Build with TensorRT support" OFF)
option(USE_COREML "Build with CoreML support" OFF)

if(USE_DML)
add_compile_definitions(USE_DML=1)
endif()
if(USE_CUDA)
add_compile_definitions(USE_CUDA=1)
endif()
if(USE_TENSORRT)
add_compile_definitions(USE_TENSORRT=1)
endif()
if(USE_COREML)
add_compile_definitions(USE_COREML=1)
endif()

# source files
file(GLOB ORT_NODEJS_BINDING_SOURCE_FILES ${CMAKE_SOURCE_DIR}/src/*.cc)

Expand Down Expand Up @@ -77,6 +98,14 @@ if (WIN32)
${ONNXRUNTIME_BUILD_DIR}/${CMAKE_BUILD_TYPE}/onnxruntime.dll
${dist_folder}
)
if (USE_DML)
add_custom_command(
TARGET onnxruntime_binding POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${ONNXRUNTIME_BUILD_DIR}/${CMAKE_BUILD_TYPE}/DirectML.dll
${dist_folder}
)
endif ()
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
add_custom_command(
TARGET onnxruntime_binding POST_BUILD
Expand Down
4 changes: 4 additions & 0 deletions node/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ Following platforms are supported with pre-built binaries:

To use on platforms without pre-built binaries, you can build Node.js binding from source and consume it by `npm install <onnxruntime_repo_root>/js/node/`. See also [instructions](https://www.onnxruntime.ai/docs/how-to/build.html#apis-and-language-bindings) for building ONNX Runtime Node.js binding locally.

# GPU Support

Right now, the Windows version supports only the DML provider. Linux x64 can use CUDA and TensorRT.

## License

License information can be found [here](https://github.com/microsoft/onnxruntime/blob/main/README.md#license).
1 change: 1 addition & 0 deletions node/lib/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,4 @@ class OnnxruntimeBackend implements Backend {
}

export const onnxruntimeBackend = new OnnxruntimeBackend();
export const listSupportedBackends = binding.listSupportedBackends;
13 changes: 10 additions & 3 deletions node/lib/binding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,18 @@ export declare namespace Binding {
export interface InferenceSessionConstructor {
new(): InferenceSession;
}

export interface SupportedBackend {
name: string;
bundled: boolean;
}
}

// export native binding
export const binding =
// eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
require(`../bin/napi-v3/${process.platform}/${process.arch}/onnxruntime_binding.node`) as
// eslint-disable-next-line @typescript-eslint/naming-convention
{InferenceSession: Binding.InferenceSessionConstructor};
require(`../bin/napi-v3/${process.platform}/${process.arch}/onnxruntime_binding.node`) as {
// eslint-disable-next-line @typescript-eslint/naming-convention
InferenceSession: Binding.InferenceSessionConstructor;
listSupportedBackends: () => Binding.SupportedBackend[];
};
8 changes: 6 additions & 2 deletions node/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
// Licensed under the MIT License.

export * from 'onnxruntime-common';
export {listSupportedBackends} from './backend';
import {registerBackend, env} from 'onnxruntime-common';
import {onnxruntimeBackend} from './backend';
import {version} from './version';
import {onnxruntimeBackend, listSupportedBackends} from './backend';

registerBackend('cpu', onnxruntimeBackend, 100);
const backends = listSupportedBackends();
for (const backend of backends) {
registerBackend(backend.name, onnxruntimeBackend, 100);
}

env.versions.node = version;
20 changes: 20 additions & 0 deletions node/script/build.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ if (ARCH !== 'x64' && ARCH !== 'ia32' && ARCH !== 'arm64' && ARCH !== 'arm') {
const ONNXRUNTIME_BUILD_DIR = buildArgs['onnxruntime-build-dir'];
// --rebuild
const REBUILD = !!buildArgs.rebuild;
// --use_dml
const USE_DML = !!buildArgs.use_dml;
// --use_cuda
const USE_CUDA = !!buildArgs.use_cuda;
// --use_tensorrt
const USE_TENSORRT = !!buildArgs.use_tensorrt;
// --use_coreml
const USE_COREML = !!buildArgs.use_coreml;

// build path
const ROOT_FOLDER = path.join(__dirname, '..');
Expand All @@ -47,6 +55,18 @@ const args = [
if (ONNXRUNTIME_BUILD_DIR && typeof ONNXRUNTIME_BUILD_DIR === 'string') {
args.push(`--CDONNXRUNTIME_BUILD_DIR=${ONNXRUNTIME_BUILD_DIR}`);
}
if (USE_DML) {
args.push('--CDUSE_DML=ON');
}
if (USE_CUDA) {
args.push('--CDUSE_CUDA=ON');
}
if (USE_TENSORRT) {
args.push('--CDUSE_TENSORRT=ON');
}
if (USE_COREML) {
args.push('--CDUSE_COREML=ON');
}

// set CMAKE_OSX_ARCHITECTURES for macOS build
if (os.platform() === 'darwin') {
Expand Down
37 changes: 37 additions & 0 deletions node/src/directml_load_helper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef _WIN32
#include "common.h"
#include "windows.h"

void LoadDirectMLDll(Napi::Env env) {
DWORD pathLen = MAX_PATH;
std::wstring path(pathLen, L'\0');
HMODULE moduleHandle = nullptr;

GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
reinterpret_cast<LPCSTR>(&LoadDirectMLDll), &moduleHandle);

DWORD getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t *>(path.c_str()), pathLen);
while (getModuleFileNameResult == 0 || getModuleFileNameResult == pathLen) {
int ret = GetLastError();
if (ret == ERROR_INSUFFICIENT_BUFFER && pathLen < 32768) {
pathLen *= 2;
path.resize(pathLen);
getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t *>(path.c_str()), pathLen);
} else {
ORT_NAPI_THROW_ERROR(env, "Failed getting path to load DirectML.dll, error code: ", ret);
}
}

path.resize(path.rfind(L'\\') + 1);
path.append(L"DirectML.dll");
HMODULE libraryLoadResult = LoadLibraryW(path.c_str());

if (!libraryLoadResult) {
int ret = GetLastError();
ORT_NAPI_THROW_ERROR(env, "Failed loading bundled DirectML.dll, error code: ", ret);
}
}
#endif
6 changes: 6 additions & 0 deletions node/src/directml_load_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if defined(USE_DML) && defined(_WIN32)
void LoadDirectMLDll(Napi::Env env);
#endif
46 changes: 43 additions & 3 deletions node/src/inference_session_wrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
#include "onnxruntime_cxx_api.h"

#include "common.h"
#include "directml_load_helper.h"
#include "inference_session_wrap.h"
#include "run_options_helper.h"
#include "session_options_helper.h"
#include "tensor_helper.h"
#include <string>

Napi::FunctionReference InferenceSessionWrap::constructor;

Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
#if defined(USE_DML) && defined(_WIN32)
LoadDirectMLDll(env);
#endif
// create ONNX runtime env
Ort::InitApi();
ORT_NAPI_THROW_ERROR_IF(
Expand All @@ -32,6 +37,10 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
constructor = Napi::Persistent(func);
constructor.SuppressDestruct();
exports.Set("InferenceSession", func);

Napi::Function listSupportedBackends = Napi::Function::New(env, InferenceSessionWrap::ListSupportedBackends);
exports.Set("listSupportedBackends", listSupportedBackends);

return exports;
}

Expand Down Expand Up @@ -70,7 +79,7 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) {
int64_t bytesOffset = info[1].As<Napi::Number>().Int64Value();
int64_t bytesLength = info[2].As<Napi::Number>().Int64Value();

ParseSessionOptions(info[1].As<Napi::Object>(), sessionOptions);
ParseSessionOptions(info[3].As<Napi::Object>(), sessionOptions);
this->session_.reset(new Ort::Session(*env.GetInstanceData<Ort::Env>(),
reinterpret_cast<char *>(buffer) + bytesOffset, bytesLength,
sessionOptions));
Expand Down Expand Up @@ -154,14 +163,15 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo &info) {
std::vector<bool> reuseOutput;
size_t inputIndex = 0;
size_t outputIndex = 0;
OrtMemoryInfo *memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault).release();

try {
for (auto &name : inputNames_) {
if (feed.Has(name)) {
inputIndex++;
inputNames_cstr.push_back(name.c_str());
auto value = feed.Get(name);
inputValues.push_back(NapiValueToOrtValue(env, value));
inputValues.push_back(NapiValueToOrtValue(env, value, memory_info));
}
}
for (auto &name : outputNames_) {
Expand All @@ -170,7 +180,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo &info) {
outputNames_cstr.push_back(name.c_str());
auto value = fetch.Get(name);
reuseOutput.push_back(!value.IsNull());
outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value));
outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, memory_info));
}
}

Expand Down Expand Up @@ -198,3 +208,33 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo &info) {
ORT_NAPI_THROW_ERROR(env, e.what());
}
}

Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
Napi::EscapableHandleScope scope(env);
Napi::Array result = Napi::Array::New(env);

auto createObject = [&env](const std::string &name, const bool bundled) -> Napi::Object {
Napi::Object result = Napi::Object::New(env);
result.Set("name", name);
result.Set("bundled", bundled);
return result;
};

result.Set(uint32_t(0), createObject("cpu", true));

#ifdef USE_DML
result.Set(result.Length(), createObject("dml", true));
#endif
#ifdef USE_CUDA
result.Set(result.Length(), createObject("cuda", false));
#endif
#ifdef USE_TENSORRT
result.Set(result.Length(), createObject("tensorrt", false));
#endif
#ifdef USE_COREML
result.Set(result.Length(), createObject("coreml", true));
#endif

return scope.Escape(result);
}
6 changes: 6 additions & 0 deletions node/src/inference_session_wrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
InferenceSessionWrap(const Napi::CallbackInfo &info);

private:
/**
* [sync] list supported backend list
* @returns array with objects { "name": "cpu", requirementsInstalled: true }
*/
static Napi::Value ListSupportedBackends(const Napi::CallbackInfo &info);

/**
* [sync] create the session.
* @param arg0 either a string (file path) or a Uint8Array
Expand Down
Loading

0 comments on commit 5ab0896

Please sign in to comment.