Skip to content

Commit

Permalink
Fix writing of DataTree subgroups to zarr or netCDF (#9677)
Browse files Browse the repository at this point in the history
* Fix writing of DataTree subgroups to zarr or netCDF

Consider a DataTree with a group, e.g.,
`tree = DataTree.from_dict({'/': ... '/child': ...})`

If we write `tree['/child']` to disk, the result should have groups
relative to `'/child'`, so writing and reading from the same path
restores the same object.

In addition, coordinates defined at the root should be written to
disk instead of being omitted.

* Add write_inherited_coords for additional control in DataTree.to_zarr

As discussed in the last xarray meeting, this defaults to
write_inherited_coords=True, which has a little more overhead but means
you always get coordinates when opening a sub-group.

* Switch write_inherited_coords default to false

* add whats new

* remove unused import
  • Loading branch information
shoyer authored Nov 4, 2024
1 parent fc05da9 commit 577221d
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 79 deletions.
9 changes: 8 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ New Features
~~~~~~~~~~~~
- Added :py:meth:`DataTree.persist` method (:issue:`9675`, :pull:`9682`).
By `Sam Levang <https://github.com/slevang>`_.
- Added ``write_inherited_coords`` option to :py:meth:`DataTree.to_netcdf`
and :py:meth:`DataTree.to_zarr` (:pull:`9677`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
(:issue:`2852`, :issue:`757`).
By `Deepak Cherian <https://github.com/dcherian>`_.
Expand All @@ -42,7 +45,11 @@ Deprecations
Bug fixes
~~~~~~~~~

- Fix inadvertent deep-copying of child data in DataTree.
- Fix inadvertent deep-copying of child data in DataTree (:issue:`9683`,
:pull:`9684`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- Avoid including parent groups when writing DataTree subgroups to Zarr or
netCDF (:pull:`9682`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- Fix regression in the interoperability of :py:meth:`DataArray.polyfit` and :py:meth:`xr.polyval` for date-time coordinates. (:pull:`9691`).
By `Pascal Bourgault <https://github.com/aulemahal>`_.
Expand Down
14 changes: 14 additions & 0 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,7 @@ def to_netcdf(
format: T_DataTreeNetcdfTypes | None = None,
engine: T_DataTreeNetcdfEngine | None = None,
group: str | None = None,
write_inherited_coords: bool = False,
compute: bool = True,
**kwargs,
):
Expand Down Expand Up @@ -1609,6 +1610,11 @@ def to_netcdf(
group : str, optional
Path to the netCDF4 group in the given file to open as the root group
of the ``DataTree``. Currently, specifying a group is not supported.
write_inherited_coords : bool, default: False
If true, replicate inherited coordinates on all descendant nodes.
Otherwise, only write coordinates at the level at which they are
originally defined. This saves disk space, but requires opening the
full tree to load inherited coordinates.
compute : bool, default: True
If true compute immediately, otherwise return a
``dask.delayed.Delayed`` object that can be computed later.
Expand All @@ -1632,6 +1638,7 @@ def to_netcdf(
format=format,
engine=engine,
group=group,
write_inherited_coords=write_inherited_coords,
compute=compute,
**kwargs,
)
Expand All @@ -1643,6 +1650,7 @@ def to_zarr(
encoding=None,
consolidated: bool = True,
group: str | None = None,
write_inherited_coords: bool = False,
compute: Literal[True] = True,
**kwargs,
):
Expand All @@ -1668,6 +1676,11 @@ def to_zarr(
after writing metadata for all groups.
group : str, optional
Group path. (a.k.a. `path` in zarr terminology.)
write_inherited_coords : bool, default: False
If true, replicate inherited coordinates on all descendant nodes.
Otherwise, only write coordinates at the level at which they are
originally defined. This saves disk space, but requires opening the
full tree to load inherited coordinates.
compute : bool, default: True
If true compute immediately, otherwise return a
``dask.delayed.Delayed`` object that can be computed later. Metadata
Expand All @@ -1690,6 +1703,7 @@ def to_zarr(
encoding=encoding,
consolidated=consolidated,
group=group,
write_inherited_coords=write_inherited_coords,
compute=compute,
**kwargs,
)
Expand Down
106 changes: 28 additions & 78 deletions xarray/core/datatree_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,54 +2,15 @@

from collections.abc import Mapping, MutableMapping
from os import PathLike
from typing import TYPE_CHECKING, Any, Literal, get_args
from typing import Any, Literal, get_args

from xarray.core.datatree import DataTree
from xarray.core.types import NetcdfWriteModes, ZarrWriteModes

if TYPE_CHECKING:
from h5netcdf.legacyapi import Dataset as h5Dataset
from netCDF4 import Dataset as ncDataset

T_DataTreeNetcdfEngine = Literal["netcdf4", "h5netcdf"]
T_DataTreeNetcdfTypes = Literal["NETCDF4"]


def _get_nc_dataset_class(
engine: T_DataTreeNetcdfEngine | None,
) -> type[ncDataset] | type[h5Dataset]:
if engine == "netcdf4":
from netCDF4 import Dataset as ncDataset

return ncDataset
if engine == "h5netcdf":
from h5netcdf.legacyapi import Dataset as h5Dataset

return h5Dataset
if engine is None:
try:
from netCDF4 import Dataset as ncDataset

return ncDataset
except ImportError:
from h5netcdf.legacyapi import Dataset as h5Dataset

return h5Dataset
raise ValueError(f"unsupported engine: {engine}")


def _create_empty_netcdf_group(
filename: str | PathLike,
group: str,
mode: NetcdfWriteModes,
engine: T_DataTreeNetcdfEngine | None,
) -> None:
ncDataset = _get_nc_dataset_class(engine)

with ncDataset(filename, mode=mode) as rootgrp:
rootgrp.createGroup(group)


def _datatree_to_netcdf(
dt: DataTree,
filepath: str | PathLike,
Expand All @@ -59,6 +20,7 @@ def _datatree_to_netcdf(
format: T_DataTreeNetcdfTypes | None = None,
engine: T_DataTreeNetcdfEngine | None = None,
group: str | None = None,
write_inherited_coords: bool = False,
compute: bool = True,
**kwargs,
) -> None:
Expand Down Expand Up @@ -97,41 +59,31 @@ def _datatree_to_netcdf(
unlimited_dims = {}

for node in dt.subtree:
ds = node.to_dataset(inherit=False)
group_path = node.path
if ds is None:
_create_empty_netcdf_group(filepath, group_path, mode, engine)
else:
ds.to_netcdf(
filepath,
group=group_path,
mode=mode,
encoding=encoding.get(node.path),
unlimited_dims=unlimited_dims.get(node.path),
engine=engine,
format=format,
compute=compute,
**kwargs,
)
at_root = node is dt
ds = node.to_dataset(inherit=write_inherited_coords or at_root)
group_path = None if at_root else "/" + node.relative_to(dt)
ds.to_netcdf(
filepath,
group=group_path,
mode=mode,
encoding=encoding.get(node.path),
unlimited_dims=unlimited_dims.get(node.path),
engine=engine,
format=format,
compute=compute,
**kwargs,
)
mode = "a"


def _create_empty_zarr_group(
store: MutableMapping | str | PathLike[str], group: str, mode: ZarrWriteModes
):
import zarr

root = zarr.open_group(store, mode=mode)
root.create_group(group, overwrite=True)


def _datatree_to_zarr(
dt: DataTree,
store: MutableMapping | str | PathLike[str],
mode: ZarrWriteModes = "w-",
encoding: Mapping[str, Any] | None = None,
consolidated: bool = True,
group: str | None = None,
write_inherited_coords: bool = False,
compute: Literal[True] = True,
**kwargs,
):
Expand Down Expand Up @@ -163,19 +115,17 @@ def _datatree_to_zarr(
)

for node in dt.subtree:
ds = node.to_dataset(inherit=False)
group_path = node.path
if ds is None:
_create_empty_zarr_group(store, group_path, mode)
else:
ds.to_zarr(
store,
group=group_path,
mode=mode,
encoding=encoding.get(node.path),
consolidated=False,
**kwargs,
)
at_root = node is dt
ds = node.to_dataset(inherit=write_inherited_coords or at_root)
group_path = None if at_root else "/" + node.relative_to(dt)
ds.to_zarr(
store,
group=group_path,
mode=mode,
encoding=encoding.get(node.path),
consolidated=False,
**kwargs,
)
if "w" in mode:
mode = "a"

Expand Down
74 changes: 74 additions & 0 deletions xarray/tests/test_backends_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,24 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree):
with pytest.raises(ValueError, match="unexpected encoding group.*"):
original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine)

def test_write_subgroup(self, tmpdir):
original_dt = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1, 2, 3]}),
"/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
}
).children["child"]

expected_dt = original_dt.copy()
expected_dt.name = None

filepath = tmpdir / "test.zarr"
original_dt.to_netcdf(filepath, engine=self.engine)

with open_datatree(filepath, engine=self.engine) as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)
assert_identical(expected_dt, roundtrip_dt)


