onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
@@ -2578,6 +2578,84 @@ const InlinedVector<const Node*> NvExecutionProvider::GetEpContextNodes() const
return ep_context_nodes;
}

std::string NvExecutionProvider::GetCompiledModelCompatibilityInfo(
const onnxruntime::GraphViewer& graph_viewer) const {
ORT_UNUSED_PARAMETER(graph_viewer);

// Protect read access to engine_headers_ for thread safety
auto lock = GetApiLock();

// Compatibility info is only supported when there is exactly one engine.
// If multiple EPContext nodes/engines exist, return empty so validation is not applicable.
if (engine_headers_.size() > 1) {
return std::string();
}

// If we have stored engine headers, return the first one found
// (typically there's only one per EP context)
if (!engine_headers_.empty()) {
return engine_headers_.begin()->second;
}

// No headers available - validation not supported for this model
return std::string();
}

common::Status NvExecutionProvider::ValidateCompiledModelCompatibilityInfo(
const std::string& compatibility_info,
OrtCompiledModelCompatibility& model_compatibility) const {
// If no compatibility info provided, validation not applicable
if (compatibility_info.empty()) {
model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
return Status::OK();
}

// Decode hex string to binary
std::vector<uint8_t> engine_header;
try {
engine_header = HexStringToBinary(compatibility_info);
} catch (const std::exception& ex) {
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to decode engine header: " << ex.what();
model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
return Status::OK();
}

// Use TensorRT RTX's getEngineValidity to check compatibility
uint64_t diagnostics = 0;
nvinfer1::EngineValidity validity = runtime_->getEngineValidity(
engine_header.data(),
engine_header.size(),
&diagnostics);

// Map TensorRT RTX validity to ORT compatibility status
switch (validity) {
case nvinfer1::EngineValidity::kVALID:
LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Engine is fully compatible with this system";
model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL;
break;

case nvinfer1::EngineValidity::kSUBOPTIMAL:
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine is compatible but recompilation recommended "
<< "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")";
model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION;
break;

case nvinfer1::EngineValidity::kINVALID:
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine is incompatible with this system "
<< "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")";
model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
break;

default:
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unknown TensorRT validity status: "
<< static_cast<int>(validity);
model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
break;
}

return Status::OK();
}
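
// Editor's note: taken together, these two methods implement a capture-and-probe
// scheme -- the compatibility string is simply the hex-encoded 64-byte engine
// header, and validation replays that header through the TensorRT RTX runtime.
// Below is a minimal standalone sketch of the validity-to-compatibility mapping,
// NOT part of this PR. It assumes a TensorRT RTX build where
// nvinfer1::IRuntime::getEngineValidity is available (as used above), and
// CheckEngineHeader is a hypothetical helper name.
//
// #include <cstdint>
// #include <vector>
// #include <NvInfer.h>
// #include "core/session/onnxruntime_c_api.h"  // OrtCompiledModelCompatibility
//
// OrtCompiledModelCompatibility CheckEngineHeader(nvinfer1::IRuntime& runtime,
//                                                 const std::vector<uint8_t>& header) {
//   uint64_t diagnostics = 0;  // extra detail reported alongside validity
//   switch (runtime.getEngineValidity(header.data(), header.size(), &diagnostics)) {
//     case nvinfer1::EngineValidity::kVALID:
//       return OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL;
//     case nvinfer1::EngineValidity::kSUBOPTIMAL:
//       return OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION;
//     case nvinfer1::EngineValidity::kINVALID:
//     default:
//       return OrtCompiledModelCompatibility_EP_UNSUPPORTED;
//   }
// }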

Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer,
const Node& fused_node,
std::unordered_map<std::string, size_t>& input_map,
@@ -2854,6 +2932,18 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"NvTensorRTRTX EP failed to create engine from network for fused node: " + fused_node.Name());
}

// Capture engine header (first 64 bytes) for compatibility validation
if (serialized_engine->size() >= kTensorRTEngineHeaderSize) {
std::string engine_header_hex = BinaryToHexString(
serialized_engine->data(),
kTensorRTEngineHeaderSize);
engine_headers_[fused_node.Name()] = engine_header_hex;
} else {
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine too small to capture header for validation: "
<< serialized_engine->size() << " bytes";
}

trt_engine = std::unique_ptr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size()));
if (trt_engine == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
11 changes: 11 additions & 0 deletions onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
@@ -345,6 +345,13 @@ class NvExecutionProvider : public IExecutionProvider {

const InlinedVector<const Node*> GetEpContextNodes() const override;

// Engine compatibility validation methods
std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const override;

common::Status ValidateCompiledModelCompatibilityInfo(
const std::string& compatibility_info,
OrtCompiledModelCompatibility& model_compatibility) const override;

private:
mutable NvExecutionProviderInfo info_;
bool external_stream_ = false;
@@ -424,6 +431,10 @@ class NvExecutionProvider : public IExecutionProvider {
std::unordered_map<std::string, std::vector<nvinfer1::IOptimizationProfile*>> profiles_;
std::unordered_map<std::string, DDSOutputAllocatorMap> dds_output_allocator_maps_;

// Storage for engine headers (64 bytes) for compatibility validation
// Maps fused_node_name -> hex-encoded engine header
mutable std::unordered_map<std::string, std::string> engine_headers_;

// for external stream, we need to create its cudnn/cublas handle before the CUDA EP enables cuda graph capture
cudnnHandle_t external_cudnn_handle_ = nullptr;
cublasHandle_t external_cublas_handle_ = nullptr;
123 changes: 122 additions & 1 deletion onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
@@ -13,6 +13,7 @@
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/providers/cuda/cuda_stream_handle.h"

#include "onnx_ctx_model_helper.h"

#include "nv_provider_factory.h"
#include "nv_execution_provider.h"
#include "nv_provider_factory_creator.h"
@@ -21,6 +22,11 @@

using namespace onnxruntime;

// External declarations
namespace onnxruntime {
extern TensorrtLogger& GetTensorrtLogger(bool verbose_log);
}

namespace onnxruntime {

void InitializeRegistry();
@@ -541,7 +547,7 @@

IsStreamAware = IsStreamAwareImpl;
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl;

ValidateCompiledModelCompatibilityInfo = ValidateCompiledModelCompatibilityInfoImpl;
ort_version_supported = ORT_API_VERSION; // Set to the ORT version we were compiled with.
}

@@ -661,6 +667,7 @@

RETURN_IF_ERROR(factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options,
&ep_devices[num_ep_devices]));

factory->ort_api.ReleaseKeyValuePairs(ep_options);
factory->ort_api.ReleaseKeyValuePairs(ep_metadata);

@@ -735,6 +742,120 @@
return nullptr;
}

/**
* This function is called by the public C API GetModelCompatibilityForEpDevices.
 * It calls runtime->getEngineValidity() directly on the TensorRT RTX runtime to check the 64-byte engine header.
*
* @param this_ptr Factory instance pointer
* @param devices Hardware devices (not used, validation is done against current system)
* @param num_devices Number of devices
* @param compatibility_info Hex-encoded 64-byte TensorRT RTX engine header (128 hex characters)
* @param model_compatibility Output parameter for compatibility status
* @return OrtStatus* nullptr on success, error status on failure
*/
static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfoImpl(
OrtEpFactory* this_ptr,
    const OrtHardwareDevice* const* devices,
size_t num_devices,
const char* compatibility_info,
OrtCompiledModelCompatibility* model_compatibility) noexcept {
auto& factory = *static_cast<NvTensorRtRtxEpFactory*>(this_ptr);

// Validate input parameters
if (compatibility_info == nullptr || model_compatibility == nullptr) {
return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT,
"[NvTensorRTRTX EP] Invalid arguments: compatibility_info or model_compatibility is null");
}

// Device parameters not used for header validation
ORT_UNUSED_PARAMETER(devices);
ORT_UNUSED_PARAMETER(num_devices);

try {
// If no compatibility info provided, validation not applicable
if (compatibility_info[0] == '\0') {
*model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
return nullptr;
}

// Decode hex string to binary
std::vector<uint8_t> engine_header;
try {
engine_header = HexStringToBinary(std::string(compatibility_info));
} catch (const std::exception& ex) {
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to decode engine header: " << ex.what();
*model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
return nullptr;
}

// Validate header size (keep in sync with TensorRT engine header size)
if (engine_header.size() != kTensorRTEngineHeaderSize) {
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Invalid header size: " << engine_header.size()
<< " bytes (expected " << kTensorRTEngineHeaderSize << ")";
*model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
return nullptr;
}

// Create TensorRT runtime for validation
static std::mutex runtime_creation_mutex;
std::unique_ptr<nvinfer1::IRuntime> runtime;
{
std::lock_guard<std::mutex> lock(runtime_creation_mutex);
TensorrtLogger& trt_logger = GetTensorrtLogger(false);
runtime.reset(nvinfer1::createInferRuntime(trt_logger));
}

if (!runtime) {
LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Failed to create TensorRT runtime";
return factory.ort_api.CreateStatus(ORT_FAIL,
"[NvTensorRTRTX EP] Failed to create TensorRT runtime");
}

// Use TensorRT's getEngineValidity to check compatibility
uint64_t diagnostics = 0;
nvinfer1::EngineValidity validity = runtime->getEngineValidity(
engine_header.data(),
engine_header.size(),
&diagnostics);

// Map TensorRT validity to ORT compatibility status
switch (validity) {
case nvinfer1::EngineValidity::kVALID:
*model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL;
break;

case nvinfer1::EngineValidity::kSUBOPTIMAL:
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine compatible but recompilation recommended "
<< "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")";
*model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION;
break;

case nvinfer1::EngineValidity::kINVALID:
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine incompatible with this system "
<< "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")";
*model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
break;

default:
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unknown validity status: "
<< static_cast<int>(validity);
*model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED;
break;
}

return nullptr;

} catch (const std::exception& ex) {
std::string error_msg = std::string("[NvTensorRTRTX EP] Exception during validation: ") + ex.what();
LOGS_DEFAULT(ERROR) << error_msg;
return factory.ort_api.CreateStatus(ORT_FAIL, error_msg.c_str());
} catch (...) {
LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Unknown exception during validation";
return factory.ort_api.CreateStatus(ORT_FAIL,
"[NvTensorRTRTX EP] Unknown exception during validation");
}
}
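
// Editor's note: because the factory wires this entry point through a function
// pointer (ValidateCompiledModelCompatibilityInfo = ValidateCompiledModelCompatibilityInfoImpl
// above), a caller can probe it directly. A hedged caller-side sketch, NOT part
// of this PR: `factory` is assumed to be an OrtEpFactory* for this EP and
// `hex_header` a std::string holding the 128-character header captured at
// compile time; the hook may be left unset by factories without validation support.
//
// OrtCompiledModelCompatibility compat = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
// if (factory->ValidateCompiledModelCompatibilityInfo != nullptr) {
//   // devices/num_devices are ignored by this EP's implementation (see
//   // ORT_UNUSED_PARAMETER above), so a caller without device handles can pass nullptr/0.
//   OrtStatus* status = factory->ValidateCompiledModelCompatibilityInfo(
//       factory, /*devices=*/nullptr, /*num_devices=*/0, hex_header.c_str(), &compat);
//   if (status != nullptr) {
//     // inspect and release the status via the OrtApi, then treat as failure
//   }
// }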

OrtStatus* CreateMemoryInfoForDevices(int num_devices) {
gpu_memory_infos.reserve(num_devices);
host_accessible_memory_infos.reserve(num_devices);
onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc
@@ -14,6 +14,53 @@
namespace onnxruntime {
extern TensorrtLogger& GetTensorrtLogger(bool verbose_log);

/*
* Convert binary data to hex string
*/
std::string BinaryToHexString(const void* data, size_t size) {
static const char hex_chars[] = "0123456789abcdef";
const uint8_t* bytes = static_cast<const uint8_t*>(data);
std::string result;
result.reserve(size * 2);

for (size_t i = 0; i < size; ++i) {
result.push_back(hex_chars[(bytes[i] >> 4) & 0xF]);
result.push_back(hex_chars[bytes[i] & 0xF]);
}
return result;
}

/*
* Convert hex string back to binary
*/
std::vector<uint8_t> HexStringToBinary(const std::string& hex) {
if (hex.size() % 2 != 0) {
ORT_THROW("Hex string must have even length");
}

std::vector<uint8_t> result;
result.reserve(hex.size() / 2);

for (size_t i = 0; i < hex.size(); i += 2) {
uint8_t byte = 0;

    // Characters outside [0-9a-fA-F] contribute 0 (lenient decoding).
    // High nibble
    char c = hex[i];
    byte |= (c >= '0' && c <= '9')   ? static_cast<uint8_t>((c - '0') << 4)
            : (c >= 'a' && c <= 'f') ? static_cast<uint8_t>((c - 'a' + 10) << 4)
            : (c >= 'A' && c <= 'F') ? static_cast<uint8_t>((c - 'A' + 10) << 4)
                                     : 0;

    // Low nibble
    c = hex[i + 1];
    byte |= (c >= '0' && c <= '9')   ? static_cast<uint8_t>(c - '0')
            : (c >= 'a' && c <= 'f') ? static_cast<uint8_t>(c - 'a' + 10)
            : (c >= 'A' && c <= 'F') ? static_cast<uint8_t>(c - 'A' + 10)
                                     : 0;

result.push_back(byte);
}
return result;
}
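
// Editor's note: a round-trip sanity check for these two helpers, as a sketch
// only (NOT in this PR): a 64-byte header encodes to 128 hex characters and
// decodes back bit-for-bit. Assumes <cassert> is included and
// kTensorRTEngineHeaderSize comes from onnx_ctx_model_helper.h.
//
// void HexRoundTripCheck() {
//   std::vector<uint8_t> header(kTensorRTEngineHeaderSize);
//   for (size_t i = 0; i < header.size(); ++i) {
//     header[i] = static_cast<uint8_t>(i);  // arbitrary test pattern
//   }
//
//   const std::string hex = BinaryToHexString(header.data(), header.size());
//   assert(hex.size() == 2 * kTensorRTEngineHeaderSize);  // 128 hex characters
//
//   const std::vector<uint8_t> decoded = HexStringToBinary(hex);
//   assert(decoded == header);  // lossless round trip
// }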

/*
* Check whether the graph has the EP context contrib op.
* The op can contain the precompiled engine info for TRT EP to directly load the engine.
onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h
@@ -24,6 +24,12 @@
static const std::string SDK_VERSION = "ep_sdk_version";
static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft";

// TensorRT does not currently expose a header size define; keep in sync with TRT engine serialization header size.
constexpr size_t kTensorRTEngineHeaderSize = 64;
// Helper functions for engine header validation
std::string BinaryToHexString(const void* data, size_t size);
std::vector<uint8_t> HexStringToBinary(const std::string& hex);

bool GraphHasCtxNode(const GraphViewer& graph_viewer, size_t& node_idx);
const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer);
std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path);
15 changes: 15 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h
@@ -86,6 +86,21 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl {
return ep_factory_.GetHardwareDeviceIncompatibilityDetails(&ep_factory_, hw, details);
}

OrtStatus* ValidateCompiledModelCompatibilityInfo(
const OrtHardwareDevice* const* devices,
size_t num_devices,
const char* compatibility_info,
OrtCompiledModelCompatibility* model_compatibility) noexcept override {
// Forward to underlying factory if it supports validation
if (ep_factory_.ValidateCompiledModelCompatibilityInfo) {
return ep_factory_.ValidateCompiledModelCompatibilityInfo(
&ep_factory_, devices, num_devices, compatibility_info, model_compatibility);
}
// If not supported, return NOT_APPLICABLE
*model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
return nullptr;
}
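
  // Editor's note: the null-check above means a plugin factory that never sets
  // the hook degrades gracefully rather than failing. A behavioral sketch
  // (hypothetical, NOT in this PR), assuming `bridge` is a constructed
  // ProviderBridgeEpFactory whose underlying factory leaves the hook unset:
  //
  // OrtCompiledModelCompatibility compat{};
  // OrtStatus* status = bridge.ValidateCompiledModelCompatibilityInfo(
  //     /*devices=*/nullptr, /*num_devices=*/0, /*compatibility_info=*/"", &compat);
  // // Expected: status == nullptr and
  // // compat == OrtCompiledModelCompatibility_EP_NOT_APPLICABLE.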

OrtEpFactory& ep_factory_;
ProviderLibrary& provider_library_;
std::optional<std::filesystem::path> library_path_;