diff --git a/numba_cuda/numba/cuda/simulator/memory_management/nrt.py b/numba_cuda/numba/cuda/simulator/memory_management/nrt.py index 87e8407dd..b62907cf9 100644 --- a/numba_cuda/numba/cuda/simulator/memory_management/nrt.py +++ b/numba_cuda/numba/cuda/simulator/memory_management/nrt.py @@ -3,7 +3,19 @@ from numba import config -rtsys = None + +class RTSys: + def __init__(self, *args, **kwargs): + pass + + def memsys_enable_stats(self): + pass + + def get_allocation_stats(self): + pass + + +rtsys = RTSys() config.CUDA_NRT_STATS = False config.CUDA_ENABLE_NRT = False diff --git a/numba_cuda/numba/cuda/tests/__init__.py b/numba_cuda/numba/cuda/tests/__init__.py index feb74ca8b..c436d765c 100644 --- a/numba_cuda/numba/cuda/tests/__init__.py +++ b/numba_cuda/numba/cuda/tests/__init__.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-2-Clause from fnmatch import fnmatch -from numba.testing import unittest +import unittest from numba import cuda from os.path import dirname, isfile, join, normpath, relpath, splitext diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py b/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py index f96fe424b..87fe592bf 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py @@ -14,7 +14,7 @@ import unittest import warnings from numba.core.errors import NumbaDebugInfoWarning -from numba.tests.support import ignore_internal_warnings +from numba.cuda.tests.support import ignore_internal_warnings import numpy as np import inspect diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_ssa.py b/numba_cuda/numba/cuda/tests/cudapy/test_ssa.py index a6b39d77f..ec78bcedb 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_ssa.py @@ -15,7 +15,7 @@ from numba.core import errors from numba.extending import overload -from numba.tests.support import override_config +from numba.cuda.tests.support import override_config from numba.cuda.testing import CUDATestCase, skip_on_cudasim diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py b/numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py index 8d11e0005..a455e0b2e 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py @@ -9,7 +9,7 @@ from numba import config, cuda, types, njit, typeof from numba.np import numpy_support from numba.cuda.tests.support import TestCase -from numba.tests.support import MemoryLeakMixin +from numba.cuda.tests.support import MemoryLeakMixin class BaseUFuncTest(MemoryLeakMixin): diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py b/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py index 2f977589c..d37332158 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py @@ -6,7 +6,7 @@ from numba import vectorize, cuda, int32, uint32, float32, float64 from numba.cuda.testing import skip_on_cudasim, CUDATestCase -from numba.tests.support import CheckWarningsMixin +from numba.cuda.tests.support import CheckWarningsMixin import unittest diff --git a/numba_cuda/numba/cuda/tests/support.py b/numba_cuda/numba/cuda/tests/support.py index 2e08e1816..7b2dfedc7 100644 --- a/numba_cuda/numba/cuda/tests/support.py +++ b/numba_cuda/numba/cuda/tests/support.py @@ -764,3 +764,56 @@ def skip_if_no_external_compiler(self): if not external_compiler_works(): self.skipTest("No suitable external compiler was found.") + + +class MemoryLeak(object): + __enable_leak_check = True + + def memory_leak_setup(self): + # Clean up any NRT-backed objects hanging in a dead reference cycle + gc.collect() + self.__init_stats = rtsys.get_allocation_stats() + + def memory_leak_teardown(self): + if self.__enable_leak_check: + self.assert_no_memory_leak() + + def assert_no_memory_leak(self): + old = self.__init_stats + new = rtsys.get_allocation_stats() + total_alloc = new.alloc - old.alloc + total_free = new.free - old.free + total_mi_alloc = new.mi_alloc - old.mi_alloc + total_mi_free = new.mi_free - old.mi_free + self.assertEqual(total_alloc, total_free) + self.assertEqual(total_mi_alloc, total_mi_free) + + def disable_leak_check(self): + # For per-test use when MemoryLeakMixin is injected into a TestCase + self.__enable_leak_check = False + + +class MemoryLeakMixin(EnableNRTStatsMixin, MemoryLeak): + def setUp(self): + super(MemoryLeakMixin, self).setUp() + self.memory_leak_setup() + + def tearDown(self): + gc.collect() + self.memory_leak_teardown() + super(MemoryLeakMixin, self).tearDown() + + +class CheckWarningsMixin(object): + @contextlib.contextmanager + def check_warnings(self, messages, category=RuntimeWarning): + with warnings.catch_warnings(record=True) as catch: + warnings.simplefilter("always") + yield + found = 0 + for w in catch: + for m in messages: + if m in str(w.message): + self.assertEqual(w.category, category) + found += 1 + self.assertEqual(found, len(messages))