Skip to content
14 changes: 12 additions & 2 deletions src/tensorrt_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,26 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp
// Create two memory infos per device.
// The memory info is required to create allocator and gpu data transfer.
int num_cuda_devices = 0;
cudaGetDeviceCount(&num_cuda_devices);
cudaError_t cuda_err = cudaGetDeviceCount(&num_cuda_devices);
if (cuda_err != cudaSuccess) {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the body of the first if block uses 6-space indentation while the rest of the file uses 4-space:

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made it to use 4-space.

I think i will include lintrunner to have a nice format in other PR.

return factory->ort_api.CreateStatus(ORT_EP_FAIL, cudaGetErrorString(cuda_err));
}

if (num_cuda_devices == 0) {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returning error when num_cuda_devices == 0 is semantically questionable.

GetSupportedDevices is an enumeration function — its job is to report which devices this EP can service. If there are no CUDA devices, the correct behavior is to report 0 supported devices and return success (nullptr), not to return a hard failure. An ORT_EP_FAIL status here may cause ORT to log spurious errors or abort initialization on machines that legitimately have no NVIDIA GPUs (e.g., an AMD-only system where the TRT EP plugin is still present).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense and i made the return success (nullptr) when there is no CUDA device available.

Also, add a check for CUDA devices in CreateEpFactories()

return factory->ort_api.CreateStatus(ORT_EP_FAIL, "No CUDA devices found.");
}

RETURN_IF_ERROR(factory->CreateMemoryInfoForDevices(num_cuda_devices));

int32_t device_id = 0;
Comment thread
chilo-ms marked this conversation as resolved.
Outdated
constexpr uint32_t kNvidiaVendorId = 0x10DE;
Comment thread
chilo-ms marked this conversation as resolved.
Outdated

for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
// C API
const OrtHardwareDevice& device = *devices[i];

if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) {
if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU &&

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In tensorrt_provider_factory.cc device_id is a zero-based counter that increments for each NVIDIA GPU found in the ORT hardware device list. This value is then used to index into cuda_gpu_memory_infos (keyed 0..num_cuda_devices-1). The implicit assumption is that the ORT hardware device enumeration order matches the CUDA device ordinals.

If ORT enumerates devices in a different order than CUDA, the device_id mapping could silently use the wrong memory info. The newly added bounds check (device_id < num_cuda_devices) catches overflow, which is good, but doesn't guard against order mismatches.

Suggestion: Consider using the actual CUDA device ID from the OrtHardwareDevice (e.g., via a device-index getter API if available) rather than relying on positional enumeration order.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the code to use actual CUDA device ID from the OrtHardwareDevice.

factory->ort_api.HardwareDevice_VendorId(&device) == kNvidiaVendorId) {
// These can be returned as nullptr if you have nothing to add.
OrtKeyValuePairs* ep_metadata = nullptr;
OrtKeyValuePairs* ep_options = nullptr;
Expand Down