Skip to content
Merged
7 changes: 2 additions & 5 deletions src/tensorrt_execution_provider_data_transfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,9 @@ bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(const OrtDataTransferImpl* this
auto src_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(src_memory_device);
auto dst_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(dst_memory_device);

// 0x10DE is the PCI vendor ID for NVIDIA
constexpr uint32_t nvidia_vendor_id = 0x10DE;

// Reject if GPU device is not NVIDIA
if ((src_type == OrtMemoryInfoDeviceType_GPU && src_vendor_id != nvidia_vendor_id) ||
(dst_type == OrtMemoryInfoDeviceType_GPU && dst_vendor_id != nvidia_vendor_id)) {
if ((src_type == OrtMemoryInfoDeviceType_GPU && src_vendor_id != kNvidiaVendorId) ||
(dst_type == OrtMemoryInfoDeviceType_GPU && dst_vendor_id != kNvidiaVendorId)) {
return false;
}

Expand Down
20 changes: 16 additions & 4 deletions src/tensorrt_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ OrtStatus* TensorrtExecutionProviderFactory::CreateMemoryInfoForDevices(int num_
for (int device_id = 0; device_id < num_devices; ++device_id) {
OrtMemoryInfo* mem_info = nullptr;
RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU,
/*vendor OrtDevice::VendorIds::NVIDIA*/ 0x10DE,
/* vendor_id */ kNvidiaVendorId,
/* device_id */ device_id, OrtDeviceMemoryType_DEFAULT,
/*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info));

Expand All @@ -64,7 +64,7 @@ OrtStatus* TensorrtExecutionProviderFactory::CreateMemoryInfoForDevices(int num_
// HOST_ACCESSIBLE memory should use the non-CPU device type
mem_info = nullptr;
RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU,
/*vendor OrtDevice::VendorIds::NVIDIA*/ 0x10DE,
/* vendor_id */ kNvidiaVendorId,
/* device_id */ device_id, OrtDeviceMemoryType_HOST_ACCESSIBLE,
/*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info));

Expand All @@ -87,7 +87,15 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
// Create two memory infos per device.
// The memory info is required to create allocator and gpu data transfer.
int num_cuda_devices = 0;
cudaGetDeviceCount(&num_cuda_devices);
cudaError_t cuda_err = cudaGetDeviceCount(&num_cuda_devices);
if (cuda_err != cudaSuccess) {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the body of the first if block uses 6-space indentation while the rest of the file uses 4-space:

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made it to use 4-space.

I think i will include lintrunner to have a nice format in other PR.

return factory->ort_api.CreateStatus(ORT_EP_FAIL, cudaGetErrorString(cuda_err));
}

if (num_cuda_devices == 0) {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returning error when num_cuda_devices == 0 is semantically questionable.

GetSupportedDevices is an enumeration function — its job is to report which devices this EP can service. If there are no CUDA devices, the correct behavior is to report 0 supported devices and return success (nullptr), not to return a hard failure. An ORT_EP_FAIL status here may cause ORT to log spurious errors or abort initialization on machines that legitimately have no NVIDIA GPUs (e.g., an AMD-only system where the TRT EP plugin is still present).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense and i made the return success (nullptr) when there is no CUDA device available.

Also, add a check for CUDA devices in CreateEpFactories()

return factory->ort_api.CreateStatus(ORT_EP_FAIL, "No CUDA devices found.");
}

RETURN_IF_ERROR(factory->CreateMemoryInfoForDevices(num_cuda_devices));

int32_t device_id = 0;
Comment thread
chilo-ms marked this conversation as resolved.
Outdated
Expand All @@ -96,7 +104,8 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
// C API
const OrtHardwareDevice& device = *devices[i];

if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU &&

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In tensorrt_provider_factory.cc device_id is a zero-based counter that increments for each NVIDIA GPU found in the ORT hardware device list. This value is then used to index into cuda_gpu_memory_infos (keyed 0..num_cuda_devices-1). The implicit assumption is that the ORT hardware device enumeration order matches the CUDA device ordinals.

If ORT enumerates devices in a different order than CUDA, the device_id mapping could silently use the wrong memory info. The newly added bounds check (device_id < num_cuda_devices) catches overflow, which is good, but doesn't guard against order mismatches.

Suggestion: Consider using the actual CUDA device ID from the OrtHardwareDevice (e.g., via a device-index getter API if available) rather than relying on positional enumeration order.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the code to use actual CUDA device ID from the OrtHardwareDevice.

factory->ort_api.HardwareDevice_VendorId(&device) == kNvidiaVendorId) {
// These can be returned as nullptr if you have nothing to add.
OrtKeyValuePairs* ep_metadata = nullptr;
OrtKeyValuePairs* ep_options = nullptr;
Expand All @@ -120,6 +129,9 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
return status;
}

RETURN_IF_NOT(device_id < num_cuda_devices,
"The device_id for supported device exceeds the number of CUDA devices.");

const OrtMemoryInfo* cuda_gpu_mem_info = factory->cuda_gpu_memory_infos[device_id].get();
const OrtMemoryInfo* cuda_pinned_mem_info = factory->cuda_pinned_memory_infos[device_id].get();

Expand Down
2 changes: 2 additions & 0 deletions src/utils/ep_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ struct ApiPtrs {

namespace trt_ep {

constexpr uint32_t kNvidiaVendorId = 0x10DE;

#define ENFORCE(condition, ...) \
do { \
if (!(condition)) { \
Expand Down