From 1aff515a0435b1db998c96270ec835711fddf7e3 Mon Sep 17 00:00:00 2001 From: Umang Bhatt Date: Wed, 21 Jan 2026 11:49:21 +0530 Subject: [PATCH 1/3] Engine compatibility validity API implementation --- .../nv_tensorrt_rtx/nv_execution_provider.cc | 86 ++++++++++++ .../nv_tensorrt_rtx/nv_execution_provider.h | 11 ++ .../nv_tensorrt_rtx/nv_provider_factory.cc | 127 +++++++++++++++++- .../nv_tensorrt_rtx/onnx_ctx_model_helper.cc | 47 +++++++ .../nv_tensorrt_rtx/onnx_ctx_model_helper.h | 6 + .../plugin_ep/ep_factory_provider_bridge.h | 15 +++ 6 files changed, 290 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index def6d7e9ea916..37406912d373d 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -2578,6 +2578,80 @@ const InlinedVector NvExecutionProvider::GetEpContextNodes() const return ep_context_nodes; } +std::string NvExecutionProvider::GetCompiledModelCompatibilityInfo( + const onnxruntime::GraphViewer& graph_viewer) const { + ORT_UNUSED_PARAMETER(graph_viewer); + + // Protect read access to engine_headers_ for thread safety + auto lock = GetApiLock(); + + // If we have stored engine headers, return the first one found + // (typically there's only one per EP context) + if (!engine_headers_.empty()) { + return engine_headers_.begin()->second; + } + + // No headers available - validation not supported for this model + return std::string(); +} + +common::Status NvExecutionProvider::ValidateCompiledModelCompatibilityInfo( + const std::string& compatibility_info, + OrtCompiledModelCompatibility& model_compatibility) const { + + // If no compatibility info provided, validation not applicable + if (compatibility_info.empty()) { + model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + return Status::OK(); + } + + // 
Decode hex string to binary + std::vector engine_header; + try { + engine_header = HexStringToBinary(compatibility_info); + } catch (const std::exception& ex) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to decode engine header: " << ex.what(); + model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + return Status::OK(); + } + + // Use TensorRT RTX's getEngineValidity to check compatibility + uint64_t diagnostics = 0; + nvinfer1::EngineValidity validity = runtime_->getEngineValidity( + engine_header.data(), + engine_header.size(), + &diagnostics + ); + + // Map TensorRT RTX validity to ORT compatibility status + switch (validity) { + case nvinfer1::EngineValidity::kVALID: + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Engine is fully compatible with this system"; + model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL; + break; + + case nvinfer1::EngineValidity::kSUBOPTIMAL: + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine is compatible but recompilation recommended " + << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")"; + model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION; + break; + + case nvinfer1::EngineValidity::kINVALID: + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine is incompatible with this system " + << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")"; + model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + break; + + default: + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unknown TensorRT validity status: " + << static_cast(validity); + model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + break; + } + + return Status::OK(); +} + Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer, const Node& fused_node, std::unordered_map& input_map, @@ -2854,6 +2928,18 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr return 
ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP failed to create engine from network for fused node: " + fused_node.Name()); } + + // Capture engine header (first 64 bytes) for compatibility validation + if (serialized_engine->size() >= kTensorRTEngineHeaderSize) { + std::string engine_header_hex = BinaryToHexString( + serialized_engine->data(), + kTensorRTEngineHeaderSize); + engine_headers_[fused_node.Name()] = engine_header_hex; + } else { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine too small to capture header for validation: " + << serialized_engine->size() << " bytes"; + } + trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); if (trt_engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 5c6ca20d75ec6..5f87d5ba024af 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -345,6 +345,13 @@ class NvExecutionProvider : public IExecutionProvider { const InlinedVector GetEpContextNodes() const override; + // Engine compatibility validation methods + std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const override; + + common::Status ValidateCompiledModelCompatibilityInfo( + const std::string& compatibility_info, + OrtCompiledModelCompatibility& model_compatibility) const override; + private: mutable NvExecutionProviderInfo info_; bool external_stream_ = false; @@ -423,6 +430,10 @@ class NvExecutionProvider : public IExecutionProvider { std::unordered_map input_shape_ranges_; // The profile shape ranges that the engine is built with std::unordered_map> profiles_; std::unordered_map dds_output_allocator_maps_; + + // Storage for engine headers (64 bytes) for compatibility validation 
+ // Maps fused_node_name -> hex-encoded engine header + mutable std::unordered_map engine_headers_; // for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture cudnnHandle_t external_cudnn_handle_ = nullptr; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index e5015e705958d..fbbb48f61f9dd 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -13,6 +13,7 @@ #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/cuda_stream_handle.h" +#include "onnx_ctx_model_helper.h" #include "nv_provider_factory.h" #include "nv_execution_provider.h" #include "nv_provider_factory_creator.h" @@ -21,6 +22,11 @@ using namespace onnxruntime; +// External declarations +namespace onnxruntime { +extern TensorrtLogger& GetTensorrtLogger(bool verbose_log); +} + namespace onnxruntime { void InitializeRegistry(); @@ -541,7 +547,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; - + ValidateCompiledModelCompatibilityInfo = ValidateCompiledModelCompatibilityInfoImpl; ort_version_supported = ORT_API_VERSION; // Set to the ORT version we were compiled with. 
} @@ -641,7 +647,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { size_t* p_num_ep_devices) noexcept { size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); - + int num_cuda_devices = 0; cudaGetDeviceCount(&num_cuda_devices); RETURN_IF_ERROR(factory->CreateMemoryInfoForDevices(num_cuda_devices)); @@ -661,6 +667,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { RETURN_IF_ERROR(factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, &ep_devices[num_ep_devices])); + factory->ort_api.ReleaseKeyValuePairs(ep_options); factory->ort_api.ReleaseKeyValuePairs(ep_metadata); @@ -735,6 +742,122 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { return nullptr; } + /** + * This function is called by the public C API GetModelCompatibilityForEpDevices. + * It uses TensorRT RTX runtime directly to call runtime->getEngineValidity() to check the 64-byte engine header. + * + * @param this_ptr Factory instance pointer + * @param devices Hardware devices (not used, validation is done against current system) + * @param num_devices Number of devices + * @param compatibility_info Hex-encoded 64-byte TensorRT RTX engine header (128 hex characters) + * @param model_compatibility Output parameter for compatibility status + * @return OrtStatus* nullptr on success, error status on failure + */ + static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfoImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + const char* compatibility_info, + OrtCompiledModelCompatibility* model_compatibility) noexcept { + + auto& factory = *static_cast(this_ptr); + + // Validate input parameters + if (compatibility_info == nullptr || model_compatibility == nullptr) { + return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "[NvTensorRTRTX EP] Invalid arguments: compatibility_info or model_compatibility is null"); + } + + // Device parameters not used for header validation + 
ORT_UNUSED_PARAMETER(devices); + ORT_UNUSED_PARAMETER(num_devices); + + try { + // If no compatibility info provided, validation not applicable + if (compatibility_info[0] == '\0') { + *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + return nullptr; + } + + // Decode hex string to binary + std::vector engine_header; + try { + engine_header = HexStringToBinary(std::string(compatibility_info)); + } catch (const std::exception& ex) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to decode engine header: " << ex.what(); + *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + return nullptr; + } + + // Validate header size (keep in sync with TensorRT engine header size) + if (engine_header.size() != kTensorRTEngineHeaderSize) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Invalid header size: " << engine_header.size() + << " bytes (expected " << kTensorRTEngineHeaderSize << ")"; + *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + return nullptr; + } + + // Create TensorRT runtime for validation + static std::mutex runtime_creation_mutex; + std::unique_ptr runtime; + { + std::lock_guard lock(runtime_creation_mutex); + TensorrtLogger& trt_logger = GetTensorrtLogger(false); + runtime.reset(nvinfer1::createInferRuntime(trt_logger)); + } + + if (!runtime) { + LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Failed to create TensorRT runtime"; + return factory.ort_api.CreateStatus(ORT_FAIL, + "[NvTensorRTRTX EP] Failed to create TensorRT runtime"); + } + + // Use TensorRT's getEngineValidity to check compatibility + uint64_t diagnostics = 0; + nvinfer1::EngineValidity validity = runtime->getEngineValidity( + engine_header.data(), + engine_header.size(), + &diagnostics + ); + + // Map TensorRT validity to ORT compatibility status + switch (validity) { + case nvinfer1::EngineValidity::kVALID: + *model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL; + break; + + case 
nvinfer1::EngineValidity::kSUBOPTIMAL: + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine compatible but recompilation recommended " + << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")"; + *model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION; + break; + + case nvinfer1::EngineValidity::kINVALID: + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine incompatible with this system " + << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")"; + *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + break; + + default: + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unknown validity status: " + << static_cast(validity); + *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + break; + } + + return nullptr; + + } catch (const std::exception& ex) { + std::string error_msg = std::string("[NvTensorRTRTX EP] Exception during validation: ") + ex.what(); + LOGS_DEFAULT(ERROR) << error_msg; + return factory.ort_api.CreateStatus(ORT_FAIL, error_msg.c_str()); + } catch (...) 
{ + LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Unknown exception during validation"; + return factory.ort_api.CreateStatus(ORT_FAIL, + "[NvTensorRTRTX EP] Unknown exception during validation"); + } + } + OrtStatus* CreateMemoryInfoForDevices(int num_devices) { gpu_memory_infos.reserve(num_devices); host_accessible_memory_infos.reserve(num_devices); diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc index c1626fa4f36ad..b5cfa38275879 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc @@ -14,6 +14,53 @@ namespace onnxruntime { extern TensorrtLogger& GetTensorrtLogger(bool verbose_log); +/* + * Convert binary data to hex string + */ +std::string BinaryToHexString(const void* data, size_t size) { + static const char hex_chars[] = "0123456789abcdef"; + const uint8_t* bytes = static_cast(data); + std::string result; + result.reserve(size * 2); + + for (size_t i = 0; i < size; ++i) { + result.push_back(hex_chars[(bytes[i] >> 4) & 0xF]); + result.push_back(hex_chars[bytes[i] & 0xF]); + } + return result; +} + +/* + * Convert hex string back to binary + */ +std::vector HexStringToBinary(const std::string& hex) { + if (hex.size() % 2 != 0) { + ORT_THROW("Hex string must have even length"); + } + + std::vector result; + result.reserve(hex.size() / 2); + + for (size_t i = 0; i < hex.size(); i += 2) { + uint8_t byte = 0; + + // High nibble + char c = hex[i]; + byte |= (c >= '0' && c <= '9') ? static_cast((c - '0') << 4) : + (c >= 'a' && c <= 'f') ? static_cast((c - 'a' + 10) << 4) : + (c >= 'A' && c <= 'F') ? static_cast((c - 'A' + 10) << 4) : 0; + + // Low nibble + c = hex[i + 1]; + byte |= (c >= '0' && c <= '9') ? static_cast(c - '0') : + (c >= 'a' && c <= 'f') ? static_cast(c - 'a' + 10) : + (c >= 'A' && c <= 'F') ? 
static_cast(c - 'A' + 10) : 0; + + result.push_back(byte); + } + return result; +} + /* * Check whether the graph has the EP context contrib op. * The op can contain the precompiled engine info for TRT EP to directly load the engine. diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h index 7c52f26cc9177..80263b1ba80d5 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h @@ -24,6 +24,12 @@ static const std::string PARTITION_NAME = "partition_name"; static const std::string SDK_VERSION = "ep_sdk_version"; static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft"; +// TensorRT does not currently expose a header size define; keep in sync with TRT engine serialization header size. +constexpr size_t kTensorRTEngineHeaderSize = 64; +// Helper functions for engine header validation +std::string BinaryToHexString(const void* data, size_t size); +std::vector HexStringToBinary(const std::string& hex); + bool GraphHasCtxNode(const GraphViewer& graph_viewer, size_t& node_idx); const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer); std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h index 3a7a1b6504d12..eb1427db87463 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -86,6 +86,21 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl { return ep_factory_.GetHardwareDeviceIncompatibilityDetails(&ep_factory_, hw, details); } + OrtStatus* ValidateCompiledModelCompatibilityInfo( + const OrtHardwareDevice* const* devices, + size_t num_devices, + const char* compatibility_info, 
+ OrtCompiledModelCompatibility* model_compatibility) noexcept override { + // Forward to underlying factory if it supports validation + if (ep_factory_.ValidateCompiledModelCompatibilityInfo) { + return ep_factory_.ValidateCompiledModelCompatibilityInfo( + &ep_factory_, devices, num_devices, compatibility_info, model_compatibility); + } + // If not supported, return NOT_APPLICABLE + *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + return nullptr; + } + OrtEpFactory& ep_factory_; ProviderLibrary& provider_library_; std::optional library_path_; From c2897835fa1dd36fe6514e4e30597a1c9cffc409 Mon Sep 17 00:00:00 2001 From: Umang Bhatt Date: Sat, 24 Jan 2026 00:04:26 +0530 Subject: [PATCH 2/3] Lintrunner fix --- .../nv_tensorrt_rtx/nv_execution_provider.cc | 32 ++++++------ .../nv_tensorrt_rtx/nv_execution_provider.h | 4 +- .../nv_tensorrt_rtx/nv_provider_factory.cc | 52 +++++++++---------- .../nv_tensorrt_rtx/onnx_ctx_model_helper.cc | 24 ++++----- 4 files changed, 54 insertions(+), 58 deletions(-) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 37406912d373d..1ae0574de3c87 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -2581,16 +2581,16 @@ const InlinedVector NvExecutionProvider::GetEpContextNodes() const std::string NvExecutionProvider::GetCompiledModelCompatibilityInfo( const onnxruntime::GraphViewer& graph_viewer) const { ORT_UNUSED_PARAMETER(graph_viewer); - + // Protect read access to engine_headers_ for thread safety auto lock = GetApiLock(); - + // If we have stored engine headers, return the first one found // (typically there's only one per EP context) if (!engine_headers_.empty()) { return engine_headers_.begin()->second; } - + // No headers available - validation not supported for this model return std::string(); } 
@@ -2598,13 +2598,12 @@ std::string NvExecutionProvider::GetCompiledModelCompatibilityInfo( common::Status NvExecutionProvider::ValidateCompiledModelCompatibilityInfo( const std::string& compatibility_info, OrtCompiledModelCompatibility& model_compatibility) const { - // If no compatibility info provided, validation not applicable if (compatibility_info.empty()) { model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; return Status::OK(); } - + // Decode hex string to binary std::vector engine_header; try { @@ -2620,35 +2619,34 @@ common::Status NvExecutionProvider::ValidateCompiledModelCompatibilityInfo( nvinfer1::EngineValidity validity = runtime_->getEngineValidity( engine_header.data(), engine_header.size(), - &diagnostics - ); - + &diagnostics); + // Map TensorRT RTX validity to ORT compatibility status switch (validity) { case nvinfer1::EngineValidity::kVALID: LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Engine is fully compatible with this system"; model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL; break; - + case nvinfer1::EngineValidity::kSUBOPTIMAL: LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine is compatible but recompilation recommended " << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")"; model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION; break; - + case nvinfer1::EngineValidity::kINVALID: LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine is incompatible with this system " << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")"; model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; break; - + default: - LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unknown TensorRT validity status: " + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unknown TensorRT validity status: " << static_cast(validity); model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; break; } - + return Status::OK(); } @@ -2928,18 +2926,18 @@ Status 
NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP failed to create engine from network for fused node: " + fused_node.Name()); } - + // Capture engine header (first 64 bytes) for compatibility validation if (serialized_engine->size() >= kTensorRTEngineHeaderSize) { std::string engine_header_hex = BinaryToHexString( - serialized_engine->data(), + serialized_engine->data(), kTensorRTEngineHeaderSize); engine_headers_[fused_node.Name()] = engine_header_hex; } else { - LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine too small to capture header for validation: " + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine too small to capture header for validation: " << serialized_engine->size() << " bytes"; } - + trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); if (trt_engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 5f87d5ba024af..e415143a6ddd1 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -347,7 +347,7 @@ class NvExecutionProvider : public IExecutionProvider { // Engine compatibility validation methods std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const override; - + common::Status ValidateCompiledModelCompatibilityInfo( const std::string& compatibility_info, OrtCompiledModelCompatibility& model_compatibility) const override; @@ -430,7 +430,7 @@ class NvExecutionProvider : public IExecutionProvider { std::unordered_map input_shape_ranges_; // The profile shape ranges that the engine is built with std::unordered_map> profiles_; std::unordered_map dds_output_allocator_maps_; - + // Storage for 
engine headers (64 bytes) for compatibility validation // Maps fused_node_name -> hex-encoded engine header mutable std::unordered_map engine_headers_; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index fbbb48f61f9dd..f11f72c20c972 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -647,7 +647,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { size_t* p_num_ep_devices) noexcept { size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); - + int num_cuda_devices = 0; cudaGetDeviceCount(&num_cuda_devices); RETURN_IF_ERROR(factory->CreateMemoryInfoForDevices(num_cuda_devices)); @@ -667,7 +667,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { RETURN_IF_ERROR(factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, &ep_devices[num_ep_devices])); - + factory->ort_api.ReleaseKeyValuePairs(ep_options); factory->ort_api.ReleaseKeyValuePairs(ep_metadata); @@ -745,7 +745,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { /** * This function is called by the public C API GetModelCompatibilityForEpDevices. * It uses TensorRT RTX runtime directly to call runtime->getEngineValidity() to check the 64-byte engine header. 
- * + * * @param this_ptr Factory instance pointer * @param devices Hardware devices (not used, validation is done against current system) * @param num_devices Number of devices @@ -759,26 +759,25 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { size_t num_devices, const char* compatibility_info, OrtCompiledModelCompatibility* model_compatibility) noexcept { - auto& factory = *static_cast(this_ptr); - + // Validate input parameters if (compatibility_info == nullptr || model_compatibility == nullptr) { - return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, - "[NvTensorRTRTX EP] Invalid arguments: compatibility_info or model_compatibility is null"); + return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "[NvTensorRTRTX EP] Invalid arguments: compatibility_info or model_compatibility is null"); } - + // Device parameters not used for header validation ORT_UNUSED_PARAMETER(devices); ORT_UNUSED_PARAMETER(num_devices); - + try { // If no compatibility info provided, validation not applicable if (compatibility_info[0] == '\0') { *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; return nullptr; } - + // Decode hex string to binary std::vector engine_header; try { @@ -788,7 +787,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; return nullptr; } - + // Validate header size (keep in sync with TensorRT engine header size) if (engine_header.size() != kTensorRTEngineHeaderSize) { LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Invalid header size: " << engine_header.size() @@ -796,7 +795,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; return nullptr; } - + // Create TensorRT runtime for validation static std::mutex runtime_creation_mutex; std::unique_ptr runtime; @@ -805,56 +804,55 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { TensorrtLogger& trt_logger = GetTensorrtLogger(false); 
runtime.reset(nvinfer1::createInferRuntime(trt_logger)); } - + if (!runtime) { LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Failed to create TensorRT runtime"; - return factory.ort_api.CreateStatus(ORT_FAIL, - "[NvTensorRTRTX EP] Failed to create TensorRT runtime"); + return factory.ort_api.CreateStatus(ORT_FAIL, + "[NvTensorRTRTX EP] Failed to create TensorRT runtime"); } - + // Use TensorRT's getEngineValidity to check compatibility uint64_t diagnostics = 0; nvinfer1::EngineValidity validity = runtime->getEngineValidity( engine_header.data(), engine_header.size(), - &diagnostics - ); - + &diagnostics); + // Map TensorRT validity to ORT compatibility status switch (validity) { case nvinfer1::EngineValidity::kVALID: *model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL; break; - + case nvinfer1::EngineValidity::kSUBOPTIMAL: LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine compatible but recompilation recommended " << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")"; *model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION; break; - + case nvinfer1::EngineValidity::kINVALID: LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine incompatible with this system " << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")"; *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; break; - + default: - LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unknown validity status: " + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unknown validity status: " << static_cast(validity); *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; break; } - + return nullptr; - + } catch (const std::exception& ex) { std::string error_msg = std::string("[NvTensorRTRTX EP] Exception during validation: ") + ex.what(); LOGS_DEFAULT(ERROR) << error_msg; return factory.ort_api.CreateStatus(ORT_FAIL, error_msg.c_str()); } catch (...) 
{ LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Unknown exception during validation"; - return factory.ort_api.CreateStatus(ORT_FAIL, - "[NvTensorRTRTX EP] Unknown exception during validation"); + return factory.ort_api.CreateStatus(ORT_FAIL, + "[NvTensorRTRTX EP] Unknown exception during validation"); } } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc index b5cfa38275879..b6a4069c59700 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc @@ -22,7 +22,7 @@ std::string BinaryToHexString(const void* data, size_t size) { const uint8_t* bytes = static_cast(data); std::string result; result.reserve(size * 2); - + for (size_t i = 0; i < size; ++i) { result.push_back(hex_chars[(bytes[i] >> 4) & 0xF]); result.push_back(hex_chars[bytes[i] & 0xF]); @@ -37,25 +37,25 @@ std::vector HexStringToBinary(const std::string& hex) { if (hex.size() % 2 != 0) { ORT_THROW("Hex string must have even length"); } - + std::vector result; result.reserve(hex.size() / 2); - + for (size_t i = 0; i < hex.size(); i += 2) { uint8_t byte = 0; - + // High nibble char c = hex[i]; - byte |= (c >= '0' && c <= '9') ? static_cast((c - '0') << 4) : - (c >= 'a' && c <= 'f') ? static_cast((c - 'a' + 10) << 4) : - (c >= 'A' && c <= 'F') ? static_cast((c - 'A' + 10) << 4) : 0; - + byte |= (c >= '0' && c <= '9') ? static_cast((c - '0') << 4) : (c >= 'a' && c <= 'f') ? static_cast((c - 'a' + 10) << 4) + : (c >= 'A' && c <= 'F') ? static_cast((c - 'A' + 10) << 4) + : 0; + // Low nibble c = hex[i + 1]; - byte |= (c >= '0' && c <= '9') ? static_cast(c - '0') : - (c >= 'a' && c <= 'f') ? static_cast(c - 'a' + 10) : - (c >= 'A' && c <= 'F') ? static_cast(c - 'A' + 10) : 0; - + byte |= (c >= '0' && c <= '9') ? static_cast(c - '0') : (c >= 'a' && c <= 'f') ? static_cast(c - 'a' + 10) + : (c >= 'A' && c <= 'F') ? 
static_cast(c - 'A' + 10) + : 0; + result.push_back(byte); } return result; From 553490e0b7d220e4f2152de3bc13779685501192 Mon Sep 17 00:00:00 2001 From: Umang Bhatt Date: Sat, 24 Jan 2026 10:58:32 +0530 Subject: [PATCH 3/3] Fixing EP_NOT_APPLICABLE for multiple EP Context Nodes --- .../core/providers/nv_tensorrt_rtx/nv_execution_provider.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 1ae0574de3c87..ee4f45f5057e0 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -2585,6 +2585,12 @@ std::string NvExecutionProvider::GetCompiledModelCompatibilityInfo( // Protect read access to engine_headers_ for thread safety auto lock = GetApiLock(); + // Compatibility info is only supported when there is exactly one engine. + // If multiple EPContext nodes/engines exist, return empty so validation is not applicable. + if (engine_headers_.size() > 1) { + return std::string(); + } + // If we have stored engine headers, return the first one found // (typically there's only one per EP context) if (!engine_headers_.empty()) {