from datetime import datetime, timedelta

import numpy as np
import numpy.typing as npt

from ndsl.monitor.protocol import Monitor


try:
    from pyfms import diag_manager

    HAS_PYFMS = True
except ImportError:
    # pyFMS is optional; DiagManagerMonitor raises at construction time
    # when it is not importable.
    HAS_PYFMS = False


class DiagManagerMonitor(Monitor):
    """
    sympl.Monitor-style object for sending diagnostics to FMS's diag manager.

    Typical call order, as exercised by the tests in this change:
    ``set_timestep`` / ``set_end_time`` -> ``register_axis`` (once per axis)
    -> ``register_field`` (once per field) -> ``store`` (once per timestep)
    -> ``cleanup``.
    """

    def __init__(
        self,
        domain_id: int,
    ) -> None:
        """Create a DiagManagerMonitor.

        Args:
            domain_id: integer domain-decomposition identifier as returned by
                mpp_define_domain

        Raises:
            RuntimeError: if pyFMS could not be imported.
        """
        if not HAS_PYFMS:
            raise RuntimeError(
                "pyFMS not installed, install ndsl[pyfms] to use the diag manager monitor"
            )
        diag_manager.init(diag_model_subset=diag_manager.DIAG_ALL)
        # field name -> id returned by diag_manager.register_field_array
        self.fields: dict[str, int] = {}
        # axis name -> id returned by diag_manager.axis_init
        self.axes: dict[str, int] = {}
        self.diag_end_time: datetime | None = None
        # Initialised here so store() can detect "never set" reliably; a
        # missing attribute would raise AttributeError, not NameError.
        self.timestep: timedelta | None = None
        self.domain_id = domain_id

    def store(self, state: dict) -> None:
        """
        Sends data from quantities in the state to be written by the diag_manager.

        Every field registered via register_field must be present in ``state``
        under its field name, together with a ``"time"`` entry.

        Args:
            state: mapping holding ``"time"`` plus one quantity per registered
                field. If None, no data is sent but the timestep is still
                completed.

        Raises:
            RuntimeError: if no timestep was set via set_timestep, or if
                diag_manager reports a failed send for any field.
            KeyError: if ``state`` lacks ``"time"`` or a registered field.
        """
        # Fail before sending anything so a misconfigured monitor does not
        # leave diag_manager with a half-completed timestep.
        if self.timestep is None:
            raise RuntimeError("no timestep set via set_timestep")

        if state is not None:
            time = state["time"]
            for field_name, field_id in self.fields.items():
                field_quantity = state[field_name]
                success = diag_manager.send_data(
                    diag_field_id=field_id,
                    field=field_quantity.field,
                    convert_cf_order=True,
                    time=time,
                )
                if not success:
                    raise RuntimeError(
                        f"Failed to send data for field {field_name} at time {time} to diag_manager"
                    )
        diag_manager.send_complete(timestep=self.timestep)

    def cleanup(self) -> None:
        """
        Calls diag_manager.end after simulation ends to ensure all data is written.

        Raises:
            RuntimeError: if set_end_time was never called.
        """
        if self.diag_end_time is None:
            raise RuntimeError(
                "End time was not set via set_end_time prior to cleanup call"
            )
        diag_manager.end(end_time=self.diag_end_time)

    def set_end_time(self, end_time: datetime) -> None:
        """
        Sets the end time to stop receiving data. Must be called prior to cleanup/diag_manager.end()
        """
        diag_manager.set_time_end(end_time)
        self.diag_end_time = end_time

    def set_timestep(self, timestep: timedelta) -> None:
        """
        Sets the timestep to increment by after data is sent.
        """
        self.timestep = timestep

    def register_field(
        self,
        module_name: str,
        field_name: str,
        units: str,
        dtype: str,
        init_time: datetime,
        dims: list[str] | None = None,  # if none, static field
        missing_value: float | None = None,
        long_name: str | None = None,
        range_data: npt.NDArray | None = None,
    ) -> None:
        """
        Register a diagnostic field with the FMS diag_manager via the pyFMS interface for fortran
        This corresponds to a variable/field in the output netcdf file.
        Any axis/dimensions used by this variable should be registered prior to this function.

        Args:
            module_name: module name forwarded to diag_manager.
            field_name: field name; also the key expected in the state passed
                to store().
            units: units string forwarded to diag_manager.
            dtype: data type name forwarded to diag_manager (the tests use
                "float64").
            init_time: initial time forwarded to diag_manager.
            dims: names of previously registered axes, or None for a static
                field.
            missing_value: forwarded to diag_manager.
            long_name: forwarded to diag_manager.
            range_data: forwarded to diag_manager.

        Raises:
            ValueError: if any name in ``dims`` was not registered via
                register_axis.
            RuntimeError: if diag_manager returns a negative field id.
        """
        # NOTE(review): for static fields (dims is None) axes=None is passed
        # through; confirm pyFMS register_field_array accepts that.
        field_axes: list[int] | None = None
        if dims is not None:
            unregistered = [dim for dim in dims if dim not in self.axes]
            if unregistered:
                raise ValueError(
                    f"All axes for field {field_name} must be registered before registering the field."
                    f" Unregistered axes: {unregistered}"
                )
            field_axes = [self.axes[dim] for dim in dims]

        field_id = diag_manager.register_field_array(
            module_name=module_name,
            field_name=field_name,
            axes=field_axes,
            long_name=long_name,
            units=units,
            dtype=dtype,
            missing_value=missing_value,
            range_data=range_data,
            init_time=init_time,
        )
        if field_id < 0:
            raise RuntimeError(
                f"Failed to register field {field_name} in diag_manager, got field_id={field_id}"
            )
        self.fields[field_name] = field_id

    def register_axis(
        self,
        name: str,
        axis_data: np.ndarray,
        not_xy: bool,
        cart_name: str | None = None,
        long_name: str | None = None,
        units: str | None = None,
        domain_id: int | None = None,
        set_name: str | None = None,
    ) -> None:
        """
        Registers an axis with the FMS diag_manager via the pyFMS interface for fortran
        This corresponds to a axis/dimension in the output netcdf file.
        Time axis will be added as an unlimited dimension automatically,
        so does not need to be explicitly registered.

        Args:
            name: axis name; also the key used in ``dims`` of register_field.
            axis_data: coordinate values forwarded to diag_manager.
            not_xy: True for axes that are not x/y; such axes are registered
                without a domain id.
            cart_name: forwarded to diag_manager.
            long_name: forwarded to diag_manager.
            units: forwarded to diag_manager.
            domain_id: domain id forwarded for x/y axes only; ignored when
                ``not_xy`` is True.
            set_name: forwarded to diag_manager.
        """
        axis_kwargs = {
            "name": name,
            "long_name": long_name,
            "axis_data": axis_data,
            "cart_name": cart_name,
            "set_name": set_name,
            "units": units,
        }
        # The two flavours differ only in which extra keyword is forwarded:
        # non-x/y axes pass not_xy, x/y axes pass the domain id instead.
        if not_xy:
            axis_kwargs["not_xy"] = not_xy
        else:
            axis_kwargs["domain_id"] = domain_id
        self.axes[name] = diag_manager.axis_init(**axis_kwargs)
