Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ If you want to manage all run-time dependencies yourself, also pass the `--no-de
python -m numba.runtests numba.cuda.tests
```

This should discover the`numba.cuda` module from the `numba_cuda` package. You
This should discover the `numba.cuda` module from the `numba_cuda` package. You
can check where `numba.cuda` files are being located by running

```
Expand Down
2 changes: 2 additions & 0 deletions numba_cuda/numba/cuda/cudadrv/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def __getitem__(self, devnum):
"""
Returns the context manager for device *devnum*.
"""
if not isinstance(devnum, (int, slice)) and USE_NV_BINDING:
devnum = int(devnum)
return self.lst[devnum]

def __str__(self):
Expand Down
13 changes: 13 additions & 0 deletions numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,19 @@ def test_gpus_iter(self):
gpulist = list(cuda.gpus)
self.assertGreater(len(gpulist), 0)

def test_gpus_cudevice_indexing(self):
"""Test that CUdevice objects can be used to index into cuda.gpus"""
# When using the CUDA Python bindings, the device ids are CUdevice
# objects, otherwise they are integers. We test that the device id is
# usable as an index into cuda.gpus.
device_ids = [device.id for device in cuda.list_devices()]
for device_id in device_ids:
with cuda.gpus[device_id]:
# Check that the device is an integer if not using the CUDA
# Python bindings, otherwise it's a CUdevice object
assert isinstance(device_id, int) != driver.USE_NV_BINDING
self.assertEqual(cuda.gpus.current.id, device_id)


class TestContextAPI(CUDATestCase):
def tearDown(self):
Expand Down