diff --git a/numbast/src/numbast/tools/tests/conftest.py b/numbast/src/numbast/tools/tests/conftest.py index 4863a6c8..60c32239 100644 --- a/numbast/src/numbast/tools/tests/conftest.py +++ b/numbast/src/numbast/tools/tests/conftest.py @@ -14,6 +14,16 @@ from click.testing import CliRunner from numbast.tools.static_binding_generator import static_binding_generator +from numbast.static.renderer import clear_base_renderer_cache +from numbast.static.function import clear_function_apis_registry + + +@pytest.fixture(autouse=True) +def reset_state(): + """Reset global state before each test to avoid cross-test pollution.""" + clear_base_renderer_cache() + clear_function_apis_registry() + yield @pytest.fixture diff --git a/numbast/src/numbast/tools/tests/test_cli.py b/numbast/src/numbast/tools/tests/test_cli.py index f9511596..36ccb371 100644 --- a/numbast/src/numbast/tools/tests/test_cli.py +++ b/numbast/src/numbast/tools/tests/test_cli.py @@ -10,8 +10,6 @@ from numba import cuda import pytest -from numbast.static.renderer import clear_base_renderer_cache -from numbast.static.function import clear_function_apis_registry from numbast.tools.static_binding_generator import static_binding_generator @@ -44,9 +42,6 @@ def kernel(arr): ], ) def test_cli_yml_invalid_inputs(tmpdir, args, arch_str): - clear_base_renderer_cache() - clear_function_apis_registry() - subdir = tmpdir.mkdir("sub") data = os.path.join(os.path.dirname(__file__), "data.cuh") @@ -85,9 +80,6 @@ def test_cli_yml_invalid_inputs(tmpdir, args, arch_str): def test_cli_yml_inputs_full_spec(tmpdir, kernel, arch_str): - clear_base_renderer_cache() - clear_function_apis_registry() - subdir = tmpdir.mkdir("sub") data = os.path.join(os.path.dirname(__file__), "data.cuh") @@ -138,9 +130,6 @@ def test_cli_yml_inputs_full_spec(tmpdir, kernel, arch_str): "cc, expected", [("sm_70", False), ("sm_86", True), ("sm_90", True)] ) def test_cli_yml_inputs_full_spec_with_cc(tmpdir, cc, expected): - clear_base_renderer_cache() - clear_function_apis_registry() - subdir = tmpdir.mkdir("sub") data = os.path.join(os.path.dirname(__file__), "data.cuh") @@ -188,9 +177,6 @@ def test_cli_yml_inputs_full_spec_with_cc(tmpdir, cc, expected): def test_yaml_deduce_missing_types(tmpdir, kernel, arch_str): - clear_base_renderer_cache() - clear_function_apis_registry() - subdir = tmpdir.mkdir("sub") data = os.path.join(os.path.dirname(__file__), "data.cuh") @@ -237,9 +223,6 @@ def test_yaml_deduce_missing_types(tmpdir, kernel, arch_str): def test_yaml_deduce_missing_datamodels(tmpdir, kernel, arch_str): - clear_base_renderer_cache() - clear_function_apis_registry() - subdir = tmpdir.mkdir("sub") data = os.path.join(os.path.dirname(__file__), "data.cuh") @@ -287,9 +270,6 @@ def test_yaml_deduce_missing_datamodels(tmpdir, kernel, arch_str): def test_yaml_exclude_function(tmpdir, arch_str): - clear_base_renderer_cache() - clear_function_apis_registry() - subdir = tmpdir.mkdir("sub") data = os.path.join(os.path.dirname(__file__), "data.cuh") @@ -350,9 +330,6 @@ def kernel(arr): def test_yaml_exclude_function_empty_list(tmpdir, kernel, arch_str): - clear_base_renderer_cache() - clear_function_apis_registry() - subdir = tmpdir.mkdir("sub") data = os.path.join(os.path.dirname(__file__), "data.cuh") @@ -401,9 +378,6 @@ def test_yaml_exclude_function_empty_list(tmpdir, kernel, arch_str): def test_yaml_exclude_struct(tmpdir, arch_str): - clear_base_renderer_cache() - clear_function_apis_registry() - subdir = tmpdir.mkdir("sub") data = os.path.join(os.path.dirname(__file__), "data.cuh") @@ -461,9 +435,6 @@ def kernel(arr): def test_yaml_exclude_struct_empty_list(tmpdir, kernel, arch_str): - clear_base_renderer_cache() - clear_function_apis_registry() - subdir = tmpdir.mkdir("sub") data = os.path.join(os.path.dirname(__file__), "data.cuh") @@ -546,9 +517,6 @@ def kernel_fail(): def test_implit_ctor_lowering(tmpdir, implicit_conversion_kernel, arch_str): - clear_base_renderer_cache() - clear_function_apis_registry() - subdir = tmpdir.mkdir("sub") data = os.path.join(os.path.dirname(__file__), "data_ctor_lowering.cuh") diff --git a/numbast/src/numbast/tools/tests/test_symbol_exposure.py b/numbast/src/numbast/tools/tests/test_symbol_exposure.py index 1b2ceac5..8abf641f 100644 --- a/numbast/src/numbast/tools/tests/test_symbol_exposure.py +++ b/numbast/src/numbast/tools/tests/test_symbol_exposure.py @@ -5,9 +5,6 @@ import subprocess import sys -from numbast.static.renderer import clear_base_renderer_cache -from numbast.static.function import clear_function_apis_registry - from cuda.core.experimental import Device dev = Device(0) @@ -16,8 +13,6 @@ def test_symbol_exposure(run_in_isolated_folder, arch_str): """Test that only a limited set of symbols are exposed via __all__ imports.""" - clear_base_renderer_cache() - clear_function_apis_registry() res = run_in_isolated_folder( "cfg.yml.j2",