"""Tests the diag_manager_monitor class can output ndsl quantity data.
This test case uses a cubic (6 tile) mosaic, and outputs a file for each tile.
"""

from datetime import datetime, timedelta
from pathlib import Path

import cftime
import numpy as np
import pytest
import xarray as xr
import yaml

from ndsl import (
    CubedSphereCommunicator,
    CubedSpherePartitioner,
    DiagManagerMonitor,
    MPIComm,
    QuantityFactory,
    TilePartitioner,
)
from ndsl.config import Backend
from ndsl.initialization import SubtileGridSizer


pyfms = pytest.importorskip("pyfms")


def fms_mpp_init():
    """Initialise FMS over MPI and define an 8x8, six-tile cubic mosaic.

    Returns:
        The domain id returned by ``define_cubic_mosaic``, which is also made
        the current domain.
    """
    pyfms.fms.init(localcomm=MPIComm()._comm.py2f(), calendar_type=pyfms.fms.NOLEAP)
    x = 8
    y = 8
    layout = [1, 1]
    io_layout = [1, 1]
    halo = 1
    tiles = 6
    domain_id = pyfms.mpp_domains.define_cubic_mosaic(
        # one entry per tile; derived from `tiles` rather than a second literal 6
        ni=[x] * tiles,
        nj=[y] * tiles,
        global_indices=[0, x - 1, 0, y - 1],
        layout=layout,
        ntiles=tiles,
        halo=halo,
        use_memsize=False,
    )
    pyfms.mpp_domains.define_io_domain(
        domain_id=domain_id,
        io_layout=io_layout,
    )
    pyfms.mpp_domains.set_current_domain(domain_id)
    return domain_id


