Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement dask methods on DataTree #9670

Merged
merged 26 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
f625f5d
implement `compute` and `load`
keewis Oct 24, 2024
507fb7d
also shallow-copy variables
keewis Oct 24, 2024
ce1683a
implement `chunksizes`
keewis Oct 24, 2024
f2a4683
add tests for `load`
keewis Oct 24, 2024
f0ff30f
add tests for `chunksizes`
keewis Oct 24, 2024
d12203c
improve the `load` tests using `DataTree.chunksizes`
keewis Oct 24, 2024
dda02ed
add a test for `compute`
keewis Oct 24, 2024
329c689
un-xfail a xpassing test
keewis Oct 24, 2024
c9fb461
implement and test `DataTree.chunk`
keewis Oct 24, 2024
0305fc5
link to `Dataset.load`
keewis Oct 24, 2024
e2a3a14
use `tree.subtree` to get absolute paths
keewis Oct 24, 2024
7f57ffa
filter out missing dims before delegating to `Dataset.chunk`
keewis Oct 24, 2024
900701b
fix the type hints for `DataTree.chunksizes`
keewis Oct 24, 2024
d45dbd0
try using `self.from_dict` instead
keewis Oct 24, 2024
5f88937
type-hint intermediate test variables
keewis Oct 24, 2024
b515972
Merge branch 'main' into datatree-dask-methods
keewis Oct 24, 2024
39d95f6
use `_node_dims` instead
keewis Oct 24, 2024
73b4466
raise on unknown chunk dim
keewis Oct 24, 2024
8b35676
check that errors in `chunk` are raised properly
keewis Oct 24, 2024
d70f4f0
adapt the docstrings of the new methods
keewis Oct 24, 2024
3e9745f
allow computing / loading unchunked trees
keewis Oct 24, 2024
b6c5f9a
reword the `chunksizes` properties
keewis Oct 24, 2024
da8df36
also freeze the top-level chunk sizes
keewis Oct 24, 2024
b11a1ef
also reword `DataArray.chunksizes`
keewis Oct 24, 2024
53c0897
fix a copy-paste error
keewis Oct 24, 2024
f7e31b4
same for `NamedArray.chunksizes`
keewis Oct 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 196 additions & 4 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from xarray.core._aggregations import DataTreeAggregations
from xarray.core._typed_ops import DataTreeOpsMixin
from xarray.core.alignment import align
from xarray.core.common import TreeAttrAccessMixin
from xarray.core.common import TreeAttrAccessMixin, get_chunksizes
from xarray.core.coordinates import Coordinates, DataTreeCoordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, DataVariables
Expand Down Expand Up @@ -49,6 +49,8 @@
parse_dims_as_set,
)
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array

try:
from xarray.core.variable import calculate_dimensions
Expand All @@ -68,8 +70,11 @@
ErrorOptions,
ErrorOptionsWithWarn,
NetcdfWriteModes,
T_ChunkDimFreq,
T_ChunksFreq,
ZarrWriteModes,
)
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint

# """
# DEVELOPERS' NOTE
Expand Down Expand Up @@ -862,9 +867,9 @@ def _copy_node(
) -> Self:
"""Copy just one node of a tree."""
new_node = super()._copy_node(inherit=inherit, deep=deep, memo=memo)
data = self._to_dataset_view(rebuild_dims=False, inherit=inherit)
if deep:
data = data._copy(deep=True, memo=memo)
data = self._to_dataset_view(rebuild_dims=False, inherit=inherit)._copy(
deep=deep, memo=memo
)
Comment on lines +870 to +872
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

new_node._set_node_data(data)
return new_node

Expand Down Expand Up @@ -1896,3 +1901,190 @@ def apply_indexers(dataset, node_indexers):

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
return self._selective_indexing(apply_indexers, indexers)

def load(self, **kwargs) -> Self:
"""Manually trigger loading and/or computation of this datatree's data
from disk or a remote source into memory and return this datatree.
Unlike compute, the original datatree is modified and returned.

Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.

Parameters
----------
**kwargs : dict
Additional keyword arguments passed on to ``dask.compute``.

See Also
--------
Dataset.load
dask.compute
keewis marked this conversation as resolved.
Show resolved Hide resolved
"""
# access .data to coerce everything to numpy or dask arrays
lazy_data = {
path: {
k: v._data
for k, v in node.variables.items()
if is_chunked_array(v._data)
}
for path, node in self.subtree_with_keys
}
flat_lazy_data = {
(path, var_name): array
for path, node in lazy_data.items()
for var_name, array in node.items()
}
if lazy_data:
chunkmanager = get_chunked_array_type(*flat_lazy_data.values())

