Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 60 additions & 31 deletions ndsl/stencils/testing/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
import re
from pathlib import Path
from typing import Optional, Tuple

import f90nml
import pytest
import xarray as xr
import yaml
from f90nml import Namelist

from ndsl import CompilationConfig, StencilConfig, StencilFactory
from ndsl.comm.communicator import (
Expand All @@ -16,11 +17,14 @@
from ndsl.comm.mpi import MPI, MPIComm
from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner
from ndsl.dsl.dace.dace_config import DaceConfig
from ndsl.namelist import Namelist

# TODO: Remove NdslNamelist import after Issue#64 is resolved.
from ndsl.namelist import Namelist as NdslNamelist
from ndsl.stencils.testing.grid import Grid # type: ignore
from ndsl.stencils.testing.parallel_translate import ParallelTranslate
from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict
from ndsl.stencils.testing.translate import TranslateGrid
from ndsl.utils import grid_params_from_f90nml, load_f90nml


def pytest_addoption(parser):
Expand Down Expand Up @@ -73,6 +77,12 @@ def pytest_addoption(parser):
default=1,
help="How many indices of failures to print from worst to best. Default to 1.",
)
parser.addoption(
"--no_legacy_namelist",
action="store_true",
default=False,
help="Removes support for `ndsl.Namelist` in translate tests (which we are trying to get rid off, see NDSL issue #64). Defaults to False.",
)
Comment thread
jjuyeonkim marked this conversation as resolved.
parser.addoption(
"--grid",
action="store",
Expand Down Expand Up @@ -124,9 +134,9 @@ def data_path(pytestconfig):
return data_path_and_namelist_filename_from_config(pytestconfig)


def data_path_and_namelist_filename_from_config(config) -> Tuple[str, str]:
data_path = config.getoption("data_path")
namelist_filename = os.path.join(data_path, "input.nml")
def data_path_and_namelist_filename_from_config(config) -> Tuple[Path, Path]:
data_path = Path(config.getoption("data_path"))
namelist_filename = data_path / "input.nml"
return data_path, namelist_filename


Expand Down Expand Up @@ -224,10 +234,6 @@ def get_savepoint_restriction(metafunc):
return int(svpt) if svpt else None


def get_namelist(namelist_filename):
return Namelist.from_f90nml(f90nml.read(namelist_filename))


def get_config(backend: str, communicator: Optional[Communicator]):
stencil_config = StencilConfig(
compilation_config=CompilationConfig(
Expand All @@ -243,14 +249,19 @@ def get_config(backend: str, communicator: Optional[Communicator]):

def sequential_savepoint_cases(metafunc, data_path, namelist_filename, *, backend: str):
savepoint_names = get_sequential_savepoint_names(metafunc, data_path)
namelist = get_namelist(namelist_filename)
namelist = load_f90nml(namelist_filename)
grid_params = grid_params_from_f90nml(namelist)
Comment on lines +252 to +253
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me like these "grid params" could be a type (instead of an un-typed, general dict). Maybe that's a good follow-up PR or just something to think about. Might be out of scope for this PR.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is full on config dataclass in a follow up PR. E.g. something like:

class CartesianGridParameters
    nx: int,
    ny: int,
    nz: int,
    
class CubeSphereGridParameters:
   layout: tuple[int, int]
   cartesian_dims: CartesianGridParameters

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds like a good follow-up. I'll put in a TODO. Also, if this PR gets approved, I'll create an issue to keep track.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stencil_config = get_config(backend, None)
ranks = get_ranks(metafunc, namelist.layout)
ranks = get_ranks(metafunc, grid_params["layout"])
savepoint_to_replay = get_savepoint_restriction(metafunc)
grid_mode = metafunc.config.getoption("grid")
topology_mode = metafunc.config.getoption("topology")
sort_report = metafunc.config.getoption("sort_report")
no_report = metafunc.config.getoption("no_report")

# Temporary flag (Issue#64): TODO Remove once ndsl.Namelist is gone.
no_legacy_namelist = metafunc.config.getoption("no_legacy_namelist")

return _savepoint_cases(
savepoint_names,
ranks,
Expand All @@ -263,6 +274,7 @@ def sequential_savepoint_cases(metafunc, data_path, namelist_filename, *, backen
topology_mode,
sort_report=sort_report,
no_report=no_report,
no_legacy_namelist=no_legacy_namelist, # Issue#64: tmp flag
)


Expand All @@ -273,36 +285,38 @@ def _savepoint_cases(
stencil_config,
namelist: Namelist,
backend: str,
data_path: str,
data_path: Path,
grid_mode: str,
topology_mode: bool,
sort_report: str,
no_report: bool,
no_legacy_namelist: bool, # Issue#64: tmp flag
):
grid_params = grid_params_from_f90nml(namelist)
return_list = []
for rank in ranks:
if grid_mode == "default":
grid = Grid._make(
namelist.npx,
namelist.npy,
namelist.npz,
namelist.layout,
grid_params["npx"],
grid_params["npy"],
grid_params["npz"],
grid_params["layout"],
rank,
backend,
)
elif grid_mode == "file" or grid_mode == "compute":
ds_grid: xr.Dataset = xr.open_dataset(
os.path.join(data_path, "Grid-Info.nc")
).isel(savepoint=0)
ds_grid: xr.Dataset = xr.open_dataset(data_path / "Grid-Info.nc").isel(
savepoint=0
)
grid = TranslateGrid(
dataset_to_dict(ds_grid.isel(rank=rank)),
rank=rank,
layout=namelist.layout,
layout=grid_params["layout"],
backend=backend,
).python_grid()
if grid_mode == "compute":
compute_grid_data(
grid, namelist, backend, namelist.layout, topology_mode
grid, grid_params, backend, grid_params["layout"], topology_mode
)
else:
raise NotImplementedError(f"Grid mode {grid_mode} is unknown.")
Expand All @@ -312,12 +326,18 @@ def _savepoint_cases(
grid_indexing=grid.grid_indexing,
)
for test_name in sorted(list(savepoint_names)):
# Temporary check (Issue#64): TODO Remove check and conversion from
# f90nml.Namelist to ndsl.Namelist after ndsl.Namelist is removed
if not no_legacy_namelist: # This means we use NdslNamelist.
if not isinstance(namelist, NdslNamelist):
namelist = NdslNamelist.from_f90nml(namelist)

testobj = get_test_class_instance(
test_name, grid, namelist, stencil_factory
)
n_calls = xr.open_dataset(
os.path.join(data_path, f"{test_name}-In.nc")
).sizes["savepoint"]
n_calls = xr.open_dataset(data_path / f"{test_name}-In.nc").sizes[
"savepoint"
]
if savepoint_to_replay is not None:
savepoint_iterator = range(savepoint_to_replay, savepoint_to_replay + 1)
else:
Expand All @@ -337,11 +357,11 @@ def _savepoint_cases(
return return_list


def compute_grid_data(grid, namelist, backend, layout, topology_mode):
def compute_grid_data(grid, grid_params, backend, layout, topology_mode):
grid.make_grid_data(
npx=namelist.npx,
npy=namelist.npy,
npz=namelist.npz,
npx=grid_params["npx"],
npy=grid_params["npy"],
npz=grid_params["npz"],
communicator=get_communicator(MPIComm(), layout, topology_mode),
backend=backend,
)
Expand All @@ -350,15 +370,20 @@ def compute_grid_data(grid, namelist, backend, layout, topology_mode):
def parallel_savepoint_cases(
metafunc, data_path, namelist_filename, mpi_rank, *, backend: str, comm
):
namelist = get_namelist(namelist_filename)
namelist = load_f90nml(namelist_filename)
grid_params = grid_params_from_f90nml(namelist)
topology_mode = metafunc.config.getoption("topology")
sort_report = metafunc.config.getoption("sort_report")
no_report = metafunc.config.getoption("no_report")
communicator = get_communicator(comm, namelist.layout, topology_mode)
communicator = get_communicator(comm, grid_params["layout"], topology_mode)
stencil_config = get_config(backend, communicator)
savepoint_names = get_parallel_savepoint_names(metafunc, data_path)
grid_mode = metafunc.config.getoption("grid")
savepoint_to_replay = get_savepoint_restriction(metafunc)

# Temporary flag (Issue#64): TODO Remove once ndsl.Namelist is gone.
no_legacy_namelist = metafunc.config.getoption("no_legacy_namelist")

return _savepoint_cases(
savepoint_names,
[mpi_rank],
Expand All @@ -371,6 +396,7 @@ def parallel_savepoint_cases(
topology_mode,
sort_report=sort_report,
no_report=no_report,
no_legacy_namelist=no_legacy_namelist, # Issue#64: tmp flag
)


Expand All @@ -388,7 +414,10 @@ def generate_sequential_stencil_tests(metafunc, *, backend: str):
metafunc.config
)
savepoint_cases = sequential_savepoint_cases(
metafunc, data_path, namelist_filename, backend=backend
metafunc,
data_path,
namelist_filename,
backend=backend,
)
metafunc.parametrize(
"case", savepoint_cases, ids=[str(item) for item in savepoint_cases]
Expand Down
13 changes: 12 additions & 1 deletion ndsl/stencils/testing/parallel_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@

from ndsl.constants import HORIZONTAL_DIMS, N_HALO_DEFAULT, X_DIMS, Y_DIMS
from ndsl.dsl import gt4py_utils as utils

# TODO: Remove once ndsl.Namelist is gone (Issue#64)
from ndsl.namelist import Namelist as NdslNamelist
from ndsl.quantity import Quantity
from ndsl.stencils.testing.translate import (
TranslateFortranData2Py,
read_serialized_data,
)
from ndsl.utils import grid_params_from_f90nml


class ParallelTranslate:
Expand Down Expand Up @@ -129,7 +133,14 @@ def rank_grids(self):

@property
def layout(self):
return self.namelist.layout
# TODO: Once ndsl.namelist.Namelist is gone (Issue#64),
# remove this check in favor of f90nml.namelist.Namelist
if isinstance(self.namelist, NdslNamelist):
return self.namelist.layout

# Assumption: namelist is f90nml.namelist.Namelist
grid_params = grid_params_from_f90nml(self.namelist)
return grid_params["layout"]

def compute_sequential(self, inputs_list, communicator_list):
"""Compute the outputs while iterating over a set of communicator
Expand Down
20 changes: 9 additions & 11 deletions ndsl/stencils/testing/savepoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
import os
from pathlib import Path
from typing import Dict, Protocol, Union

import numpy as np
Expand All @@ -22,7 +22,7 @@ def _process_if_scalar(value: np.ndarray) -> Union[np.ndarray, float, int]:


class DataLoader:
def __init__(self, rank: int, data_path: str):
def __init__(self, rank: int, data_path: Path):
self._data_path = data_path
self._rank = rank

Expand All @@ -33,7 +33,7 @@ def load(
i_call: int = 0,
) -> Dict[str, Union[np.ndarray, float, int]]:
return dataset_to_dict(
xr.open_dataset(os.path.join(self._data_path, f"{name}{postfix}.nc"))
xr.open_dataset(self._data_path / f"{name}{postfix}.nc")
.isel(rank=self._rank)
.isel(savepoint=i_call)
)
Expand All @@ -54,7 +54,7 @@ class SavepointCase:
"""

savepoint_name: str
data_dir: str
data_dir: Path
i_call: int
testobj: Translate
grid: Grid
Expand All @@ -67,26 +67,24 @@ def __str__(self):
@property
def exists(self) -> bool:
return (
xr.open_dataset(
os.path.join(self.data_dir, f"{self.savepoint_name}-In.nc")
).sizes["rank"]
xr.open_dataset(self.data_dir / f"{self.savepoint_name}-In.nc").sizes[
"rank"
]
> self.grid.rank
)

@property
def ds_in(self) -> xr.Dataset:
return (
xr.open_dataset(os.path.join(self.data_dir, f"{self.savepoint_name}-In.nc"))
xr.open_dataset(self.data_dir / f"{self.savepoint_name}-In.nc")
.isel(rank=self.grid.rank)
.isel(savepoint=self.i_call)
)

@property
def ds_out(self) -> xr.Dataset:
return (
xr.open_dataset(
os.path.join(self.data_dir, f"{self.savepoint_name}-Out.nc")
)
xr.open_dataset(self.data_dir / f"{self.savepoint_name}-Out.nc")
.isel(rank=self.grid.rank)
.isel(savepoint=self.i_call)
)
Loading
Loading