-
Notifications
You must be signed in to change notification settings - Fork 16
State V0 #242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
FlorianDeconinck
merged 12 commits into
NOAA-GFDL:develop
from
FlorianDeconinck:feature/state_v0
Oct 2, 2025
Merged
State V0 #242
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
646764f
State v0 + utest
FlorianDeconinck 8a4993d
`dacite` as a new dependancy
FlorianDeconinck 6aac410
Fix for `data` setter in quantity reseting the `compute_view` for `fi…
FlorianDeconinck f77228d
Lint
FlorianDeconinck 376d66a
Fix circular import
FlorianDeconinck f9f480b
Add update/init distinction
FlorianDeconinck e8da572
Harden `quantity.data = ...` & add unit test
FlorianDeconinck 3a9fe33
Use extent
FlorianDeconinck 04079d3
Proper test fix
FlorianDeconinck e00a8c3
Lint
FlorianDeconinck af1529f
Better raise programming for `Quantity.data`
FlorianDeconinck ed3c77a
Better docs, naming for states. Improve quantity unit test
FlorianDeconinck File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,11 @@ | ||
| from .metadata import QuantityHaloSpec, QuantityMetadata | ||
| from .quantity import Quantity | ||
| from .state import State | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "Quantity", | ||
| "QuantityMetadata", | ||
| "QuantityHaloSpec", | ||
| "FieldBundle", | ||
| "FieldBundleType", | ||
| "State", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,320 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import dataclasses | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING, Any, Callable, Self | ||
|
|
||
| import dacite | ||
| import xarray as xr | ||
| from mpi4py import MPI | ||
| from numpy.typing import ArrayLike | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| from ndsl import QuantityFactory | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class State: | ||
| """Base class for state objects in models. | ||
|
|
||
| A State groups a collection of (possibly nested) Quantities in a dataclass. | ||
|
|
||
| This baseclass implements common initialization functions and serialization. | ||
|
|
||
| Typical usage example: | ||
|
|
||
| ```python | ||
| class MyState(State): | ||
| pass | ||
|
|
||
| my_state = MyState.zeros(quantity_factory) | ||
|
|
||
| # ... | ||
|
|
||
| my_state.to_netcdf() | ||
| ``` | ||
| """ | ||
|
|
||
| @classmethod | ||
| def _init(cls, quantity_factory_allocator: Callable) -> Self: | ||
| """Allocate memory and init with a blind quantity init operation""" | ||
|
|
||
| def _init_recursive(cls): | ||
| initial_quantities = {} | ||
| for _field in dataclasses.fields(cls): | ||
| if dataclasses.is_dataclass(_field.type): | ||
| initial_quantities[_field.name] = _init_recursive(_field.type) | ||
| else: | ||
| if "dims" not in _field.metadata.keys(): | ||
| raise ValueError( | ||
| "Malformed state - no dims to init " | ||
| f"Quantity in {_field.name} of type {_field.type}" | ||
| ) | ||
|
|
||
| initial_quantities[_field.name] = quantity_factory_allocator( | ||
| _field.metadata["dims"], | ||
| _field.metadata["units"], | ||
| dtype=_field.metadata["dtype"], | ||
| allow_mismatch_float_precision=True, | ||
| ) | ||
|
|
||
| return initial_quantities | ||
|
|
||
| dict_of_quantities = _init_recursive(cls) | ||
| return dacite.from_dict(data_class=cls, data=dict_of_quantities) | ||
|
|
||
| @classmethod | ||
| def empty(cls, quantity_factory: QuantityFactory) -> Self: | ||
| """Allocate all quantities""" | ||
|
|
||
| return cls._init(quantity_factory.empty) | ||
|
|
||
| @classmethod | ||
| def zeros(cls, quantity_factory: QuantityFactory) -> Self: | ||
| """Allocate all quantities and fill their value to zeros""" | ||
|
|
||
| return cls._init(quantity_factory.zeros) | ||
|
|
||
| @classmethod | ||
| def ones(cls, quantity_factory: QuantityFactory) -> Self: | ||
| """Allocate all quantities and fill their value to ones""" | ||
|
|
||
| return cls._init(quantity_factory.ones) | ||
|
|
||
| @classmethod | ||
| def copy_memory( | ||
| cls, | ||
| quantity_factory: QuantityFactory, | ||
| memory_map: dict[str, Any], | ||
| ) -> Self: | ||
| """Allocate all quantities and fill their value based | ||
| on the given memory map. See `update_from_memory`""" | ||
|
|
||
| state = cls.zeros(quantity_factory) | ||
| state.update_copy_memory(memory_map) | ||
|
|
||
| return state | ||
|
|
||
| @classmethod | ||
| def move_memory( | ||
| cls, | ||
| quantity_factory: QuantityFactory, | ||
| memory_map: dict[str, Any], | ||
| *, | ||
| check_shape_and_strides: bool = True, | ||
| ) -> Self: | ||
| """Allocate all quantities and move memory based on | ||
| on the given memory map. See `update_move_memory`.""" | ||
|
|
||
| state = cls.zeros(quantity_factory) | ||
| state.update_move_memory( | ||
| memory_map, | ||
| check_shape_and_strides=check_shape_and_strides, | ||
| ) | ||
|
|
||
| return state | ||
|
|
||
| def update_copy_memory(self, memory_map: dict[str, Any]) -> None: | ||
| """Copy data into the Quantities carried by the state. | ||
|
|
||
| The memory map must follow the dataclass naming convention, e.g. | ||
|
|
||
| ```python | ||
| @dataclass | ||
| class MyState: | ||
| @dataclass | ||
| class InnerA | ||
| a: Quantity | ||
|
|
||
| inner_a: InnerA | ||
| b: Quantity | ||
| ``` | ||
| will update with a dictionary looking like | ||
| ```python | ||
| { | ||
| "inner_a": | ||
| { | ||
| "a": Quantity(...) | ||
| } | ||
| "b": Quantity(...) | ||
|
|
||
| } | ||
| ``` | ||
|
|
||
| The memory map can be sparse. | ||
| """ | ||
|
|
||
| def _update_from_memory_recursive( | ||
| state: State, | ||
| memory_map: dict[str, dict | ArrayLike], | ||
| ): | ||
| for name, array in memory_map.items(): | ||
| if isinstance(array, dict): | ||
| _update_from_memory_recursive(state.__getattribute__(name), array) | ||
| else: | ||
| try: | ||
| state.__getattribute__(name).field[:] = array | ||
| except Exception as e: | ||
| e.add_note( | ||
| f"Error when initializing field {name} on state {type(self)}" | ||
| ) | ||
| raise e | ||
|
|
||
| _update_from_memory_recursive(self, memory_map) | ||
|
|
||
| def update_move_memory( | ||
| self, | ||
| memory_map: dict[str, dict | ArrayLike], | ||
| *, | ||
| check_shape_and_strides: bool = True, | ||
| ) -> None: | ||
| """Move memory into the Quantities carried by the state. | ||
| Memory is moved rather than copied (e.g. buffers are swapped) | ||
|
|
||
| The memory map must follow the dataclass naming convention, e.g. | ||
|
|
||
| ```python | ||
| @dataclass | ||
| class MyState: | ||
| @dataclass | ||
| class InnerA | ||
| a: Quantity | ||
|
|
||
| inner_a: InnerA | ||
| b: Quantity | ||
| ``` | ||
| will update with a dictionary looking like | ||
| ```python | ||
| { | ||
| "inner_a": | ||
| { | ||
| "a": Quantity(...) | ||
| } | ||
| "b": Quantity(...) | ||
|
|
||
| } | ||
| ``` | ||
|
|
||
| The memory map can be sparse. | ||
|
|
||
| Args: | ||
| memory_map: Dictionary of keys to buffers. Buffers must be np.ArrayLike | ||
| check_shape_and_strides: check that the given buffers have the same | ||
| shape and strides as the original quantity | ||
| """ | ||
|
|
||
| def _update_zero_copy_recursive( | ||
| state: State, memory_map: dict[str, dict | ArrayLike] | ||
| ): | ||
| for name, array in memory_map.items(): | ||
| if isinstance(array, dict): | ||
| _update_zero_copy_recursive(state.__getattribute__(name), array) | ||
| else: | ||
| quantity = state.__getattribute__(name) | ||
| if check_shape_and_strides: | ||
| if array.shape != quantity.field.shape: | ||
| e = ValueError("Shape mismatch on zero copy for") | ||
| e.add_note(f" Error on {name} for {type(state)}") | ||
| e.add_note( | ||
| f" Shapes: {array.shape} != {quantity.field.shape}" | ||
| ) | ||
| raise e | ||
| if array.strides != quantity.data.strides: | ||
| e = ValueError("Stride mismatch on zero copy for") | ||
| e.add_note(f" Error on {name} for {type(state)}") | ||
| e.add_note( | ||
| f" Strides: {array.strides} != {quantity.data.strides}" | ||
| ) | ||
| raise e | ||
|
|
||
| quantity.data = array | ||
|
|
||
| _update_zero_copy_recursive(self, memory_map) | ||
|
|
||
| def _netcdf_name(self, directory_path: Path) -> Path: | ||
| """Resolve rank-tied postfix if needed""" | ||
| rank_postfix = "" | ||
| if MPI.COMM_WORLD.Get_size() > 1: | ||
| rank_postfix = f"_rank{MPI.COMM_WORLD.Get_rank()}" | ||
| return directory_path / f"{type(self).__name__}{rank_postfix}.nc4" | ||
|
|
||
| def to_netcdf(self, directory_path: Path = Path("./")) -> None: | ||
| """ | ||
| Save state to NetCDF. Can be reloaded with `update_from_netcdf`. | ||
|
|
||
| If applicable, will save seperate NetCDF files for each running rank. | ||
|
|
||
| The file names are deduced from the class name, and post fix with rank number | ||
| in the case of a multi-process use. | ||
|
|
||
| Args: | ||
| directory_path: directory to save the netcdf in | ||
| """ | ||
|
|
||
| def _save_recursive(state: State): | ||
| local_data = {} | ||
| for _field in dataclasses.fields(state): | ||
| if dataclasses.is_dataclass(_field.type): | ||
| local_data[_field.name] = xr.Dataset( | ||
| data_vars=_save_recursive(state.__getattribute__(_field.name)) | ||
| ) | ||
| else: | ||
| if "dims" not in _field.metadata.keys(): | ||
| raise ValueError( | ||
| "Malformed state - no dims to init " | ||
| f"Quantity in {_field.name} of type {_field.type}" | ||
| ) | ||
|
|
||
| local_data[_field.name] = state.__getattribute__( | ||
| _field.name | ||
| ).field_as_xarray | ||
|
|
||
| return local_data | ||
|
|
||
| datatree = _save_recursive(self) | ||
|
|
||
| # Move top-level into their own dataset in the "/" prefix | ||
| # to match DataTree expected format | ||
| top_level = {} | ||
| for key, value in datatree.items(): | ||
| if not isinstance(value, xr.Dataset): | ||
| top_level[key] = value | ||
| for key, value in top_level.items(): | ||
| datatree.pop(key) | ||
| datatree["/"] = xr.Dataset(data_vars=top_level) | ||
|
|
||
| xr.DataTree.from_dict(datatree).to_netcdf(self._netcdf_name(directory_path)) | ||
|
|
||
| def update_from_netcdf(self, directory_path: Path) -> None: | ||
| """This is a mirror of the `to_netcdf` method NOT a generic | ||
| NetCDF loader. It expects the NetCDF to be named with the auto-naming scheme | ||
| of `to_netcdf`. | ||
|
|
||
| Args: | ||
| directory_path: directory carrying the netcdf saved with `to_netcdf` | ||
|
|
||
| """ | ||
|
FlorianDeconinck marked this conversation as resolved.
|
||
| datatree = xr.open_datatree(self._netcdf_name(directory_path)) | ||
| datatree_as_dict = datatree.to_dict() | ||
|
|
||
| # All other cases - recursing downward | ||
| def _load_recursive(data_tree_as_dict: dict[str, xr.Dataset] | xr.Dataset): | ||
| local_data_dict = {} | ||
| for name, data_array in data_tree_as_dict.items(): | ||
| # Case of the top_level "/" | ||
| if name == "/": | ||
| for root_name, root_data_array in datatree_as_dict["/"].items(): | ||
| local_data_dict[root_name] = root_data_array.to_numpy() | ||
| else: | ||
| # Get the leading `/` out | ||
| if isinstance(data_array, xr.Dataset): | ||
| local_data_dict[name[1:]] = _load_recursive(data_array) | ||
| else: | ||
| local_data_dict[name] = data_array.to_numpy() | ||
|
|
||
| return local_data_dict | ||
|
|
||
| data_as_numpy_dict = _load_recursive(datatree_as_dict) | ||
|
|
||
| self.update_copy_memory(data_as_numpy_dict) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.