Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement dask methods on DataTree #9670

Merged
merged 26 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
f625f5d
implement `compute` and `load`
keewis Oct 24, 2024
507fb7d
also shallow-copy variables
keewis Oct 24, 2024
ce1683a
implement `chunksizes`
keewis Oct 24, 2024
f2a4683
add tests for `load`
keewis Oct 24, 2024
f0ff30f
add tests for `chunksizes`
keewis Oct 24, 2024
d12203c
improve the `load` tests using `DataTree.chunksizes`
keewis Oct 24, 2024
dda02ed
add a test for `compute`
keewis Oct 24, 2024
329c689
un-xfail a xpassing test
keewis Oct 24, 2024
c9fb461
implement and test `DataTree.chunk`
keewis Oct 24, 2024
0305fc5
link to `Dataset.load`
keewis Oct 24, 2024
e2a3a14
use `tree.subtree` to get absolute paths
keewis Oct 24, 2024
7f57ffa
filter out missing dims before delegating to `Dataset.chunk`
keewis Oct 24, 2024
900701b
fix the type hints for `DataTree.chunksizes`
keewis Oct 24, 2024
d45dbd0
try using `self.from_dict` instead
keewis Oct 24, 2024
5f88937
type-hint intermediate test variables
keewis Oct 24, 2024
b515972
Merge branch 'main' into datatree-dask-methods
keewis Oct 24, 2024
39d95f6
use `_node_dims` instead
keewis Oct 24, 2024
73b4466
raise on unknown chunk dim
keewis Oct 24, 2024
8b35676
check that errors in `chunk` are raised properly
keewis Oct 24, 2024
d70f4f0
adapt the docstrings of the new methods
keewis Oct 24, 2024
3e9745f
allow computing / loading unchunked trees
keewis Oct 24, 2024
b6c5f9a
reword the `chunksizes` properties
keewis Oct 24, 2024
da8df36
also freeze the top-level chunk sizes
keewis Oct 24, 2024
b11a1ef
also reword `DataArray.chunksizes`
keewis Oct 24, 2024
53c0897
fix a copy-paste error
keewis Oct 24, 2024
f7e31b4
same for `NamedArray.chunksizes`
keewis Oct 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 105 additions & 4 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from xarray.core._aggregations import DataTreeAggregations
from xarray.core._typed_ops import DataTreeOpsMixin
from xarray.core.alignment import align
from xarray.core.common import TreeAttrAccessMixin
from xarray.core.common import TreeAttrAccessMixin, get_chunksizes
from xarray.core.coordinates import Coordinates, DataTreeCoordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, DataVariables
Expand Down Expand Up @@ -49,6 +49,8 @@
parse_dims_as_set,
)
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array

try:
from xarray.core.variable import calculate_dimensions
Expand Down Expand Up @@ -862,9 +864,9 @@ def _copy_node(
) -> Self:
"""Copy just one node of a tree."""
new_node = super()._copy_node(inherit=inherit, deep=deep, memo=memo)
data = self._to_dataset_view(rebuild_dims=False, inherit=inherit)
if deep:
data = data._copy(deep=True, memo=memo)
data = self._to_dataset_view(rebuild_dims=False, inherit=inherit)._copy(
deep=deep, memo=memo
)
Comment on lines +870 to +872
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

new_node._set_node_data(data)
return new_node

Expand Down Expand Up @@ -1896,3 +1898,102 @@ def apply_indexers(dataset, node_indexers):

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
return self._selective_indexing(apply_indexers, indexers)

def load(self, **kwargs) -> Self:
"""Manually trigger loading and/or computation of this datatree's data
from disk or a remote source into memory and return this datatree.
Unlike compute, the original datatree is modified and returned.

Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.

Parameters
----------
**kwargs : dict
Additional keyword arguments passed on to ``dask.compute``.

See Also
--------
dask.compute
keewis marked this conversation as resolved.
Show resolved Hide resolved
"""
# access .data to coerce everything to numpy or dask arrays
lazy_data = {
path: {
k: v._data
for k, v in node.variables.items()
if is_chunked_array(v._data)
}
for path, node in self.subtree_with_keys
}
flat_lazy_data = {
(path, var_name): array
for path, node in lazy_data.items()
for var_name, array in node.items()
}
if lazy_data:
chunkmanager = get_chunked_array_type(*flat_lazy_data.values())

# evaluate all the chunked arrays simultaneously
evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
*flat_lazy_data.values(), **kwargs
)

for (path, var_name), data in zip(
flat_lazy_data, evaluated_data, strict=False
):
self[path].variables[var_name].data = data
keewis marked this conversation as resolved.
Show resolved Hide resolved

