Skip to content

Commit 524ec5f

Browse files
authored
[Runtime] Use cudaGetDeviceCount to check if device exists (#16377)
Using `cudaDeviceGetAttribute` will set the global error code when the device doesn't exist and will impact subsequent CUDA API calls.
1 parent 3166366 commit 524ec5f

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/runtime/cuda/cuda_device_api.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ class CUDADeviceAPI final : public DeviceAPI {
4242
int value = 0;
4343
switch (kind) {
4444
case kExist:
45-
value = (cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, dev.device_id) ==
46-
cudaSuccess);
45+
int count;
46+
CUDA_CALL(cudaGetDeviceCount(&count));
47+
value = static_cast<int>(dev.device_id < count);
4748
break;
4849
case kMaxThreadsPerBlock: {
4950
CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, dev.device_id));

0 commit comments

Comments
 (0)