diff --git a/CMakeLists.txt b/CMakeLists.txt
index 661647020..d1022d2a4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -30,6 +30,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
 option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
 option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
 option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
+option(SHERPA_ONNX_ENABLE_DIRECTML "Enable ONNX Runtime DirectML support" OFF)
 option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF)
 option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF)
 option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF)
@@ -94,6 +95,19 @@ to install CUDA toolkit if you have not installed it.")
   endif()
 endif()
 
+if(SHERPA_ONNX_ENABLE_DIRECTML)
+  message(WARNING "\
+Compiling with DirectML enabled. Please make sure Windows 10 SDK
+is installed on your system. Otherwise, you will get errors at runtime.
+Please refer to
+  https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html#requirements
+to install Windows 10 SDK if you have not installed it.")
+  if(NOT BUILD_SHARED_LIBS)
+    message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_DIRECTML is ON")
+    set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
+  endif()
+endif()
+
 # see https://cmake.org/cmake/help/latest/prop_tgt/MSVC_RUNTIME_LIBRARY.html
 # https://stackoverflow.com/questions/14172856/compile-with-mt-instead-of-md-using-cmake
 if(MSVC)
@@ -160,6 +174,14 @@ else()
   add_definitions(-DSHERPA_ONNX_ENABLE_TTS=0)
 endif()
 
+if(SHERPA_ONNX_ENABLE_DIRECTML)
+  message(STATUS "DirectML is enabled")
+  add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=1)
+else()
+  message(WARNING "DirectML is disabled")
+  add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=0)
+endif()
+
 if(SHERPA_ONNX_ENABLE_WASM_TTS)
   if(NOT SHERPA_ONNX_ENABLE_TTS)
     message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build wasm TTS")
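The option is off by default and is enabled at configure time with -DSHERPA_ONNX_ENABLE_DIRECTML=ON; the first hunk then forces BUILD_SHARED_LIBS to ON, since the DirectML onnxruntime package ships only DLLs. The two add_definitions() calls mirror the CMake option into the C++ preprocessor so that sources can gate DirectML-specific code at compile time; the session.cc hunk further down relies on exactly this. A minimal sketch of the gating pattern (the branch bodies are placeholders, not part of the patch):

  // SHERPA_ONNX_ENABLE_DIRECTML is always defined (to 1 or 0) by CMake,
  // so comparing its value against 1 is safe; _WIN32 limits the path to
  // Windows builds.
  #if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
    // DirectML-specific code (see the session.cc hunk below).
  #else
    // CPU/other-provider path for all remaining configurations.
  #endif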
diff --git a/cmake/onnxruntime-win-x64-directml.cmake b/cmake/onnxruntime-win-x64-directml.cmake
new file mode 100644
index 000000000..9648ffecc
--- /dev/null
+++ b/cmake/onnxruntime-win-x64-directml.cmake
@@ -0,0 +1,161 @@
+# Copyright (c) 2022-2023 Xiaomi Corporation
+message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
+message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
+message(STATUS "CMAKE_VS_PLATFORM_NAME: ${CMAKE_VS_PLATFORM_NAME}")
+
+if(NOT CMAKE_SYSTEM_NAME STREQUAL Windows)
+  message(FATAL_ERROR "This file is for Windows only. Given: ${CMAKE_SYSTEM_NAME}")
+endif()
+
+if(NOT (CMAKE_VS_PLATFORM_NAME STREQUAL X64 OR CMAKE_VS_PLATFORM_NAME STREQUAL x64))
+  message(FATAL_ERROR "This file is for Windows x64 only. Given: ${CMAKE_VS_PLATFORM_NAME}")
+endif()
+
+if(NOT BUILD_SHARED_LIBS)
+  message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
+endif()
+
+if(NOT SHERPA_ONNX_ENABLE_DIRECTML)
+  message(FATAL_ERROR "This file is for DirectML. Given SHERPA_ONNX_ENABLE_DIRECTML: ${SHERPA_ONNX_ENABLE_DIRECTML}")
+endif()
+
+set(onnxruntime_URL  "https://globalcdn.nuget.org/packages/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
+set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
+set(onnxruntime_HASH "SHA256=c8ae7623385b19cd5de968d0df5383e13b97d1b3a6771c9177eac15b56013a5a")
+
+# If you don't have access to the Internet,
+# please download onnxruntime to one of the following locations.
+# You can add more if you want.
+set(possible_file_locations
+  $ENV{HOME}/Downloads/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
+  ${PROJECT_SOURCE_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
+  ${PROJECT_BINARY_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
+  /tmp/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
+)
+
+foreach(f IN LISTS possible_file_locations)
+  if(EXISTS ${f})
+    set(onnxruntime_URL "${f}")
+    file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
+    message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}")
+    set(onnxruntime_URL2)
+    break()
+  endif()
+endforeach()
+
+FetchContent_Declare(onnxruntime
+  URL
+    ${onnxruntime_URL}
+    ${onnxruntime_URL2}
+  URL_HASH ${onnxruntime_HASH}
+)
+
+FetchContent_GetProperties(onnxruntime)
+if(NOT onnxruntime_POPULATED)
+  message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}")
+  FetchContent_Populate(onnxruntime)
+endif()
+message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
+
+find_library(location_onnxruntime onnxruntime
+  PATHS
+  "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native"
+  NO_CMAKE_SYSTEM_PATH
+)
+
+message(STATUS "location_onnxruntime: ${location_onnxruntime}")
+
+add_library(onnxruntime SHARED IMPORTED)
+
+set_target_properties(onnxruntime PROPERTIES
+  IMPORTED_LOCATION ${location_onnxruntime}
+  INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/build/native/include"
+)
+
+set_property(TARGET onnxruntime
+  PROPERTY
+    IMPORTED_IMPLIB "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.lib"
+)
+
+file(COPY ${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.dll
+  DESTINATION
+    ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}
+)
+
+file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.*")
+
+message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")
+
+if(SHERPA_ONNX_ENABLE_PYTHON)
+  install(FILES ${onnxruntime_lib_files} DESTINATION ..)
+else()
+  install(FILES ${onnxruntime_lib_files} DESTINATION lib)
+endif()
+
+install(FILES ${onnxruntime_lib_files} DESTINATION bin)
+
+# Setup DirectML
+
+set(directml_URL "https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.15.0")
+set(directml_HASH "SHA256=10d175f8e97447712b3680e3ac020bbb8eafdf651332b48f09ffee2eec801c23")
+
+set(possible_directml_file_locations
+  $ENV{HOME}/Downloads/Microsoft.AI.DirectML.1.15.0.nupkg
+  ${PROJECT_SOURCE_DIR}/Microsoft.AI.DirectML.1.15.0.nupkg
+  ${PROJECT_BINARY_DIR}/Microsoft.AI.DirectML.1.15.0.nupkg
+  /tmp/Microsoft.AI.DirectML.1.15.0.nupkg
+)
+
+foreach(f IN LISTS possible_directml_file_locations)
+  if(EXISTS ${f})
+    set(directml_URL "${f}")
+    file(TO_CMAKE_PATH "${directml_URL}" directml_URL)
+    message(STATUS "Found local downloaded DirectML: ${directml_URL}")
+    break()
+  endif()
+endforeach()
+
+FetchContent_Declare(directml
+  URL
+    ${directml_URL}
+  URL_HASH ${directml_HASH}
+)
+
+FetchContent_GetProperties(directml)
+if(NOT directml_POPULATED)
+  message(STATUS "Downloading DirectML from ${directml_URL}")
+  FetchContent_Populate(directml)
+endif()
+message(STATUS "DirectML is downloaded to ${directml_SOURCE_DIR}")
+
+find_library(location_directml DirectML
+  PATHS
+  "${directml_SOURCE_DIR}/bin/x64-win"
+  NO_CMAKE_SYSTEM_PATH
+)
+
+message(STATUS "location_directml: ${location_directml}")
+
+add_library(directml SHARED IMPORTED)
+
+set_target_properties(directml PROPERTIES
+  IMPORTED_LOCATION ${location_directml}
+  INTERFACE_INCLUDE_DIRECTORIES "${directml_SOURCE_DIR}/bin/x64-win"
+)
+
+set_property(TARGET directml
+  PROPERTY
+    IMPORTED_IMPLIB "${directml_SOURCE_DIR}/bin/x64-win/DirectML.lib"
+)
+
+file(COPY ${directml_SOURCE_DIR}/bin/x64-win/DirectML.dll
+  DESTINATION
+    ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}
+)
+
+file(GLOB directml_lib_files "${directml_SOURCE_DIR}/bin/x64-win/DirectML.*")
+
+message(STATUS "DirectML lib files: ${directml_lib_files}")
+
+install(FILES ${directml_lib_files} DESTINATION lib)
+install(FILES ${directml_lib_files} DESTINATION bin)
\ No newline at end of file
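The nupkg fetched above ships a DirectML-enabled onnxruntime build, and the file(COPY)/install() commands place onnxruntime.dll and DirectML.dll where the loader can find them at runtime. One way to sanity-check the resulting setup is to list the execution providers compiled into the onnxruntime binary that actually gets loaded. A minimal sketch using the public onnxruntime C++ API (the provider name string "DmlExecutionProvider" is taken from the provider.h comment below):

  #include <algorithm>
  #include <iostream>
  #include <string>
  #include <vector>

  #include "onnxruntime_cxx_api.h"

  int main() {
    // Enumerate the execution providers built into the loaded
    // onnxruntime binary (here, onnxruntime.dll from the nupkg).
    std::vector<std::string> providers = Ort::GetAvailableProviders();
    for (const auto &p : providers) {
      std::cout << p << "\n";
    }
    // A DirectML-enabled build should list "DmlExecutionProvider".
    bool has_dml = std::find(providers.begin(), providers.end(),
                             "DmlExecutionProvider") != providers.end();
    std::cout << "DirectML available: " << (has_dml ? "yes" : "no") << "\n";
    return 0;
  }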
diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake
index d1c4dc851..6655b45cd 100644
--- a/cmake/onnxruntime.cmake
+++ b/cmake/onnxruntime.cmake
@@ -95,7 +95,10 @@ function(download_onnxruntime)
       include(onnxruntime-win-arm64)
     else()
       # for 64-bit windows (x64)
-      if(BUILD_SHARED_LIBS)
+      if(SHERPA_ONNX_ENABLE_DIRECTML)
+        message(STATUS "Use DirectML")
+        include(onnxruntime-win-x64-directml)
+      elseif(BUILD_SHARED_LIBS)
         message(STATUS "Use dynamic onnxruntime libraries")
         if(SHERPA_ONNX_ENABLE_GPU)
           include(onnxruntime-win-x64-gpu)
diff --git a/sherpa-onnx/csrc/provider.cc b/sherpa-onnx/csrc/provider.cc
index 19d585976..3baed32c1 100644
--- a/sherpa-onnx/csrc/provider.cc
+++ b/sherpa-onnx/csrc/provider.cc
@@ -26,6 +26,8 @@ Provider StringToProvider(std::string s) {
     return Provider::kNNAPI;
   } else if (s == "trt") {
     return Provider::kTRT;
+  } else if (s == "directml") {
+    return Provider::kDirectML;
   } else {
     SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
     return Provider::kCPU;
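With this mapping, users select DirectML by passing the provider string "directml"; any unrecognized string logs an error and falls back to CPU instead of failing hard. A small illustrative check of the new mapping (the test scaffolding is hypothetical; it only assumes StringToProvider is declared in provider.h as in the existing code):

  #include <cassert>

  #include "sherpa-onnx/csrc/provider.h"

  int main() {
    using sherpa_onnx::Provider;
    using sherpa_onnx::StringToProvider;

    // New mapping added by this patch.
    assert(StringToProvider("directml") == Provider::kDirectML);

    // Unknown strings log an error and fall back to CPU.
    assert(StringToProvider("no-such-provider") == Provider::kCPU);
    return 0;
  }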
Fallback to cpu", s.c_str()); return Provider::kCPU; diff --git a/sherpa-onnx/csrc/provider.h b/sherpa-onnx/csrc/provider.h index 712006f2b..2b85b8a2e 100644 --- a/sherpa-onnx/csrc/provider.h +++ b/sherpa-onnx/csrc/provider.h @@ -14,12 +14,13 @@ namespace sherpa_onnx { // https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java // for a list of available providers enum class Provider { - kCPU = 0, // CPUExecutionProvider - kCUDA = 1, // CUDAExecutionProvider - kCoreML = 2, // CoreMLExecutionProvider - kXnnpack = 3, // XnnpackExecutionProvider - kNNAPI = 4, // NnapiExecutionProvider - kTRT = 5, // TensorRTExecutionProvider + kCPU = 0, // CPUExecutionProvider + kCUDA = 1, // CUDAExecutionProvider + kCoreML = 2, // CoreMLExecutionProvider + kXnnpack = 3, // XnnpackExecutionProvider + kNNAPI = 4, // NnapiExecutionProvider + kTRT = 5, // TensorRTExecutionProvider + kDirectML = 6, // DmlExecutionProvider }; /** diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index fb5932c47..50d4abfe6 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -19,6 +19,10 @@ #include "nnapi_provider_factory.h" // NOLINT #endif +#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1 +#include "dml_provider_factory.h" // NOLINT +#endif + namespace sherpa_onnx { static void OrtStatusFailure(OrtStatus *status, const char *s) { @@ -167,6 +171,24 @@ static Ort::SessionOptions GetSessionOptionsImpl( } break; } + case Provider::kDirectML: { +#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1 + sess_opts.DisableMemPattern(); + sess_opts.SetExecutionMode(ORT_SEQUENTIAL); + int32_t device_id = 0; + OrtStatus *status = + OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id); + if (status) { + const auto &api = Ort::GetApi(); + const char *msg = api.GetErrorMessage(status); + SHERPA_ONNX_LOGE("Failed to enable DirectML: %s. Fallback to cpu", msg); + api.ReleaseStatus(status); + } +#else + SHERPA_ONNX_LOGE("DirectML is for Windows only. Fallback to cpu!"); +#endif + break; + } case Provider::kCoreML: { #if defined(__APPLE__) uint32_t coreml_flags = 0;