Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/physics/state.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ You can initialize a zero-filled PhysicsState and MicrophysicsState from other P
... ny_tile=12,
... nz=79,
... n_halo=3,
... extra_dim_lengths={},
... layout=layout,
... tile_partitioner=partitioner.tile,
... tile_rank=communicator.tile.rank,
Expand Down
1 change: 0 additions & 1 deletion examples/notebooks/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ def configure_domain(
ny_tile=dimensions["ny"],
nz=dimensions["nz"],
n_halo=dimensions["nhalo"],
extra_dim_lengths={},
layout=dimensions["layout"],
tile_partitioner=partitioner.tile,
tile_rank=communicator.tile.rank,
Expand Down
1 change: 0 additions & 1 deletion examples/notebooks/grid_generation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@
" ny_tile=ny,\n",
" nz=nz,\n",
" n_halo=nhalo,\n",
" extra_dim_lengths={},\n",
" layout=layout,\n",
" tile_partitioner=partitioner.tile,\n",
" tile_rank=communicator.tile.rank,\n",
Expand Down
97 changes: 48 additions & 49 deletions pace/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import dataclasses
import warnings
from datetime import datetime, timedelta
from typing import List, Optional, Union
from pathlib import Path

import numpy as np

from ndsl import Quantity
from ndsl.constants import RGRAV, Z_DIM, Z_INTERFACE_DIM
from ndsl.dsl.dace.orchestration import dace_inhibitor
from ndsl.dsl.typing import Float
from ndsl.filesystem import get_fs
from ndsl.grid import GridData
from ndsl.monitor import Monitor, ZarrMonitor
from ndsl.monitor.netcdf_monitor import NetCDFMonitor
Expand All @@ -27,7 +26,7 @@

class Diagnostics(abc.ABC):
@abc.abstractmethod
def store(self, time: Union[datetime, timedelta], state: DriverState): ...
def store(self, time: datetime | timedelta, state: DriverState): ...

@abc.abstractmethod
def store_grid(self, grid_data: GridData): ...
Expand All @@ -39,7 +38,7 @@ def cleanup(self): ...
@dataclasses.dataclass
class ZSelect:
level: int
names: List[str]
names: list[str]

def select_data(self, state: DycoreState):
output = {}
Expand Down Expand Up @@ -77,15 +76,15 @@ class DiagnosticsConfig:
output_format is "netcdf"
names: state variables to save as diagnostics
derived_names: derived diagnostics to save
z_select: save a veritcal slice of a 3D state
z_select: save a vertical slice of a 3D state
"""

path: Optional[str] = None
path: str | None = None
output_format: str = "zarr"
time_chunk_size: int = 1
names: List[str] = dataclasses.field(default_factory=list)
derived_names: List[str] = dataclasses.field(default_factory=list)
z_select: List[ZSelect] = dataclasses.field(default_factory=list)
names: list[str] = dataclasses.field(default_factory=list)
derived_names: list[str] = dataclasses.field(default_factory=list)
z_select: list[ZSelect] = dataclasses.field(default_factory=list)
precision: str = "Float"

def __post_init__(self):
Expand Down Expand Up @@ -113,43 +112,43 @@ def diagnostics_factory(self, communicator: Communicator) -> Diagnostics:
or to coordinate filesystem access between ranks
"""
if self.path is None:
diagnostics: Diagnostics = NullDiagnostics()
return NullDiagnostics()

if not Path(self.path).exists():
Path(self.path).mkdir()

if self.output_format == "zarr":
store = zarr_storage.DirectoryStore(path=self.path)
monitor: Monitor = ZarrMonitor(
store=store,
partitioner=communicator.partitioner,
mpi_comm=communicator.comm,
)
elif self.output_format == "netcdf":
if self.precision == "Float":
precision = Float
elif self.precision == "float32":
precision = np.float32
elif self.precision == "float64":
precision = np.float64
monitor = NetCDFMonitor(
path=self.path,
communicator=communicator,
time_chunk_size=self.time_chunk_size,
precision=precision,
)
else:
fs = get_fs(self.path)
if not fs.exists(self.path):
fs.makedirs(self.path, exist_ok=True)
if self.output_format == "zarr":
store = zarr_storage.DirectoryStore(path=self.path)
monitor: Monitor = ZarrMonitor(
store=store,
partitioner=communicator.partitioner,
mpi_comm=communicator.comm,
)
elif self.output_format == "netcdf":
if self.precision == "Float":
precision = Float
elif self.precision == "float32":
precision = np.float32
elif self.precision == "float64":
precision = np.float64
monitor = NetCDFMonitor(
path=self.path,
communicator=communicator,
time_chunk_size=self.time_chunk_size,
precision=precision,
)
else:
raise ValueError(
"output_format must be one of 'zarr' or 'netcdf', "
f"got {self.output_format}"
)
diagnostics = MonitorDiagnostics(
monitor=monitor,
names=self.names,
derived_names=self.derived_names,
z_select=self.z_select,
raise ValueError(
"output_format must be one of 'zarr' or 'netcdf', "
f"got {self.output_format}"
)
return diagnostics

return MonitorDiagnostics(
monitor=monitor,
names=self.names,
derived_names=self.derived_names,
z_select=self.z_select,
)


class MonitorDiagnostics(Diagnostics):
Expand All @@ -158,9 +157,9 @@ class MonitorDiagnostics(Diagnostics):
def __init__(
self,
monitor: Monitor,
names: List[str],
derived_names: List[str],
z_select: List[ZSelect],
names: list[str],
derived_names: list[str],
z_select: list[ZSelect],
):
"""
Args:
Expand All @@ -174,7 +173,7 @@ def __init__(
self.monitor = monitor

@dace_inhibitor
def store(self, time: Union[datetime, timedelta], state: DriverState):
def store(self, time: datetime | timedelta, state: DriverState):
monitor_state = {"time": time}
for name in self.names:
try:
Expand Down Expand Up @@ -226,7 +225,7 @@ def cleanup(self):
class NullDiagnostics(Diagnostics):
"""Diagnostics that do nothing."""

def store(self, time: Union[datetime, timedelta], state: DriverState):
def store(self, time: datetime | timedelta, state: DriverState):
pass

def store_grid(self, grid_data: GridData):
Expand Down
3 changes: 0 additions & 3 deletions pace/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def get_grid(
ny_tile=self.nx_tile,
nz=self.nz,
n_halo=N_HALO_DEFAULT,
extra_dim_lengths={},
layout=self.layout,
tile_partitioner=communicator.partitioner.tile,
tile_rank=communicator.tile.rank,
Expand Down Expand Up @@ -203,7 +202,6 @@ def get_driver_state(
ny_tile=self.nx_tile,
nz=self.nz,
n_halo=N_HALO_DEFAULT,
extra_dim_lengths={},
layout=self.layout,
tile_partitioner=communicator.partitioner.tile,
tile_rank=communicator.tile.rank,
Expand Down Expand Up @@ -767,7 +765,6 @@ def _setup_factories(
ny_tile=config.nx_tile,
nz=config.nz,
n_halo=N_HALO_DEFAULT,
extra_dim_lengths={},
layout=config.layout,
tile_partitioner=communicator.partitioner.tile,
tile_rank=communicator.tile.rank,
Expand Down
29 changes: 11 additions & 18 deletions pace/state.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import dataclasses
from dataclasses import fields
from typing import List
from pathlib import Path
from typing import Self

import xarray as xr

import ndsl.dsl.gt4py_utils as gt_utils
from ndsl import Quantity, QuantityFactory, SubtileGridSizer
from ndsl.constants import N_HALO_DEFAULT, X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.typing import Float
from ndsl.filesystem import get_fs
from ndsl.grid import DampingCoefficients, DriverGridData, GridData
from ndsl.typing import Communicator
from pyfv3 import DycoreState
Expand All @@ -18,7 +17,7 @@
@dataclasses.dataclass()
class TendencyState:
"""
Accumulated tendencies from physical parameterizations to be applied
Accumulated tendencies from physical parametrizations to be applied
to the dynamical core model state.
"""

Expand Down Expand Up @@ -48,7 +47,7 @@ class TendencyState:
)

@classmethod
def init_zeros(cls, quantity_factory: QuantityFactory) -> "TendencyState":
def init_zeros(cls, quantity_factory: QuantityFactory) -> Self:
initial_quantities = {}
for _field in dataclasses.fields(cls):
initial_quantities[_field.name] = quantity_factory.zeros(
Expand Down Expand Up @@ -79,16 +78,15 @@ def load_state_from_restart(
damping_coefficients: DampingCoefficients,
driver_grid_data: DriverGridData,
grid_data: GridData,
schemes: List[PHYSICS_PACKAGES],
) -> "DriverState":
schemes: list[PHYSICS_PACKAGES],
) -> Self:
comm = driver_config.comm_config.get_comm()
communicator = Communicator.from_layout(comm=comm, layout=driver_config.layout)
sizer = SubtileGridSizer.from_tile_params(
nx_tile=driver_config.nx_tile,
ny_tile=driver_config.nx_tile,
nz=driver_config.nz,
n_halo=N_HALO_DEFAULT,
extra_dim_lengths={},
layout=driver_config.layout,
tile_partitioner=communicator.partitioner.tile,
tile_rank=communicator.tile.rank,
Expand All @@ -110,8 +108,6 @@ def load_state_from_restart(
return state

def save_state(self, comm, restart_path: str = "RESTART"):
from pathlib import Path

Path(restart_path).mkdir(parents=True, exist_ok=True)
current_rank = str(comm.Get_rank())
self.dycore_state.xr_dataset.to_netcdf(
Expand Down Expand Up @@ -165,7 +161,7 @@ def _overwrite_state_from_restart(
"""
ds = xr.open_dataset(path + f"/{restart_file_prefix}_{rank}.nc")

for _field in fields(type(state)):
for _field in dataclasses.fields(type(state)):
if "units" in _field.metadata.keys():
state.__dict__[_field.name].data[:] = gt_utils.asarray(
ds[_field.name].data[:], to_type=state.__dict__[_field.name].np.ndarray
Expand All @@ -180,14 +176,11 @@ def _restart_driver_state(
damping_coefficients: DampingCoefficients,
driver_grid_data: DriverGridData,
grid_data: GridData,
schemes: List[PHYSICS_PACKAGES],
schemes: list[PHYSICS_PACKAGES],
):
fs = get_fs(path)

restart_files = fs.ls(path)
is_fortran_restart = any(
fname.endswith("fv_core.res.nc") for fname in restart_files
)
# It's a restart from a FORTRAN run if we find any files in the restart directory
# that ends in "fv_core.res.nc".
is_fortran_restart = any(fname for fname in Path(path).glob("**/*fv_core.res.nc"))

if is_fortran_restart:
dycore_state = DycoreState.from_fortran_restart(
Expand Down
5 changes: 2 additions & 3 deletions tests/main/driver/test_analytic_init.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from pathlib import Path
from typing import List

import pytest
import yaml
Expand All @@ -13,7 +12,7 @@
# TODO: Location of test configurations will be changed after refactor,
# need to update after

TESTED_CONFIGS: List[Path] = [
TESTED_CONFIGS: list[Path] = [
EXAMPLE_CONFIGS_DIR / "analytic_test.yaml",
EXAMPLE_CONFIGS_DIR / "baroclinic_c48_6ranks_serialbox_test.yaml",
]
Expand All @@ -25,7 +24,7 @@
pytest.param(TESTED_CONFIGS, id="example configs"),
],
)
def test_analytic_init_config(tested_configs: List[Path]):
def test_analytic_init_config(tested_configs: list[Path]):
for config_file in tested_configs:
with open(Path(__file__).parent / config_file, "r") as f:
config = yaml.safe_load(f)
Expand Down
2 changes: 0 additions & 2 deletions tests/main/driver/test_diagnostics_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def test_zselect_raises_error_if_not_3d(tmpdir):
ny_tile=12,
nz=79,
n_halo=3,
extra_dim_lengths={},
layout=(1, 1),
),
backend="numpy",
Expand All @@ -70,7 +69,6 @@ def test_zselect_raises_error_if_3rd_dim_not_z(tmpdir):
ny_tile=12,
nz=79,
n_halo=3,
extra_dim_lengths={},
layout=(1, 1),
),
backend="numpy",
Expand Down
1 change: 0 additions & 1 deletion tests/main/driver/test_restart_fortran.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def test_state_from_fortran_restart():
ny_tile=12,
nz=63,
n_halo=3,
extra_dim_lengths={},
layout=layout,
tile_partitioner=partitioner.tile,
tile_rank=0,
Expand Down
1 change: 0 additions & 1 deletion tests/main/driver/test_restart_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def test_restart_save_to_disk():
ny_tile=12,
nz=79,
n_halo=3,
extra_dim_lengths={},
layout=(1, 1),
tile_partitioner=partitioner.tile,
tile_rank=communicator.tile.rank,
Expand Down
1 change: 0 additions & 1 deletion tests/main/fv3core/test_dycore_baroclinic.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def setup_dycore(
ny_tile=config.npy - 1,
nz=config.npz,
n_halo=3,
extra_dim_lengths={},
layout=config.layout,
tile_partitioner=partitioner.tile,
tile_rank=communicator.tile.rank,
Expand Down
1 change: 0 additions & 1 deletion tests/main/fv3core/test_dycore_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def setup_dycore() -> Tuple[DynamicalCore, DycoreState, Timer]:
ny_tile=config.npy - 1,
nz=config.npz,
n_halo=3,
extra_dim_lengths={},
layout=config.layout,
tile_partitioner=partitioner.tile,
tile_rank=communicator.tile.rank,
Expand Down
1 change: 0 additions & 1 deletion tests/main/physics/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def setup_physics():
ny_tile=physics_config.npy - 1,
nz=physics_config.npz,
n_halo=3,
extra_dim_lengths={},
layout=layout,
tile_partitioner=partitioner.tile,
tile_rank=communicator.tile.rank,
Expand Down
1 change: 0 additions & 1 deletion tests/main/test_grid_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def get_quantity_factory(layout, nx_tile, ny_tile, nz):
ny_tile=ny,
nz=nz,
n_halo=3,
extra_dim_lengths={},
layout=(1, 1),
),
backend="numpy",
Expand Down
Loading