diff --git a/ndsl/quantity/state.py b/ndsl/quantity/state.py index 4e1a7dd0..9447cccd 100644 --- a/ndsl/quantity/state.py +++ b/ndsl/quantity/state.py @@ -500,20 +500,10 @@ def _netcdf_name(self, directory_path: Path, postfix: str = "") -> Path: rank_postfix = f"_rank{MPI.COMM_WORLD.Get_rank()}" return directory_path / f"{type(self).__name__}{rank_postfix}{postfix}.nc4" - def to_netcdf(self, directory_path: Path | None = None, postfix: str = "") -> None: + def to_xarray(self) -> xr.DataTree: """ - Save state to NetCDF. Can be reloaded with `update_from_netcdf`. - - If applicable, will save separate NetCDF files for each running rank. - - The file names are deduced from the class name, and post fix with rank number - in the case of a multi-process use. - - Args: - directory_path: directory to save the netcdf in + Format the State into a xr.DataTree. """ - if directory_path is None: - directory_path = Path("./") def _save_recursive(state: State) -> dict: local_data = {} @@ -547,9 +537,24 @@ def _save_recursive(state: State) -> dict: datatree.pop(key) datatree["/"] = xr.Dataset(data_vars=top_level) - xr.DataTree.from_dict(datatree).to_netcdf( - self._netcdf_name(directory_path, postfix) - ) + return xr.DataTree.from_dict(datatree) + + def to_netcdf(self, directory_path: Path | None = None, postfix: str = "") -> None: + """ + Save state to NetCDF. Can be reloaded with `update_from_netcdf`. + + If applicable, will save separate NetCDF files for each running rank. + + The file names are deduced from the class name, and post fix with rank number + in the case of a multi-process use. + + Args: + directory_path: directory to save the netcdf in + """ + if directory_path is None: + directory_path = Path("./") + + self.to_xarray().to_netcdf(self._netcdf_name(directory_path, postfix)) def update_from_netcdf(self, directory_path: Path, postfix: str = "") -> None: """This is a mirror of the `to_netcdf` method NOT a generic