Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,4 @@ kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.20.0.tar.
# this entry will be updated to use refs/tags/<version> instead of the raw commit hash.
kleidiai-qmx;https://github.com/qualcomm/kleidiai/archive/2f10c9a8d32f81ffeeb6d4885a29cc35d2b0da87.zip;5e855730a2d69057a569f43dd7532db3b2d2a05c
duktape;https://github.com/svaarala/duktape/releases/download/v2.7.0/duktape-2.7.0.tar.xz;8200c8e417dbab7adcc12c4dbdef7651cfc55794
vulkan_headers;https://codeload.github.com/KhronosGroup/Vulkan-Headers/tar.gz/refs/tags/v1.4.344;57bc528ef7c4a3f7bfbb59e64a187e3734bd29d8
11 changes: 11 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -734,10 +734,21 @@ if(onnxruntime_USE_TENSORRT)
endif()

if(onnxruntime_USE_NV)
# If an external project (e.g. Dawn from the WebGPU EP) has already added a Vulkan::Headers target, we shouldn't try to import another version of the Vulkan headers.
if (NOT TARGET Vulkan::Headers)
onnxruntime_fetchcontent_declare(
vulkan_headers
URL ${DEP_URL_vulkan_headers}
URL_HASH SHA1=${DEP_SHA1_vulkan_headers}
EXCLUDE_FROM_ALL
)
onnxruntime_fetchcontent_makeavailable(vulkan_headers)
endif()
list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/nv_tensorrt_rtx/*)
list(APPEND onnxruntime_test_framework_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h")
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_nv_tensorrt_rtx onnxruntime_providers_shared)
list(APPEND onnxruntime_test_providers_libs ${TENSORRT_LIBRARY_INFER})
list(APPEND onnxruntime_test_providers_libs Vulkan::Headers)
endif()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ constexpr const char* kCudaGraphEnable = "enable_cuda_graph";
constexpr const char* kMultiProfileEnable = "nv_multi_profile_enable";
constexpr const char* kUseExternalDataInitializer = "nv_use_external_data_initializer";
constexpr const char* kRuntimeCacheFile = "nv_runtime_cache_path";
constexpr const char* kExternalComputeQueueDataParamNV_data = "VkExternalComputeQueueDataParamsNV_data";

} // namespace provider_option_names
namespace run_option_names {
Expand Down
18 changes: 11 additions & 7 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -972,15 +972,15 @@ typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t n

/** \brief External memory handle type for importing GPU resources.
*
* \todo Add OPAQUE_WIN32 for Windows Vulkan-specific memory handles
* \todo Add POSIX file descriptor (OPAQUE_FD) for Linux Vulkan/CUDA/OpenCL interop
* \todo Add Linux DMA-BUF file descriptor for embedded GPU memory sharing
*
* \since Version 1.24.
*/
typedef enum OrtExternalMemoryHandleType {
ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE = 0, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(resource) */
ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP = 1, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(heap) */
ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE = 0, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(resource) */
ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP = 1, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(heap) */
ORT_EXTERNAL_MEMORY_HANDLE_TYPE_VK_MEMORY_WIN32 = 2, /**< Shared HANDLE from vkGetMemoryWin32HandleKHR, non-dedicated allocation */
ORT_EXTERNAL_MEMORY_HANDLE_TYPE_VK_MEMORY_OPAQUE_FD = 3, /**< File descriptor from vkGetMemoryOpaqueFdKHR, non-dedicated allocation */
} OrtExternalMemoryHandleType;

/** \brief Descriptor for importing external memory.
Expand All @@ -1004,7 +1004,9 @@ typedef struct OrtExternalMemoryDescriptor {
* \since Version 1.24.
*/
typedef enum OrtExternalSemaphoreType {
ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE = 0, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(fence) */
ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE = 0, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(fence) */
ORT_EXTERNAL_SEMAPHORE_VK_TIMELINE_SEMAPHORE_WIN32 = 1, /**< Shared HANDLE from vkGetSemaphoreWin32HandleKHR of a VkSemaphore created as VK_SEMAPHORE_TYPE_TIMELINE */
ORT_EXTERNAL_SEMAPHORE_VK_TIMELINE_SEMAPHORE_OPAQUE_FD = 2, /**< File descriptor from vkGetSemaphoreFdKHR of a VkSemaphore created as VK_SEMAPHORE_TYPE_TIMELINE */
} OrtExternalSemaphoreType;

/** \brief Descriptor for importing external semaphores.
Expand Down Expand Up @@ -1071,14 +1073,16 @@ typedef struct OrtGraphicsInteropConfig {
* works; streams use the default context.
*
* For D3D12: ID3D12CommandQueue*
* For Vulkan: VkQueue (cast to void*)
* For Vulkan: pass NULL
*/
void* command_queue;

