Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 62 additions & 68 deletions ndsl/stencils/testing/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import re
from collections.abc import Callable
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -124,10 +123,10 @@ def pytest_configure(config: pytest.Config) -> None:

@pytest.fixture()
def data_path(pytestconfig: pytest.Config) -> tuple[Path, Path]:
return data_path_and_namelist_filename_from_config(pytestconfig)
return _data_path_and_namelist_filename_from_config(pytestconfig)


def data_path_and_namelist_filename_from_config(
def _data_path_and_namelist_filename_from_config(
config: pytest.Config,
) -> tuple[Path, Path]:
data_path = Path(config.getoption("data_path"))
Expand All @@ -136,81 +135,78 @@ def data_path_and_namelist_filename_from_config(

@pytest.fixture
def threshold_overrides(pytestconfig: pytest.Config) -> dict | None:
return thresholds_from_file(pytestconfig)
return _thresholds_from_file(pytestconfig)


def thresholds_from_file(config: pytest.Config) -> dict | None:
def _thresholds_from_file(config: pytest.Config) -> dict | None:
thresholds_file = config.getoption("threshold_overrides_file")
if thresholds_file is None:
return None
return yaml.safe_load(open(thresholds_file, "r"))


def get_test_class(test_name: str) -> type | None:
def _test_class_from_name(test_name: str) -> type:
translate_class_name = f"Translate{test_name.replace('-', '_')}"
try:
return_class = getattr(translate, translate_class_name) # type: ignore[name-defined] # noqa: F821
except AttributeError as err:
if translate_class_name in err.args[0]:
return None
# raise with custom error message if translate test wasn't found
raise ValueError(
f"Could not find translate test class for test name '{test_name}'."
)
raise err
return return_class


def is_parallel_test(test_name: str) -> bool:
test_class = get_test_class(test_name)
if test_class is None:
return False
def _is_parallel(test_name: str) -> bool:
test_class = _test_class_from_name(test_name)
return issubclass(test_class, ParallelTranslate)


def get_test_class_instance(
def _is_sequential(test_name: str) -> bool:
return not _is_parallel(test_name)


def _test_class_instance(
test_name: str, grid: Grid, namelist: Namelist, stencil_factory: StencilFactory
) -> Translate:
translate_class = get_test_class(test_name)
if translate_class is None:
raise ValueError(
f"Could not find translate test class for test name '{test_name}'."
)

translate_class = _test_class_from_name(test_name)
return translate_class(grid, namelist, stencil_factory)


def get_all_savepoint_names(metafunc: Any, data_path: Path) -> set[str]:
def _all_savepoint_names(
metafunc: Any, data_path: Path, predicate: Callable[[str], bool] | None
) -> list[str]:
only_names = metafunc.config.getoption("which_modules")
if only_names is None:
names = [
fname[:-3] for fname in os.listdir(data_path) if re.match(r".*\.nc", fname)
]
savepoint_names = set([s[:-3] for s in names if s.endswith("-In")])
savepoint_names = set(
str(fname.name)[:-6] for fname in data_path.glob("*-In.nc")
)
else:
savepoint_names = set(only_names.split(","))
savepoint_names.discard("")

# Handle skipped translate tests
skip_names = metafunc.config.getoption("skip_modules")
if skip_names is not None:
savepoint_names.difference_update(skip_names.split(","))
return savepoint_names

if predicate is None:
return list(savepoint_names)

return [name for name in savepoint_names if predicate(name)]

def get_sequential_savepoint_names(metafunc: Any, data_path: Path) -> list[str]:
all_names = get_all_savepoint_names(metafunc, data_path)
sequential_names = []
for name in all_names:
if not is_parallel_test(name):
sequential_names.append(name)
return sequential_names

def _sequential_savepoint_names(metafunc: Any, data_path: Path) -> list[str]:
return _all_savepoint_names(metafunc, data_path, _is_sequential)

def get_parallel_savepoint_names(metafunc: Any, data_path: Path) -> list[str]:
all_names = get_all_savepoint_names(metafunc, data_path)
parallel_names = []
for name in all_names:
if is_parallel_test(name):
parallel_names.append(name)
return parallel_names

def _parallel_savepoint_names(metafunc: Any, data_path: Path) -> list[str]:
return _all_savepoint_names(metafunc, data_path, _is_parallel)

def get_ranks(metafunc: Any, layout: tuple[int, int]) -> list[int] | range:

def _get_ranks(metafunc: Any, layout: tuple[int, int]) -> list[int] | range:
only_rank = metafunc.config.getoption("which_rank")
if only_rank is not None:
return [int(only_rank)]
Expand All @@ -222,17 +218,17 @@ def get_ranks(metafunc: Any, layout: tuple[int, int]) -> list[int] | range:
elif topology == "cubed-sphere":
total_ranks = 6 * layout[0] * layout[1]
else:
raise NotImplementedError(f"Topology {topology} is unknown.")
raise NotImplementedError(f"Topology '{topology}' is unknown.")

return range(total_ranks)


def get_savepoint_restriction(metafunc: Any) -> int | None:
def _get_savepoint_restriction(metafunc: Any) -> int | None:
svpt = metafunc.config.getoption("which_savepoint")
return int(svpt) if svpt else None


def get_config(backend: Backend, communicator: Communicator | None) -> StencilConfig:
def _get_config(backend: Backend, communicator: Communicator | None) -> StencilConfig:
stencil_config = StencilConfig(
compilation_config=CompilationConfig(
backend=backend, rebuild=False, validate_args=True
Expand All @@ -245,16 +241,16 @@ def get_config(backend: Backend, communicator: Communicator | None) -> StencilCo
return stencil_config


def sequential_savepoint_cases(
def _sequential_savepoint_cases(
metafunc: Any, data_path: Path, namelist_filename: Path, *, backend: str
) -> list[SavepointCase]:
ndsl_backend = Backend(backend)
savepoint_names = get_sequential_savepoint_names(metafunc, data_path)
savepoint_names = _sequential_savepoint_names(metafunc, data_path)
namelist = load_f90nml(namelist_filename)
grid_params = grid_params_from_f90nml(namelist)
stencil_config = get_config(ndsl_backend, None)
ranks = get_ranks(metafunc, grid_params["layout"])
savepoint_to_replay = get_savepoint_restriction(metafunc)
stencil_config = _get_config(ndsl_backend, None)
ranks = _get_ranks(metafunc, grid_params["layout"])
savepoint_to_replay = _get_savepoint_restriction(metafunc)
grid_mode = metafunc.config.getoption("grid")
topology_mode = metafunc.config.getoption("topology")
sort_report = metafunc.config.getoption("sort_report")
Expand Down Expand Up @@ -311,7 +307,7 @@ def _savepoint_cases(
backend=backend,
).python_grid()
if grid_mode == "compute":
compute_grid_data(
_compute_grid_data(
grid, grid_params, backend, grid_params["layout"], topology_mode
)
else:
Expand All @@ -322,9 +318,7 @@ def _savepoint_cases(
grid_indexing=grid.grid_indexing,
)
for test_name in sorted(list(savepoint_names)):
testobj = get_test_class_instance(
test_name, grid, namelist, stencil_factory
)
testobj = _test_class_instance(test_name, grid, namelist, stencil_factory)
n_calls = xr.open_dataset(data_path / f"{test_name}-In.nc").sizes[
"savepoint"
]
Expand All @@ -347,7 +341,7 @@ def _savepoint_cases(
return return_list


def compute_grid_data(
def _compute_grid_data(
grid: Grid,
grid_params: dict,
backend: Backend,
Expand All @@ -358,12 +352,12 @@ def compute_grid_data(
npx=grid_params["npx"],
npy=grid_params["npy"],
npz=grid_params["npz"],
communicator=get_communicator(MPIComm(), layout, topology_mode),
communicator=_get_communicator(MPIComm(), layout, topology_mode),
backend=backend,
)


def parallel_savepoint_cases(
def _parallel_savepoint_cases(
metafunc: Any,
data_path: Path,
namelist_filename: Path,
Expand All @@ -378,11 +372,11 @@ def parallel_savepoint_cases(
topology_mode = metafunc.config.getoption("topology")
sort_report = metafunc.config.getoption("sort_report")
no_report = metafunc.config.getoption("no_report")
communicator = get_communicator(comm, grid_params["layout"], topology_mode)
stencil_config = get_config(ndsl_backend, communicator)
savepoint_names = get_parallel_savepoint_names(metafunc, data_path)
communicator = _get_communicator(comm, grid_params["layout"], topology_mode)
stencil_config = _get_config(ndsl_backend, communicator)
savepoint_names = _parallel_savepoint_names(metafunc, data_path)
grid_mode = metafunc.config.getoption("grid")
savepoint_to_replay = get_savepoint_restriction(metafunc)
savepoint_to_replay = _get_savepoint_restriction(metafunc)

return _savepoint_cases(
savepoint_names,
Expand All @@ -403,16 +397,16 @@ def pytest_generate_tests(metafunc: Any) -> None:
backend = metafunc.config.getoption("backend")
if MPI.COMM_WORLD.Get_size() > 1:
if metafunc.function.__name__ == "test_parallel_savepoint":
generate_parallel_stencil_tests(metafunc, backend=backend)
_generate_parallel_stencil_tests(metafunc, backend=backend)
elif metafunc.function.__name__ == "test_sequential_savepoint":
generate_sequential_stencil_tests(metafunc, backend=backend)
_generate_sequential_stencil_tests(metafunc, backend=backend)


def generate_sequential_stencil_tests(metafunc: Any, *, backend: str) -> None:
data_path, namelist_filename = data_path_and_namelist_filename_from_config(
def _generate_sequential_stencil_tests(metafunc: Any, *, backend: str) -> None:
data_path, namelist_filename = _data_path_and_namelist_filename_from_config(
metafunc.config
)
savepoint_cases = sequential_savepoint_cases(
savepoint_cases = _sequential_savepoint_cases(
metafunc,
data_path,
namelist_filename,
Expand All @@ -423,13 +417,13 @@ def generate_sequential_stencil_tests(metafunc: Any, *, backend: str) -> None:
)


def generate_parallel_stencil_tests(metafunc: Any, *, backend: str) -> None:
data_path, namelist_filename = data_path_and_namelist_filename_from_config(
def _generate_parallel_stencil_tests(metafunc: Any, *, backend: str) -> None:
data_path, namelist_filename = _data_path_and_namelist_filename_from_config(
metafunc.config
)
# get MPI environment
comm = MPIComm()
savepoint_cases = parallel_savepoint_cases(
savepoint_cases = _parallel_savepoint_cases(
metafunc,
data_path,
namelist_filename,
Expand All @@ -442,7 +436,7 @@ def generate_parallel_stencil_tests(metafunc: Any, *, backend: str) -> None:
)


def get_communicator(
def _get_communicator(
comm: Comm, layout: tuple[int, int], topology_mode: str
) -> Communicator:
tile_partitioner = TilePartitioner(layout)
Expand Down