def _create_input(reduction: str = "none"):
    """Write ``diag_table.yaml`` and ``input.nml`` into the current directory.

    Args:
        reduction: value written to the ``reduction`` key of both variables.
    """
    diag_config = {
        "title": "ndsl_diag_manager_test",
        "base_date": "1 1 1 0 0 0",
        "diag_files": [
            {
                "file_name": "diag_manager_cubed_sphere",
                "freq": "15 seconds",
                "time_units": "seconds",
                "unlimdim": "time",
                "varlist": [
                    {
                        "module": "atm_mod",
                        "var_name": "var1",
                        "long_name": "variable_number_one",
                        "reduction": reduction,
                        "kind": "r8",
                    },
                    {
                        "module": "atm_mod",
                        "var_name": "var2",
                        # was a copy-paste of var1's long_name
                        "long_name": "variable_number_two",
                        "reduction": reduction,
                        "kind": "r8",
                    },
                ],
            }
        ],
    }
    with open("diag_table.yaml", "w") as f:
        yaml.dump(diag_config, f, default_flow_style=False, sort_keys=False)
    text_content = "&diag_manager_nml\nuse_modern_diag=.true.\n/"
    with open("input.nml", "w", encoding="utf-8") as f:
        f.write(text_content)


# Cubed-sphere test: six tiles, one rank per tile ((1, 1) layout on each tile)
@pytest.mark.parallel
def test_dm_monitor():
    """Send three timesteps of two fields and check the per-tile netcdf output."""
    npes = MPIComm()._comm.Get_size()
    if npes % 6 != 0:
        raise RuntimeError("this test requires npes to be a multiple of 6 to run")

    _create_input()

    nx = 8
    ny = 8
    nz = 2
    nhalo = 0
    layout = (1, 1)  # 1 pe per tile
    backend = Backend.python()
    ntimesteps = 3

    domain_id = fms_mpp_init()
    partitioner = CubedSpherePartitioner(TilePartitioner((1, 1)))
    communicator = CubedSphereCommunicator(MPIComm(), partitioner)
    # NOTE(review): bare attribute access kept from the original; presumably
    # it forces creation of the tile communicator - confirm or remove.
    communicator.tile

    sizer = SubtileGridSizer.from_tile_params(
        nx_tile=nx,
        ny_tile=ny,
        nz=nz,
        n_halo=nhalo,
        layout=layout,
        tile_partitioner=partitioner.tile,
        tile_rank=communicator.tile.rank,
        backend=backend,
    )
    quantity_factory = QuantityFactory(sizer, backend=backend)

    # pace will set up model start/end times and register axis info
    monitor = DiagManagerMonitor(
        domain_id=domain_id,
    )
    start = datetime(1, 1, 1, 0, 0, second=0)
    end = datetime(1, 1, 1, 0, 0, second=45)
    step = timedelta(seconds=15)
    monitor.set_timestep(step)
    monitor.set_end_time(end)

    monitor.register_axis(
        name="x",
        axis_data=np.arange(nx, dtype=np.float64),
        cart_name="x",
        long_name="x coordinate",
        units="m",
        not_xy=False,
        domain_id=domain_id,
    )
    monitor.register_axis(
        name="y",
        axis_data=np.arange(ny, dtype=np.float64),
        cart_name="y",
        long_name="y coordinate",
        units="m",
        not_xy=False,
        domain_id=domain_id,
    )
    monitor.register_axis(
        name="z",
        axis_data=np.arange(nz, dtype=np.float64),
        cart_name="z",
        long_name="z coordinate",
        units="m",
        not_xy=True,
    )

    # fields will be registered in the component they are defined in (either pyFV3 or pySHiELD)
    monitor.register_field(
        module_name="atm_mod",
        field_name="var1",
        dims=["x", "y"],
        units="m",
        long_name="variable one",
        init_time=start,
        missing_value=-999.0,
        dtype="float64",
    )
    monitor.register_field(
        module_name="atm_mod",
        field_name="var2",
        dims=["x", "y", "z"],
        units="m",
        long_name="variable two",
        init_time=start,
        missing_value=-999.0,
        dtype="float64",
    )
    assert "x" in monitor.axes
    assert "y" in monitor.axes
    assert "var1" in monitor.fields

    # pace driver will call store for each timestep to send the data
    for t in range(1, ntimesteps + 1):
        current_time = start + t * step
        field_q1 = quantity_factory.full(
            dims=("i", "j"), units="m", value=t, dtype=np.float64
        )
        field_q2 = quantity_factory.full(
            dims=("i", "j", "k"), units="m", value=t * 2, dtype=np.float64
        )
        state = {"time": current_time, "var1": field_q1, "var2": field_q2}
        monitor.store(state)

    # cleanup writes and closes the file
    monitor.cleanup()

    # each rank checks the file written for its own tile (tiles are 1-based)
    pe = MPIComm()._comm.Get_rank() + 1
    filename = "diag_manager_cubed_sphere.tile" + str(pe) + ".nc"
    assert Path(filename).exists()
    # context manager so the netcdf handle is released even if an assert fails
    with xr.open_mfdataset(filename, decode_times=True) as ds:
        assert "var1" in ds
        np.testing.assert_array_equal(ds["var1"].shape, (ntimesteps, ny, nx))
        assert "var2" in ds
        np.testing.assert_array_equal(ds["var2"].shape, (ntimesteps, nz, ny, nx))
        assert ds["var1"].dims == ("time", "y", "x")
        assert ds["var2"].dims == ("time", "z", "y", "x")
        assert ds["time"].shape == (ntimesteps,)
        assert ds["time"].dims == ("time",)
        assert ds["time"].values[0] == cftime.DatetimeNoLeap(1, 1, 1, 0, 0, 15)
        assert ds["time"].values[1] == cftime.DatetimeNoLeap(1, 1, 1, 0, 0, 30)
        assert ds["time"].values[2] == cftime.DatetimeNoLeap(1, 1, 1, 0, 0, 45)
        # data is just the timestep number (doubled for var2)
        for step_index in range(ntimesteps):
            expected = step_index + 1
            np.testing.assert_array_equal(
                ds["var1"].values[step_index, :, :], expected
            )
            np.testing.assert_array_equal(
                ds["var2"].values[step_index, :, :, :], expected * 2
            )

    pyfms.fms.end()
