-
Notifications
You must be signed in to change notification settings - Fork 16
Namelist Refactor: Utility functions for namelist to dict + conftest --no_legacy_namelist #246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8b78cb0
3311101
17cfd67
a721999
03f87d8
a4f273f
6f796f6
b2d969f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,12 @@ | ||
| import os | ||
| import re | ||
| from pathlib import Path | ||
| from typing import Optional, Tuple | ||
|
|
||
| import f90nml | ||
| import pytest | ||
| import xarray as xr | ||
| import yaml | ||
| from f90nml import Namelist | ||
|
|
||
| from ndsl import CompilationConfig, StencilConfig, StencilFactory | ||
| from ndsl.comm.communicator import ( | ||
|
|
@@ -16,11 +17,14 @@ | |
| from ndsl.comm.mpi import MPI, MPIComm | ||
| from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner | ||
| from ndsl.dsl.dace.dace_config import DaceConfig | ||
| from ndsl.namelist import Namelist | ||
|
|
||
| # TODO: Remove NdslNamelist import after Issue#64 is resolved. | ||
| from ndsl.namelist import Namelist as NdslNamelist | ||
| from ndsl.stencils.testing.grid import Grid # type: ignore | ||
| from ndsl.stencils.testing.parallel_translate import ParallelTranslate | ||
| from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict | ||
| from ndsl.stencils.testing.translate import TranslateGrid | ||
| from ndsl.utils import grid_params_from_f90nml, load_f90nml | ||
|
|
||
|
|
||
| def pytest_addoption(parser): | ||
|
|
@@ -73,6 +77,12 @@ def pytest_addoption(parser): | |
| default=1, | ||
| help="How many indices of failures to print from worst to best. Default to 1.", | ||
| ) | ||
| parser.addoption( | ||
| "--no_legacy_namelist", | ||
| action="store_true", | ||
| default=False, | ||
| help="Removes support for `ndsl.Namelist` in translate tests (which we are trying to get rid off, see NDSL issue #64). Defaults to False.", | ||
| ) | ||
| parser.addoption( | ||
| "--grid", | ||
| action="store", | ||
|
|
@@ -124,9 +134,9 @@ def data_path(pytestconfig): | |
| return data_path_and_namelist_filename_from_config(pytestconfig) | ||
|
|
||
|
|
||
| def data_path_and_namelist_filename_from_config(config) -> Tuple[str, str]: | ||
| data_path = config.getoption("data_path") | ||
| namelist_filename = os.path.join(data_path, "input.nml") | ||
| def data_path_and_namelist_filename_from_config(config) -> Tuple[Path, Path]: | ||
| data_path = Path(config.getoption("data_path")) | ||
| namelist_filename = data_path / "input.nml" | ||
| return data_path, namelist_filename | ||
|
|
||
|
|
||
|
|
@@ -224,10 +234,6 @@ def get_savepoint_restriction(metafunc): | |
| return int(svpt) if svpt else None | ||
|
|
||
|
|
||
| def get_namelist(namelist_filename): | ||
| return Namelist.from_f90nml(f90nml.read(namelist_filename)) | ||
|
|
||
|
|
||
| def get_config(backend: str, communicator: Optional[Communicator]): | ||
| stencil_config = StencilConfig( | ||
| compilation_config=CompilationConfig( | ||
|
|
@@ -243,14 +249,19 @@ def get_config(backend: str, communicator: Optional[Communicator]): | |
|
|
||
| def sequential_savepoint_cases(metafunc, data_path, namelist_filename, *, backend: str): | ||
| savepoint_names = get_sequential_savepoint_names(metafunc, data_path) | ||
| namelist = get_namelist(namelist_filename) | ||
| namelist = load_f90nml(namelist_filename) | ||
| grid_params = grid_params_from_f90nml(namelist) | ||
|
Comment on lines
+252
to
+253
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems to me like these "grid params" could be a type (instead of an un-typed, general
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is full on config dataclass in a follow up PR. E.g. something like: class CartesianGridParameters
nx: int,
ny: int,
nz: int,
class CubeSphereGridParameters:
layout: tuple[int, int]
cartesian_dims: CartesianGridParameters
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds like a good follow-up. I'll put in a TODO. Also, if this PR gets approved, I'll create an issue to keep track.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| stencil_config = get_config(backend, None) | ||
| ranks = get_ranks(metafunc, namelist.layout) | ||
| ranks = get_ranks(metafunc, grid_params["layout"]) | ||
| savepoint_to_replay = get_savepoint_restriction(metafunc) | ||
| grid_mode = metafunc.config.getoption("grid") | ||
| topology_mode = metafunc.config.getoption("topology") | ||
| sort_report = metafunc.config.getoption("sort_report") | ||
| no_report = metafunc.config.getoption("no_report") | ||
|
|
||
| # Temporary flag (Issue#64): TODO Remove once ndsl.Namelist is gone. | ||
| no_legacy_namelist = metafunc.config.getoption("no_legacy_namelist") | ||
|
|
||
| return _savepoint_cases( | ||
| savepoint_names, | ||
| ranks, | ||
|
|
@@ -263,6 +274,7 @@ def sequential_savepoint_cases(metafunc, data_path, namelist_filename, *, backen | |
| topology_mode, | ||
| sort_report=sort_report, | ||
| no_report=no_report, | ||
| no_legacy_namelist=no_legacy_namelist, # Issue#64: tmp flag | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -273,36 +285,38 @@ def _savepoint_cases( | |
| stencil_config, | ||
| namelist: Namelist, | ||
| backend: str, | ||
| data_path: str, | ||
| data_path: Path, | ||
| grid_mode: str, | ||
| topology_mode: bool, | ||
| sort_report: str, | ||
| no_report: bool, | ||
| no_legacy_namelist: bool, # Issue#64: tmp flag | ||
| ): | ||
| grid_params = grid_params_from_f90nml(namelist) | ||
| return_list = [] | ||
| for rank in ranks: | ||
| if grid_mode == "default": | ||
| grid = Grid._make( | ||
| namelist.npx, | ||
| namelist.npy, | ||
| namelist.npz, | ||
| namelist.layout, | ||
| grid_params["npx"], | ||
| grid_params["npy"], | ||
| grid_params["npz"], | ||
| grid_params["layout"], | ||
| rank, | ||
| backend, | ||
| ) | ||
| elif grid_mode == "file" or grid_mode == "compute": | ||
| ds_grid: xr.Dataset = xr.open_dataset( | ||
| os.path.join(data_path, "Grid-Info.nc") | ||
| ).isel(savepoint=0) | ||
| ds_grid: xr.Dataset = xr.open_dataset(data_path / "Grid-Info.nc").isel( | ||
| savepoint=0 | ||
| ) | ||
| grid = TranslateGrid( | ||
| dataset_to_dict(ds_grid.isel(rank=rank)), | ||
| rank=rank, | ||
| layout=namelist.layout, | ||
| layout=grid_params["layout"], | ||
| backend=backend, | ||
| ).python_grid() | ||
| if grid_mode == "compute": | ||
| compute_grid_data( | ||
| grid, namelist, backend, namelist.layout, topology_mode | ||
| grid, grid_params, backend, grid_params["layout"], topology_mode | ||
| ) | ||
| else: | ||
| raise NotImplementedError(f"Grid mode {grid_mode} is unknown.") | ||
|
|
@@ -312,12 +326,18 @@ def _savepoint_cases( | |
| grid_indexing=grid.grid_indexing, | ||
| ) | ||
| for test_name in sorted(list(savepoint_names)): | ||
| # Temporary check (Issue#64): TODO Remove check and conversion from | ||
| # f90nml.Namelist to ndsl.Namelist after ndsl.Namelist is removed | ||
| if not no_legacy_namelist: # This means we use NdslNamelist. | ||
| if not isinstance(namelist, NdslNamelist): | ||
| namelist = NdslNamelist.from_f90nml(namelist) | ||
|
|
||
| testobj = get_test_class_instance( | ||
| test_name, grid, namelist, stencil_factory | ||
| ) | ||
| n_calls = xr.open_dataset( | ||
| os.path.join(data_path, f"{test_name}-In.nc") | ||
| ).sizes["savepoint"] | ||
| n_calls = xr.open_dataset(data_path / f"{test_name}-In.nc").sizes[ | ||
| "savepoint" | ||
| ] | ||
| if savepoint_to_replay is not None: | ||
| savepoint_iterator = range(savepoint_to_replay, savepoint_to_replay + 1) | ||
| else: | ||
|
|
@@ -337,11 +357,11 @@ def _savepoint_cases( | |
| return return_list | ||
|
|
||
|
|
||
| def compute_grid_data(grid, namelist, backend, layout, topology_mode): | ||
| def compute_grid_data(grid, grid_params, backend, layout, topology_mode): | ||
| grid.make_grid_data( | ||
| npx=namelist.npx, | ||
| npy=namelist.npy, | ||
| npz=namelist.npz, | ||
| npx=grid_params["npx"], | ||
| npy=grid_params["npy"], | ||
| npz=grid_params["npz"], | ||
| communicator=get_communicator(MPIComm(), layout, topology_mode), | ||
| backend=backend, | ||
| ) | ||
|
|
@@ -350,15 +370,20 @@ def compute_grid_data(grid, namelist, backend, layout, topology_mode): | |
| def parallel_savepoint_cases( | ||
| metafunc, data_path, namelist_filename, mpi_rank, *, backend: str, comm | ||
| ): | ||
| namelist = get_namelist(namelist_filename) | ||
| namelist = load_f90nml(namelist_filename) | ||
| grid_params = grid_params_from_f90nml(namelist) | ||
| topology_mode = metafunc.config.getoption("topology") | ||
| sort_report = metafunc.config.getoption("sort_report") | ||
| no_report = metafunc.config.getoption("no_report") | ||
| communicator = get_communicator(comm, namelist.layout, topology_mode) | ||
| communicator = get_communicator(comm, grid_params["layout"], topology_mode) | ||
| stencil_config = get_config(backend, communicator) | ||
| savepoint_names = get_parallel_savepoint_names(metafunc, data_path) | ||
| grid_mode = metafunc.config.getoption("grid") | ||
| savepoint_to_replay = get_savepoint_restriction(metafunc) | ||
|
|
||
| # Temporary flag (Issue#64): TODO Remove once ndsl.Namelist is gone. | ||
| no_legacy_namelist = metafunc.config.getoption("no_legacy_namelist") | ||
|
|
||
| return _savepoint_cases( | ||
| savepoint_names, | ||
| [mpi_rank], | ||
|
|
@@ -371,6 +396,7 @@ def parallel_savepoint_cases( | |
| topology_mode, | ||
| sort_report=sort_report, | ||
| no_report=no_report, | ||
| no_legacy_namelist=no_legacy_namelist, # Issue#64: tmp flag | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -388,7 +414,10 @@ def generate_sequential_stencil_tests(metafunc, *, backend: str): | |
| metafunc.config | ||
| ) | ||
| savepoint_cases = sequential_savepoint_cases( | ||
| metafunc, data_path, namelist_filename, backend=backend | ||
| metafunc, | ||
| data_path, | ||
| namelist_filename, | ||
| backend=backend, | ||
| ) | ||
| metafunc.parametrize( | ||
| "case", savepoint_cases, ids=[str(item) for item in savepoint_cases] | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.