diff --git a/README.md b/README.md index 3aeb9bde7..bd45c4f00 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ If you want to manage all run-time dependencies yourself, also pass the `--no-de python -m numba.runtests numba.cuda.tests ``` -This should discover the`numba.cuda` module from the `numba_cuda` package. You +This should discover the `numba.cuda` module from the `numba_cuda` package. You can check where `numba.cuda` files are being located by running ``` diff --git a/numba_cuda/numba/cuda/cudadrv/devices.py b/numba_cuda/numba/cuda/cudadrv/devices.py index 362d8ebe8..158ad0786 100644 --- a/numba_cuda/numba/cuda/cudadrv/devices.py +++ b/numba_cuda/numba/cuda/cudadrv/devices.py @@ -40,6 +40,8 @@ def __getitem__(self, devnum): """ Returns the context manager for device *devnum*. """ + if not isinstance(devnum, (int, slice)) and USE_NV_BINDING: + devnum = int(devnum) return self.lst[devnum] def __str__(self): diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py b/numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py index 173b4181f..f2f851d58 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py @@ -25,6 +25,19 @@ def test_gpus_iter(self): gpulist = list(cuda.gpus) self.assertGreater(len(gpulist), 0) + def test_gpus_cudevice_indexing(self): + """Test that CUdevice objects can be used to index into cuda.gpus""" + # When using the CUDA Python bindings, the device ids are CUdevice + # objects, otherwise they are integers. We test that the device id is + # usable as an index into cuda.gpus. + device_ids = [device.id for device in cuda.list_devices()] + for device_id in device_ids: + with cuda.gpus[device_id]: + # Check that the device is an integer if not using the CUDA + # Python bindings, otherwise it's a CUdevice object + assert isinstance(device_id, int) != driver.USE_NV_BINDING + self.assertEqual(cuda.gpus.current.id, device_id) + class TestContextAPI(CUDATestCase): def tearDown(self):