# evaluate all the chunked arrays simultaneously
evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
*flat_lazy_data.values(), **kwargs
)

for (path, var_name), data in zip(
flat_lazy_data, evaluated_data, strict=False
):
self[path].variables[var_name].data = data
keewis marked this conversation as resolved.
Show resolved Hide resolved

# load everything else sequentially
for node in self.subtree:
for k, v in node.variables.items():
if k not in lazy_data:
v.load()

return self

def compute(self, **kwargs) -> Self:
"""Manually trigger loading and/or computation of this datatree's data
from disk or a remote source into memory and return a new dataset.
Unlike load, the original dataset is left unaltered.

Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.

Parameters
----------
**kwargs : dict
Additional keyword arguments passed on to ``dask.compute``.

Returns
-------
object : DataTree
New object with lazy data variables and coordinates as in-memory arrays.

See Also
--------
dask.compute
"""
new = self.copy(deep=False)
return new.load(**kwargs)

@property
def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]:
"""
Mapping from group paths to a mapping of dimension names to block lengths for this dataset's data, or None if
the underlying data is not a dask array.

Cannot be modified directly, but can be modified by calling .chunk().

See Also
--------
DataTree.chunk
Dataset.chunksizes
"""
return {
node.path: get_chunksizes(node.variables.values()) for node in self.subtree
}

def chunk(
self,
chunks: T_ChunksFreq = {}, # noqa: B006 # {} even though it's technically unsafe, is being used intentionally here (#4667)
name_prefix: str = "xarray-",
token: str | None = None,
lock: bool = False,
inline_array: bool = False,
chunked_array_type: str | ChunkManagerEntrypoint | None = None,
from_array_kwargs=None,
**chunks_kwargs: T_ChunkDimFreq,
) -> Self:
"""Coerce all arrays in all groups in this tree into dask arrays with the given
chunks.

Non-dask arrays in this tree will be converted to dask arrays. Dask
arrays will be rechunked to the given chunk sizes.

If neither chunks is not provided for one or more dimensions, chunk
sizes along that dimension will not be updated; non-dask arrays will be
converted into dask arrays with a single block.

Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` object is also accepted.

Parameters
----------
chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional
Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, or
``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}``.
name_prefix : str, default: "xarray-"
Prefix for the name of any new dask arrays.
token : str, optional
Token uniquely identifying this dataset.
lock : bool, default: False
Passed on to :py:func:`dask.array.from_array`, if the array is not
already as dask array.
inline_array: bool, default: False
Passed on to :py:func:`dask.array.from_array`, if the array is not
already as dask array.
chunked_array_type: str, optional
Which chunked array type to coerce this datasets' arrays to.
Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntryPoint` system.
Experimental API that should not be relied upon.
from_array_kwargs: dict, optional
Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create
chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg.
For example, with dask as the default chunked array type, this method would pass additional kwargs
to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
**chunks_kwargs : {dim: chunks, ...}, optional
The keyword arguments form of ``chunks``.
One of chunks or chunks_kwargs must be provided

Returns
-------
chunked : xarray.DataTree

See Also
--------
Dataset.chunk
Dataset.chunksizes
xarray.unify_chunks
dask.array.from_array
"""
# don't support deprecated ways of passing chunks
if not isinstance(chunks, Mapping):
raise TypeError(
f"invalid type for chunks: {type(chunks)}. Only mappings are supported."
)
combined_chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")

