diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc index 5fe37a6c30e33..90e488a1eda18 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc @@ -7,18 +7,7 @@ #include "core/framework/provider_options.h" #include "nv_execution_provider_custom_ops.h" #include "nv_execution_provider.h" - -// The filename extension for a shared library is different per platform -#ifdef _WIN32 -#define LIBRARY_PREFIX -#define LIBRARY_EXTENSION ORT_TSTR(".dll") -#elif defined(__APPLE__) -#define LIBRARY_PREFIX "lib" -#define LIBRARY_EXTENSION ".dylib" -#else -#define LIBRARY_PREFIX "lib" -#define LIBRARY_EXTENSION ".so" -#endif +#include "nv_platform_utils.h" namespace onnxruntime { extern TensorrtLogger& GetTensorrtLogger(bool verbose); @@ -76,14 +65,14 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& // This library contains GroupQueryAttention and RotaryEmbedding plugins for transformer models try { const auto& env = onnxruntime::GetDefaultEnv(); - auto external_plugin_path = env.GetRuntimePath() + + auto external_plugin_path = GetEPLibraryDirectory() + PathString(LIBRARY_PREFIX ORT_TSTR("tensorrt_plugins") LIBRARY_EXTENSION); void* external_plugin_handle = nullptr; auto status = env.LoadDynamicLibrary(external_plugin_path, false, &external_plugin_handle); if (status.IsOK()) { LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] External plugins loaded: tensorrt_plugins (GQA + RotaryEmbedding)"; } else { - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] tensorrt_plugins library not found in runtime path (optional)"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] tensorrt_plugins library not found in EP library path (optional)"; } } catch (const std::exception& e) { LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] tensorrt_plugins library not available: " << e.what(); diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_platform_utils.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_platform_utils.h new file mode 100644 index 0000000000000..f3298a8449157 --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_platform_utils.h @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include "core/common/path_string.h" + +#ifdef _WIN32 +#include +#else +#include +#endif + +// The filename extension for a shared library is different per platform +#ifdef _WIN32 +#define LIBRARY_PREFIX +#define LIBRARY_EXTENSION ORT_TSTR(".dll") +#elif defined(__APPLE__) +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".dylib" +#else +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".so" +#endif + +namespace onnxruntime { +inline PathString GetEPLibraryDirectory() { +#ifdef _WIN32 + HMODULE hModule = NULL; + // Get handle to the DLL executing this code + if (!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + reinterpret_cast(GetEPLibraryDirectory), + &hModule)) { + return PathString(); + } + + wchar_t buffer[MAX_PATH]; + DWORD len = GetModuleFileNameW(hModule, buffer, MAX_PATH); + if (len == 0 || len >= MAX_PATH) { + return PathString(); + } + + std::wstring path(buffer); + size_t lastSlash = path.find_last_of(L"\\/"); + if (lastSlash != std::wstring::npos) { + return PathString(path.substr(0, lastSlash + 1)); + } + return PathString(); +#else + // Linux and other Unix-like platforms + Dl_info dl_info; + + if (dladdr((void*)&GetEPLibraryDirectory, &dl_info) == 0 || dl_info.dli_fname == nullptr) { + return PathString(); + } + + std::string so_path(dl_info.dli_fname); + size_t last_slash = so_path.find_last_of('/'); + if (last_slash != std::string::npos) { + return PathString(so_path.substr(0, last_slash + 1)); + } + return PathString(); +#endif +} +} // namespace onnxruntime