Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion numba_cuda/numba/cuda/simulator/memory_management/nrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,19 @@

from numba import config

rtsys = None

class RTSys:
def __init__(self, *args, **kwargs):
pass

def memsys_enable_stats(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For future reference - It's a little non-obvious why this was needed but the reason is that the BaseUFuncTest switched to the Numba-CUDA MemoryLeakMixin which initializes the Numba-CUDA rtsys. Prior to this PR, it was erroneously using the CPU-based MemoryLeakMixin.

pass

def get_allocation_stats(self):
pass


rtsys = RTSys()

config.CUDA_NRT_STATS = False
config.CUDA_ENABLE_NRT = False
2 changes: 1 addition & 1 deletion numba_cuda/numba/cuda/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: BSD-2-Clause

from fnmatch import fnmatch
from numba.testing import unittest
import unittest
from numba import cuda
from os.path import dirname, isfile, join, normpath, relpath, splitext

Expand Down
2 changes: 1 addition & 1 deletion numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import unittest
import warnings
from numba.core.errors import NumbaDebugInfoWarning
from numba.tests.support import ignore_internal_warnings
from numba.cuda.tests.support import ignore_internal_warnings
import numpy as np
import inspect

Expand Down
2 changes: 1 addition & 1 deletion numba_cuda/numba/cuda/tests/cudapy/test_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from numba.core import errors

from numba.extending import overload
from numba.tests.support import override_config
from numba.cuda.tests.support import override_config
from numba.cuda.testing import CUDATestCase, skip_on_cudasim


Expand Down
2 changes: 1 addition & 1 deletion numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from numba import config, cuda, types, njit, typeof
from numba.np import numpy_support
from numba.cuda.tests.support import TestCase
from numba.tests.support import MemoryLeakMixin
from numba.cuda.tests.support import MemoryLeakMixin


class BaseUFuncTest(MemoryLeakMixin):
Expand Down
2 changes: 1 addition & 1 deletion numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from numba import vectorize, cuda, int32, uint32, float32, float64
from numba.cuda.testing import skip_on_cudasim, CUDATestCase
from numba.tests.support import CheckWarningsMixin
from numba.cuda.tests.support import CheckWarningsMixin

import unittest

Expand Down
53 changes: 53 additions & 0 deletions numba_cuda/numba/cuda/tests/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,3 +764,56 @@ def skip_if_no_external_compiler(self):

if not external_compiler_works():
self.skipTest("No suitable external compiler was found.")


class MemoryLeak(object):
__enable_leak_check = True

def memory_leak_setup(self):
# Clean up any NRT-backed objects hanging in a dead reference cycle
gc.collect()
self.__init_stats = rtsys.get_allocation_stats()

def memory_leak_teardown(self):
if self.__enable_leak_check:
self.assert_no_memory_leak()

def assert_no_memory_leak(self):
old = self.__init_stats
new = rtsys.get_allocation_stats()
total_alloc = new.alloc - old.alloc
total_free = new.free - old.free
total_mi_alloc = new.mi_alloc - old.mi_alloc
total_mi_free = new.mi_free - old.mi_free
self.assertEqual(total_alloc, total_free)
self.assertEqual(total_mi_alloc, total_mi_free)

def disable_leak_check(self):
# For per-test use when MemoryLeakMixin is injected into a TestCase
self.__enable_leak_check = False


class MemoryLeakMixin(EnableNRTStatsMixin, MemoryLeak):
def setUp(self):
super(MemoryLeakMixin, self).setUp()
self.memory_leak_setup()

def tearDown(self):
gc.collect()
self.memory_leak_teardown()
super(MemoryLeakMixin, self).tearDown()


class CheckWarningsMixin(object):
@contextlib.contextmanager
def check_warnings(self, messages, category=RuntimeWarning):
with warnings.catch_warnings(record=True) as catch:
warnings.simplefilter("always")
yield
found = 0
for w in catch:
for m in messages:
if m in str(w.message):
self.assertEqual(w.category, category)
found += 1
self.assertEqual(found, len(messages))