diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_init.py b/numba_cuda/numba/cuda/tests/cudadrv/test_init.py index de9eb291b..d018de5ea 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_init.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_init.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-2-Clause +import concurrent.futures import multiprocessing as mp import os @@ -20,88 +21,83 @@ def cuInit_raising(arg): # not assigned until we attempt to initialize - mock.patch.object cannot locate # the non-existent original method, and so fails. Instead we patch # driver.cuInit with our raising version prior to any attempt to initialize. -def cuInit_raising_test(result_queue): +def cuInit_raising_test(): driver.cuInit = cuInit_raising - success = False - msg = None - try: # A CUDA operation that forces initialization of the device cuda.device_array(1) except CudaSupportError as e: success = True msg = e.msg + else: + success = False + msg = None - result_queue.put((success, msg)) + return success, msg # Similar to cuInit_raising_test above, but for testing that the string # returned by cuda_error() is as expected. 
-def initialization_error_test(result_queue): +def initialization_error_test(): driver.cuInit = cuInit_raising - success = False - msg = None - try: # A CUDA operation that forces initialization of the device cuda.device_array(1) except CudaSupportError: success = True + else: + success = False - msg = cuda.cuda_error() - result_queue.put((success, msg)) + return success, cuda.cuda_error() # For testing the path where Driver.__init__() catches a CudaSupportError -def cuda_disabled_test(result_queue): - success = False - msg = None - +def cuda_disabled_test(): try: # A CUDA operation that forces initialization of the device cuda.device_array(1) except CudaSupportError as e: success = True msg = e.msg + else: + success = False + msg = None - result_queue.put((success, msg)) + return success, msg # Similar to cuda_disabled_test, but checks cuda.cuda_error() instead of the # exception raised on initialization -def cuda_disabled_error_test(result_queue): - success = False - msg = None - +def cuda_disabled_error_test(): try: # A CUDA operation that forces initialization of the device cuda.device_array(1) except CudaSupportError: success = True + else: + success = False - msg = cuda.cuda_error() - result_queue.put((success, msg)) + return success, cuda.cuda_error() @skip_on_cudasim("CUDA Simulator does not initialize driver") class TestInit(CUDATestCase): def _test_init_failure(self, target, expected): # Run the initialization failure test in a separate subprocess - ctx = mp.get_context("spawn") - result_queue = ctx.Queue() - proc = ctx.Process(target=target, args=(result_queue,)) - proc.start() - proc.join(30) # should complete within 30s - success, msg = result_queue.get() + with concurrent.futures.ProcessPoolExecutor( + mp_context=mp.get_context("spawn") + ) as exe: + # should complete within 30s + success, msg = exe.submit(target).result(timeout=30) # Ensure the child process raised an exception during initialization # before checking the message if not success: - 
self.fail("CudaSupportError not raised") + raise AssertionError("CudaSupportError not raised") - self.assertIn(expected, msg) + assert expected in msg def test_init_failure_raising(self): expected = "Error at driver init: CUDA_ERROR_UNKNOWN (999)" diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py b/numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py index 225031069..45617729f 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py @@ -1,21 +1,18 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-2-Clause +import concurrent.futures import multiprocessing import os from numba.cuda.testing import unittest -def set_visible_devices_and_check(q): - try: - from numba import cuda - import os +def set_visible_devices_and_check(): + from numba import cuda + import os - os.environ["CUDA_VISIBLE_DEVICES"] = "0" - q.put(len(cuda.gpus.lst)) - except: # noqa: E722 - # Sentinel value for error executing test code - q.put(-1) + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + return len(cuda.gpus.lst) class TestVisibleDevices(unittest.TestCase): @@ -38,22 +35,13 @@ def test_visible_devices_set_after_import(self): msg = "Cannot test when CUDA_VISIBLE_DEVICES already set" self.skipTest(msg) - ctx = multiprocessing.get_context("spawn") - q = ctx.Queue() - p = ctx.Process(target=set_visible_devices_and_check, args=(q,)) - p.start() - try: - visible_gpu_count = q.get() - finally: - p.join() - - # Make an obvious distinction between an error running the test code - # and an incorrect number of GPUs in the list - msg = "Error running set_visible_devices_and_check" - self.assertNotEqual(visible_gpu_count, -1, msg=msg) - - # The actual check that we see only one GPU - self.assertEqual(visible_gpu_count, 1) + with concurrent.futures.ProcessPoolExecutor( + mp_context=multiprocessing.get_context("spawn") + ) as exe: + future = 
exe.submit(set_visible_devices_and_check) + + visible_gpu_count = future.result() + assert visible_gpu_count == 1 if __name__ == "__main__": diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py b/numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py index c632de938..c9c7dd7c5 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py @@ -3,6 +3,8 @@ import os import multiprocessing as mp +import pytest +import concurrent.futures import numpy as np @@ -10,40 +12,27 @@ from numba.cuda.testing import skip_on_cudasim, CUDATestCase import unittest -has_mp_get_context = hasattr(mp, "get_context") -is_unix = os.name == "posix" - - -def fork_test(q): - from numba.cuda.cudadrv.error import CudaDriverError - - try: - cuda.to_device(np.arange(1)) - except CudaDriverError as e: - q.put(e) - else: - q.put(None) - @skip_on_cudasim("disabled for cudasim") class TestMultiprocessing(CUDATestCase): - @unittest.skipUnless(has_mp_get_context, "requires mp.get_context") - @unittest.skipUnless(is_unix, "requires Unix") + @unittest.skipUnless(hasattr(mp, "get_context"), "requires mp.get_context") + @unittest.skipUnless(os.name == "posix", "requires Unix") def test_fork(self): """ Test fork detection. 
""" + from numba.cuda.cudadrv.error import CudaDriverError + cuda.current_context() # force cuda initialize - # fork in process that also uses CUDA - ctx = mp.get_context("fork") - q = ctx.Queue() - proc = ctx.Process(target=fork_test, args=[q]) - proc.start() - exc = q.get() - proc.join() - # there should be an exception raised in the child process - self.assertIsNotNone(exc) - self.assertIn("CUDA initialized before forking", str(exc)) + with concurrent.futures.ProcessPoolExecutor( + mp_context=mp.get_context("fork") + ) as exe: + future = exe.submit(cuda.to_device, np.arange(1)) + + with pytest.raises( + CudaDriverError, match="CUDA initialized before forking" + ): + future.result() if __name__ == "__main__": diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py b/numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py index 26834e079..d432d2939 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-2-Clause -import traceback import threading import multiprocessing import numpy as np @@ -13,12 +12,7 @@ ) import unittest -try: - from concurrent.futures import ThreadPoolExecutor -except ImportError: - has_concurrent_futures = False -else: - has_concurrent_futures = True +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor has_mp_get_context = hasattr(multiprocessing, "get_context") @@ -41,21 +35,9 @@ def use_foo(x): np.testing.assert_equal(ary, expected) -def spawn_process_entry(q): - try: - check_concurrent_compiling() - # Catch anything that goes wrong in the threads - except: # noqa: E722 - msg = traceback.format_exc() - q.put("\n".join(["", "=" * 80, msg])) - else: - q.put(None) - - @skip_under_cuda_memcheck("Hangs cuda-memcheck") @skip_on_cudasim("disabled for cudasim") class TestMultiThreadCompiling(CUDATestCase): - @unittest.skipIf(not has_concurrent_futures, "no concurrent.futures") def test_concurrent_compiling(self): check_concurrent_compiling() @@ -63,19 +45,13 @@ def test_concurrent_compiling(self): def test_spawn_concurrent_compilation(self): # force CUDA context init cuda.get_current_device() - # use "spawn" to avoid inheriting the CUDA context - ctx = multiprocessing.get_context("spawn") - - q = ctx.Queue() - p = ctx.Process(target=spawn_process_entry, args=(q,)) - p.start() - try: - err = q.get() - finally: - p.join() - if err is not None: - raise AssertionError(err) - self.assertEqual(p.exitcode, 0, "test failed in child process") + + with ProcessPoolExecutor( + # use "spawn" to avoid inheriting the CUDA context + mp_context=multiprocessing.get_context("spawn") + ) as exe: + future = exe.submit(check_concurrent_compiling) + future.result() def test_invalid_context_error_with_d2h(self): def d2h(arr, out): @@ -83,10 +59,10 @@ def d2h(arr, out): arr = np.arange(1, 4) out = np.zeros_like(arr) - darr = cuda.to_device(arr) - th = threading.Thread(target=d2h, args=[darr, out]) - th.start() - th.join() + 
+ with ThreadPoolExecutor() as exe: + exe.submit(d2h, cuda.to_device(arr), out).result() + np.testing.assert_equal(arr, out) def test_invalid_context_error_with_d2d(self): diff --git a/numba_cuda/numba/cuda/tests/support.py b/numba_cuda/numba/cuda/tests/support.py index 1a7afeb2b..7f8f254c2 100644 --- a/numba_cuda/numba/cuda/tests/support.py +++ b/numba_cuda/numba/cuda/tests/support.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-2-Clause import cmath +import concurrent.futures import contextlib import enum import gc @@ -206,18 +207,6 @@ def run_in_subprocess(code, flags=None, env=None, timeout=30): return out, err -@contextlib.contextmanager -def captured_output(stream_name): - """Return a context manager used by captured_stdout/stdin/stderr - that temporarily replaces the sys stream *stream_name* with a StringIO.""" - orig_stdout = getattr(sys, stream_name) - setattr(sys, stream_name, io.StringIO()) - try: - yield getattr(sys, stream_name) - finally: - setattr(sys, stream_name, orig_stdout) - - def captured_stdout(): """Capture the output of sys.stdout: @@ -225,7 +214,7 @@ def captured_stdout(): print("hello") self.assertEqual(stdout.getvalue(), "hello\n") """ - return captured_output("stdout") + return contextlib.redirect_stdout(io.StringIO()) def captured_stderr(): @@ -235,7 +224,7 @@ def captured_stderr(): print("hello", file=sys.stderr) self.assertEqual(stderr.getvalue(), "hello\n") """ - return captured_output("stderr") + return contextlib.redirect_stderr(io.StringIO()) class TestCase(unittest.TestCase): @@ -878,40 +867,33 @@ def run_in_new_process_in_cache_dir(func, cache_dir, verbose=True): stdout: str stderr: str """ - ctx = mp.get_context("spawn") - qout = ctx.Queue() with override_env_config("NUMBA_CACHE_DIR", cache_dir): - proc = ctx.Process(target=_remote_runner, args=[func, qout]) - proc.start() - proc.join() - stdout = qout.get_nowait() - stderr = qout.get_nowait() - if verbose and stdout.strip(): - print() - print("STDOUT".center(80, "-")) - print(stdout)
- if verbose and stderr.strip(): - print(file=sys.stderr) - print("STDERR".center(80, "-"), file=sys.stderr) - print(stderr, file=sys.stderr) - return { - "exitcode": proc.exitcode, - "stdout": stdout, - "stderr": stderr, - } + with concurrent.futures.ProcessPoolExecutor( + mp_context=mp.get_context("spawn") + ) as exe: + future = exe.submit(_remote_runner, func) + + stdout, stderr, exitcode = future.result() + if verbose: + if stdout: + print() + print("STDOUT".center(80, "-")) + print(stdout) + if stderr: + print(file=sys.stderr) + print("STDERR".center(80, "-"), file=sys.stderr) + print(stderr, file=sys.stderr) + return {"exitcode": exitcode, "stdout": stdout, "stderr": stderr} -def _remote_runner(fn, qout): +def _remote_runner(fn): """Used by `run_in_new_process_caching()`""" - with captured_stderr() as stderr: - with captured_stdout() as stdout: - try: - fn() - except Exception: - traceback.print_exc() - exitcode = 1 - else: - exitcode = 0 - qout.put(stdout.getvalue()) - qout.put(stderr.getvalue()) - sys.exit(exitcode) + with captured_stderr() as stderr, captured_stdout() as stdout: + try: + fn() + except Exception: + traceback.print_exc(file=sys.stderr) + exitcode = 1 + else: + exitcode = 0 + return stdout.getvalue().strip(), stderr.getvalue().strip(), exitcode