From 8bb9f19b78034baf6bf9de574549503472abb665 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 23 Mar 2026 14:43:13 +0100 Subject: [PATCH 1/2] cleanups in translate test discovery --- ndsl/stencils/testing/conftest.py | 128 ++++++++++++++---------------- 1 file changed, 60 insertions(+), 68 deletions(-) diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index b2ceaacf..9eeabab6 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -1,5 +1,4 @@ -import os -import re +from collections.abc import Callable from pathlib import Path from typing import Any @@ -124,10 +123,10 @@ def pytest_configure(config: pytest.Config) -> None: @pytest.fixture() def data_path(pytestconfig: pytest.Config) -> tuple[Path, Path]: - return data_path_and_namelist_filename_from_config(pytestconfig) + return _data_path_and_namelist_filename_from_config(pytestconfig) -def data_path_and_namelist_filename_from_config( +def _data_path_and_namelist_filename_from_config( config: pytest.Config, ) -> tuple[Path, Path]: data_path = Path(config.getoption("data_path")) @@ -136,81 +135,76 @@ def data_path_and_namelist_filename_from_config( @pytest.fixture def threshold_overrides(pytestconfig: pytest.Config) -> dict | None: - return thresholds_from_file(pytestconfig) + return _thresholds_from_file(pytestconfig) -def thresholds_from_file(config: pytest.Config) -> dict | None: +def _thresholds_from_file(config: pytest.Config) -> dict | None: thresholds_file = config.getoption("threshold_overrides_file") if thresholds_file is None: return None return yaml.safe_load(open(thresholds_file, "r")) -def get_test_class(test_name: str) -> type | None: +def _test_class_from_name(test_name: str) -> type: translate_class_name = f"Translate{test_name.replace('-', '_')}" try: return_class = getattr(translate, translate_class_name) # type: ignore[name-defined] # noqa: F821 except AttributeError as err: if translate_class_name in err.args[0]: - return None + # raise with custom error message if translate test wasn't found + raise ValueError( + f"Could not find translate test class for test name '{test_name}'." + ) raise err return return_class -def is_parallel_test(test_name: str) -> bool: - test_class = get_test_class(test_name) - if test_class is None: - return False +def _is_parallel(test_name: str) -> bool: + test_class = _test_class_from_name(test_name) return issubclass(test_class, ParallelTranslate) -def get_test_class_instance( +def _is_sequential(test_name: str) -> bool: + return not _is_parallel(test_name) + + +def _test_class_instance( test_name: str, grid: Grid, namelist: Namelist, stencil_factory: StencilFactory ) -> Translate: - translate_class = get_test_class(test_name) - if translate_class is None: - raise ValueError( - f"Could not find translate test class for test name '{test_name}'." - ) - + translate_class = _test_class_from_name(test_name) return translate_class(grid, namelist, stencil_factory) -def get_all_savepoint_names(metafunc: Any, data_path: Path) -> set[str]: +def _all_savepoint_names( + metafunc: Any, data_path: Path, predicate: Callable[[str], bool] | None +) -> list[str]: only_names = metafunc.config.getoption("which_modules") if only_names is None: - names = [ - fname[:-3] for fname in os.listdir(data_path) if re.match(r".*\.nc", fname) - ] - savepoint_names = set([s[:-3] for s in names if s.endswith("-In")]) + savepoint_names = set(str(fname)[:-6] for fname in data_path.glob("*-In.nc")) else: savepoint_names = set(only_names.split(",")) savepoint_names.discard("") + + # Handle skipped translate tests skip_names = metafunc.config.getoption("skip_modules") if skip_names is not None: savepoint_names.difference_update(skip_names.split(",")) - return savepoint_names + if predicate is None: + return list(savepoint_names) + + return [name for name in savepoint_names if predicate(name)] -def get_sequential_savepoint_names(metafunc: Any, data_path: Path) -> list[str]: - all_names = get_all_savepoint_names(metafunc, data_path) - sequential_names = [] - for name in all_names: - if not is_parallel_test(name): - sequential_names.append(name) - return sequential_names +def _sequential_savepoint_names(metafunc: Any, data_path: Path) -> list[str]: + return _all_savepoint_names(metafunc, data_path, _is_sequential) -def get_parallel_savepoint_names(metafunc: Any, data_path: Path) -> list[str]: - all_names = get_all_savepoint_names(metafunc, data_path) - parallel_names = [] - for name in all_names: - if is_parallel_test(name): - parallel_names.append(name) - return parallel_names +def _parallel_savepoint_names(metafunc: Any, data_path: Path) -> list[str]: + return _all_savepoint_names(metafunc, data_path, _is_parallel) -def get_ranks(metafunc: Any, layout: tuple[int, int]) -> list[int] | range: + +def _get_ranks(metafunc: Any, layout: tuple[int, int]) -> list[int] | range: only_rank = metafunc.config.getoption("which_rank") if only_rank is not None: return [int(only_rank)] @@ -222,17 +216,17 @@ def get_ranks(metafunc: Any, layout: tuple[int, int]) -> list[int] | range: elif topology == "cubed-sphere": total_ranks = 6 * layout[0] * layout[1] else: - raise NotImplementedError(f"Topology {topology} is unknown.") + raise NotImplementedError(f"Topology '{topology}' is unknown.") return range(total_ranks) -def get_savepoint_restriction(metafunc: Any) -> int | None: +def _get_savepoint_restriction(metafunc: Any) -> int | None: svpt = metafunc.config.getoption("which_savepoint") return int(svpt) if svpt else None -def get_config(backend: Backend, communicator: Communicator | None) -> StencilConfig: +def _get_config(backend: Backend, communicator: Communicator | None) -> StencilConfig: stencil_config = StencilConfig( compilation_config=CompilationConfig( backend=backend, rebuild=False, validate_args=True @@ -245,16 +239,16 @@ def get_config(backend: Backend, communicator: Communicator | None) -> StencilCo return stencil_config -def sequential_savepoint_cases( +def _sequential_savepoint_cases( metafunc: Any, data_path: Path, namelist_filename: Path, *, backend: str ) -> list[SavepointCase]: ndsl_backend = Backend(backend) - savepoint_names = get_sequential_savepoint_names(metafunc, data_path) + savepoint_names = _sequential_savepoint_names(metafunc, data_path) namelist = load_f90nml(namelist_filename) grid_params = grid_params_from_f90nml(namelist) - stencil_config = get_config(ndsl_backend, None) - ranks = get_ranks(metafunc, grid_params["layout"]) - savepoint_to_replay = get_savepoint_restriction(metafunc) + stencil_config = _get_config(ndsl_backend, None) + ranks = _get_ranks(metafunc, grid_params["layout"]) + savepoint_to_replay = _get_savepoint_restriction(metafunc) grid_mode = metafunc.config.getoption("grid") topology_mode = metafunc.config.getoption("topology") sort_report = metafunc.config.getoption("sort_report") @@ -311,7 +305,7 @@ def _savepoint_cases( backend=backend, ).python_grid() if grid_mode == "compute": - compute_grid_data( + _compute_grid_data( grid, grid_params, backend, grid_params["layout"], topology_mode ) else: @@ -322,9 +316,7 @@ def _savepoint_cases( grid_indexing=grid.grid_indexing, ) for test_name in sorted(list(savepoint_names)): - testobj = get_test_class_instance( - test_name, grid, namelist, stencil_factory - ) + testobj = _test_class_instance(test_name, grid, namelist, stencil_factory) n_calls = xr.open_dataset(data_path / f"{test_name}-In.nc").sizes[ "savepoint" ] @@ -347,7 +339,7 @@ def _savepoint_cases( return return_list -def compute_grid_data( +def _compute_grid_data( grid: Grid, grid_params: dict, backend: Backend, @@ -358,12 +350,12 @@ def compute_grid_data( npx=grid_params["npx"], npy=grid_params["npy"], npz=grid_params["npz"], - communicator=get_communicator(MPIComm(), layout, topology_mode), + communicator=_get_communicator(MPIComm(), layout, topology_mode), backend=backend, ) -def parallel_savepoint_cases( +def _parallel_savepoint_cases( metafunc: Any, data_path: Path, namelist_filename: Path, @@ -378,11 +370,11 @@ def parallel_savepoint_cases( topology_mode = metafunc.config.getoption("topology") sort_report = metafunc.config.getoption("sort_report") no_report = metafunc.config.getoption("no_report") - communicator = get_communicator(comm, grid_params["layout"], topology_mode) - stencil_config = get_config(ndsl_backend, communicator) - savepoint_names = get_parallel_savepoint_names(metafunc, data_path) + communicator = _get_communicator(comm, grid_params["layout"], topology_mode) + stencil_config = _get_config(ndsl_backend, communicator) + savepoint_names = _parallel_savepoint_names(metafunc, data_path) grid_mode = metafunc.config.getoption("grid") - savepoint_to_replay = get_savepoint_restriction(metafunc) + savepoint_to_replay = _get_savepoint_restriction(metafunc) return _savepoint_cases( savepoint_names, @@ -403,16 +395,16 @@ def pytest_generate_tests(metafunc: Any) -> None: backend = metafunc.config.getoption("backend") if MPI.COMM_WORLD.Get_size() > 1: if metafunc.function.__name__ == "test_parallel_savepoint": - generate_parallel_stencil_tests(metafunc, backend=backend) + _generate_parallel_stencil_tests(metafunc, backend=backend) elif metafunc.function.__name__ == "test_sequential_savepoint": - generate_sequential_stencil_tests(metafunc, backend=backend) + _generate_sequential_stencil_tests(metafunc, backend=backend) -def generate_sequential_stencil_tests(metafunc: Any, *, backend: str) -> None: - data_path, namelist_filename = data_path_and_namelist_filename_from_config( +def _generate_sequential_stencil_tests(metafunc: Any, *, backend: str) -> None: + data_path, namelist_filename = _data_path_and_namelist_filename_from_config( metafunc.config ) - savepoint_cases = sequential_savepoint_cases( + savepoint_cases = _sequential_savepoint_cases( metafunc, data_path, namelist_filename, @@ -423,13 +415,13 @@ def generate_sequential_stencil_tests(metafunc: Any, *, backend: str) -> None: ) -def generate_parallel_stencil_tests(metafunc: Any, *, backend: str) -> None: - data_path, namelist_filename = data_path_and_namelist_filename_from_config( +def _generate_parallel_stencil_tests(metafunc: Any, *, backend: str) -> None: + data_path, namelist_filename = _data_path_and_namelist_filename_from_config( metafunc.config ) # get MPI environment comm = MPIComm() - savepoint_cases = parallel_savepoint_cases( + savepoint_cases = _parallel_savepoint_cases( metafunc, data_path, namelist_filename, @@ -442,7 +434,7 @@ def generate_parallel_stencil_tests(metafunc: Any, *, backend: str) -> None: ) -def get_communicator( +def _get_communicator( comm: Comm, layout: tuple[int, int], topology_mode: str ) -> Communicator: tile_partitioner = TilePartitioner(layout) From f78a7d304d5269aac1dc76ef373857450edfe41a Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 23 Mar 2026 17:01:14 +0100 Subject: [PATCH 2/2] only take file part --- ndsl/stencils/testing/conftest.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index 9eeabab6..bd7d7b42 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -180,7 +180,9 @@ def _all_savepoint_names( ) -> list[str]: only_names = metafunc.config.getoption("which_modules") if only_names is None: - savepoint_names = set(str(fname)[:-6] for fname in data_path.glob("*-In.nc")) + savepoint_names = set( + str(fname.name)[:-6] for fname in data_path.glob("*-In.nc") + ) else: savepoint_names = set(only_names.split(",")) savepoint_names.discard("")