rechunked_groups = {
path: node.dataset.chunk(
{
dim: size
for dim, size in combined_chunks.items()
if dim in node.dataset.dims
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to use node.dataset.dims to avoid including inherited dims (which we can't chunk anyways because we only inherit indexed dims, and there is no chunked index so far)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

otherwise I just saw _node_dims in _get_all_dims, so that might have less overhead. I'll go ahead and use that instead.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That will also be more explicit as it will avoid using "rebuilt" dims (which doesn't really matter anyway because we can't chunk indexes).

},
name_prefix=name_prefix,
token=token,
lock=lock,
inline_array=inline_array,
chunked_array_type=chunked_array_type,
from_array_kwargs=from_array_kwargs,
)
for path, node in self.subtree_with_keys
}

return DataTree.from_dict(rechunked_groups, name=self.name)
102 changes: 100 additions & 2 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from xarray.core.datatree import DataTree
from xarray.core.treenode import NotFoundInTreeError
from xarray.testing import assert_equal, assert_identical
from xarray.tests import assert_array_equal, create_test_data, source_ndarray
from xarray.tests import (
assert_array_equal,
create_test_data,
requires_dask,
source_ndarray,
)

ON_WINDOWS = sys.platform == "win32"

Expand Down Expand Up @@ -858,7 +863,6 @@ def test_to_dict(self):
actual = DataTree.from_dict(tree.children["a"].to_dict(relative=True))
assert_identical(expected, actual)

@pytest.mark.xfail
def test_roundtrip_unnamed_root(self, simple_datatree) -> None:
# See GH81

Expand Down Expand Up @@ -2195,3 +2199,97 @@ def test_close_dataset(self, tree_and_closers):

# with tree:
# pass


@requires_dask
class TestDask:
def test_chunksizes(self):
ds1 = xr.Dataset({"a": ("x", np.arange(10))})
ds2 = xr.Dataset({"b": ("y", np.arange(5))})
ds3 = xr.Dataset({"c": ("z", np.arange(4))})
ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))})

groups = {
"/": ds1.chunk({"x": 5}),
"/group1": ds2.chunk({"y": 3}),
"/group2": ds3.chunk({"z": 2}),
"/group1/subgroup1": ds4.chunk({"x": 5}),
}

tree = xr.DataTree.from_dict(groups)

expected_chunksizes = {path: node.chunksizes for path, node in groups.items()}

assert tree.chunksizes == expected_chunksizes

def test_load(self):
ds1 = xr.Dataset({"a": ("x", np.arange(10))})
ds2 = xr.Dataset({"b": ("y", np.arange(5))})
ds3 = xr.Dataset({"c": ("z", np.arange(4))})
ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))})

expected = xr.DataTree.from_dict(
{"/": ds1, "/group1": ds2, "/group2": ds3, "/group1/subgroup1": ds4}
)
tree = xr.DataTree.from_dict(
{
"/": ds1.chunk({"x": 5}),
"/group1": ds2.chunk({"y": 3}),
"/group2": ds3.chunk({"z": 2}),
"/group1/subgroup1": ds4.chunk({"x": 5}),
}
)
expected_chunksizes = {node.path: {} for node in tree.subtree}
actual = tree.load()

assert_identical(actual, expected)
assert tree.chunksizes == expected_chunksizes

def test_compute(self):
ds1 = xr.Dataset({"a": ("x", np.arange(10))})
ds2 = xr.Dataset({"b": ("y", np.arange(5))})
ds3 = xr.Dataset({"c": ("z", np.arange(4))})
ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))})

expected = xr.DataTree.from_dict(
{"/": ds1, "/group1": ds2, "/group2": ds3, "/group1/subgroup1": ds4}
)
tree = xr.DataTree.from_dict(
{
"/": ds1.chunk({"x": 5}),
"/group1": ds2.chunk({"y": 3}),
"/group2": ds3.chunk({"z": 2}),
"/group1/subgroup1": ds4.chunk({"x": 5}),
}
)
original_chunksizes = tree.chunksizes
expected_chunksizes = {node.path: {} for node in tree.subtree}
actual = tree.compute()

assert_identical(actual, expected)

assert actual.chunksizes == expected_chunksizes, "mismatching chunksizes"
assert tree.chunksizes == original_chunksizes, "original tree was modified"

def test_chunk(self):
ds1 = xr.Dataset({"a": ("x", np.arange(10))})
ds2 = xr.Dataset({"b": ("y", np.arange(5))})
ds3 = xr.Dataset({"c": ("z", np.arange(4))})
ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))})

expected = xr.DataTree.from_dict(
{
"/": ds1.chunk({"x": 5}),
"/group1": ds2.chunk({"y": 3}),
"/group2": ds3.chunk({"z": 2}),
"/group1/subgroup1": ds4.chunk({"x": 5}),
}
)

tree = xr.DataTree.from_dict(
{"/": ds1, "/group1": ds2, "/group2": ds3, "/group1/subgroup1": ds4}
)
actual = tree.chunk({"x": 5, "y": 3, "z": 2})

assert_identical(actual, expected)
assert actual.chunksizes == expected.chunksizes
Loading