Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions ndsl/quantity/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,20 +500,10 @@ def _netcdf_name(self, directory_path: Path, postfix: str = "") -> Path:
rank_postfix = f"_rank{MPI.COMM_WORLD.Get_rank()}"
return directory_path / f"{type(self).__name__}{rank_postfix}{postfix}.nc4"

def to_netcdf(self, directory_path: Path | None = None, postfix: str = "") -> None:
def to_xarray(self) -> xr.DataTree:
"""
Save state to NetCDF. Can be reloaded with `update_from_netcdf`.

If applicable, will save separate NetCDF files for each running rank.

The file names are deduced from the class name, and post fix with rank number
in the case of a multi-process use.

Args:
directory_path: directory to save the netcdf in
Format the State into a xr.DataTree.
"""
if directory_path is None:
directory_path = Path("./")

def _save_recursive(state: State) -> dict:
local_data = {}
Expand Down Expand Up @@ -547,9 +537,24 @@ def _save_recursive(state: State) -> dict:
datatree.pop(key)
datatree["/"] = xr.Dataset(data_vars=top_level)

xr.DataTree.from_dict(datatree).to_netcdf(
self._netcdf_name(directory_path, postfix)
)
return xr.DataTree.from_dict(datatree)

def to_netcdf(self, directory_path: Path | None = None, postfix: str = "") -> None:
"""
Save state to NetCDF. Can be reloaded with `update_from_netcdf`.

If applicable, will save separate NetCDF files for each running rank.

The file names are deduced from the class name, and post fix with rank number
in the case of a multi-process use.

Args:
directory_path: directory to save the netcdf in
"""
if directory_path is None:
directory_path = Path("./")

self.to_xarray().to_netcdf(self._netcdf_name(directory_path, postfix))

def update_from_netcdf(self, directory_path: Path, postfix: str = "") -> None:
"""This is a mirror of the `to_netcdf` method NOT a generic
Expand Down