Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions numbast/src/numbast/tools/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@
from click.testing import CliRunner

from numbast.tools.static_binding_generator import static_binding_generator
from numbast.static.renderer import clear_base_renderer_cache
from numbast.static.function import clear_function_apis_registry


@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test to avoid cross-test pollution."""
clear_base_renderer_cache()
clear_function_apis_registry()
yield


@pytest.fixture
Expand Down
32 changes: 0 additions & 32 deletions numbast/src/numbast/tools/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from numba import cuda
import pytest

from numbast.static.renderer import clear_base_renderer_cache
from numbast.static.function import clear_function_apis_registry
from numbast.tools.static_binding_generator import static_binding_generator


Expand Down Expand Up @@ -44,9 +42,6 @@ def kernel(arr):
],
)
def test_cli_yml_invalid_inputs(tmpdir, args, arch_str):
clear_base_renderer_cache()
clear_function_apis_registry()

subdir = tmpdir.mkdir("sub")
data = os.path.join(os.path.dirname(__file__), "data.cuh")

Expand Down Expand Up @@ -85,9 +80,6 @@ def test_cli_yml_invalid_inputs(tmpdir, args, arch_str):


def test_cli_yml_inputs_full_spec(tmpdir, kernel, arch_str):
clear_base_renderer_cache()
clear_function_apis_registry()

subdir = tmpdir.mkdir("sub")
data = os.path.join(os.path.dirname(__file__), "data.cuh")

Expand Down Expand Up @@ -138,9 +130,6 @@ def test_cli_yml_inputs_full_spec(tmpdir, kernel, arch_str):
"cc, expected", [("sm_70", False), ("sm_86", True), ("sm_90", True)]
)
def test_cli_yml_inputs_full_spec_with_cc(tmpdir, cc, expected):
clear_base_renderer_cache()
clear_function_apis_registry()

subdir = tmpdir.mkdir("sub")
data = os.path.join(os.path.dirname(__file__), "data.cuh")

Expand Down Expand Up @@ -188,9 +177,6 @@ def test_cli_yml_inputs_full_spec_with_cc(tmpdir, cc, expected):


def test_yaml_deduce_missing_types(tmpdir, kernel, arch_str):
clear_base_renderer_cache()
clear_function_apis_registry()

subdir = tmpdir.mkdir("sub")
data = os.path.join(os.path.dirname(__file__), "data.cuh")

Expand Down Expand Up @@ -237,9 +223,6 @@ def test_yaml_deduce_missing_types(tmpdir, kernel, arch_str):


def test_yaml_deduce_missing_datamodels(tmpdir, kernel, arch_str):
clear_base_renderer_cache()
clear_function_apis_registry()

subdir = tmpdir.mkdir("sub")
data = os.path.join(os.path.dirname(__file__), "data.cuh")

Expand Down Expand Up @@ -287,9 +270,6 @@ def test_yaml_deduce_missing_datamodels(tmpdir, kernel, arch_str):


def test_yaml_exclude_function(tmpdir, arch_str):
clear_base_renderer_cache()
clear_function_apis_registry()

subdir = tmpdir.mkdir("sub")
data = os.path.join(os.path.dirname(__file__), "data.cuh")

Expand Down Expand Up @@ -350,9 +330,6 @@ def kernel(arr):


def test_yaml_exclude_function_empty_list(tmpdir, kernel, arch_str):
clear_base_renderer_cache()
clear_function_apis_registry()

subdir = tmpdir.mkdir("sub")
data = os.path.join(os.path.dirname(__file__), "data.cuh")

Expand Down Expand Up @@ -401,9 +378,6 @@ def test_yaml_exclude_function_empty_list(tmpdir, kernel, arch_str):


def test_yaml_exclude_struct(tmpdir, arch_str):
clear_base_renderer_cache()
clear_function_apis_registry()

subdir = tmpdir.mkdir("sub")
data = os.path.join(os.path.dirname(__file__), "data.cuh")

Expand Down Expand Up @@ -461,9 +435,6 @@ def kernel(arr):


def test_yaml_exclude_struct_empty_list(tmpdir, kernel, arch_str):
clear_base_renderer_cache()
clear_function_apis_registry()

subdir = tmpdir.mkdir("sub")
data = os.path.join(os.path.dirname(__file__), "data.cuh")

Expand Down Expand Up @@ -546,9 +517,6 @@ def kernel_fail():


def test_implit_ctor_lowering(tmpdir, implicit_conversion_kernel, arch_str):
clear_base_renderer_cache()
clear_function_apis_registry()

subdir = tmpdir.mkdir("sub")
data = os.path.join(os.path.dirname(__file__), "data_ctor_lowering.cuh")

Expand Down
5 changes: 0 additions & 5 deletions numbast/src/numbast/tools/tests/test_symbol_exposure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
import subprocess
import sys

from numbast.static.renderer import clear_base_renderer_cache
from numbast.static.function import clear_function_apis_registry

from cuda.core.experimental import Device

dev = Device(0)
Expand All @@ -16,8 +13,6 @@

def test_symbol_exposure(run_in_isolated_folder, arch_str):
"""Test that only a limited set of symbols are exposed via __all__ imports."""
clear_base_renderer_cache()
clear_function_apis_registry()

res = run_in_isolated_folder(
"cfg.yml.j2",
Expand Down
Loading