"""Tests the diag_manager_monitor class can output ndsl quantity data.
This test case uses a single tile domain decomposition, and outputs a file from the root pe
with data gathered from any other processors.
"""

from datetime import datetime, timedelta
from pathlib import Path

import cftime
import numpy as np
import pytest
import xarray as xr
import yaml

from ndsl import (
    DiagManagerMonitor,
    LocalComm,
    MPIComm,
    QuantityFactory,
    TileCommunicator,
    TilePartitioner,
)
from ndsl.config import Backend
from ndsl.initialization import SubtileGridSizer


pyfms = pytest.importorskip("pyfms")


def _create_input(reduction: str = "none"):
    """Write ``diag_table.yaml`` and ``input.nml`` into the current directory.

    Args:
        reduction: value written to the ``reduction`` key of both variables.
            The original ignored this parameter and always wrote "none"; the
            default keeps that output unchanged.
    """
    diag_config = {
        "title": "ndsl_diag_manager_test",
        "base_date": "2 1 1 1 1 1",
        "diag_files": [
            {
                "file_name": "diag_manager_single_tile",
                "freq": "1 hours",
                "time_units": "hours",
                "unlimdim": "time",
                "varlist": [
                    {
                        "module": "atm_mod",
                        "var_name": "var_2d",
                        "long_name": "variable_too_dee",
                        "reduction": reduction,
                        "kind": "r8",
                    },
                    {
                        "module": "atm_mod",
                        "var_name": "var_3d",
                        "long_name": "variable_three_dee",
                        "reduction": reduction,
                        "kind": "r8",
                    },
                ],
            }
        ],
    }
    with open("diag_table.yaml", "w") as f:
        yaml.dump(diag_config, f, default_flow_style=False, sort_keys=False)
    text_content = "&diag_manager_nml\nuse_modern_diag=.true.\n/"
    with open("input.nml", "w", encoding="utf-8") as f:
        f.write(text_content)