# load everything else sequentially
for node in self.subtree:
for k, v in node.variables.items():
if k not in lazy_data:
v.load()

return self

def compute(self, **kwargs) -> Self:
"""Manually trigger loading and/or computation of this datatree's data
from disk or a remote source into memory and return a new dataset.
Unlike load, the original dataset is left unaltered.

Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.

Parameters
----------
**kwargs : dict
Additional keyword arguments passed on to ``dask.compute``.

Returns
-------
object : DataTree
New object with lazy data variables and coordinates as in-memory arrays.

See Also
--------
dask.compute
"""
new = self.copy(deep=False)
return new.load(**kwargs)

@property
def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]:
"""
Mapping from group paths to a mapping of dimension names to block lengths for this dataset's data, or None if
the underlying data is not a dask array.

Cannot be modified directly, but can be modified by calling .chunk().

See Also
--------
DataTree.chunk
Dataset.chunksizes
"""
return {
f"/{path}" if path != "." else "/": get_chunksizes(node.variables.values())
for path, node in self.subtree_with_keys
keewis marked this conversation as resolved.
Show resolved Hide resolved
}
83 changes: 81 additions & 2 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from xarray.core.datatree import DataTree
from xarray.core.treenode import NotFoundInTreeError
from xarray.testing import assert_equal, assert_identical
from xarray.tests import assert_array_equal, create_test_data, source_ndarray
from xarray.tests import (
assert_array_equal,
create_test_data,
requires_dask,
source_ndarray,
)

ON_WINDOWS = sys.platform == "win32"

Expand Down Expand Up @@ -858,7 +863,6 @@ def test_to_dict(self):
actual = DataTree.from_dict(tree.children["a"].to_dict(relative=True))
assert_identical(expected, actual)

@pytest.mark.xfail
def test_roundtrip_unnamed_root(self, simple_datatree) -> None:
# See GH81

Expand Down Expand Up @@ -2195,3 +2199,78 @@ def test_close_dataset(self, tree_and_closers):

# with tree:
# pass


@requires_dask
class TestDask:
def test_chunksizes(self):
ds1 = xr.Dataset({"a": ("x", np.arange(10))})
ds2 = xr.Dataset({"b": ("y", np.arange(5))})
ds3 = xr.Dataset({"c": ("z", np.arange(4))})
ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))})

groups = {
"/": ds1.chunk({"x": 5}),
"/group1": ds2.chunk({"y": 3}),
"/group2": ds3.chunk({"z": 2}),
"/group1/subgroup1": ds4.chunk({"x": 5}),
}

tree = xr.DataTree.from_dict(groups)

expected_chunksizes = {path: node.chunksizes for path, node in groups.items()}

assert tree.chunksizes == expected_chunksizes

def test_load(self):
ds1 = xr.Dataset({"a": ("x", np.arange(10))})
ds2 = xr.Dataset({"b": ("y", np.arange(5))})
ds3 = xr.Dataset({"c": ("z", np.arange(4))})
ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))})

expected = xr.DataTree.from_dict(
{"/": ds1, "/group1": ds2, "/group2": ds3, "/group1/subgroup1": ds4}
)
tree = xr.DataTree.from_dict(
{
"/": ds1.chunk({"x": 5}),
"/group1": ds2.chunk({"y": 3}),
"/group2": ds3.chunk({"z": 2}),
"/group1/subgroup1": ds4.chunk({"x": 5}),
}
)
expected_chunksizes = {
f"/{path}" if path != "." else "/": {} for path, _ in tree.subtree_with_keys
}
actual = tree.load()

assert_identical(actual, expected)
assert tree.chunksizes == expected_chunksizes

def test_compute(self):
ds1 = xr.Dataset({"a": ("x", np.arange(10))})
ds2 = xr.Dataset({"b": ("y", np.arange(5))})
ds3 = xr.Dataset({"c": ("z", np.arange(4))})
ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))})

expected = xr.DataTree.from_dict(
{"/": ds1, "/group1": ds2, "/group2": ds3, "/group1/subgroup1": ds4}
)
tree = xr.DataTree.from_dict(
{
"/": ds1.chunk({"x": 5}),
"/group1": ds2.chunk({"y": 3}),
"/group2": ds3.chunk({"z": 2}),
"/group1/subgroup1": ds4.chunk({"x": 5}),
}
)
original_chunksizes = tree.chunksizes
expected_chunksizes = {
f"/{path}" if path != "." else "/": {} for path, _ in tree.subtree_with_keys
}
actual = tree.compute()

assert_identical(actual, expected)

assert actual.chunksizes == expected_chunksizes, "mismatching chunksizes"
assert tree.chunksizes == original_chunksizes, "original tree was modified"
Loading