diff --git a/numba_cuda/numba/cuda/core/base.py b/numba_cuda/numba/cuda/core/base.py
index 13059d952..241f7ace4 100644
--- a/numba_cuda/numba/cuda/core/base.py
+++ b/numba_cuda/numba/cuda/core/base.py
@@ -933,7 +933,12 @@ def compile_subroutine(
         If *caching* evaluates True, the function keeps the compiled
         function for reuse in *.cached_internal_func*.
         """
-        cache_key = (impl.__code__, sig, type(self.error_model))
+        cache_key = (
+            impl.__code__,
+            sig,
+            type(self.error_model),
+            self.enable_nrt,
+        )
         if not caching:
             cached = None
         else:
diff --git a/numba_cuda/numba/cuda/tests/nrt/test_nrt.py b/numba_cuda/numba/cuda/tests/nrt/test_nrt.py
index 9cc9ccbef..459032d71 100644
--- a/numba_cuda/numba/cuda/tests/nrt/test_nrt.py
+++ b/numba_cuda/numba/cuda/tests/nrt/test_nrt.py
@@ -382,6 +382,37 @@ def foo():
         self.assertEqual(stats.free, stats_free)
         self.assertEqual(stats.mi_free, stats_mi_free)
 
+    def test_nrt_toggle_enabled(self):
+        def array_reshape1d(arr, newshape, got):
+            y = arr.reshape(newshape)
+            for i in range(y.shape[0]):
+                got[i] = y[i]
+
+        def array_reshape(arr, newshape):
+            return arr.reshape(newshape)
+
+        with override_config("CUDA_ENABLE_NRT", True):
+            # compile a kernel that caches an NRT enabled reshape primitive
+            @cuda.jit
+            def kernel(out):
+                out = out.reshape(out.shape)
+                out[0] = 1
+
+            out = cuda.to_device(np.zeros(1, dtype=np.float64))
+            kernel[1, 1](out)
+
+        with override_config("CUDA_ENABLE_NRT", False):
+            # compile and launch a new kernel that gets a cache hit on the
+            # NRT enabled reshape, but tries to launch with NRT disabled
+            # globally
+            new_kernel = cuda.jit(array_reshape1d)
+            arr = np.arange(24)
+            expected = array_reshape(arr, (24,))
+            got = np.zeros(expected.shape, dtype=arr.dtype)
+            new_kernel[1, 1](arr, (24,), got)
+
+            self.assertTrue(np.array_equal(expected, got))
+
 
 if __name__ == "__main__":
     unittest.main()