@requires_netCDF4
class TestNetCDF4DatatreeIO(DatatreeIOBase):
Expand Down Expand Up @@ -556,3 +574,59 @@ def test_open_groups_chunks(self, tmpdir) -> None:

for ds in dict_of_datasets.values():
ds.close()

def test_write_subgroup(self, tmpdir):
original_dt = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1, 2, 3]}),
"/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
}
).children["child"]

expected_dt = original_dt.copy()
expected_dt.name = None

filepath = tmpdir / "test.zarr"
original_dt.to_zarr(filepath)

with open_datatree(filepath, engine="zarr") as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)
assert_identical(expected_dt, roundtrip_dt)

def test_write_inherited_coords_false(self, tmpdir):
original_dt = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1, 2, 3]}),
"/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
}
)

filepath = tmpdir / "test.zarr"
original_dt.to_zarr(filepath, write_inherited_coords=False)

with open_datatree(filepath, engine="zarr") as roundtrip_dt:
assert_identical(original_dt, roundtrip_dt)

expected_child = original_dt.children["child"].copy(inherit=False)
expected_child.name = None
with open_datatree(filepath, group="child", engine="zarr") as roundtrip_child:
assert_identical(expected_child, roundtrip_child)

def test_write_inherited_coords_true(self, tmpdir):
original_dt = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1, 2, 3]}),
"/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
}
)

filepath = tmpdir / "test.zarr"
original_dt.to_zarr(filepath, write_inherited_coords=True)

with open_datatree(filepath, engine="zarr") as roundtrip_dt:
assert_identical(original_dt, roundtrip_dt)

expected_child = original_dt.children["child"].copy(inherit=True)
expected_child.name = None
with open_datatree(filepath, group="child", engine="zarr") as roundtrip_child:
assert_identical(expected_child, roundtrip_child)

0 comments on commit 577221d

Please sign in to comment.