/** \brief Additional API-specific options (optional).
*
* Can be used for future extensibility without changing the struct layout.
* For example, Vulkan-specific queue family index, or D3D12 fence sharing flags.
* For example, D3D12 fence sharing flags or provider-specific options like
* onnxruntime::nv::provider_option_names::kExternalComputeQueueDataParamNV_data
* for Vulkan interop for the NvTensorRTRTX provider.
*/
const OrtKeyValuePairs* additional_options;
} OrtGraphicsInteropConfig;
Expand Down
95 changes: 77 additions & 18 deletions onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <cstdlib>
#include <mutex>
#include <unordered_map>
#include <cuda.h>

Check warning on line 12 in onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C system header after C++ system header. Should be: nv_provider_factory.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc:12: Found C system header after C++ system header. Should be: nv_provider_factory.h, c system, c++ system, other. [build/include_order] [4]

#include "core/providers/shared_library/provider_api.h"
#include "core/session/onnxruntime_c_api.h"
Expand Down Expand Up @@ -539,8 +540,6 @@
const OrtApi& ort_api;
};

#if defined(_WIN32)

// External Resource Import Implementation (D3D12 to CUDA)
/**
* @brief Derived handle for imported external memory from D3D12 to CUDA.
Expand Down Expand Up @@ -642,8 +641,13 @@
_In_ OrtExternalMemoryHandleType handle_type) noexcept {
(void)this_ptr;
// CUDA supports D3D12 resource/heap handles and Vulkan opaque memory handles
#if defined(_WIN32)
return handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE ||
handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP;
handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP ||
handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_VK_MEMORY_WIN32;
#elif __linux__
return handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_VK_MEMORY_OPAQUE_FD;
#endif
}

static OrtStatus* ORT_API_CALL ImportMemoryImpl(
Expand Down Expand Up @@ -694,6 +698,14 @@
cu_handle_type = CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP;
is_dedicated = false; // D3D12 heaps are not dedicated
break;
case ORT_EXTERNAL_MEMORY_HANDLE_TYPE_VK_MEMORY_WIN32:
cu_handle_type = CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32;
is_dedicated = false;  // The CUDA driver API header currently documents this handle type as non-dedicated
break;
case ORT_EXTERNAL_MEMORY_HANDLE_TYPE_VK_MEMORY_OPAQUE_FD:
cu_handle_type = CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD;
is_dedicated = false;  // The CUDA driver API header currently documents this handle type as non-dedicated
break;
default:
// Should not reach here - CanImportMemory already validated handle type
return impl.ort_api.CreateStatus(ORT_EP_FAIL, "Unexpected external memory handle type");
Expand All @@ -702,7 +714,11 @@
// Setup external memory handle descriptor
CUDA_EXTERNAL_MEMORY_HANDLE_DESC ext_mem_desc = {};
ext_mem_desc.type = cu_handle_type;
#if defined(_WIN32)
ext_mem_desc.handle.win32.handle = desc->native_handle;
#else
ext_mem_desc.handle.fd = static_cast<int>(reinterpret_cast<intptr_t>((desc->native_handle)));
#endif
ext_mem_desc.size = desc->size_bytes;
ext_mem_desc.flags = is_dedicated ? CUDA_EXTERNAL_MEMORY_DEDICATED : 0;

Expand Down Expand Up @@ -825,8 +841,11 @@
_In_ const OrtExternalResourceImporterImpl* this_ptr,
_In_ OrtExternalSemaphoreType type) noexcept {
(void)this_ptr;
// CUDA supports D3D12 timeline fences
return type == ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE;
#if defined(_WIN32)
return type == ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE || type == ORT_EXTERNAL_SEMAPHORE_VK_TIMELINE_SEMAPHORE_WIN32;
#else
return type == ORT_EXTERNAL_SEMAPHORE_VK_TIMELINE_SEMAPHORE_OPAQUE_FD;
#endif
}

static OrtStatus* ORT_API_CALL ImportSemaphoreImpl(
Expand Down Expand Up @@ -857,8 +876,25 @@

// Setup external semaphore handle descriptor (D3D12 fence or Vulkan timeline semaphore)
CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC ext_sem_desc = {};
ext_sem_desc.type = CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE;
switch (desc->type) {
case ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE:
ext_sem_desc.type = CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE;
break;
case ORT_EXTERNAL_SEMAPHORE_VK_TIMELINE_SEMAPHORE_WIN32:
ext_sem_desc.type = CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32;
break;
case ORT_EXTERNAL_SEMAPHORE_VK_TIMELINE_SEMAPHORE_OPAQUE_FD:
ext_sem_desc.type = CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD;
break;
default:
// Should not reach here - the semaphore type was already validated by the CanImportSemaphore check
return impl.ort_api.CreateStatus(ORT_EP_FAIL, "Unexpected external memory handle type");
}
#if defined(_WIN32)
ext_sem_desc.handle.win32.handle = desc->native_handle;
#else
ext_sem_desc.handle.fd = static_cast<int>(reinterpret_cast<intptr_t>(desc->native_handle));
#endif
ext_sem_desc.flags = 0;

// Import the external semaphore
Expand Down Expand Up @@ -956,7 +992,7 @@
// Get the CUDA stream from OrtSyncStream
cudaStream_t cuda_stream = static_cast<cudaStream_t>(impl.ort_api.SyncStream_GetHandle(stream));

// Setup signal parameters for D3D12 fence (timeline semaphore)
// Setup signal parameters for D3D12 fence / VK timeline semaphore
CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS signal_params = {};
signal_params.params.fence.value = value;
signal_params.flags = 0;
Expand Down Expand Up @@ -994,8 +1030,6 @@
const OrtEpApi& ep_api;
};

#endif // defined(_WIN32)

// OrtEpApi infrastructure to be able to use the NvTensorRTRTX EP as an OrtEpFactory for auto EP selection.
struct NvTensorRtRtxEpFactory : OrtEpFactory {
using MemoryInfoUniquePtr = std::unique_ptr<OrtMemoryInfo, std::function<void(OrtMemoryInfo*)>>;
Expand Down Expand Up @@ -1280,17 +1314,11 @@

*out_importer = nullptr;

#if defined(_WIN32)
// Create the external resource importer
auto importer = std::make_unique<NvTrtRtxExternalResourceImporterImpl>(ep_device, factory.ort_api);
*out_importer = importer.release();

return nullptr;
#else
ORT_UNUSED_PARAMETER(ep_device);
return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED,
"External resource import is only available on Windows builds.");
#endif
}

/**
Expand Down Expand Up @@ -1529,9 +1557,40 @@
"[NvTensorRTRTX EP] D3D12 CIG context creation not supported on this platform");
#endif
} else if (config->graphics_api == ORT_GRAPHICS_API_VULKAN) {
// TODO: Add Vulkan CIG context support if needed
return onnxruntime::CreateStatus(ORT_NOT_IMPLEMENTED,
"[NvTensorRTRTX EP] Vulkan CIG context not yet implemented");
int cig_supported{false};
if (cudaSuccess != cudaDeviceGetAttribute(&cig_supported, cudaDevAttrVulkanCigSupported, device_id)) {
return onnxruntime::CreateStatus(ORT_EP_FAIL,
"[NvTensorRTRTX EP] Could not determine CiG support for CUDA device");
}
if (!cig_supported) {
LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] InitGraphicsInterop: CiG for Vulkan is not supported on the given device. Will use the default CUDA context";
return nullptr;
}
const char* nv_blob_ptr_str = factory.ort_api.GetKeyValue(config->additional_options, onnxruntime::nv::provider_option_names::kExternalComputeQueueDataParamNV_data);
if (!nv_blob_ptr_str) {
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] InitGraphicsInterop: Can't enable CUDA in Graphics (CiG) for Vulkan without onnxruntime::nv::provider_option_names::kExternalComputeQueueDataParamNV_data";
return nullptr;
}
uint64_t nv_blob_ptr = std::stoull(nv_blob_ptr_str);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ std::stoull can throw on malformed input (High-Priority): std::stoull(nv_blob_ptr_str) will throw std::invalid_argument or std::out_of_range if the string is not a valid unsigned integer. This is inside a noexcept-equivalent C API implementation — an uncaught exception will call std::terminate. Wrap in a try-catch or use a safe parser.

uint64_t nv_blob_ptr = 0;
try {
  nv_blob_ptr = std::stoull(nv_blob_ptr_str);
} catch (...) {
  return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT,
      "[NvTensorRTRTX EP] Invalid value for kExternalComputeQueueDataParamNV_data");
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fix can be in another PR.

if (nv_blob_ptr == 0) {
return onnxruntime::CreateStatus(ORT_EP_FAIL,
"[NvTensorRTRTX EP] Could not parse provided values for onnxruntime::nv::provider_option_names::kExternalComputeQueueDataParamNV_data or onnxruntime::nv::provider_option_names::kExternalComputeQueueDataParamNV_data_len");
}

CUctxCigParam cig_params{};
cig_params.sharedDataType = CIG_DATA_TYPE_NV_BLOB;
cig_params.sharedData = reinterpret_cast<void*>(nv_blob_ptr);
CUctxCreateParams params{};
params.cigParams = &cig_params;

cu_result = cuCtxCreate_v4(&cig_context, &params, 0, device_id);
if (cu_result != CUDA_SUCCESS) {
const char* error_str = nullptr;
cuGetErrorString(cu_result, &error_str);
std::string error_msg = "[NvTensorRTRTX EP] Failed to create CIG context for Vulkan: ";
error_msg += error_str ? error_str : "unknown error";
return onnxruntime::CreateStatus(ORT_FAIL, error_msg.c_str());
}
} else {
return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT,
"[NvTensorRTRTX EP] Unsupported graphics API for CIG context");
Expand Down
Loading
Loading