-
Notifications
You must be signed in to change notification settings - Fork 2
Only add support for CUDA devices in GetSupportedDevices #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
21653db
1de07bd
ab98037
0fa5f54
d8a8d84
bb325a2
86b17b0
41ea68c
3699c7f
0e36718
100f440
e8e0de6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,7 +55,7 @@ OrtStatus* TensorrtExecutionProviderFactory::CreateMemoryInfoForDevices(int num_ | |
| for (int device_id = 0; device_id < num_devices; ++device_id) { | ||
| OrtMemoryInfo* mem_info = nullptr; | ||
| RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU, | ||
| /*vendor OrtDevice::VendorIds::NVIDIA*/ 0x10DE, | ||
| /* vendor_id */ kNvidiaVendorId, | ||
| /* device_id */ device_id, OrtDeviceMemoryType_DEFAULT, | ||
| /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info)); | ||
|
|
||
|
|
@@ -64,7 +64,7 @@ OrtStatus* TensorrtExecutionProviderFactory::CreateMemoryInfoForDevices(int num_ | |
| // HOST_ACCESSIBLE memory should use the non-CPU device type | ||
| mem_info = nullptr; | ||
| RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU, | ||
| /*vendor OrtDevice::VendorIds::NVIDIA*/ 0x10DE, | ||
| /* vendor_id */ kNvidiaVendorId, | ||
| /* device_id */ device_id, OrtDeviceMemoryType_HOST_ACCESSIBLE, | ||
| /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info)); | ||
|
|
||
|
|
@@ -87,7 +87,15 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp | |
| // Create two memory infos per device. | ||
| // The memory info is required to create allocator and gpu data transfer. | ||
| int num_cuda_devices = 0; | ||
| cudaGetDeviceCount(&num_cuda_devices); | ||
| cudaError_t cuda_err = cudaGetDeviceCount(&num_cuda_devices); | ||
| if (cuda_err != cudaSuccess) { | ||
| return factory->ort_api.CreateStatus(ORT_EP_FAIL, cudaGetErrorString(cuda_err)); | ||
| } | ||
|
|
||
| if (num_cuda_devices == 0) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Returning error when num_cuda_devices == 0 is semantically questionable. GetSupportedDevices is an enumeration function — its job is to report which devices this EP can service. If there are no CUDA devices, the correct behavior is to report 0 supported devices and return success (nullptr), not to return a hard failure. An ORT_EP_FAIL status here may cause ORT to log spurious errors or abort initialization on machines that legitimately have no NVIDIA GPUs (e.g., an AMD-only system where the TRT EP plugin is still present).
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That makes sense and i made the return success (nullptr) when there is no CUDA device available. Also, add a check for CUDA devices in |
||
| return factory->ort_api.CreateStatus(ORT_EP_FAIL, "No CUDA devices found."); | ||
| } | ||
|
|
||
| RETURN_IF_ERROR(factory->CreateMemoryInfoForDevices(num_cuda_devices)); | ||
|
|
||
| int32_t device_id = 0; | ||
|
chilo-ms marked this conversation as resolved.
Outdated
|
||
|
|
@@ -96,7 +104,8 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp | |
| // C API | ||
| const OrtHardwareDevice& device = *devices[i]; | ||
|
|
||
| if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { | ||
| if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In tensorrt_provider_factory.cc device_id is a zero-based counter that increments for each NVIDIA GPU found in the ORT hardware device list. This value is then used to index into cuda_gpu_memory_infos (keyed 0..num_cuda_devices-1). The implicit assumption is that the ORT hardware device enumeration order matches the CUDA device ordinals. If ORT enumerates devices in a different order than CUDA, the device_id mapping could silently use the wrong memory info. The newly added bounds check (device_id < num_cuda_devices) catches overflow, which is good, but doesn't guard against order mismatches. Suggestion: Consider using the actual CUDA device ID from the OrtHardwareDevice (e.g., via a device-index getter API if available) rather than relying on positional enumeration order.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I updated the code to use actual CUDA device ID from the OrtHardwareDevice. |
||
| factory->ort_api.HardwareDevice_VendorId(&device) == kNvidiaVendorId) { | ||
| // These can be returned as nullptr if you have nothing to add. | ||
| OrtKeyValuePairs* ep_metadata = nullptr; | ||
| OrtKeyValuePairs* ep_options = nullptr; | ||
|
|
@@ -120,6 +129,9 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImp | |
| return status; | ||
| } | ||
|
|
||
| RETURN_IF_NOT(device_id < num_cuda_devices, | ||
| "The device_id for supported device exceeds the number of CUDA devices."); | ||
|
|
||
| const OrtMemoryInfo* cuda_gpu_mem_info = factory->cuda_gpu_memory_infos[device_id].get(); | ||
| const OrtMemoryInfo* cuda_pinned_mem_info = factory->cuda_pinned_memory_infos[device_id].get(); | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the body of the first if block uses 6-space indentation while the rest of the file uses 4-space:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Made it to use 4-space.
I think i will include lintrunner to have a nice format in other PR.