def test_dm_monitor_single_tile():
    """Send three timesteps of a 2D and a 3D field and check the single output file."""
    # mpi info
    npes = MPIComm()._comm.Get_size()
    pe = MPIComm()._comm.Get_rank()
    # tile parameters for quantities/domains
    nx = 8
    ny = 8
    nz = 2
    nhalo = 0
    backend = Backend.python()
    ntimesteps = 3
    layout_fms = [1, npes]
    io_layout = [1, 1]
    layout_ndsl = (npes, 1)  # flipped to match fms domain decomposition
    global_indices = [0, nx - 1, 0, ny - 1]

    _create_input()

    pyfms.fms.init(localcomm=MPIComm()._comm.py2f(), calendar_type=pyfms.fms.NOLEAP)

    domain = pyfms.mpp_domains.define_domains(
        global_indices=global_indices,
        layout=layout_fms,
    )
    pyfms.mpp_domains.set_current_domain(domain_id=domain.domain_id)
    domain_id = domain.domain_id
    pyfms.mpp_domains.define_io_domain(
        domain_id=domain_id,
        io_layout=io_layout,
    )

    if npes > 1:
        # reuse the rank fetched above instead of querying MPI a second time
        print(f"intializing partitioner/communicator rank {pe} of {npes}")
        partitioner = TilePartitioner(layout=layout_ndsl)
        communicator = TileCommunicator(MPIComm(), partitioner)
        communicator.tile
    else:
        buffer = {}
        partitioner = TilePartitioner((1, 1))
        communicator = TileCommunicator(
            comm=LocalComm(rank=0, total_ranks=npes, buffer_dict=buffer),
            partitioner=partitioner,
        )
        communicator.tile

    sizer = SubtileGridSizer.from_tile_params(
        nx_tile=nx,
        ny_tile=ny,
        nz=nz,
        n_halo=nhalo,
        layout=layout_ndsl,
        tile_partitioner=partitioner.tile,
        tile_rank=communicator.tile.rank,
        backend=backend,
    )
    quantity_factory = QuantityFactory(sizer, backend=backend)

    # set up for diag manager for before the main loop, need to set timestep
    # + end_time and register all axes and fields
    monitor = DiagManagerMonitor(domain_id=domain_id)
    start = datetime(2, 1, 1, 1, 1, 1)
    step = timedelta(seconds=3600)
    end = start + ntimesteps * step

    monitor.set_timestep(step)
    monitor.set_end_time(end)

    monitor.register_axis(
        name="x",
        axis_data=np.arange(nx, dtype=np.float64),
        units="point_E",
        cart_name="x",
        domain_id=domain_id,
        long_name="point_E",
        set_name="atm",
        not_xy=False,
    )
    monitor.register_axis(
        name="y",
        axis_data=np.arange(ny, dtype=np.float64),
        units="point_N",
        cart_name="y",
        domain_id=domain_id,
        long_name="point_N",
        set_name="atm",
        not_xy=False,
    )
    monitor.register_axis(
        name="z",
        axis_data=np.arange(nz, dtype=np.float64),
        units="point_Z",
        cart_name="z",
        long_name="point_Z",
        set_name="atm",
        not_xy=True,
    )

    monitor.register_field(
        module_name="atm_mod",
        field_name="var_2d",
        dims=["x", "y"],
        units="muntin",
        init_time=start,
        dtype="float64",
        missing_value=-99.99,
    )
    monitor.register_field(
        module_name="atm_mod",
        field_name="var_3d",
        dims=["x", "y", "z"],
        units="muntin",
        init_time=start,
        dtype="float64",
        missing_value=-99.99,
    )
    assert "x" in monitor.axes
    assert "y" in monitor.axes
    assert "z" in monitor.axes
    assert "var_2d" in monitor.fields
    assert "var_3d" in monitor.fields

    # set up data to send for diagnostics: each value encodes its own indices
    # (var2[i, j] = 10*i + j, var3[i, j, k] = 100*i + 10*j + k)
    i_idx = np.arange(nx, dtype=np.float64)
    j_idx = np.arange(ny, dtype=np.float64)
    k_idx = np.arange(nz, dtype=np.float64)
    var2_global = i_idx[:, None] * 10.0 + j_idx[None, :]
    var3_global = (
        i_idx[:, None, None] * 100 + j_idx[None, :, None] * 10 + k_idx[None, None, :]
    )
    # slice out the part of the global data owned by this rank's compute domain
    var2 = var2_global[domain.isc : domain.iec + 1, domain.jsc : domain.jec + 1]
    var3 = var3_global[domain.isc : domain.iec + 1, domain.jsc : domain.jec + 1, :]

    # pad arrays for quantity factory (one trailing element on every axis)
    var2 = np.pad(var2, (0, 1))
    var3 = np.pad(var3, (0, 1))
    field_q1 = quantity_factory.from_array(var2, dims=("i", "j"), units="m")
    field_q2 = quantity_factory.from_array(var3, dims=("i", "j", "k"), units="m")

    MPIComm()._comm.Barrier()

    current_time = start
    for _t in range(ntimesteps):
        current_time = current_time + step
        state = {
            "time": current_time,
            "var_2d": field_q1,
            "var_3d": field_q2,
        }
        monitor.store(state)

    # cleanup writes and closes the file
    monitor.cleanup()

    # check output!
    assert Path("diag_manager_single_tile.nc").exists()
    # context manager so the netcdf handle is released even if an assert fails
    with xr.open_mfdataset("diag_manager_single_tile.nc", decode_times=True) as ds:
        assert "var_2d" in ds
        # shape follows the (time, y, x) dimension order asserted just below
        np.testing.assert_array_equal(ds["var_2d"].shape, (ntimesteps, ny, nx))
        assert ds["var_2d"].dims == ("time", "y", "x")
        assert ds["var_2d"].attrs["units"] == "muntin"
        assert "var_3d" in ds
        assert ds["var_3d"].dims == ("time", "z", "y", "x")
        assert ds["var_3d"].attrs["units"] == "muntin"
        assert ds["time"].shape == (ntimesteps,)
        assert ds["time"].dims == ("time",)
        assert ds["time"].values[0] == cftime.DatetimeNoLeap(2, 1, 1, 2, 1, 1)
        assert ds["time"].values[1] == cftime.DatetimeNoLeap(2, 1, 1, 3, 1, 1)
        assert ds["time"].values[2] == cftime.DatetimeNoLeap(2, 1, 1, 4, 1, 1)
        # data is transposed when passed into fortran; the same fields are
        # sent every timestep, so every time slice must match
        for step_index in range(ntimesteps):
            np.testing.assert_array_equal(
                ds["var_2d"].values[step_index, :, :], var2_global.transpose()
            )
            np.testing.assert_array_equal(
                ds["var_3d"].values[step_index, :, :, :], var3_global.transpose()
            )

    pyfms.fms.end()