From 646764f47eca7191ba2c09b82515281a5b913f3e Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 30 Sep 2025 12:32:49 -0400 Subject: [PATCH 01/12] State v0 + utest --- ndsl/__init__.py | 3 +- ndsl/quantity/__init__.py | 4 +- ndsl/quantity/state.py | 160 +++++++++++++++++++++++++++++++++++ tests/quantity/test_state.py | 67 +++++++++++++++ 4 files changed, 231 insertions(+), 3 deletions(-) create mode 100644 ndsl/quantity/state.py create mode 100644 tests/quantity/test_state.py diff --git a/ndsl/__init__.py b/ndsl/__init__.py index bb77d751..b066833f 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -28,7 +28,7 @@ from .performance.collector import NullPerformanceCollector, PerformanceCollector from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport -from .quantity import Quantity +from .quantity import Quantity, State from .quantity.field_bundle import FieldBundle, FieldBundleType # Break circular import from .testing.dummy_comm import DummyComm from .types import Allocator @@ -87,4 +87,5 @@ "DummyComm", "Allocator", "MetaEnumStr", + "State", ] diff --git a/ndsl/quantity/__init__.py b/ndsl/quantity/__init__.py index 43751528..ee7e4d78 100644 --- a/ndsl/quantity/__init__.py +++ b/ndsl/quantity/__init__.py @@ -1,11 +1,11 @@ from .metadata import QuantityHaloSpec, QuantityMetadata from .quantity import Quantity +from .state import State __all__ = [ "Quantity", "QuantityMetadata", "QuantityHaloSpec", - "FieldBundle", - "FieldBundleType", + "State", ] diff --git a/ndsl/quantity/state.py b/ndsl/quantity/state.py new file mode 100644 index 00000000..23f59889 --- /dev/null +++ b/ndsl/quantity/state.py @@ -0,0 +1,160 @@ +from typing import Any, Self +from ndsl import QuantityFactory +import dataclasses +import dacite +import xarray as xr +from mpi4py import MPI +from numpy.typing import ArrayLike + + +@dataclasses.dataclass +class State: + """Base class for State object in models that bundles a 
collection + of functions to deal with nested dataclasses and common usage of States: + - init (zero, from memory, zero copy buffer swap) + - IO (save to NetCDF, from NetCDF) + + The State expects Quantities. + """ + + @classmethod + def zeros(cls, quantity_factory: QuantityFactory) -> Self: + """Init all quantities to zeros - included nested ones""" + + def _zeros_recursive(cls): + initial_quantities = {} + for _field in dataclasses.fields(cls): + if dataclasses.is_dataclass(_field.type): + initial_quantities[_field.name] = _zeros_recursive(_field.type) + else: + if "dims" not in _field.metadata.keys(): + raise ValueError( + "Malformed state - no dims to init " + f"Quantity in {_field.name} of type {_field.type}" + ) + + initial_quantities[_field.name] = quantity_factory.zeros( + _field.metadata["dims"], + _field.metadata["units"], + dtype=_field.metadata["dtype"], + allow_mismatch_float_precision=True, + ) + + return initial_quantities + + dict_of_qty = _zeros_recursive(cls) + return dacite.from_dict(data_class=cls, data=dict_of_qty) + + def init_from_memory(self, memory_map: dict[str, Any]): + """Will copy data from the memory map if it follows the nested + naming convention of the dataclass""" + + def _init_from_memory_recursive(dataclss, memory_map: dict[str, Any]): + for name, array in memory_map.items(): + if isinstance(array, dict): + _init_from_memory_recursive(dataclss.__getattribute__(name), array) + else: + try: + dataclss.__getattribute__(name).field[:] = array + except ValueError as e: + e.add_note( + f"Error when initializing field {name} on state {type(self)}" + ) + raise e + + _init_from_memory_recursive(self, memory_map) + + def init_zero_copy(self, memory_map: dict[str, Any], check: bool = True): + """Swap buffers given into the Quantities carried by the state + by following dataclass naming convention""" + + def _init_zero_copy_recursive(dataclss, memory_map: dict[str, Any | ArrayLike]): + for name, array in memory_map.items(): + if 
isinstance(array, dict): + _init_zero_copy_recursive(dataclss.__getattribute__(name), array) + else: + qty = dataclss.__getattribute__(name) + if check: + if array.shape != qty.field.shape: + e = ValueError("Shape mismatch on zero copy for") + e.add_note(f" Error on {name} for {type(dataclss)}") + e.add_note(f" Shapes: {array.shape} != {qty.field.shape}") + raise e + if array.strides != qty.data.strides: + e = ValueError("Stride mismatch on zero copy for") + e.add_note(f" Error on {name} for {type(dataclss)}") + e.add_note( + f" Strides: {array.strides} != {qty.data.strides}" + ) + raise e + + qty.data = array + + _init_zero_copy_recursive(self, memory_map) + + def to_netcdf(self, path: str = "./"): + def _save_recursive(datclss: State): + local_data = {} + for _field in dataclasses.fields(datclss): + if dataclasses.is_dataclass(_field.type): + local_data[_field.name] = xr.Dataset( + data_vars=_save_recursive(datclss.__getattribute__(_field.name)) + ) + else: + if "dims" not in _field.metadata.keys(): + raise ValueError( + "Malformed state - no dims to init " + f"Quantity in {_field.name} of type {_field.type}" + ) + + local_data[_field.name] = datclss.__getattribute__( + _field.name + ).field_as_xarray + + return local_data + + datatree = _save_recursive(self) + + # Move top-level into their own dataset in the "/" prefix + # to match DataTree expected format + top_level = {} + for key, value in datatree.items(): + if not isinstance(value, xr.Dataset): + top_level[key] = value + for key, value in top_level.items(): + datatree.pop(key) + datatree["/"] = xr.Dataset(data_vars=top_level) + + # Resolve rank-tied postfix if needed + rank_postfix = "" + if MPI.COMM_WORLD.Get_size() > 1: + rank_postfix = f"_rank{MPI.COMM_WORLD.Get_rank()}" + + xr.DataTree.from_dict(datatree).to_netcdf( + f"{path}{type(self).__name__}{rank_postfix}.nc4" + ) + + def from_netcdf(self, path: str): + datatree = xr.open_datatree(path) + datatree_as_dict = datatree.to_dict() + + # All other cases 
- recursing downward + def _load_recursive(data_tree_as_dict: dict[str, xr.Dataset] | xr.Dataset): + local_data_dict = {} + for name, data_array in data_tree_as_dict.items(): + # Case of the top_level "/" + if name == "/": + for root_name, root_data_array in datatree_as_dict["/"].items(): + local_data_dict[root_name] = root_data_array.to_numpy() + else: + # Get the leading `/` out + if isinstance(data_array, xr.Dataset): + local_data_dict[name[1:]] = _load_recursive(data_array) + else: + local_data_dict[name] = data_array.to_numpy() + + return local_data_dict + + data_as_numpy_dict = _load_recursive(datatree_as_dict) + + self.init_from_memory(data_as_numpy_dict) diff --git a/tests/quantity/test_state.py b/tests/quantity/test_state.py new file mode 100644 index 00000000..3af42028 --- /dev/null +++ b/tests/quantity/test_state.py @@ -0,0 +1,67 @@ +from ndsl import Quantity +from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Float +from ndsl.boilerplate import get_factories_single_tile +import dataclasses +import numpy as np + +from ndsl import State + + +@dataclasses.dataclass +class CodeState(State): + @dataclasses.dataclass + class InnerA: + A: Quantity = dataclasses.field( + metadata={ + "name": "A", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + @dataclasses.dataclass + class InnerB: + B: Quantity = dataclasses.field( + metadata={ + "name": "B", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "1", + "intent": "?", + "dtype": Float, + } + ) + + inner_A: InnerA + inner_B: InnerB + C: Quantity = dataclasses.field( + metadata={ + "name": "C", + "dims": [X_DIM, Y_DIM, Z_DIM], + "units": "kg kg-1", + "intent": "?", + "dtype": Float, + } + ) + + +def test_state(): + _, qty_factry = get_factories_single_tile(5, 5, 3, 0, backend="dace:cpu_kfirst") + + microphys_state = CodeState.zeros(qty_factry) + microphys_state.inner_A.A.field[:] = 42.42 + microphys_state.to_netcdf() + microphys_state2 = CodeState.zeros(qty_factry) + 
microphys_state2.from_netcdf("CodeState.nc4") + assert (microphys_state2.inner_A.A.field[:] == 42.42).all() + a = np.ones((5, 5, 3)) + b = np.ones((5, 5, 3)) + c = np.ones((5, 5, 3)) + b[:] = 23.23 + microphys_state2.init_zero_copy( + {"inner_A": {"A": a}, "inner_B": {"B": b}, "C": c}, + check=False, + ) + assert (microphys_state2.inner_A.A.field[:] == 1.0).all() + assert (microphys_state2.inner_B.B.field[:] == 23.23).all() From 8a4993d6ef806ae1de732a11a25e333b5948e1fa Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 30 Sep 2025 12:33:04 -0400 Subject: [PATCH 02/12] `dacite` as a new dependency --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 4064cb62..1719226f 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ def local_pkg(name: str, relative_path: str) -> str: "matplotlib", # for plotting in boilerplate "cartopy", # for plotting in ndsl.viz "pytest-subtests", # for translate tests + "dacite", # for state ] From 6aac410332a65d0f8b8edb5d5a41bb42b32db60b Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 30 Sep 2025 12:33:27 -0400 Subject: [PATCH 03/12] Fix for `data` setter in quantity resetting the `compute_view` for `field` accessor --- ndsl/quantity/quantity.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 0aec5eef..aab045eb 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -263,6 +263,9 @@ def data(self) -> Union[np.ndarray, cupy.ndarray]: def data(self, inputData): if type(inputData) in [np.ndarray, cupy.ndarray]: self._data = inputData + self._compute_domain_view = BoundedArrayView( + self.data, self.dims, self.origin, self.extent + ) @property def origin(self) -> Tuple[int, ...]: From f77228df72eaa3f83290936cd065507ec275eb85 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 30 Sep 2025 14:48:22 -0400 Subject: [PATCH 04/12] Lint --- ndsl/quantity/state.py | 6 ++++-- tests/quantity/test_state.py |
8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/ndsl/quantity/state.py b/ndsl/quantity/state.py index 23f59889..38a7d9b2 100644 --- a/ndsl/quantity/state.py +++ b/ndsl/quantity/state.py @@ -1,11 +1,13 @@ -from typing import Any, Self -from ndsl import QuantityFactory import dataclasses +from typing import Any, Self + import dacite import xarray as xr from mpi4py import MPI from numpy.typing import ArrayLike +from ndsl import QuantityFactory + @dataclasses.dataclass class State: diff --git a/tests/quantity/test_state.py b/tests/quantity/test_state.py index 3af42028..4f409ccd 100644 --- a/tests/quantity/test_state.py +++ b/tests/quantity/test_state.py @@ -1,10 +1,10 @@ -from ndsl import Quantity -from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Float -from ndsl.boilerplate import get_factories_single_tile import dataclasses + import numpy as np -from ndsl import State +from ndsl import Quantity, State +from ndsl.boilerplate import get_factories_single_tile +from ndsl.constants import X_DIM, Y_DIM, Z_DIM, Float @dataclasses.dataclass From 376d66a78c2d88934683957451eeb6f8d1cf3ceb Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 30 Sep 2025 14:58:54 -0400 Subject: [PATCH 05/12] Fix circular import --- ndsl/__init__.py | 1 - ndsl/quantity/state.py | 8 ++++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ndsl/__init__.py b/ndsl/__init__.py index b066833f..675610e3 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -48,7 +48,6 @@ "FV3CodePath", "DaceConfig", "DaCeOrchestration", - "FrozenCompiledSDFG", "orchestrate", "orchestrate_function", "ArrayReport", diff --git a/ndsl/quantity/state.py b/ndsl/quantity/state.py index 38a7d9b2..d0d1cd44 100644 --- a/ndsl/quantity/state.py +++ b/ndsl/quantity/state.py @@ -1,12 +1,16 @@ +from __future__ import annotations + import dataclasses -from typing import Any, Self +from typing import TYPE_CHECKING, Any, Self import dacite import xarray as xr from mpi4py import MPI from 
numpy.typing import ArrayLike -from ndsl import QuantityFactory + +if TYPE_CHECKING: + from ndsl import QuantityFactory @dataclasses.dataclass From f9f480b982a515d93a4306c34e2cc30fb96089a8 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 1 Oct 2025 13:16:55 -0400 Subject: [PATCH 06/12] Add update/init distinction Better docstrings --- ndsl/quantity/state.py | 229 ++++++++++++++++++++++++++++------- tests/quantity/test_state.py | 15 ++- 2 files changed, 195 insertions(+), 49 deletions(-) diff --git a/ndsl/quantity/state.py b/ndsl/quantity/state.py index d0d1cd44..83daa53f 100644 --- a/ndsl/quantity/state.py +++ b/ndsl/quantity/state.py @@ -1,7 +1,8 @@ from __future__ import annotations import dataclasses -from typing import TYPE_CHECKING, Any, Self +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Self import dacite import xarray as xr @@ -15,23 +16,35 @@ @dataclasses.dataclass class State: - """Base class for State object in models that bundles a collection - of functions to deal with nested dataclasses and common usage of States: - - init (zero, from memory, zero copy buffer swap) - - IO (save to NetCDF, from NetCDF) + """Base class for state objects in models. - The State expects Quantities. + A State groups a collection of (possibly nested) Quantities in a dataclass. + + This baseclass implements common initialization functions and serialization. + + Typical usage example: + + ```python + class MyState(State): + pass + + my_state = MyState.zeros(quantity_factory) + + # ...
+ + my_state.to_netcdf() + ``` """ @classmethod - def zeros(cls, quantity_factory: QuantityFactory) -> Self: - """Init all quantities to zeros - included nested ones""" + def _init(cls, quantity_factory_allocator: Callable) -> Self: + """Allocate memory and init with a blind quantity init operation""" - def _zeros_recursive(cls): + def _init_recursive(cls): initial_quantities = {} for _field in dataclasses.fields(cls): if dataclasses.is_dataclass(_field.type): - initial_quantities[_field.name] = _zeros_recursive(_field.type) + initial_quantities[_field.name] = _init_recursive(_field.type) else: if "dims" not in _field.metadata.keys(): raise ValueError( @@ -39,7 +52,7 @@ def _zeros_recursive(cls): f"Quantity in {_field.name} of type {_field.type}" ) - initial_quantities[_field.name] = quantity_factory.zeros( + initial_quantities[_field.name] = quantity_factory_allocator( _field.metadata["dims"], _field.metadata["units"], dtype=_field.metadata["dtype"], @@ -48,57 +61,186 @@ def _zeros_recursive(cls): return initial_quantities - dict_of_qty = _zeros_recursive(cls) - return dacite.from_dict(data_class=cls, data=dict_of_qty) + dict_of_quantities = _init_recursive(cls) + return dacite.from_dict(data_class=cls, data=dict_of_quantities) + + @classmethod + def empty(cls, quantity_factory: QuantityFactory) -> Self: + """Allocate all quantities""" + + return cls._init(quantity_factory.empty) + + @classmethod + def zeros(cls, quantity_factory: QuantityFactory) -> Self: + """Allocate all quantities and fill their value to zeros""" + + return cls._init(quantity_factory.zeros) + + @classmethod + def ones(cls, quantity_factory: QuantityFactory) -> Self: + """Allocate all quantities and fill their value to ones""" + + return cls._init(quantity_factory.ones) + + @classmethod + def from_memory( + cls, + quantity_factory: QuantityFactory, + memory_map: dict[str, Any], + ) -> Self: + """Allocate all quantities and fill their value based + on the given memory map. 
See `update_from_memory`""" + + state = cls.zeros(quantity_factory) + state.update_from_memory(memory_map) + + return state + + @classmethod + def from_zero_copy( + cls, + quantity_factory: QuantityFactory, + memory_map: dict[str, Any], + check_shape_and_strides: bool, + ) -> Self: + """Allocate all quantities and buffer swap their memory based on + on the given memory map. See `update_zero_copy`.""" + + state = cls.zeros(quantity_factory) + state.update_zero_copy(memory_map, check_shape_and_strides) + + return state - def init_from_memory(self, memory_map: dict[str, Any]): - """Will copy data from the memory map if it follows the nested - naming convention of the dataclass""" + def update_from_memory(self, memory_map: dict[str, Any]): + """Copy data from the memory map if it follows the nested + naming convention of the dataclass. E.g. - def _init_from_memory_recursive(dataclss, memory_map: dict[str, Any]): + ```python + @dataclass + class MyState: + @dataclass + class InnerA + a: Quantity + + inner_a: InnerA + b: Quantity + ``` + will update with a dictionary looking like + ```python + { + "inner_a": + { + "a": Quantity(...) + } + "b": Quantity(...) + + } + ``` + + The memory map drives what is loaded, therefore it can be sparse. 
+ """ + + def _update_from_memory_recursive(dataclss, memory_map: dict[str, Any]): for name, array in memory_map.items(): if isinstance(array, dict): - _init_from_memory_recursive(dataclss.__getattribute__(name), array) + _update_from_memory_recursive( + dataclss.__getattribute__(name), array + ) else: try: dataclss.__getattribute__(name).field[:] = array - except ValueError as e: + except Exception as e: e.add_note( f"Error when initializing field {name} on state {type(self)}" ) raise e - _init_from_memory_recursive(self, memory_map) + _update_from_memory_recursive(self, memory_map) - def init_zero_copy(self, memory_map: dict[str, Any], check: bool = True): + def update_zero_copy( + self, + memory_map: dict[str, Any], + check_shape_and_strides: bool, + ): """Swap buffers given into the Quantities carried by the state - by following dataclass naming convention""" + by following dataclass naming convention, e.g. + + ```python + @dataclass + class MyState: + @dataclass + class InnerA + a: Quantity + + inner_a: InnerA + b: Quantity + ``` + will update with a dictionary looking like + ```python + { + "inner_a": + { + "a": Quantity(...) + } + "b": Quantity(...) - def _init_zero_copy_recursive(dataclss, memory_map: dict[str, Any | ArrayLike]): + } + ``` + + The memory map drives what is loaded, therefore it can be sparse. + + Args: + memory_map: Dictionary of keys to buffers. 
Buffers must be np.ArrayLike + check_shape_and_strides: check that the given buffers have the same + shape and strides as the original quantity + """ + + def _update_zero_copy_recursive( + dataclss, memory_map: dict[str, Any | ArrayLike] + ): for name, array in memory_map.items(): if isinstance(array, dict): - _init_zero_copy_recursive(dataclss.__getattribute__(name), array) + _update_zero_copy_recursive(dataclss.__getattribute__(name), array) else: - qty = dataclss.__getattribute__(name) - if check: - if array.shape != qty.field.shape: + quantity = dataclss.__getattribute__(name) + if check_shape_and_strides: + assert hasattr(array, "shape") + if array.shape != quantity.field.shape: e = ValueError("Shape mismatch on zero copy for") e.add_note(f" Error on {name} for {type(dataclss)}") - e.add_note(f" Shapes: {array.shape} != {qty.field.shape}") + e.add_note( + f" Shapes: {array.shape} != {quantity.field.shape}" + ) raise e - if array.strides != qty.data.strides: + if array.strides != quantity.data.strides: e = ValueError("Stride mismatch on zero copy for") e.add_note(f" Error on {name} for {type(dataclss)}") e.add_note( - f" Strides: {array.strides} != {qty.data.strides}" + f" Strides: {array.strides} != {quantity.data.strides}" ) raise e - qty.data = array + quantity.data = array - _init_zero_copy_recursive(self, memory_map) + _update_zero_copy_recursive(self, memory_map) + + def _netcdf_name(self, directory_path: Path) -> Path: + # Resolve rank-tied postfix if needed + rank_postfix = "" + if MPI.COMM_WORLD.Get_size() > 1: + rank_postfix = f"_rank{MPI.COMM_WORLD.Get_rank()}" + return directory_path / f"{type(self).__name__}{rank_postfix}.nc4" + + def to_netcdf(self, directory_path: Path = Path("./")) -> None: + """ + Save state to NetCDF. Can be reloaded with `update_from_netcdf`. 
+ + If applicable will ave seperate netcdf for each running rank + + Args: + directory_path: directory to save the netcdf in + """ - def to_netcdf(self, path: str = "./"): def _save_recursive(datclss: State): local_data = {} for _field in dataclasses.fields(datclss): @@ -131,17 +273,18 @@ def _save_recursive(datclss: State): datatree.pop(key) datatree["/"] = xr.Dataset(data_vars=top_level) - # Resolve rank-tied postfix if needed - rank_postfix = "" - if MPI.COMM_WORLD.Get_size() > 1: - rank_postfix = f"_rank{MPI.COMM_WORLD.Get_rank()}" + xr.DataTree.from_dict(datatree).to_netcdf(self._netcdf_name(directory_path)) + + def update_from_netcdf(self, directory_path: Path) -> None: + """This is a mirror of the `to_netcdf` method NOT a generic + NetCDF loader. It expects the NetCDF to be named with auto-naming scheme + of `to_netcdf`, be a `xarray.DataTree` in shape matching exactly the - xr.DataTree.from_dict(datatree).to_netcdf( - f"{path}{type(self).__name__}{rank_postfix}.nc4" - ) + Args: + directory_path: directory carrying the netcdf saved with `to_netcdf` - def from_netcdf(self, path: str): - datatree = xr.open_datatree(path) + """ + datatree = xr.open_datatree(self._netcdf_name(directory_path)) datatree_as_dict = datatree.to_dict() # All other cases - recursing downward @@ -163,4 +306,4 @@ def _load_recursive(data_tree_as_dict: dict[str, xr.Dataset] | xr.Dataset): data_as_numpy_dict = _load_recursive(datatree_as_dict) - self.init_from_memory(data_as_numpy_dict) + self.update_from_memory(data_as_numpy_dict) diff --git a/tests/quantity/test_state.py b/tests/quantity/test_state.py index 4f409ccd..851ec00b 100644 --- a/tests/quantity/test_state.py +++ b/tests/quantity/test_state.py @@ -1,4 +1,5 @@ import dataclasses +from pathlib import Path import numpy as np @@ -47,21 +48,23 @@ class InnerB: def test_state(): - _, qty_factry = get_factories_single_tile(5, 5, 3, 0, backend="dace:cpu_kfirst") + _, quantity_factory = get_factories_single_tile( + 5, 5, 3, 0, 
backend="dace:cpu_kfirst" + ) - microphys_state = CodeState.zeros(qty_factry) + microphys_state = CodeState.zeros(quantity_factory) microphys_state.inner_A.A.field[:] = 42.42 microphys_state.to_netcdf() - microphys_state2 = CodeState.zeros(qty_factry) - microphys_state2.from_netcdf("CodeState.nc4") + microphys_state2 = CodeState.zeros(quantity_factory) + microphys_state2.update_from_netcdf(Path("./")) assert (microphys_state2.inner_A.A.field[:] == 42.42).all() a = np.ones((5, 5, 3)) b = np.ones((5, 5, 3)) c = np.ones((5, 5, 3)) b[:] = 23.23 - microphys_state2.init_zero_copy( + microphys_state2.update_zero_copy( {"inner_A": {"A": a}, "inner_B": {"B": b}, "C": c}, - check=False, + check_shape_and_strides=False, ) assert (microphys_state2.inner_A.A.field[:] == 1.0).all() assert (microphys_state2.inner_B.B.field[:] == 23.23).all() From e8da5727e8adb1faa67fa7a1ad075f7b883357d2 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 1 Oct 2025 13:27:28 -0400 Subject: [PATCH 07/12] Harden `quantity.data = ...` & add unit test --- ndsl/quantity/quantity.py | 19 +++++++++++++++---- tests/quantity/test_quantity.py | 17 +++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index aab045eb..8a00a909 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -255,17 +255,28 @@ def field(self) -> np.ndarray | cupy.ndarray: return self._compute_domain_view[:] @property - def data(self) -> Union[np.ndarray, cupy.ndarray]: + def data(self) -> np.ndarray | cupy.ndarray: """The underlying array of data""" return self._data @data.setter - def data(self, inputData): - if type(inputData) in [np.ndarray, cupy.ndarray]: - self._data = inputData + def data(self, input_data: np.ndarray | cupy.ndarray): + if type(input_data) in [np.ndarray, cupy.ndarray]: + if input_data.shape < self.data.shape: + raise ValueError( + "Quantity.data buffer swap failed: " + f"new data ({input_data.shape}) is 
smaller " + f"than expected data ({self.shape})." + ) + self._data = input_data self._compute_domain_view = BoundedArrayView( self.data, self.dims, self.origin, self.extent ) + else: + raise TypeError( + "Quantity.data buffer swap failed: " + f"given data is not an array (type: {type(input_data)})" + ) @property def origin(self) -> Tuple[int, ...]: diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index 0250a96b..ea4202d3 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -261,3 +261,20 @@ def test_to_data_array(quantity): assert ( quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data ), "data memory address is not equal" + + +def test_data_setter(): + quantity = Quantity(np.array(5), dims=[], units="") + + # Allows swap: new array is bigger than Q.shape + new_array = np.array(10) + quantity.data = new_array + + # Expected fail: new array is too small + new_array = np.array(2) + with pytest.raises(ValueError): + quantity.data = new_array + + # Expected fail: new array is not even an array + with pytest.raises(TypeError): + quantity.data = "meh" From 3a9fe3301ce668e41ab8f924d2180ec417c0c26e Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 1 Oct 2025 13:37:26 -0400 Subject: [PATCH 08/12] Use extent --- ndsl/quantity/quantity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 8a00a909..82965ba7 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -262,11 +262,11 @@ def data(self) -> np.ndarray | cupy.ndarray: @data.setter def data(self, input_data: np.ndarray | cupy.ndarray): if type(input_data) in [np.ndarray, cupy.ndarray]: - if input_data.shape < self.data.shape: + if input_data.shape < self.extent: raise ValueError( "Quantity.data buffer swap failed: " f"new data ({input_data.shape}) is smaller " - f"than expected data ({self.shape})." 
+ f"than expected extent ({self.extent})." ) self._data = input_data self._compute_domain_view = BoundedArrayView( From 04079d3d611eaecef4f1a2da23c67b307198860d Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 1 Oct 2025 13:54:09 -0400 Subject: [PATCH 09/12] Proper test fix --- tests/quantity/test_quantity.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index ea4202d3..dac7ab59 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -258,23 +258,27 @@ def test_to_data_array(quantity): assert quantity.field_as_xarray.shape == quantity.extent np.testing.assert_array_equal(quantity.field_as_xarray.values, quantity.view[:]) if quantity.extent == quantity.data.shape: - assert ( - quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data - ), "data memory address is not equal" + assert quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data, ( + "data memory address is not equal" + ) def test_data_setter(): - quantity = Quantity(np.array(5), dims=[], units="") + quantity = Quantity(np.ones((5,)), dims=["dim1"], units="") # Allows swap: new array is bigger than Q.shape - new_array = np.array(10) + new_array = np.ones((10,)) quantity.data = new_array # Expected fail: new array is too small - new_array = np.array(2) + new_array = np.ones((2,)) with pytest.raises(ValueError): quantity.data = new_array # Expected fail: new array is not even an array with pytest.raises(TypeError): quantity.data = "meh" + + +if __name__ == "__main__": + test_data_setter() From e00a8c3476625745bdbcfb4344cd28ad55f40202 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 1 Oct 2025 13:55:59 -0400 Subject: [PATCH 10/12] Lint --- tests/quantity/test_quantity.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index dac7ab59..96271309 
100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -258,9 +258,9 @@ def test_to_data_array(quantity): assert quantity.field_as_xarray.shape == quantity.extent np.testing.assert_array_equal(quantity.field_as_xarray.values, quantity.view[:]) if quantity.extent == quantity.data.shape: - assert quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data, ( - "data memory address is not equal" - ) + assert ( + quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data + ), "data memory address is not equal" def test_data_setter(): From af1529f4b5d486623a2f152c7056c5281387403d Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 2 Oct 2025 09:32:32 -0400 Subject: [PATCH 11/12] Better raise programming for `Quantity.data` --- ndsl/quantity/quantity.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 82965ba7..9272fa22 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -261,23 +261,24 @@ def data(self) -> np.ndarray | cupy.ndarray: @data.setter def data(self, input_data: np.ndarray | cupy.ndarray): - if type(input_data) in [np.ndarray, cupy.ndarray]: - if input_data.shape < self.extent: - raise ValueError( - "Quantity.data buffer swap failed: " - f"new data ({input_data.shape}) is smaller " - f"than expected extent ({self.extent})." - ) - self._data = input_data - self._compute_domain_view = BoundedArrayView( - self.data, self.dims, self.origin, self.extent - ) - else: + if type(input_data) not in [np.ndarray, cupy.ndarray]: raise TypeError( "Quantity.data buffer swap failed: " f"given data is not an array (type: {type(input_data)})" ) + if input_data.shape < self.extent: + raise ValueError( + "Quantity.data buffer swap failed: " + f"new data ({input_data.shape}) is smaller " + f"than expected extent ({self.extent})." 
+ ) + + self._data = input_data + self._compute_domain_view = BoundedArrayView( + self.data, self.dims, self.origin, self.extent + ) + @property def origin(self) -> Tuple[int, ...]: """The start of the computational domain""" From ed3c77acab950c53fa2b4f1b4551be8c52094128 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 2 Oct 2025 09:33:02 -0400 Subject: [PATCH 12/12] Better docs, naming for states. Improve quantity unit test --- ndsl/quantity/state.py | 87 +++++++++++++++++++-------------- tests/quantity/test_quantity.py | 19 ++++--- tests/quantity/test_state.py | 2 +- 3 files changed, 63 insertions(+), 45 deletions(-) diff --git a/ndsl/quantity/state.py b/ndsl/quantity/state.py index 83daa53f..b9ccbeb0 100644 --- a/ndsl/quantity/state.py +++ b/ndsl/quantity/state.py @@ -83,7 +83,7 @@ def ones(cls, quantity_factory: QuantityFactory) -> Self: return cls._init(quantity_factory.ones) @classmethod - def from_memory( + def copy_memory( cls, quantity_factory: QuantityFactory, memory_map: dict[str, Any], @@ -92,28 +92,33 @@ def from_memory( on the given memory map. See `update_from_memory`""" state = cls.zeros(quantity_factory) - state.update_from_memory(memory_map) + state.update_copy_memory(memory_map) return state @classmethod - def from_zero_copy( + def move_memory( cls, quantity_factory: QuantityFactory, memory_map: dict[str, Any], - check_shape_and_strides: bool, + *, + check_shape_and_strides: bool = True, ) -> Self: - """Allocate all quantities and buffer swap their memory based on - on the given memory map. See `update_zero_copy`.""" + """Allocate all quantities and move memory based on + the given memory map.
See `update_move_memory`.""" state = cls.zeros(quantity_factory) - state.update_zero_copy(memory_map, check_shape_and_strides) + state.update_move_memory( + memory_map, + check_shape_and_strides=check_shape_and_strides, + ) return state - def update_from_memory(self, memory_map: dict[str, Any]): - """Copy data from the memory map if it follows the nested - naming convention of the dataclass. E.g. + def update_copy_memory(self, memory_map: dict[str, Any]) -> None: + """Copy data into the Quantities carried by the state. + + The memory map must follow the dataclass naming convention, e.g. ```python @dataclass @@ -137,18 +142,19 @@ class InnerA } ``` - The memory map drives what is loaded, therefore it can be sparse. + The memory map can be sparse. """ - def _update_from_memory_recursive(dataclss, memory_map: dict[str, Any]): + def _update_from_memory_recursive( + state: State, + memory_map: dict[str, dict | ArrayLike], + ): for name, array in memory_map.items(): if isinstance(array, dict): - _update_from_memory_recursive( - dataclss.__getattribute__(name), array - ) + _update_from_memory_recursive(state.__getattribute__(name), array) else: try: - dataclss.__getattribute__(name).field[:] = array + state.__getattribute__(name).field[:] = array except Exception as e: e.add_note( f"Error when initializing field {name} on state {type(self)}" @@ -157,13 +163,16 @@ def _update_from_memory_recursive(dataclss, memory_map: dict[str, Any]): _update_from_memory_recursive(self, memory_map) - def update_zero_copy( + def update_move_memory( self, - memory_map: dict[str, Any], - check_shape_and_strides: bool, - ): - """Swap buffers given into the Quantities carried by the state - by following dataclass naming convention, e.g. + memory_map: dict[str, dict | ArrayLike], + *, + check_shape_and_strides: bool = True, + ) -> None: + """Move memory into the Quantities carried by the state. + Memory is moved rather than copied (e.g. 
buffers are swapped) + + The memory map must follow the dataclass naming convention, e.g. ```python @dataclass @@ -187,7 +196,7 @@ class InnerA } ``` - The memory map drives what is loaded, therefore it can be sparse. + The memory map can be sparse. Args: memory_map: Dictionary of keys to buffers. Buffers must be np.ArrayLike @@ -196,25 +205,24 @@ class InnerA """ def _update_zero_copy_recursive( - dataclss, memory_map: dict[str, Any | ArrayLike] + state: State, memory_map: dict[str, dict | ArrayLike] ): for name, array in memory_map.items(): if isinstance(array, dict): - _update_zero_copy_recursive(dataclss.__getattribute__(name), array) + _update_zero_copy_recursive(state.__getattribute__(name), array) else: - quantity = dataclss.__getattribute__(name) + quantity = state.__getattribute__(name) if check_shape_and_strides: - assert hasattr(array, "shape") if array.shape != quantity.field.shape: e = ValueError("Shape mismatch on zero copy for") - e.add_note(f" Error on {name} for {type(dataclss)}") + e.add_note(f" Error on {name} for {type(state)}") e.add_note( f" Shapes: {array.shape} != {quantity.field.shape}" ) raise e if array.strides != quantity.data.strides: e = ValueError("Stride mismatch on zero copy for") - e.add_note(f" Error on {name} for {type(dataclss)}") + e.add_note(f" Error on {name} for {type(state)}") e.add_note( f" Strides: {array.strides} != {quantity.data.strides}" ) @@ -225,7 +233,7 @@ def _update_zero_copy_recursive( _update_zero_copy_recursive(self, memory_map) def _netcdf_name(self, directory_path: Path) -> Path: - # Resolve rank-tied postfix if needed + """Resolve rank-tied postfix if needed""" rank_postfix = "" if MPI.COMM_WORLD.Get_size() > 1: rank_postfix = f"_rank{MPI.COMM_WORLD.Get_rank()}" @@ -235,18 +243,21 @@ def to_netcdf(self, directory_path: Path = Path("./")) -> None: """ Save state to NetCDF. Can be reloaded with `update_from_netcdf`. 
- If applicable will ave seperate netcdf for each running rank + If applicable, will save separate NetCDF files for each running rank. + + The file names are deduced from the class name, and postfixed with the rank number + in the case of multi-process use. Args: directory_path: directory to save the netcdf in """ - def _save_recursive(datclss: State): + def _save_recursive(state: State): local_data = {} - for _field in dataclasses.fields(datclss): + for _field in dataclasses.fields(state): if dataclasses.is_dataclass(_field.type): local_data[_field.name] = xr.Dataset( - data_vars=_save_recursive(datclss.__getattribute__(_field.name)) + data_vars=_save_recursive(state.__getattribute__(_field.name)) ) else: if "dims" not in _field.metadata.keys(): @@ -255,7 +266,7 @@ def _save_recursive(datclss: State): f"Quantity in {_field.name} of type {_field.type}" ) - local_data[_field.name] = datclss.__getattribute__( + local_data[_field.name] = state.__getattribute__( _field.name ).field_as_xarray @@ -277,8 +288,8 @@ def _save_recursive(datclss: State): def update_from_netcdf(self, directory_path: Path) -> None: """This is a mirror of the `to_netcdf` method NOT a generic - NetCDF loader. It expects the NetCDF to be named with auto-naming scheme - of `to_netcdf`, be a `xarray.DataTree` in shape matching exactly the + NetCDF loader. It expects the NetCDF to be named with the auto-naming scheme + of `to_netcdf`. 
Args: directory_path: directory carrying the netcdf saved with `to_netcdf` @@ -306,4 +317,4 @@ def _load_recursive(data_tree_as_dict: dict[str, xr.Dataset] | xr.Dataset): data_as_numpy_dict = _load_recursive(datatree_as_dict) - self.update_from_memory(data_as_numpy_dict) + self.update_copy_memory(data_as_numpy_dict) diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index 96271309..dccfa94f 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -266,19 +266,26 @@ def test_to_data_array(quantity): def test_data_setter(): quantity = Quantity(np.ones((5,)), dims=["dim1"], units="") + # After allocation - field and data are the same (origin is 0) + assert quantity.data.shape == quantity.field.shape + # Allows swap: new array is bigger than Q.shape new_array = np.ones((10,)) + new_array[:] = 2 quantity.data = new_array + # After swap - field and data points to the same memory + # BUT field still respects the original origin/extent + assert (quantity.data[:] == 2).all() + assert (quantity.field[:] == 2).all() + assert quantity.data.shape != quantity.field.shape + assert quantity.field.shape == (5,) + # Expected fail: new array is too small new_array = np.ones((2,)) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Quantity.data buffer swap failed.*"): quantity.data = new_array # Expected fail: new array is not even an array - with pytest.raises(TypeError): + with pytest.raises(TypeError, match="Quantity.data buffer swap failed.*"): quantity.data = "meh" - - -if __name__ == "__main__": - test_data_setter() diff --git a/tests/quantity/test_state.py b/tests/quantity/test_state.py index 851ec00b..86ec1a19 100644 --- a/tests/quantity/test_state.py +++ b/tests/quantity/test_state.py @@ -62,7 +62,7 @@ def test_state(): b = np.ones((5, 5, 3)) c = np.ones((5, 5, 3)) b[:] = 23.23 - microphys_state2.update_zero_copy( + microphys_state2.update_move_memory( {"inner_A": {"A": a}, "inner_B": {"B": b}, 
"C": c}, check_shape_and_strides=False, )