diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
index def6d7e9ea916..ee4f45f5057e0 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
@@ -2578,6 +2578,84 @@ const InlinedVector<const Node*> NvExecutionProvider::GetEpContextNodes() const
   return ep_context_nodes;
 }
 
+std::string NvExecutionProvider::GetCompiledModelCompatibilityInfo(
+    const onnxruntime::GraphViewer& graph_viewer) const {
+  ORT_UNUSED_PARAMETER(graph_viewer);
+
+  // Protect read access to engine_headers_ for thread safety
+  auto lock = GetApiLock();
+
+  // Compatibility info is only supported when there is exactly one engine.
+  // If multiple EPContext nodes/engines exist, return empty so validation is not applicable.
+  if (engine_headers_.size() > 1) {
+    return std::string();
+  }
+
+  // If we have stored engine headers, return the first one found
+  // (typically there is only one per EP context)
+  if (!engine_headers_.empty()) {
+    return engine_headers_.begin()->second;
+  }
+
+  // No headers available - validation not supported for this model
+  return std::string();
+}
+
+common::Status NvExecutionProvider::ValidateCompiledModelCompatibilityInfo(
+    const std::string& compatibility_info,
+    OrtCompiledModelCompatibility& model_compatibility) const {
+  // If no compatibility info is provided, validation is not applicable
+  if (compatibility_info.empty()) {
+    model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
+    return Status::OK();
+  }
+
+  // Decode hex string to binary
+  std::vector<uint8_t> engine_header;
+  try {
+    engine_header = HexStringToBinary(compatibility_info);
+  } catch (const std::exception& ex) {
+    LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to decode engine header: " << ex.what();
+    model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
+    return Status::OK();
+  }
+
+  // Use TensorRT RTX's getEngineValidity to check compatibility
+  uint64_t diagnostics = 0;
+  nvinfer1::EngineValidity validity = runtime_->getEngineValidity(
+      engine_header.data(),
+      engine_header.size(),
+      &diagnostics);
+
+  // Map TensorRT RTX validity to ORT compatibility status
+  switch (validity) {
+    case nvinfer1::EngineValidity::kVALID:
+      LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Engine is fully compatible with this system";
+      model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL;
+      break;
+
+    case nvinfer1::EngineValidity::kSUBOPTIMAL:
+      LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine is compatible but recompilation is recommended "
+                            << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")";
+      model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION;
+      break;
+
+    case nvinfer1::EngineValidity::kINVALID:
+      LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine is incompatible with this system "
+                            << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")";
+      model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
+      break;
+
+    default:
+      LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unknown TensorRT validity status: "
+                            << static_cast<int32_t>(validity);
+      model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
+      break;
+  }
+
+  return Status::OK();
+}
+
 Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer,
                                                            const Node& fused_node,
                                                            std::unordered_map<std::string, size_t>& input_map,
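The validation path above reduces to a single TensorRT RTX runtime call on a 64-byte header. A minimal standalone sketch of the same technique, assuming the TensorRT RTX headers are available; the name CheckHeader is illustrative and not part of this change:

#include <NvInfer.h>
#include <cstdint>
#include <vector>

// Ask the local TensorRT RTX runtime whether a serialized-engine header
// (the first 64 bytes of an engine blob) is usable on this machine.
nvinfer1::EngineValidity CheckHeader(nvinfer1::IRuntime& runtime,
                                     const std::vector<uint8_t>& header) {
  uint64_t diagnostics = 0;  // populated with flags explaining kSUBOPTIMAL/kINVALID results
  return runtime.getEngineValidity(header.data(), header.size(), &diagnostics);
}
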
@@ -2854,6 +2932,18 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                              "NvTensorRTRTX EP failed to create engine from network for fused node: " + fused_node.Name());
     }
+
+    // Capture engine header (first 64 bytes) for compatibility validation
+    if (serialized_engine->size() >= kTensorRTEngineHeaderSize) {
+      std::string engine_header_hex = BinaryToHexString(
+          serialized_engine->data(),
+          kTensorRTEngineHeaderSize);
+      engine_headers_[fused_node.Name()] = engine_header_hex;
+    } else {
+      LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine too small to capture header for validation: "
+                            << serialized_engine->size() << " bytes";
+    }
+
     trt_engine = std::unique_ptr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size()));
     if (trt_engine == nullptr) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
index 5c6ca20d75ec6..e415143a6ddd1 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
@@ -345,6 +345,13 @@ class NvExecutionProvider : public IExecutionProvider {
 
   const InlinedVector<const Node*> GetEpContextNodes() const override;
 
+  // Engine compatibility validation methods
+  std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const override;
+
+  common::Status ValidateCompiledModelCompatibilityInfo(
+      const std::string& compatibility_info,
+      OrtCompiledModelCompatibility& model_compatibility) const override;
+
  private:
   mutable NvExecutionProviderInfo info_;
   bool external_stream_ = false;
@@ -424,6 +431,10 @@ class NvExecutionProvider : public IExecutionProvider {
   std::unordered_map<std::string, std::vector<nvinfer1::IOptimizationProfile*>> profiles_;
   std::unordered_map<std::string, DDSOutputAllocatorMap> dds_output_allocator_maps_;
 
+  // Storage for engine headers (64 bytes) for compatibility validation
+  // Maps fused_node_name -> hex-encoded engine header
+  mutable std::unordered_map<std::string, std::string> engine_headers_;
+
   // for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture
   cudnnHandle_t external_cudnn_handle_ = nullptr;
   cublasHandle_t external_cublas_handle_ = nullptr;
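The capture step in CreateNodeComputeInfoFromGraph reduces each serialized engine to a 64-byte prefix stored as hex in engine_headers_. A sketch of that step in isolation, assuming the helpers this change adds in onnx_ctx_model_helper.{h,cc}; CaptureHeaderHex is an illustrative name:

#include <cstdint>
#include <string>
#include <vector>

// Returns the hex-encoded 64-byte engine header, or an empty string when the
// blob is too small (mirroring the warning path above).
std::string CaptureHeaderHex(const std::vector<uint8_t>& serialized_engine) {
  if (serialized_engine.size() < kTensorRTEngineHeaderSize) return {};
  return BinaryToHexString(serialized_engine.data(), kTensorRTEngineHeaderSize);
}
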
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
index e5015e705958d..f11f72c20c972 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
@@ -13,6 +13,7 @@
 #include "core/providers/cuda/shared_inc/cuda_call.h"
 #include "core/providers/cuda/cuda_stream_handle.h"
 
+#include "onnx_ctx_model_helper.h"
 #include "nv_provider_factory.h"
 #include "nv_execution_provider.h"
 #include "nv_provider_factory_creator.h"
@@ -21,6 +22,11 @@
 
 using namespace onnxruntime;
 
+// External declarations
+namespace onnxruntime {
+extern TensorrtLogger& GetTensorrtLogger(bool verbose_log);
+}
+
 namespace onnxruntime {
 
 void InitializeRegistry();
@@ -541,7 +547,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
     IsStreamAware = IsStreamAwareImpl;
     CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl;
-
+    ValidateCompiledModelCompatibilityInfo = ValidateCompiledModelCompatibilityInfoImpl;
     ort_version_supported = ORT_API_VERSION;  // Set to the ORT version we were compiled with.
   }
@@ -661,6 +667,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
     RETURN_IF_ERROR(factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options,
                                                                 &ep_devices[num_ep_devices]));
 
+    factory->ort_api.ReleaseKeyValuePairs(ep_options);
     factory->ort_api.ReleaseKeyValuePairs(ep_metadata);
 
@@ -735,6 +742,120 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
     return nullptr;
   }
 
+  /**
+   * This function is called by the public C API GetModelCompatibilityForEpDevices.
+   * It uses the TensorRT RTX runtime directly, calling runtime->getEngineValidity() on the 64-byte engine header.
+   *
+   * @param this_ptr Factory instance pointer
+   * @param devices Hardware devices (not used; validation is done against the current system)
+   * @param num_devices Number of devices
+   * @param compatibility_info Hex-encoded 64-byte TensorRT RTX engine header (128 hex characters)
+   * @param model_compatibility Output parameter for the compatibility status
+   * @return OrtStatus* nullptr on success, error status on failure
+   */
+  static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfoImpl(
+      OrtEpFactory* this_ptr,
+      const OrtHardwareDevice* const* devices,
+      size_t num_devices,
+      const char* compatibility_info,
+      OrtCompiledModelCompatibility* model_compatibility) noexcept {
+    auto& factory = *static_cast<NvTensorRtRtxEpFactory*>(this_ptr);
+
+    // Validate input parameters
+    if (compatibility_info == nullptr || model_compatibility == nullptr) {
+      return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT,
+                                          "[NvTensorRTRTX EP] Invalid arguments: compatibility_info or model_compatibility is null");
+    }
+
+    // Device parameters are not used for header validation
+    ORT_UNUSED_PARAMETER(devices);
+    ORT_UNUSED_PARAMETER(num_devices);
+
+    try {
+      // If no compatibility info is provided, validation is not applicable
+      if (compatibility_info[0] == '\0') {
+        *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
+        return nullptr;
+      }
+
+      // Decode hex string to binary
+      std::vector<uint8_t> engine_header;
+      try {
+        engine_header = HexStringToBinary(std::string(compatibility_info));
+      } catch (const std::exception& ex) {
+        LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to decode engine header: " << ex.what();
+        *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
+        return nullptr;
+      }
+
+      // Validate header size (keep in sync with the TensorRT engine header size)
+      if (engine_header.size() != kTensorRTEngineHeaderSize) {
+        LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Invalid header size: " << engine_header.size()
+                              << " bytes (expected " << kTensorRTEngineHeaderSize << ")";
+        *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
+        return nullptr;
+      }
+
+      // Create a TensorRT runtime for validation
+      static std::mutex runtime_creation_mutex;
+      std::unique_ptr<nvinfer1::IRuntime> runtime;
+      {
+        std::lock_guard<std::mutex> lock(runtime_creation_mutex);
+        TensorrtLogger& trt_logger = GetTensorrtLogger(false);
+        runtime.reset(nvinfer1::createInferRuntime(trt_logger));
+      }
+
+      if (!runtime) {
+        LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Failed to create TensorRT runtime";
+        return factory.ort_api.CreateStatus(ORT_FAIL,
+                                            "[NvTensorRTRTX EP] Failed to create TensorRT runtime");
+      }
+
+      // Use TensorRT's getEngineValidity to check compatibility
+      uint64_t diagnostics = 0;
+      nvinfer1::EngineValidity validity = runtime->getEngineValidity(
+          engine_header.data(),
+          engine_header.size(),
+          &diagnostics);
+
+      // Map TensorRT validity to ORT compatibility status
+      switch (validity) {
+        case nvinfer1::EngineValidity::kVALID:
+          *model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL;
+          break;
+
+        case nvinfer1::EngineValidity::kSUBOPTIMAL:
+          LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine compatible but recompilation recommended "
+                                << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")";
+          *model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION;
+          break;
+
+        case nvinfer1::EngineValidity::kINVALID:
+          LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine incompatible with this system "
+                                << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")";
+          *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
+          break;
+
+        default:
+          LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unknown validity status: "
+                                << static_cast<int32_t>(validity);
+          *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
+          break;
+      }
+
+      return nullptr;
+
+    } catch (const std::exception& ex) {
+      std::string error_msg = std::string("[NvTensorRTRTX EP] Exception during validation: ") + ex.what();
+      LOGS_DEFAULT(ERROR) << error_msg;
+      return factory.ort_api.CreateStatus(ORT_FAIL, error_msg.c_str());
+    } catch (...) {
+      LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Unknown exception during validation";
+      return factory.ort_api.CreateStatus(ORT_FAIL,
+                                          "[NvTensorRTRTX EP] Unknown exception during validation");
+    }
+  }
+
   OrtStatus* CreateMemoryInfoForDevices(int num_devices) {
     gpu_memory_infos.reserve(num_devices);
     host_accessible_memory_infos.reserve(num_devices);
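The factory implementation rejects any string that does not decode to exactly kTensorRTEngineHeaderSize bytes, so the well-formedness of a compatibility string can be checked cheaply before decoding. LooksLikeEngineHeader is an illustrative helper, not part of this change:

#include <string>

// A valid compatibility string is exactly 2 * kTensorRTEngineHeaderSize
// (i.e. 128) hex characters.
bool LooksLikeEngineHeader(const std::string& info) {
  return info.size() == 2 * kTensorRTEngineHeaderSize &&
         info.find_first_not_of("0123456789abcdefABCDEF") == std::string::npos;
}
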
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc
index c1626fa4f36ad..b6a4069c59700 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc
@@ -14,6 +14,53 @@ namespace onnxruntime {
 
 extern TensorrtLogger& GetTensorrtLogger(bool verbose_log);
 
+/*
+ * Convert binary data to a lowercase hex string
+ */
+std::string BinaryToHexString(const void* data, size_t size) {
+  static const char hex_chars[] = "0123456789abcdef";
+  const uint8_t* bytes = static_cast<const uint8_t*>(data);
+  std::string result;
+  result.reserve(size * 2);
+
+  for (size_t i = 0; i < size; ++i) {
+    result.push_back(hex_chars[(bytes[i] >> 4) & 0xF]);
+    result.push_back(hex_chars[bytes[i] & 0xF]);
+  }
+  return result;
+}
+
+/*
+ * Convert a hex string back to binary. Throws on odd length or invalid characters.
+ */
+std::vector<uint8_t> HexStringToBinary(const std::string& hex) {
+  if (hex.size() % 2 != 0) {
+    ORT_THROW("Hex string must have even length");
+  }
+
+  auto decode_nibble = [](char c) -> uint8_t {
+    if (c >= '0' && c <= '9') return static_cast<uint8_t>(c - '0');
+    if (c >= 'a' && c <= 'f') return static_cast<uint8_t>(c - 'a' + 10);
+    if (c >= 'A' && c <= 'F') return static_cast<uint8_t>(c - 'A' + 10);
+    ORT_THROW("Invalid hex character in string");
+  };
+
+  std::vector<uint8_t> result;
+  result.reserve(hex.size() / 2);
+
+  for (size_t i = 0; i < hex.size(); i += 2) {
+    uint8_t byte = static_cast<uint8_t>(decode_nibble(hex[i]) << 4);  // high nibble
+    byte |= decode_nibble(hex[i + 1]);                                // low nibble
+    result.push_back(byte);
+  }
+  return result;
+}
+
 /*
  * Check whether the graph has the EP context contrib op.
  * The op can contain the precompiled engine info for TRT EP to directly load the engine.
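A quick round-trip check of the two helpers above, as one might write in a unit test; this sketch assumes the declarations from onnx_ctx_model_helper.h:

#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

void RoundTripSmokeTest() {
  std::vector<uint8_t> header(kTensorRTEngineHeaderSize, 0xAB);
  std::string hex = BinaryToHexString(header.data(), header.size());
  assert(hex.size() == 2 * kTensorRTEngineHeaderSize);  // 128 characters
  assert(HexStringToBinary(hex) == header);             // lossless round trip
}
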
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h
index 7c52f26cc9177..80263b1ba80d5 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h
@@ -24,6 +24,12 @@ static const std::string PARTITION_NAME = "partition_name";
 static const std::string SDK_VERSION = "ep_sdk_version";
 static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft";
 
+// TensorRT does not currently expose a header size define; keep in sync with the TRT engine serialization header size.
+constexpr size_t kTensorRTEngineHeaderSize = 64;
+// Helper functions for engine header validation
+std::string BinaryToHexString(const void* data, size_t size);
+std::vector<uint8_t> HexStringToBinary(const std::string& hex);
+
 bool GraphHasCtxNode(const GraphViewer& graph_viewer, size_t& node_idx);
 const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer);
 std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path);
diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h
index 3a7a1b6504d12..eb1427db87463 100644
--- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h
+++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h
@@ -86,6 +86,21 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl {
     return ep_factory_.GetHardwareDeviceIncompatibilityDetails(&ep_factory_, hw, details);
   }
 
+  OrtStatus* ValidateCompiledModelCompatibilityInfo(
+      const OrtHardwareDevice* const* devices,
+      size_t num_devices,
+      const char* compatibility_info,
+      OrtCompiledModelCompatibility* model_compatibility) noexcept override {
+    // Forward to the underlying factory if it supports validation
+    if (ep_factory_.ValidateCompiledModelCompatibilityInfo) {
+      return ep_factory_.ValidateCompiledModelCompatibilityInfo(
+          &ep_factory_, devices, num_devices, compatibility_info, model_compatibility);
+    }
+    // If not supported, return NOT_APPLICABLE
+    *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
+    return nullptr;
+  }
+
   OrtEpFactory& ep_factory_;
   ProviderLibrary& provider_library_;
   std::optional<std::filesystem::path> library_path_;
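End to end, a caller that obtains an OrtCompiledModelCompatibility for a precompiled EPContext model can act on it roughly as follows. This is a sketch; how compat is retrieved (e.g. via the public API mentioned in the factory doc comment) is outside this diff, and the enum is assumed to come from the ORT C API header:

#include <onnxruntime_c_api.h>

const char* DescribeCompatibility(OrtCompiledModelCompatibility compat) {
  switch (compat) {
    case OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL:
      return "load the compiled model as-is";
    case OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION:
      return "usable now, but schedule a recompile for best performance";
    case OrtCompiledModelCompatibility_EP_UNSUPPORTED:
      return "recompile from the original ONNX model";
    case OrtCompiledModelCompatibility_EP_NOT_APPLICABLE:
    default:
      return "no EP-specific compatibility check applies";
  }
}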