diff --git a/ndsl/__init__.py b/ndsl/__init__.py index bb77d751..675610e3 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -28,7 +28,7 @@ from .performance.collector import NullPerformanceCollector, PerformanceCollector from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport -from .quantity import Quantity +from .quantity import Quantity, State from .quantity.field_bundle import FieldBundle, FieldBundleType # Break circular import from .testing.dummy_comm import DummyComm from .types import Allocator @@ -48,7 +48,6 @@ "FV3CodePath", "DaceConfig", "DaCeOrchestration", - "FrozenCompiledSDFG", "orchestrate", "orchestrate_function", "ArrayReport", @@ -87,4 +86,5 @@ "DummyComm", "Allocator", "MetaEnumStr", + "State", ] diff --git a/ndsl/quantity/__init__.py b/ndsl/quantity/__init__.py index 43751528..ee7e4d78 100644 --- a/ndsl/quantity/__init__.py +++ b/ndsl/quantity/__init__.py @@ -1,11 +1,11 @@ from .metadata import QuantityHaloSpec, QuantityMetadata from .quantity import Quantity +from .state import State __all__ = [ "Quantity", "QuantityMetadata", "QuantityHaloSpec", - "FieldBundle", - "FieldBundleType", + "State", ] diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 0aec5eef..9272fa22 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -255,14 +255,29 @@ def field(self) -> np.ndarray | cupy.ndarray: return self._compute_domain_view[:] @property - def data(self) -> Union[np.ndarray, cupy.ndarray]: + def data(self) -> np.ndarray | cupy.ndarray: """The underlying array of data""" return self._data @data.setter - def data(self, inputData): - if type(inputData) in [np.ndarray, cupy.ndarray]: - self._data = inputData + def data(self, input_data: np.ndarray | cupy.ndarray): + if type(input_data) not in [np.ndarray, cupy.ndarray]: + raise TypeError( + "Quantity.data buffer swap failed: " + f"given data is not an array (type: {type(input_data)})" + ) + + if input_data.shape < self.extent: + raise ValueError( + "Quantity.data buffer swap failed: " + f"new data ({input_data.shape}) is smaller " + f"than expected extent ({self.extent})." + ) + + self._data = input_data + self._compute_domain_view = BoundedArrayView( + self.data, self.dims, self.origin, self.extent + ) @property def origin(self) -> Tuple[int, ...]: diff --git a/ndsl/quantity/state.py b/ndsl/quantity/state.py new file mode 100644 index 00000000..b9ccbeb0 --- /dev/null +++ b/ndsl/quantity/state.py @@ -0,0 +1,320 @@ +from __future__ import annotations + +import dataclasses +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Self + +import dacite +import xarray as xr +from mpi4py import MPI +from numpy.typing import ArrayLike + + +if TYPE_CHECKING: + from ndsl import QuantityFactory + + +@dataclasses.dataclass +class State: + """Base class for state objects in models. + + A State groups a collection of (possibly nested) Quantities in a dataclass. + + This baseclass implements common initialization functions and serialization. + + Typical usage example: + + ```python + class MyState(State): + pass + + my_state = MyState.zeros(quantity_factory) + + # ... + + my_state.to_netcdf() + ``` + """ + + @classmethod + def _init(cls, quantity_factory_allocator: Callable) -> Self: + """Allocate memory and init with a blind quantity init operation""" + + def _init_recursive(cls): + initial_quantities = {} + for _field in dataclasses.fields(cls): + if dataclasses.is_dataclass(_field.type): + initial_quantities[_field.name] = _init_recursive(_field.type) + else: + if "dims" not in _field.metadata.keys(): + raise ValueError( + "Malformed state - no dims to init " + f"Quantity in {_field.name} of type {_field.type}" + ) + + initial_quantities[_field.name] = quantity_factory_allocator( + _field.metadata["dims"], + _field.metadata["units"], + dtype=_field.metadata["dtype"], + allow_mismatch_float_precision=True, + ) + + return initial_quantities + + dict_of_quantities = _init_recursive(cls) + return dacite.from_dict(data_class=cls, data=dict_of_quantities) + + @classmethod + def empty(cls, quantity_factory: QuantityFactory) -> Self: + """Allocate all quantities""" + + return cls._init(quantity_factory.empty) + + @classmethod + def zeros(cls, quantity_factory: QuantityFactory) -> Self: + """Allocate all quantities and fill their value to zeros""" + + return cls._init(quantity_factory.zeros) + + @classmethod + def ones(cls, quantity_factory: QuantityFactory) -> Self: + """Allocate all quantities and fill their value to ones""" + + return cls._init(quantity_factory.ones) + + @classmethod + def copy_memory( + cls, + quantity_factory: QuantityFactory, + memory_map: dict[str, Any], + ) -> Self: + """Allocate all quantities and fill their value based + on the given memory map. See `update_from_memory`""" + + state = cls.zeros(quantity_factory) + state.update_copy_memory(memory_map) + + return state + + @classmethod + def move_memory( + cls, + quantity_factory: QuantityFactory, + memory_map: dict[str, Any], + *, + check_shape_and_strides: bool = True, + ) -> Self: + """Allocate all quantities and move memory based on + on the given memory map. See `update_move_memory`.""" + + state = cls.zeros(quantity_factory) + state.update_move_memory( + memory_map, + check_shape_and_strides=check_shape_and_strides, + ) + + return state + + def update_copy_memory(self, memory_map: dict[str, Any]) -> None: + """Copy data into the Quantities carried by the state. + + The memory map must follow the dataclass naming convention, e.g. + + ```python + @dataclass + class MyState: + @dataclass + class InnerA + a: Quantity + + inner_a: InnerA + b: Quantity + ``` + will update with a dictionary looking like + ```python + { + "inner_a": + { + "a": Quantity(...) + } + "b": Quantity(...) + + } + ``` + + The memory map can be sparse. + """ + + def _update_from_memory_recursive( + state: State, + memory_map: dict[str, dict | ArrayLike], + ): + for name, array in memory_map.items(): + if isinstance(array, dict): + _update_from_memory_recursive(state.__getattribute__(name), array) + else: + try: + state.__getattribute__(name).field[:] = array + except Exception as e: + e.add_note( + f"Error when initializing field {name} on state {type(self)}" + ) + raise e + + _update_from_memory_recursive(self, memory_map) + + def update_move_memory( + self, + memory_map: dict[str, dict | ArrayLike], + *, + check_shape_and_strides: bool = True, + ) -> None: + """Move memory into the Quantities carried by the state. + Memory is moved rather than copied (e.g. buffers are swapped) + + The memory map must follow the dataclass naming convention, e.g. + + ```python + @dataclass + class MyState: + @dataclass + class InnerA + a: Quantity + + inner_a: InnerA + b: Quantity + ``` + will update with a dictionary looking like + ```python + { + "inner_a": + { + "a": Quantity(...) + } + "b": Quantity(...) + + } + ``` + + The memory map can be sparse. + + Args: + memory_map: Dictionary of keys to buffers. Buffers must be np.ArrayLike + check_shape_and_strides: check that the given buffers have the same + shape and strides as the original quantity + """ + + def _update_zero_copy_recursive( + state: State, memory_map: dict[str, dict | ArrayLike] + ): + for name, array in memory_map.items(): + if isinstance(array, dict): + _update_zero_copy_recursive(state.__getattribute__(name), array) + else: + quantity = state.__getattribute__(name) + if check_shape_and_strides: + if array.shape != quantity.field.shape: + e = ValueError("Shape mismatch on zero copy for") + e.add_note(f" Error on {name} for {type(state)}") + e.add_note( + f" Shapes: {array.shape} != {quantity.field.shape}" + ) + raise e + if array.strides != quantity.data.strides: + e = ValueError("Stride mismatch on zero copy for") + e.add_note(f" Error on {name} for {type(state)}") + e.add_note( + f" Strides: {array.strides} != {quantity.data.strides}" + ) + raise e + + quantity.data = array + + _update_zero_copy_recursive(self, memory_map) + + def _netcdf_name(self, directory_path: Path) -> Path: + """Resolve rank-tied postfix if needed""" + rank_postfix = "" + if MPI.COMM_WORLD.Get_size() > 1: + rank_postfix = f"_rank{MPI.COMM_WORLD.Get_rank()}" + return directory_path / f"{type(self).__name__}{rank_postfix}.nc4" + + def to_netcdf(self, directory_path: Path = Path("./")) -> None: + """ + Save state to NetCDF. Can be reloaded with `update_from_netcdf`. + + If applicable, will save seperate NetCDF files for each running rank. + + The file names are deduced from the class name, and post fix with rank number + in the case of a multi-process use. + + Args: + directory_path: directory to save the netcdf in + """ + + def _save_recursive(state: State): + local_data = {} + for _field in dataclasses.fields(state): + if dataclasses.is_dataclass(_field.type): + local_data[_field.name] = xr.Dataset( + data_vars=_save_recursive(state.__getattribute__(_field.name)) + ) + else: + if "dims" not in _field.metadata.keys(): + raise ValueError( + "Malformed state - no dims to init " + f"Quantity in {_field.name} of type {_field.type}" + ) + + local_data[_field.name] = state.__getattribute__( + _field.name + ).field_as_xarray + + return local_data + + datatree = _save_recursive(self) + + # Move top-level into their own dataset in the "/" prefix + # to match DataTree expected format + top_level = {} + for key, value in datatree.items(): + if not isinstance(value, xr.Dataset): + top_level[key] = value + for key, value in top_level.items(): + datatree.pop(key) + datatree["/"] = xr.Dataset(data_vars=top_level) + + xr.DataTree.from_dict(datatree).to_netcdf(self._netcdf_name(directory_path)) + + def update_from_netcdf(self, directory_path: Path) -> None: + """This is a mirror of the `to_netcdf` method NOT a generic + NetCDF loader. It expects the NetCDF to be named with the auto-naming scheme + of `to_netcdf`. + + Args: + directory_path: directory carrying the netcdf saved with `to_netcdf` + + """ + datatree = xr.open_datatree(self._netcdf_name(directory_path)) + datatree_as_dict = datatree.to_dict() + + # All other cases - recursing downward + def _load_recursive(data_tree_as_dict: dict[str, xr.Dataset] | xr.Dataset): + local_data_dict = {} + for name, data_array in data_tree_as_dict.items(): + # Case of the top_level "/" + if name == "/": + for root_name, root_data_array in datatree_as_dict["/"].items(): + local_data_dict[root_name] = root_data_array.to_numpy() + else: + # Get the leading `/` out + if isinstance(data_array, xr.Dataset): + local_data_dict[name[1:]] = _load_recursive(data_array) + else: + local_data_dict[name] = data_array.to_numpy() + + return local_data_dict + + data_as_numpy_dict = _load_recursive(datatree_as_dict) + + self.update_copy_memory(data_as_numpy_dict) diff --git a/setup.py b/setup.py index 4064cb62..1719226f 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ def local_pkg(name: str, relative_path: str) -> str: "matplotlib", # for plotting in boilerplate "cartopy", # for plotting in ndsl.viz "pytest-subtests", # for translate tests + "dacite", # for state ] diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index 0250a96b..dccfa94f 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -261,3 +261,31 @@ def test_to_data_array(quantity): assert ( quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data ), "data memory address is not equal" + + +def test_data_setter(): + quantity = Quantity(np.ones((5,)), dims=["dim1"], units="") + + # After allocation - field and data are the same (origin is 0) + assert quantity.data.shape == quantity.field.shape + + # Allows swap: new array is bigger than Q.shape + new_array = np.ones((10,)) + new_array[:] = 2 + quantity.data = new_array + + # After swap - field and data points to the same memory + # BUT field still respects the original origin/extent + assert (quantity.data[:] == 2).all() + assert (quantity.field[:] == 2).all() + assert quantity.data.shape != quantity.field.shape + assert quantity.field.shape == (5,) + + # Expected fail: new array is too small + new_array = np.ones((2,)) + with pytest.raises(ValueError, match="Quantity.data buffer swap failed.*"): + quantity.data = new_array + + # Expected fail: new array is not even an array + with pytest.raises(TypeError, match="Quantity.data buffer swap failed.*"): + quantity.data = "meh" diff --git a/tests/quantity/test_state.py b/tests/quantity/test_state.py new file mode 100644 index 00000000..86ec1a19 --- /dev/null +++ b/tests/quantity/test_state.py @@ -0,0 +1,70 @@ +import dataclasses +from pathlib import Path + +import numpy as np + +from ndsl import Quantity, State +from ndsl.boilerplate import get_factories_single_tile +from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Float + + +@dataclasses.dataclass +class CodeState(State): + @dataclasses.dataclass + class InnerA: + A: Quantity = dataclasses.field( + metadata={ + "name": "A", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + @dataclasses.dataclass + class InnerB: + B: Quantity = dataclasses.field( + metadata={ + "name": "B", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "1", + "intent": "?", + "dtype": Float, + } + ) + + inner_A: InnerA + inner_B: InnerB + C: Quantity = dataclasses.field( + metadata={ + "name": "C", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + +def test_state(): + _, quantity_factory = get_factories_single_tile( + 5, 5, 3, 0, backend="dace:cpu_kfirst" + ) + + microphys_state = CodeState.zeros(quantity_factory) + microphys_state.inner_A.A.field[:] = 42.42 + microphys_state.to_netcdf() + microphys_state2 = CodeState.zeros(quantity_factory) + microphys_state2.update_from_netcdf(Path("./")) + assert (microphys_state2.inner_A.A.field[:] == 42.42).all() + a = np.ones((5, 5, 3)) + b = np.ones((5, 5, 3)) + c = np.ones((5, 5, 3)) + b[:] = 23.23 + microphys_state2.update_move_memory( + {"inner_A": {"A": a}, "inner_B": {"B": b}, "C": c}, + check_shape_and_strides=False, + ) + assert (microphys_state2.inner_A.A.field[:] == 1.0).all() + assert (microphys_state2.inner_B.B.field[:] == 23.23).all()