diff --git a/docs/physics/state.rst b/docs/physics/state.rst index 1b603887..e56dfb07 100644 --- a/docs/physics/state.rst +++ b/docs/physics/state.rst @@ -30,7 +30,6 @@ You can initialize a zero-filled PhysicsState and MicrophysicsState from other P ... ny_tile=12, ... nz=79, ... n_halo=3, - ... extra_dim_lengths={}, ... layout=layout, ... tile_partitioner=partitioner.tile, ... tile_rank=communicator.tile.rank, diff --git a/examples/notebooks/functions.py b/examples/notebooks/functions.py index c603d09d..0c81ff58 100644 --- a/examples/notebooks/functions.py +++ b/examples/notebooks/functions.py @@ -263,7 +263,6 @@ def configure_domain( ny_tile=dimensions["ny"], nz=dimensions["nz"], n_halo=dimensions["nhalo"], - extra_dim_lengths={}, layout=dimensions["layout"], tile_partitioner=partitioner.tile, tile_rank=communicator.tile.rank, diff --git a/examples/notebooks/grid_generation.ipynb b/examples/notebooks/grid_generation.ipynb index 2f56fa34..eb331536 100644 --- a/examples/notebooks/grid_generation.ipynb +++ b/examples/notebooks/grid_generation.ipynb @@ -182,7 +182,6 @@ " ny_tile=ny,\n", " nz=nz,\n", " n_halo=nhalo,\n", - " extra_dim_lengths={},\n", " layout=layout,\n", " tile_partitioner=partitioner.tile,\n", " tile_rank=communicator.tile.rank,\n", diff --git a/pace/diagnostics.py b/pace/diagnostics.py index 8850bc08..35895e03 100644 --- a/pace/diagnostics.py +++ b/pace/diagnostics.py @@ -2,7 +2,7 @@ import dataclasses import warnings from datetime import datetime, timedelta -from typing import List, Optional, Union +from pathlib import Path import numpy as np @@ -10,7 +10,6 @@ from ndsl.constants import RGRAV, Z_DIM, Z_INTERFACE_DIM from ndsl.dsl.dace.orchestration import dace_inhibitor from ndsl.dsl.typing import Float -from ndsl.filesystem import get_fs from ndsl.grid import GridData from ndsl.monitor import Monitor, ZarrMonitor from ndsl.monitor.netcdf_monitor import NetCDFMonitor @@ -27,7 +26,7 @@ class Diagnostics(abc.ABC): @abc.abstractmethod - def store(self, time: Union[datetime, timedelta], state: DriverState): ... + def store(self, time: datetime | timedelta, state: DriverState): ... @abc.abstractmethod def store_grid(self, grid_data: GridData): ... @@ -39,7 +38,7 @@ def cleanup(self): ... @dataclasses.dataclass class ZSelect: level: int - names: List[str] + names: list[str] def select_data(self, state: DycoreState): output = {} @@ -77,15 +76,15 @@ class DiagnosticsConfig: output_format is "netcdf" names: state variables to save as diagnostics derived_names: derived diagnostics to save - z_select: save a veritcal slice of a 3D state + z_select: save a vertical slice of a 3D state """ - path: Optional[str] = None + path: str | None = None output_format: str = "zarr" time_chunk_size: int = 1 - names: List[str] = dataclasses.field(default_factory=list) - derived_names: List[str] = dataclasses.field(default_factory=list) - z_select: List[ZSelect] = dataclasses.field(default_factory=list) + names: list[str] = dataclasses.field(default_factory=list) + derived_names: list[str] = dataclasses.field(default_factory=list) + z_select: list[ZSelect] = dataclasses.field(default_factory=list) precision: str = "Float" def __post_init__(self): @@ -113,43 +112,43 @@ def diagnostics_factory(self, communicator: Communicator) -> Diagnostics: or to coordinate filesystem access between ranks """ if self.path is None: - diagnostics: Diagnostics = NullDiagnostics() + return NullDiagnostics() + + if not Path(self.path).exists(): + Path(self.path).mkdir() + + if self.output_format == "zarr": + store = zarr_storage.DirectoryStore(path=self.path) + monitor: Monitor = ZarrMonitor( + store=store, + partitioner=communicator.partitioner, + mpi_comm=communicator.comm, + ) + elif self.output_format == "netcdf": + if self.precision == "Float": + precision = Float + elif self.precision == "float32": + precision = np.float32 + elif self.precision == "float64": + precision = np.float64 + monitor = NetCDFMonitor( + path=self.path, + communicator=communicator, + time_chunk_size=self.time_chunk_size, + precision=precision, + ) else: - fs = get_fs(self.path) - if not fs.exists(self.path): - fs.makedirs(self.path, exist_ok=True) - if self.output_format == "zarr": - store = zarr_storage.DirectoryStore(path=self.path) - monitor: Monitor = ZarrMonitor( - store=store, - partitioner=communicator.partitioner, - mpi_comm=communicator.comm, - ) - elif self.output_format == "netcdf": - if self.precision == "Float": - precision = Float - elif self.precision == "float32": - precision = np.float32 - elif self.precision == "float64": - precision = np.float64 - monitor = NetCDFMonitor( - path=self.path, - communicator=communicator, - time_chunk_size=self.time_chunk_size, - precision=precision, - ) - else: - raise ValueError( - "output_format must be one of 'zarr' or 'netcdf', " - f"got {self.output_format}" - ) - diagnostics = MonitorDiagnostics( - monitor=monitor, - names=self.names, - derived_names=self.derived_names, - z_select=self.z_select, + raise ValueError( + "output_format must be one of 'zarr' or 'netcdf', " + f"got {self.output_format}" ) - return diagnostics + + return MonitorDiagnostics( + monitor=monitor, + names=self.names, + derived_names=self.derived_names, + z_select=self.z_select, + ) class MonitorDiagnostics(Diagnostics): @@ -158,9 +157,9 @@ class MonitorDiagnostics(Diagnostics): def __init__( self, monitor: Monitor, - names: List[str], - derived_names: List[str], - z_select: List[ZSelect], + names: list[str], + derived_names: list[str], + z_select: list[ZSelect], ): """ Args: @@ -174,7 +173,7 @@ def __init__( self.monitor = monitor @dace_inhibitor - def store(self, time: Union[datetime, timedelta], state: DriverState): + def store(self, time: datetime | timedelta, state: DriverState): monitor_state = {"time": time} for name in self.names: try: @@ -226,7 +225,7 @@ def cleanup(self): class NullDiagnostics(Diagnostics): """Diagnostics that do nothing.""" - def store(self, time: Union[datetime, timedelta], state: DriverState): + def store(self, time: datetime | timedelta, state: DriverState): pass def store_grid(self, grid_data: GridData): diff --git a/pace/driver.py b/pace/driver.py index 6ea2d40f..fdc44a21 100644 --- a/pace/driver.py +++ b/pace/driver.py @@ -173,7 +173,6 @@ def get_grid( ny_tile=self.nx_tile, nz=self.nz, n_halo=N_HALO_DEFAULT, - extra_dim_lengths={}, layout=self.layout, tile_partitioner=communicator.partitioner.tile, tile_rank=communicator.tile.rank, @@ -203,7 +202,6 @@ def get_driver_state( ny_tile=self.nx_tile, nz=self.nz, n_halo=N_HALO_DEFAULT, - extra_dim_lengths={}, layout=self.layout, tile_partitioner=communicator.partitioner.tile, tile_rank=communicator.tile.rank, @@ -767,7 +765,6 @@ def _setup_factories( ny_tile=config.nx_tile, nz=config.nz, n_halo=N_HALO_DEFAULT, - extra_dim_lengths={}, layout=config.layout, tile_partitioner=communicator.partitioner.tile, tile_rank=communicator.tile.rank, diff --git a/pace/state.py b/pace/state.py index fa3370b6..9caa1dc6 100644 --- a/pace/state.py +++ b/pace/state.py @@ -1,6 +1,6 @@ import dataclasses -from dataclasses import fields -from typing import List +from pathlib import Path +from typing import Self import xarray as xr @@ -8,7 +8,6 @@ from ndsl import Quantity, QuantityFactory, SubtileGridSizer from ndsl.constants import N_HALO_DEFAULT, X_DIM, Y_DIM, Z_DIM from ndsl.dsl.typing import Float -from ndsl.filesystem import get_fs from ndsl.grid import DampingCoefficients, DriverGridData, GridData from ndsl.typing import Communicator from pyfv3 import DycoreState @@ -18,7 +17,7 @@ @dataclasses.dataclass() class TendencyState: """ - Accumulated tendencies from physical parameterizations to be applied + Accumulated tendencies from physical parametrizations to be applied to the dynamical core model state. """ @@ -48,7 +47,7 @@ class TendencyState: ) @classmethod - def init_zeros(cls, quantity_factory: QuantityFactory) -> "TendencyState": + def init_zeros(cls, quantity_factory: QuantityFactory) -> Self: initial_quantities = {} for _field in dataclasses.fields(cls): initial_quantities[_field.name] = quantity_factory.zeros( @@ -79,8 +78,8 @@ def load_state_from_restart( damping_coefficients: DampingCoefficients, driver_grid_data: DriverGridData, grid_data: GridData, - schemes: List[PHYSICS_PACKAGES], - ) -> "DriverState": + schemes: list[PHYSICS_PACKAGES], + ) -> Self: comm = driver_config.comm_config.get_comm() communicator = Communicator.from_layout(comm=comm, layout=driver_config.layout) sizer = SubtileGridSizer.from_tile_params( @@ -88,7 +87,6 @@ def load_state_from_restart( ny_tile=driver_config.nx_tile, nz=driver_config.nz, n_halo=N_HALO_DEFAULT, - extra_dim_lengths={}, layout=driver_config.layout, tile_partitioner=communicator.partitioner.tile, tile_rank=communicator.tile.rank, @@ -110,8 +108,6 @@ def load_state_from_restart( return state def save_state(self, comm, restart_path: str = "RESTART"): - from pathlib import Path - Path(restart_path).mkdir(parents=True, exist_ok=True) current_rank = str(comm.Get_rank()) self.dycore_state.xr_dataset.to_netcdf( @@ -165,7 +161,7 @@ def _overwrite_state_from_restart( """ ds = xr.open_dataset(path + f"/{restart_file_prefix}_{rank}.nc") - for _field in fields(type(state)): + for _field in dataclasses.fields(type(state)): if "units" in _field.metadata.keys(): state.__dict__[_field.name].data[:] = gt_utils.asarray( ds[_field.name].data[:], to_type=state.__dict__[_field.name].np.ndarray @@ -180,14 +176,11 @@ def _restart_driver_state( damping_coefficients: DampingCoefficients, driver_grid_data: DriverGridData, grid_data: GridData, - schemes: List[PHYSICS_PACKAGES], + schemes: list[PHYSICS_PACKAGES], ): - fs = get_fs(path) - - restart_files = fs.ls(path) - is_fortran_restart = any( - fname.endswith("fv_core.res.nc") for fname in restart_files - ) + # It's a restart from a FORTRAN run if we find any files in the restart directory + # that ends in "fv_core.res.nc". + is_fortran_restart = any(fname for fname in Path(path).glob("**/*fv_core.res.nc")) if is_fortran_restart: dycore_state = DycoreState.from_fortran_restart( diff --git a/tests/main/driver/test_analytic_init.py b/tests/main/driver/test_analytic_init.py index 6f5405a0..8bd2dd1c 100644 --- a/tests/main/driver/test_analytic_init.py +++ b/tests/main/driver/test_analytic_init.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List import pytest import yaml @@ -13,7 +12,7 @@ # TODO: Location of test configurations will be changed after refactor, # need to update after -TESTED_CONFIGS: List[Path] = [ +TESTED_CONFIGS: list[Path] = [ EXAMPLE_CONFIGS_DIR / "analytic_test.yaml", EXAMPLE_CONFIGS_DIR / "baroclinic_c48_6ranks_serialbox_test.yaml", ] @@ -25,7 +24,7 @@ pytest.param(TESTED_CONFIGS, id="example configs"), ], ) -def test_analytic_init_config(tested_configs: List[Path]): +def test_analytic_init_config(tested_configs: list[Path]): for config_file in tested_configs: with open(Path(__file__).parent / config_file, "r") as f: config = yaml.safe_load(f) diff --git a/tests/main/driver/test_diagnostics_config.py b/tests/main/driver/test_diagnostics_config.py index 14283349..cd0b2ca0 100644 --- a/tests/main/driver/test_diagnostics_config.py +++ b/tests/main/driver/test_diagnostics_config.py @@ -48,7 +48,6 @@ def test_zselect_raises_error_if_not_3d(tmpdir): ny_tile=12, nz=79, n_halo=3, - extra_dim_lengths={}, layout=(1, 1), ), backend="numpy", @@ -70,7 +69,6 @@ def test_zselect_raises_error_if_3rd_dim_not_z(tmpdir): ny_tile=12, nz=79, n_halo=3, - extra_dim_lengths={}, layout=(1, 1), ), backend="numpy", diff --git a/tests/main/driver/test_restart_fortran.py b/tests/main/driver/test_restart_fortran.py index 92e5ed6f..c658fc02 100644 --- a/tests/main/driver/test_restart_fortran.py +++ b/tests/main/driver/test_restart_fortran.py @@ -30,7 +30,6 @@ def test_state_from_fortran_restart(): ny_tile=12, nz=63, n_halo=3, - extra_dim_lengths={}, layout=layout, tile_partitioner=partitioner.tile, tile_rank=0, diff --git a/tests/main/driver/test_restart_serial.py b/tests/main/driver/test_restart_serial.py index 7f9430f8..aa9fe693 100644 --- a/tests/main/driver/test_restart_serial.py +++ b/tests/main/driver/test_restart_serial.py @@ -63,7 +63,6 @@ def test_restart_save_to_disk(): ny_tile=12, nz=79, n_halo=3, - extra_dim_lengths={}, layout=(1, 1), tile_partitioner=partitioner.tile, tile_rank=communicator.tile.rank, diff --git a/tests/main/fv3core/test_dycore_baroclinic.py b/tests/main/fv3core/test_dycore_baroclinic.py index 950b119e..afebd877 100644 --- a/tests/main/fv3core/test_dycore_baroclinic.py +++ b/tests/main/fv3core/test_dycore_baroclinic.py @@ -105,7 +105,6 @@ def setup_dycore( ny_tile=config.npy - 1, nz=config.npz, n_halo=3, - extra_dim_lengths={}, layout=config.layout, tile_partitioner=partitioner.tile, tile_rank=communicator.tile.rank, diff --git a/tests/main/fv3core/test_dycore_call.py b/tests/main/fv3core/test_dycore_call.py index 80d8ab70..e41f6972 100644 --- a/tests/main/fv3core/test_dycore_call.py +++ b/tests/main/fv3core/test_dycore_call.py @@ -89,7 +89,6 @@ def setup_dycore() -> Tuple[DynamicalCore, DycoreState, Timer]: ny_tile=config.npy - 1, nz=config.npz, n_halo=3, - extra_dim_lengths={}, layout=config.layout, tile_partitioner=partitioner.tile, tile_rank=communicator.tile.rank, diff --git a/tests/main/physics/test_integration.py b/tests/main/physics/test_integration.py index 76168041..6e215ea5 100644 --- a/tests/main/physics/test_integration.py +++ b/tests/main/physics/test_integration.py @@ -43,7 +43,6 @@ def setup_physics(): ny_tile=physics_config.npy - 1, nz=physics_config.npz, n_halo=3, - extra_dim_lengths={}, layout=layout, tile_partitioner=partitioner.tile, tile_rank=communicator.tile.rank, diff --git a/tests/main/test_grid_init.py b/tests/main/test_grid_init.py index 1361d2e5..32f05aed 100644 --- a/tests/main/test_grid_init.py +++ b/tests/main/test_grid_init.py @@ -31,7 +31,6 @@ def get_quantity_factory(layout, nx_tile, ny_tile, nz): ny_tile=ny, nz=nz, n_halo=3, - extra_dim_lengths={}, layout=(1, 1), ), backend="numpy", diff --git a/tests/mpi/test_grid_init.py b/tests/mpi/test_grid_init.py index b0aefab0..7e0ab41a 100644 --- a/tests/mpi/test_grid_init.py +++ b/tests/mpi/test_grid_init.py @@ -29,7 +29,7 @@ def get_quantity_factory(layout, nx_tile, ny_tile, nz): ny = ny_tile // layout[1] return QuantityFactory.from_backend( sizer=SubtileGridSizer.from_tile_params( - nx=nx, ny=ny, nz=nz, n_halo=3, extra_dim_lengths={}, layout=(1, 1) + nx=nx, ny=ny, nz=nz, n_halo=3, layout=(1, 1) ), backend="numpy", ) diff --git a/tests/savepoint/test_checkpoints.py b/tests/savepoint/test_checkpoints.py index 990e60f1..2a2f22fd 100644 --- a/tests/savepoint/test_checkpoints.py +++ b/tests/savepoint/test_checkpoints.py @@ -89,7 +89,6 @@ def test_fv_dynamics( n_halo=3, tile_partitioner=communicator.partitioner.tile, tile_rank=communicator.rank, - extra_dim_lengths={}, layout=namelist.layout, ), comm=communicator, diff --git a/tests/savepoint/translate/translate_driver.py b/tests/savepoint/translate/translate_driver.py index 56aeb6a7..f3499c56 100644 --- a/tests/savepoint/translate/translate_driver.py +++ b/tests/savepoint/translate/translate_driver.py @@ -31,7 +31,6 @@ def compute_parallel(self, inputs, communicator): ny_tile=self.namelist.npy - 1, nz=self.namelist.npz, n_halo=N_HALO_DEFAULT, - extra_dim_lengths={}, layout=self.namelist.layout, tile_partitioner=communicator.partitioner.tile, tile_rank=communicator.tile.rank,