diff --git a/src/tensorrt_execution_provider.cc b/src/tensorrt_execution_provider.cc index 80cc43d..3f3fed7 100644 --- a/src/tensorrt_execution_provider.cc +++ b/src/tensorrt_execution_provider.cc @@ -3050,11 +3050,12 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* std::unordered_map& dds_output_allocator_maps = ep.GetDDSOutputAllocators(); auto& dds_output_allocator_map = dds_output_allocator_maps[fused_node_name]; - // Get default OrtMemoryInfo from factory - const OrtMemoryInfo* mem_info = nullptr; - if (ep.factory_.cuda_gpu_memory_infos.find(device_id) != - ep.factory_.cuda_gpu_memory_infos.end()) { - mem_info = ep.factory_.cuda_gpu_memory_infos[device_id].get(); + // Get default OrtMemoryInfo from factory's device cache + const OrtMemoryInfo* mem_info = ep.factory_.GetMemoryInfoByOrdinal(device_id, /* is pinned */false); + if (mem_info == nullptr) { + std::string err_msg = "TensorRT EP failed to get OrtMemoryInfo for device_id " + + std::to_string(device_id) + " from provider factory."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } // Get allocator from OrtKernelContext @@ -3770,11 +3771,12 @@ OrtStatus* TRTEpEpContextNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_p std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input - // Get default OrtMemoryInfo from factory - const OrtMemoryInfo* mem_info = nullptr; - if (ep.factory_.cuda_gpu_memory_infos.find(device_id) != - ep.factory_.cuda_gpu_memory_infos.end()) { - mem_info = ep.factory_.cuda_gpu_memory_infos[device_id].get(); + // Get default OrtMemoryInfo from factory's device cache + const OrtMemoryInfo* mem_info = ep.factory_.GetMemoryInfoByOrdinal(device_id, /* is pinned */false); + if (mem_info == nullptr) { + std::string err_msg = "TensorRT EP failed to get OrtMemoryInfo for device_id " + + std::to_string(device_id) + " from provider factory."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); } // Get allocator from OrtKernelContext diff --git a/src/tensorrt_execution_provider.def b/src/tensorrt_execution_provider.def index ae83cb7..d2589b2 100644 --- a/src/tensorrt_execution_provider.def +++ b/src/tensorrt_execution_provider.def @@ -1,4 +1,4 @@ -LIBRARY "TensorRTEp.dll" +LIBRARY "ORTTensorRTEp.dll" EXPORTS CreateEpFactories @1 ReleaseEpFactory @2 diff --git a/src/tensorrt_execution_provider_data_transfer.cc b/src/tensorrt_execution_provider_data_transfer.cc index ca74a33..61af7ee 100644 --- a/src/tensorrt_execution_provider_data_transfer.cc +++ b/src/tensorrt_execution_provider_data_transfer.cc @@ -23,12 +23,9 @@ bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(const OrtDataTransferImpl* this auto src_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(src_memory_device); auto dst_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(dst_memory_device); - // 0x10DE is the PCI vendor ID for NVIDIA - constexpr uint32_t nvidia_vendor_id = 0x10DE; - // Reject if GPU device is not NVIDIA - if ((src_type == OrtMemoryInfoDeviceType_GPU && src_vendor_id != nvidia_vendor_id) || - (dst_type == OrtMemoryInfoDeviceType_GPU && dst_vendor_id != nvidia_vendor_id)) { + if ((src_type == OrtMemoryInfoDeviceType_GPU && src_vendor_id != kNvidiaVendorId) || + (dst_type == OrtMemoryInfoDeviceType_GPU && dst_vendor_id != kNvidiaVendorId)) { return false; } @@ -110,11 +107,6 @@ OrtStatus* ORT_API_CALL TRTEpDataTransfer::CopyTensorsImpl(OrtDataTransferImpl* /*static*/ void ORT_API_CALL TRTEpDataTransfer::ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept { - // In our setup the factory owns a shared ExampleDataTransfer instance so it will do the cleanup, and we ignore - // the call to Release from the plugin_ep::DataTransfer dtor (see /onnxruntime/core/framework/plugin_data_transfer.h) - // - // If you create a new instance on each call to OrtEpFactory::CreateDataTransfer you call `delete` here - // delete static_cast(this_ptr); - ; + delete static_cast(this_ptr); } } // namespace trt_ep diff --git a/src/tensorrt_execution_provider_data_transfer.h b/src/tensorrt_execution_provider_data_transfer.h index 816a5eb..c6bf8a6 100644 --- a/src/tensorrt_execution_provider_data_transfer.h +++ b/src/tensorrt_execution_provider_data_transfer.h @@ -9,9 +9,7 @@ namespace trt_ep { struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs { - TRTEpDataTransfer(ApiPtrs api_ptrs, std::vector& device_mem_infos, - std::vector& shared_mem_infos) - : ApiPtrs(api_ptrs), cuda_gpu_mem_devices_{device_mem_infos}, cuda_pinned_mem_devices_{shared_mem_infos} { + TRTEpDataTransfer(ApiPtrs api_ptrs) : OrtDataTransferImpl{}, ApiPtrs(api_ptrs) { CanCopy = CanCopyImpl; CopyTensors = CopyTensorsImpl; Release = ReleaseImpl; @@ -26,9 +24,5 @@ struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs { OrtValue** dst_tensors_ptr, OrtSyncStream** streams_ptr, size_t num_tensors) noexcept; static void ORT_API_CALL ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept; - - private: - std::vector& cuda_gpu_mem_devices_; - std::vector& cuda_pinned_mem_devices_; }; } // namespace trt_ep \ No newline at end of file diff --git a/src/tensorrt_provider_factory.cc b/src/tensorrt_provider_factory.cc index f0e985a..41319c1 100644 --- a/src/tensorrt_provider_factory.cc +++ b/src/tensorrt_provider_factory.cc @@ -11,26 +11,96 @@ #include #include +// --------------------------------------------------------------------------- +// C API boundary guard. +// +// Every C API entry point (ORT_API_CALL / extern "C") that returns OrtStatus* +// must catch all C++ exceptions before they cross the boundary — propagating +// exceptions through a C ABI is undefined behaviour. +// +// Usage: +// OrtStatus* ORT_API_CALL SomeImpl(...) noexcept { +// API_IMPL_BEGIN +// ... body ... +// return nullptr; +// API_IMPL_END(factory->ort_api) +// } +// --------------------------------------------------------------------------- +#define API_IMPL_BEGIN try { +#define API_IMPL_END(ort_api_ref) \ + } catch (const std::exception& ex) { \ + return (ort_api_ref).CreateStatus(ORT_EP_FAIL, ex.what()); \ + } catch (...) { \ + return (ort_api_ref).CreateStatus(ORT_EP_FAIL, "Unknown exception in TRT EP"); \ + } + +// --------------------------------------------------------------------------- +// TensorRT builder placeholder for test scenarios. +// +// TensorRT loads/unloads heavy internal libraries every time all IBuilder +// instances are destroyed. During unit testing (e.g., onnxruntime_provider_test) +// EPs are rapidly created and torn down, causing repeated overhead. +// +// ORT's test_main.cc has the same optimization behind `#ifdef USE_TENSORRT`, +// but that define is never set for plugin EPs. Instead we guard creation with +// an environment variable that the test harness can set: +// +// set ORT_TRT_EP_ENABLE_BUILDER_PLACEHOLDER=1 +// +// The placeholder is created once in CreateEpFactories() and destroyed in +// ReleaseEpFactory(), matching the factory's lifetime. +// --------------------------------------------------------------------------- +namespace { + +class PlaceholderTrtLogger : public nvinfer1::ILogger { + public: + void log(Severity /*severity*/, const char* /*msg*/) noexcept override {} +}; + +PlaceholderTrtLogger g_placeholder_trt_logger; +std::unique_ptr g_trt_builder_placeholder; + +void MaybeCreateBuilderPlaceholder() { + if (g_trt_builder_placeholder) return; // already created + + const char* env = std::getenv("ORT_TRT_EP_ENABLE_BUILDER_PLACEHOLDER"); + if (env != nullptr && std::string(env) == "1") { + g_trt_builder_placeholder.reset(nvinfer1::createInferBuilder(g_placeholder_trt_logger)); + } +} + +void DestroyBuilderPlaceholder() { + g_trt_builder_placeholder.reset(); +} + +} // namespace + namespace trt_ep { TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* ep_name, const OrtLogger& default_logger, ApiPtrs apis) - : OrtEpFactory {}, ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} { + : OrtEpFactory {}, + ApiPtrs(apis), + default_logger_{default_logger}, + ep_name_{ep_name}, + ort_api_{apis.ort_api}, + ep_api_{apis.ep_api} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. GetName = GetNameImpl; GetVendor = GetVendorImpl; GetVersion = GetVersionImpl; - GetSupportedDevices = GetSupportedDevicesImpl; - CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; - CreateAllocator = CreateAllocatorImpl; ReleaseAllocator = ReleaseAllocatorImpl; - CreateDataTransfer = CreateDataTransferImpl; + IsStreamAware = IsStreamAwareImpl; +} - IsStreamAware = IsStreamAwareImpl; +TensorrtExecutionProviderFactory::~TensorrtExecutionProviderFactory() { + if (kernel_registry_ != nullptr) { + ep_api_.ReleaseKernelRegistry(kernel_registry_); + } } const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetNameImpl(const OrtEpFactory* this_ptr) noexcept { @@ -48,30 +118,26 @@ const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetVersionImpl(const return factory->ep_version_.c_str(); } -OrtStatus* TensorrtExecutionProviderFactory::CreateMemoryInfoForDevices(int num_devices) { - cuda_gpu_memory_infos.reserve(num_devices); - cuda_pinned_memory_infos.reserve(num_devices); - - for (int device_id = 0; device_id < num_devices; ++device_id) { - OrtMemoryInfo* mem_info = nullptr; - RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU, - /*vendor OrtDevice::VendorIds::NVIDIA*/ 0x10DE, - /* device_id */ device_id, OrtDeviceMemoryType_DEFAULT, - /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info)); - - cuda_gpu_memory_infos[device_id] = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); - - // HOST_ACCESSIBLE memory should use the non-CPU device type - mem_info = nullptr; - RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU, - /*vendor OrtDevice::VendorIds::NVIDIA*/ 0x10DE, - /* device_id */ device_id, OrtDeviceMemoryType_HOST_ACCESSIBLE, - /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info)); - - cuda_pinned_memory_infos[device_id] = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); - } +const OrtMemoryInfo* TensorrtExecutionProviderFactory::GetMemoryInfoByOrdinal(int cuda_ordinal, bool is_pinned) { + // Get default OrtMemoryInfo from factory's device cache + const OrtMemoryInfo* mem_info = nullptr; + auto* cache_entry = FindDeviceCacheEntryByOrdinal(cuda_ordinal); + if (cache_entry != nullptr) { + mem_info = is_pinned ? cache_entry->pinned_memory_info : + cache_entry->device_memory_info; // Ort::MemoryInfo implicitly converts to OrtMemoryInfo* + } + return mem_info; +} - return nullptr; +TensorrtExecutionProviderFactory::HardwareDeviceKey TensorrtExecutionProviderFactory::MakeDeviceKey(const OrtApi& ort_api, + const OrtHardwareDevice& device, + int cuda_ordinal) { + return { + ort_api.HardwareDevice_Type(&device), + ort_api.HardwareDevice_VendorId(&device), + ort_api.HardwareDevice_DeviceId(&device), + cuda_ordinal, + }; } OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImpl( @@ -81,83 +147,165 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp OrtEpDevice** ep_devices, size_t max_ep_devices, size_t* p_num_ep_devices) noexcept { - size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); + API_IMPL_BEGIN + size_t& num_ep_devices = *p_num_ep_devices; - // Create two memory infos per device. - // The memory info is required to create allocator and gpu data transfer. - int num_cuda_devices = 0; - cudaGetDeviceCount(&num_cuda_devices); - RETURN_IF_ERROR(factory->CreateMemoryInfoForDevices(num_cuda_devices)); + // Clear stale ordinal mappings from any prior enumeration. + { + std::lock_guard lock(factory->device_cache_mutex_); + factory->ordinal_to_device_key_.clear(); + } + + auto release_ep_devices = [&](OrtStatus* status) -> OrtStatus* { + for (size_t j = 0; j < num_ep_devices; ++j) { + factory->ep_api.ReleaseEpDevice(ep_devices[j]); + ep_devices[j] = nullptr; + } + num_ep_devices = 0; + return status; + }; + + // Query CUDA device count once upfront so we can validate assigned ordinals. + int cuda_device_count = 0; + cudaError_t cuda_err = cudaGetDeviceCount(&cuda_device_count); + if (cuda_err != cudaSuccess) { + // CUDA API failure (e.g., driver not loaded, version mismatch) is a hard error. + // This is distinct from the case where CUDA works but reports zero devices. + std::string err_msg = std::string("cudaGetDeviceCount failed: ") + cudaGetErrorString(cuda_err) + + " (" + std::to_string(static_cast(cuda_err)) + ")"; + return factory->ort_api.CreateStatus(ORT_RUNTIME_EXCEPTION, err_msg.c_str()); + } - int32_t device_id = 0; + if (cuda_device_count == 0) { + RETURN_IF_ERROR(factory->ort_api.Logger_LogMessage(&factory->default_logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + "No CUDA devices found on the system. No OrtEpDevice will be created and returned.", + ORT_FILE, __LINE__, __FUNCTION__)); + } + int cuda_device_index_fallback = 0; // fallback counter when metadata lacks PCI bus ID for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - // C API const OrtHardwareDevice& device = *devices[i]; - if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - // These can be returned as nullptr if you have nothing to add. - OrtKeyValuePairs* ep_metadata = nullptr; - OrtKeyValuePairs* ep_options = nullptr; - factory->ort_api.CreateKeyValuePairs(&ep_metadata); - factory->ort_api.CreateKeyValuePairs(&ep_options); + if (factory->ort_api.HardwareDevice_Type(&device) != OrtHardwareDeviceType::OrtHardwareDeviceType_GPU || + factory->ort_api.HardwareDevice_VendorId(&device) != kNvidiaVendorId) { + continue; + } - // The ep options can be provided here as default values. - // Users can also call SessionOptionsAppendExecutionProvider_V2 C API with provided ep options to override. - factory->ort_api.AddKeyValuePair(ep_metadata, "gpu_type", "data center"); // random example using made up values - factory->ort_api.AddKeyValuePair(ep_options, "trt_builder_optimization_level", "3"); + // Try to resolve the CUDA ordinal from pci_bus_id metadata if available. + // This is more reliable than counter-based ordinal assignment because it is + // not affected by enumeration order, CUDA_VISIBLE_DEVICES remapping, or + // mixed-vendor GPU configurations. + int current_device_id = -1; + const OrtKeyValuePairs* metadata = factory->ort_api_.HardwareDevice_Metadata(&device); + if (metadata != nullptr) { + const char* pci_bus_id = factory->ort_api_.GetKeyValue(metadata, "pci_bus_id"); + if (pci_bus_id != nullptr && pci_bus_id[0] != '\0') { + int resolved_ordinal = -1; + cudaError_t err = cudaDeviceGetByPCIBusId(&resolved_ordinal, pci_bus_id); + if (err == cudaSuccess && resolved_ordinal >= 0 && resolved_ordinal < cuda_device_count) { + current_device_id = resolved_ordinal; + } + } + } - // OrtEpDevice copies ep_metadata and ep_options. - OrtEpDevice* ep_device = nullptr; - auto* status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, - &ep_device); + // Fallback: if pci_bus_id was not available, use counter-based ordinal assignment. + if (current_device_id < 0) { + current_device_id = cuda_device_index_fallback++; + } - factory->ort_api.ReleaseKeyValuePairs(ep_metadata); - factory->ort_api.ReleaseKeyValuePairs(ep_options); + // Validate the assigned ordinal is within the range of CUDA-visible devices. + // If hardware enumeration reports GPUs not visible to CUDA (e.g. due to + // CUDA_VISIBLE_DEVICES), skip them to avoid failures in allocator/stream creation. + if (current_device_id >= cuda_device_count) { + continue; + } - if (status != nullptr) { - return status; + const auto device_key = MakeDeviceKey(factory->ort_api, device, current_device_id); + DeviceCacheEntry* cache_entry = nullptr; + { + std::lock_guard lock(factory->device_cache_mutex_); + auto [it, inserted] = factory->device_cache_.try_emplace(device_key); + if (inserted) { + it->second.cuda_device_id = current_device_id; + it->second.device_memory_info = Ort::MemoryInfo{"Cuda", + OrtMemoryInfoDeviceType_GPU, + kNvidiaVendorId, + static_cast(current_device_id), + OrtDeviceMemoryType_DEFAULT, + /*alignment is default*/ 0, + OrtAllocatorType::OrtDeviceAllocator}; + it->second.pinned_memory_info = Ort::MemoryInfo{"CudaPinned", + OrtMemoryInfoDeviceType_GPU, + kNvidiaVendorId, + static_cast(current_device_id), + OrtDeviceMemoryType_HOST_ACCESSIBLE, + /*alignment is default*/ 0, + OrtAllocatorType::OrtDeviceAllocator}; } - const OrtMemoryInfo* cuda_gpu_mem_info = factory->cuda_gpu_memory_infos[device_id].get(); - const OrtMemoryInfo* cuda_pinned_mem_info = factory->cuda_pinned_memory_infos[device_id].get(); + cache_entry = &it->second; + current_device_id = cache_entry->cuda_device_id; + // Build ordinal -> key mapping for CreateAllocatorImpl lookups. + factory->ordinal_to_device_key_[current_device_id] = device_key; + } + + // These can be returned as nullptr if EP has nothing to add. + OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; + factory->ort_api.CreateKeyValuePairs(&ep_metadata); + factory->ort_api.CreateKeyValuePairs(&ep_options); + factory->ort_api.AddKeyValuePair(ep_metadata, "cuda_device_id", std::to_string(current_device_id).c_str()); + factory->ort_api.AddKeyValuePair(ep_options, "device_id", std::to_string(current_device_id).c_str()); + + // Get CUDA device properties for metadata + { + cudaDeviceProp prop; + if (cudaGetDeviceProperties(&prop, current_device_id) == cudaSuccess) { + factory->ort_api.AddKeyValuePair(ep_metadata, "cuda_device_name", prop.name); + factory->ort_api.AddKeyValuePair(ep_metadata, "cuda_compute_capability", + (std::to_string(prop.major) + "." + std::to_string(prop.minor)).c_str()); + } + } - // Register the allocator info required by TRT EP. - RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_gpu_mem_info)); - RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_pinned_mem_info)); + // OrtEpDevice copies ep_metadata and ep_options. + OrtEpDevice* ep_device = nullptr; + auto* status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, &ep_device); - // Get memory device from memory info for gpu data transfer - factory->cuda_gpu_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_gpu_mem_info)); - factory->cuda_pinned_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_pinned_mem_info)); + factory->ort_api.ReleaseKeyValuePairs(ep_metadata); + factory->ort_api.ReleaseKeyValuePairs(ep_options); - ep_devices[num_ep_devices++] = ep_device; - ++device_id; + if (status != nullptr) { + return release_ep_devices(status); } - // C++ API equivalent. Throws on error. - //{ - // Ort::ConstHardwareDevice device(devices[i]); - // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - // Ort::KeyValuePairs ep_metadata; - // Ort::KeyValuePairs ep_options; - // ep_metadata.Add("version", "0.1"); - // ep_options.Add("trt_builder_optimization_level", "3"); - // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; - // ep_devices[num_ep_devices++] = ep_device.release(); - // } - //} - } + auto release_current_ep_device = [factory](OrtEpDevice* device) { + factory->ep_api.ReleaseEpDevice(device); + }; - // Create gpu data transfer - auto data_transfer_impl = std::make_unique(static_cast(*factory), - factory->cuda_gpu_mem_devices, // device memory - factory->cuda_pinned_mem_devices // shared memory - ); + // ep_device_guard owns the current device. On error, release_ep_devices cleans up + // previously committed devices [0, num_ep_devices), while the guard cleans up this one. + std::unique_ptr ep_device_guard(ep_device, release_current_ep_device); - factory->data_transfer_impl = std::move(data_transfer_impl); + // Register allocator info for GPU device memory + status = factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cache_entry->device_memory_info); + if (status != nullptr) { + return release_ep_devices(status); + } + + // Register allocator info for pinned host memory associated with the + // same CUDA ordinal as the device allocator above. + status = factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cache_entry->pinned_memory_info); + if (status != nullptr) { + return release_ep_devices(status); + } + + ep_devices[num_ep_devices++] = ep_device_guard.release(); + } return nullptr; + API_IMPL_END(factory->ort_api) } OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateEpImpl( @@ -169,6 +317,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateEpImpl( _In_ const OrtLogger* logger, _Out_ OrtEp** ep) noexcept { auto* factory = static_cast(this_ptr); *ep = nullptr; + API_IMPL_BEGIN if (num_devices != 1) { // we only registered for GPU and only expected to be selected for one GPU @@ -191,11 +340,29 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateEpImpl( *ep = trt_ep.release(); return nullptr; + API_IMPL_END(factory->ort_api) } -void ORT_API_CALL TensorrtExecutionProviderFactory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { - TensorrtExecutionProvider* trt_ep = static_cast(ep); - delete trt_ep; +void ORT_API_CALL TensorrtExecutionProviderFactory::ReleaseEpImpl(OrtEpFactory* this_ptr, OrtEp* ep) noexcept { + try { + TensorrtExecutionProvider* trt_ep = static_cast(ep); + delete trt_ep; + } catch (const std::exception& ex) { + // void return — cannot report via OrtStatus*. Log so teardown failures are diagnosable. + auto* factory = static_cast(this_ptr); + auto* log_status = factory->ort_api.Logger_LogMessage(&factory->default_logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR, + (std::string("Exception in ReleaseEpImpl: ") + ex.what()).c_str(), + ORT_FILE, __LINE__, __FUNCTION__); + if (log_status) factory->ort_api.ReleaseStatus(log_status); + } catch (...) { + auto* factory = static_cast(this_ptr); + auto* log_status = factory->ort_api.Logger_LogMessage(&factory->default_logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR, + "Unknown exception in ReleaseEpImpl", + ORT_FILE, __LINE__, __FUNCTION__); + if (log_status) factory->ort_api.ReleaseStatus(log_status); + } } OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl(OrtEpFactory* this_ptr, @@ -203,6 +370,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl(Or const OrtKeyValuePairs* /*allocator_options*/, OrtAllocator** allocator) noexcept { auto& factory = *static_cast(this_ptr); + API_IMPL_BEGIN // NOTE: The factory implementation is free to return a shared OrtAllocator* instance instead of creating a new // allocator on each call. To do this have an allocator instance as an OrtEpFactory class member and make @@ -249,6 +417,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl(Or } return nullptr; + API_IMPL_END(factory.ort_api) } void ORT_API_CALL TensorrtExecutionProviderFactory::ReleaseAllocatorImpl(OrtEpFactory* /*this*/, @@ -261,9 +430,13 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl OrtEpFactory* this_ptr, OrtDataTransferImpl** data_transfer) noexcept { auto& factory = *static_cast(this_ptr); - *data_transfer = factory.data_transfer_impl.get(); + API_IMPL_BEGIN + + auto data_transfer_impl = std::make_unique(static_cast(factory)); + *data_transfer = data_transfer_impl.release(); return nullptr; + API_IMPL_END(factory.ort_api) } bool ORT_API_CALL TensorrtExecutionProviderFactory::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { @@ -279,22 +452,39 @@ OrtStatus* TensorrtExecutionProviderFactory::GetKernelRegistryForEp(TensorrtExec } if (kernel_registry_ == nullptr) { - // Optional state that is provided to kernels on creation (can be null). - // We pass the OrtDataTransferImpl created by this factory to allow kernels to copy data between devices. - void* op_kernel_state = static_cast(data_transfer_impl.get()); const char* ep_name = ep.GetName(static_cast(&ep)); // This statement creates the kernel registry and caches it in the OrtEpFactory instance. // We assume that all EPs created by this factory can use the same kernel registry. This may not be the // case in a more complex OrtEpFactory that can create EP instances that are each configured for different // hardware devices. In such a scenario, a different kernel registry may be created for each EP configuration. - RETURN_IF_ERROR(CreateKernelRegistry(ep_name, op_kernel_state, &kernel_registry_)); + RETURN_IF_ERROR(CreateKernelRegistry(ep_name, nullptr, &kernel_registry_)); } *out_kernel_registry = kernel_registry_; return nullptr; } +TensorrtExecutionProviderFactory::DeviceCacheEntry* TensorrtExecutionProviderFactory::FindDeviceCacheEntryByOrdinalLocked(int cuda_ordinal) { + auto key_it = ordinal_to_device_key_.find(cuda_ordinal); + if (key_it == ordinal_to_device_key_.end()) { + return nullptr; + } + auto cache_it = device_cache_.find(key_it->second); + if (cache_it == device_cache_.end()) { + return nullptr; + } + return &cache_it->second; +} + +// IMPORTANT: Entries are never erased from device_cache_ after insertion. +// This guarantees pointer stability for DeviceCacheEntry* returned by +// FindDeviceCacheEntryByOrdinal() after the lock is released. +TensorrtExecutionProviderFactory::DeviceCacheEntry* TensorrtExecutionProviderFactory::FindDeviceCacheEntryByOrdinal(int cuda_ordinal) { + std::lock_guard lock(device_cache_mutex_); + return FindDeviceCacheEntryByOrdinalLocked(cuda_ordinal); +} + } // namespace trt_ep #define EXPORT_SYMBOL @@ -313,23 +503,78 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const // Manual init for the C++ API Ort::InitApi(ort_api); - // Factory could use registration_name or define its own EP name. - std::unique_ptr factory = std::make_unique(registration_name, *default_logger, ApiPtrs{*ort_api, *ort_ep_api, *model_editor_api}); + try { + int cuda_device_count = 0; + const cudaError_t cuda_err = cudaGetDeviceCount(&cuda_device_count); + if (cuda_err != cudaSuccess) { + // CUDA API failure (e.g., driver not loaded, version mismatch) is a hard error. + // This is distinct from the case where CUDA works but reports zero devices. + std::string err_msg = std::string("cudaGetDeviceCount failed: ") + cudaGetErrorString(cuda_err) + + " (" + std::to_string(static_cast(cuda_err)) + ")"; + return ort_api->CreateStatus(ORT_RUNTIME_EXCEPTION, err_msg.c_str()); + } - if (max_factories < 1) { - return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, - "Not enough space to return EP factory. Need at least one."); - } + if (cuda_device_count == 0) { + auto* log_status = ort_api->Logger_LogMessage(default_logger, ORT_LOGGING_LEVEL_INFO, + "No CUDA devices found on the system." + "TensorRT execution provider will still be " + "created but will not be able to run any models.", + ORT_FILE, __LINE__, __FUNCTION__); + if (log_status) ort_api->ReleaseStatus(log_status); + } - factories[0] = factory.release(); - *num_factories = 1; + // Create TRT builder placeholder if running under a test harness. + // This prevents TensorRT from repeatedly loading/unloading internal + // libraries as EP instances are created and destroyed across tests. + MaybeCreateBuilderPlaceholder(); - return nullptr; + // Factory could use registration_name or define its own EP name. + std::unique_ptr factory = std::make_unique( + registration_name, *default_logger, ApiPtrs{*ort_api, *ort_ep_api, *model_editor_api}); + + if (max_factories < 1) { + return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, + "Not enough space to return EP factory. Need at least one."); + } + + factories[0] = factory.release(); + *num_factories = 1; + + return nullptr; + } catch (const std::exception& ex) { + return ort_api->CreateStatus(ORT_EP_FAIL, ex.what()); + } catch (...) { + return ort_api->CreateStatus(ORT_EP_FAIL, "Unknown exception in CreateEpFactories"); + } } EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { - delete static_cast(factory); - return nullptr; + const OrtApi* ort_api = nullptr; + + try { + // Grab the OrtApi reference before destroying the factory, so we can + // use it to create an error status if the catch block is reached. + auto* trt_factory = static_cast(factory); + ort_api = &trt_factory->ort_api; + + delete trt_factory; + + // Release the placeholder builder when the last factory is torn down. + DestroyBuilderPlaceholder(); + + return nullptr; + } catch (const std::exception& ex) { + if (ort_api != nullptr) { + return ort_api->CreateStatus(ORT_EP_FAIL, ex.what()); + } + // ort_api not yet captured — nothing we can do except not crash. + return nullptr; + } catch (...) { + if (ort_api != nullptr) { + return ort_api->CreateStatus(ORT_EP_FAIL, "Unknown exception in ReleaseEpFactory"); + } + return nullptr; + } } } // extern "C" diff --git a/src/tensorrt_provider_factory.h b/src/tensorrt_provider_factory.h index d016a9f..0372712 100644 --- a/src/tensorrt_provider_factory.h +++ b/src/tensorrt_provider_factory.h @@ -4,6 +4,13 @@ #include "tensorrt_execution_provider_data_transfer.h" #include "cuda_allocator.h" +#include +#include +#include +#include +#include +#include + using MemoryInfoUniquePtr = std::unique_ptr>; namespace trt_ep { @@ -16,26 +23,17 @@ struct TensorrtExecutionProvider; struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { public: TensorrtExecutionProviderFactory(const char* ep_name, const OrtLogger& default_logger, ApiPtrs apis); - - OrtStatus* CreateMemoryInfoForDevices(int num_devices); + ~TensorrtExecutionProviderFactory(); // Called by child OrtEp instances to retrieve the cached kernel registry for that EP. OrtStatus* GetKernelRegistryForEp(TensorrtExecutionProvider& ep, /*out*/ const OrtKernelRegistry** kernel_registry); - // CUDA gpu memory and CUDA pinned memory are required for allocator and data transfer, these are the OrtMemoryInfo - // instance required for that. - // Current TRT EP implementation uses one default OrtMemoryInfo and one host accessible OrtMemoryInfo per ep device. - std::unordered_map cuda_gpu_memory_infos; // device id -> memory info - std::unordered_map cuda_pinned_memory_infos; + const OrtMemoryInfo* GetMemoryInfoByOrdinal(int cuda_ordinal, bool is_pinned); // Keeps allocators per ep device in factory so they can be shared across sessions. std::unordered_map> cuda_gpu_allocators; // device id -> allocator std::unordered_map> cuda_pinned_allocators; - std::vector cuda_gpu_mem_devices; - std::vector cuda_pinned_mem_devices; - std::unique_ptr data_transfer_impl; // data transfer implementation for this factory - private: static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; @@ -69,6 +67,9 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { const std::string ep_name_; // EP name const std::string vendor_{"Nvidia"}; // EP vendor name const std::string ep_version_{"0.1.0"}; // EP version + + const OrtApi& ort_api_; + const OrtEpApi& ep_api_; const OrtLogger& default_logger_; // Cached kernel registry used by all OrtEp instances created by this factory. Refer to OrtEp::GetKernelRegistry. @@ -76,5 +77,54 @@ struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { // Note: If this factory instead created EP instances that each supported different hardware configurations, then // the factory could cache a different kernel registry per EP configuration. OrtKernelRegistry* kernel_registry_ = nullptr; + + struct HardwareDeviceKey { + OrtHardwareDeviceType type{ OrtHardwareDeviceType::OrtHardwareDeviceType_CPU }; + uint32_t vendor_id{ 0 }; + uint32_t device_id{ 0 }; // PCI device ID — identifies the hardware model, NOT a unique device + int cuda_ordinal{ -1 }; // CUDA ordinal — unique per physical GPU on this host + + bool operator==(const HardwareDeviceKey& other) const noexcept { + return type == other.type && + vendor_id == other.vendor_id && + device_id == other.device_id && + cuda_ordinal == other.cuda_ordinal; + } + }; + + struct HardwareDeviceKeyHasher { + size_t operator()(const HardwareDeviceKey& key) const noexcept { + size_t hash = static_cast(key.type); + hash = (hash * 1315423911u) ^ static_cast(key.vendor_id); + hash = (hash * 1315423911u) ^ static_cast(key.device_id); + hash = (hash * 1315423911u) ^ static_cast(key.cuda_ordinal); + return hash; + } + }; + + static HardwareDeviceKey MakeDeviceKey(const OrtApi& ort_api, + const OrtHardwareDevice& device, + int cuda_ordinal); + + struct DeviceCacheEntry { + int cuda_device_id{ -1 }; + Ort::MemoryInfo device_memory_info{ nullptr }; + Ort::MemoryInfo pinned_memory_info{ nullptr }; + }; + + // Per-physical-device cache. The key includes the CUDA ordinal to distinguish + // identical GPUs (same PCI vendor/device ID) on multi-GPU hosts. + std::mutex device_cache_mutex_; + std::unordered_map device_cache_; + + // Ordinal-to-HardwareDeviceKey mapping built during GetSupportedDevicesImpl. + std::unordered_map ordinal_to_device_key_; + + /// Find the DeviceCacheEntry for a given CUDA ordinal. + /// Returns nullptr if the ordinal has not been registered. + DeviceCacheEntry* FindDeviceCacheEntryByOrdinal(int cuda_ordinal); + + /// Same as FindDeviceCacheEntryByOrdinal but assumes device_cache_mutex_ is already held. + DeviceCacheEntry* FindDeviceCacheEntryByOrdinalLocked(int cuda_ordinal); }; } // namespace trt_ep \ No newline at end of file diff --git a/src/utils/ep_utils.h b/src/utils/ep_utils.h index f940195..4ba8d05 100644 --- a/src/utils/ep_utils.h +++ b/src/utils/ep_utils.h @@ -23,6 +23,8 @@ struct ApiPtrs { namespace trt_ep { +constexpr uint32_t kNvidiaVendorId = 0x10DE; + #define ENFORCE(condition, ...) \ do { \ if (!(condition)) { \