diff --git a/numba_cuda/numba/cuda/api.py b/numba_cuda/numba/cuda/api.py index a0c5dc45e..cf63bd4f3 100644 --- a/numba_cuda/numba/cuda/api.py +++ b/numba_cuda/numba/cuda/api.py @@ -508,6 +508,11 @@ def close(): Explicitly clears all contexts in the current thread, and destroys all contexts if the current thread is the main thread. """ + # Must clear memsys object in case it has been used already + from .memory_management import rtsys + + rtsys.close() + devices.reset() diff --git a/numba_cuda/numba/cuda/memory_management/nrt.py b/numba_cuda/numba/cuda/memory_management/nrt.py index dac3ec746..8b31969fb 100644 --- a/numba_cuda/numba/cuda/memory_management/nrt.py +++ b/numba_cuda/numba/cuda/memory_management/nrt.py @@ -69,10 +69,18 @@ def __new__(cls, *args, **kwargs): def __init__(self): """Initialize memsys module and variable""" + self._reset() + + def _reset(self): + """Reset to the uninitialized state""" self._memsys_module = None self._memsys = None self._initialized = False + def close(self): + """Close and reset""" + self._reset() + def _compile_memsys_module(self): """ Compile memsys.cu and create a module from it in the current context