Skip to content

Commit 37522e9

Browse files
authored
Support for dask.graph_manipulation (#4965)
* Support dask.graph_manipulation * fix * What's New * [test-upstream]
1 parent 66acafa commit 37522e9

File tree

5 files changed

+106
-65
lines changed

5 files changed

+106
-65
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ v0.17.1 (unreleased)
2222

2323
New Features
2424
~~~~~~~~~~~~
25-
25+
- Support for `dask.graph_manipulation
26+
<https://docs.dask.org/en/latest/graph_manipulation.html>`_ (requires dask >=2021.3)
27+
By `Guido Imperiale <https://github.com/crusaderky>`_
2628

2729
Breaking changes
2830
~~~~~~~~~~~~~~~~

xarray/core/dataarray.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -839,15 +839,15 @@ def __dask_scheduler__(self):
839839

840840
def __dask_postcompute__(self):
841841
func, args = self._to_temp_dataset().__dask_postcompute__()
842-
return self._dask_finalize, (func, args, self.name)
842+
return self._dask_finalize, (self.name, func) + args
843843

844844
def __dask_postpersist__(self):
845845
func, args = self._to_temp_dataset().__dask_postpersist__()
846-
return self._dask_finalize, (func, args, self.name)
846+
return self._dask_finalize, (self.name, func) + args
847847

848848
@staticmethod
849-
def _dask_finalize(results, func, args, name):
850-
ds = func(results, *args)
849+
def _dask_finalize(results, name, func, *args, **kwargs):
850+
ds = func(results, *args, **kwargs)
851851
variable = ds._variables.pop(_THIS_ARRAY)
852852
coords = ds._variables
853853
return DataArray(variable, coords, name=name, fastpath=True)

xarray/core/dataset.py

Lines changed: 59 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -863,72 +863,83 @@ def __dask_scheduler__(self):
863863
return da.Array.__dask_scheduler__
864864

865865
def __dask_postcompute__(self):
866+
return self._dask_postcompute, ()
867+
868+
def __dask_postpersist__(self):
869+
return self._dask_postpersist, ()
870+
871+
def _dask_postcompute(self, results: "Iterable[Variable]") -> "Dataset":
866872
import dask
867873

868-
info = [
869-
(k, None) + v.__dask_postcompute__()
870-
if dask.is_dask_collection(v)
871-
else (k, v, None, None)
872-
for k, v in self._variables.items()
873-
]
874-
construct_direct_args = (
874+
variables = {}
875+
results_iter = iter(results)
876+
877+
for k, v in self._variables.items():
878+
if dask.is_dask_collection(v):
879+
rebuild, args = v.__dask_postcompute__()
880+
v = rebuild(next(results_iter), *args)
881+
variables[k] = v
882+
883+
return Dataset._construct_direct(
884+
variables,
875885
self._coord_names,
876886
self._dims,
877887
self._attrs,
878888
self._indexes,
879889
self._encoding,
880890
self._close,
881891
)
882-
return self._dask_postcompute, (info, construct_direct_args)
883892

884-
def __dask_postpersist__(self):
885-
import dask
893+
def _dask_postpersist(
894+
self, dsk: Mapping, *, rename: Mapping[str, str] = None
895+
) -> "Dataset":
896+
from dask import is_dask_collection
897+
from dask.highlevelgraph import HighLevelGraph
898+
from dask.optimization import cull
886899

887-
info = [
888-
(k, None, v.__dask_keys__()) + v.__dask_postpersist__()
889-
if dask.is_dask_collection(v)
890-
else (k, v, None, None, None)
891-
for k, v in self._variables.items()
892-
]
893-
construct_direct_args = (
900+
variables = {}
901+
902+
for k, v in self._variables.items():
903+
if not is_dask_collection(v):
904+
variables[k] = v
905+
continue
906+
907+
if isinstance(dsk, HighLevelGraph):
908+
# dask >= 2021.3
909+
# __dask_postpersist__() was called by dask.highlevelgraph.
910+
# Don't use dsk.cull(), as we need to prevent partial layers:
911+
# https://github.com/dask/dask/issues/7137
912+
layers = v.__dask_layers__()
913+
if rename:
914+
layers = [rename.get(k, k) for k in layers]
915+
dsk2 = dsk.cull_layers(layers)
916+
elif rename:  # pragma: no cover
917+
# At the moment of writing, this is only for forward compatibility.
918+
# replace_name_in_key requires dask >= 2021.3.
919+
from dask.base import flatten, replace_name_in_key
920+
921+
keys = [
922+
replace_name_in_key(k, rename) for k in flatten(v.__dask_keys__())
923+
]
924+
dsk2, _ = cull(dsk, keys)
925+
else:
926+
# __dask_postpersist__() was called by dask.optimize or dask.persist
927+
dsk2, _ = cull(dsk, v.__dask_keys__())
928+
929+
rebuild, args = v.__dask_postpersist__()
930+
# rename was added in dask 2021.3
931+
kwargs = {"rename": rename} if rename else {}
932+
variables[k] = rebuild(dsk2, *args, **kwargs)
933+
934+
return Dataset._construct_direct(
935+
variables,
894936
self._coord_names,
895937
self._dims,
896938
self._attrs,
897939
self._indexes,
898940
self._encoding,
899941
self._close,
900942
)
901-
return self._dask_postpersist, (info, construct_direct_args)
902-
903-
@staticmethod
904-
def _dask_postcompute(results, info, construct_direct_args):
905-
variables = {}
906-
results_iter = iter(results)
907-
for k, v, rebuild, rebuild_args in info:
908-
if v is None:
909-
variables[k] = rebuild(next(results_iter), *rebuild_args)
910-
else:
911-
variables[k] = v
912-
913-
final = Dataset._construct_direct(variables, *construct_direct_args)
914-
return final
915-
916-
@staticmethod
917-
def _dask_postpersist(dsk, info, construct_direct_args):
918-
from dask.optimization import cull
919-
920-
variables = {}
921-
# postpersist is called in both dask.optimize and dask.persist
922-
# When persisting, we want to filter out unrelated keys for
923-
# each Variable's task graph.
924-
for k, v, dask_keys, rebuild, rebuild_args in info:
925-
if v is None:
926-
dsk2, _ = cull(dsk, dask_keys)
927-
variables[k] = rebuild(dsk2, *rebuild_args)
928-
else:
929-
variables[k] = v
930-
931-
return Dataset._construct_direct(variables, *construct_direct_args)
932943

933944
def compute(self, **kwargs) -> "Dataset":
934945
"""Manually trigger loading and/or computation of this dataset's data

xarray/core/variable.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -531,22 +531,15 @@ def __dask_scheduler__(self):
531531

532532
def __dask_postcompute__(self):
533533
array_func, array_args = self._data.__dask_postcompute__()
534-
return (
535-
self._dask_finalize,
536-
(array_func, array_args, self._dims, self._attrs, self._encoding),
537-
)
534+
return self._dask_finalize, (array_func,) + array_args
538535

539536
def __dask_postpersist__(self):
540537
array_func, array_args = self._data.__dask_postpersist__()
541-
return (
542-
self._dask_finalize,
543-
(array_func, array_args, self._dims, self._attrs, self._encoding),
544-
)
538+
return self._dask_finalize, (array_func,) + array_args
545539

546-
@staticmethod
547-
def _dask_finalize(results, array_func, array_args, dims, attrs, encoding):
548-
data = array_func(results, *array_args)
549-
return Variable(dims, data, attrs=attrs, encoding=encoding)
540+
def _dask_finalize(self, results, array_func, *args, **kwargs):
541+
data = array_func(results, *args, **kwargs)
542+
return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding)
550543

551544
@property
552545
def values(self):

xarray/tests/test_dask.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,3 +1599,38 @@ def test_optimize():
15991599
arr = xr.DataArray(a).chunk(5)
16001600
(arr2,) = dask.optimize(arr)
16011601
arr2.compute()
1602+
1603+
1604+
# The graph_manipulation module is in dask since 2021.2 but it became usable with
1605+
# xarray only since 2021.3
1606+
@pytest.mark.skipif(LooseVersion(dask.__version__) <= "2021.02.0", reason="new module")
1607+
def test_graph_manipulation():
1608+
"""dask.graph_manipulation passes an optional parameter, "rename", to the rebuilder
1609+
function returned by __dask_postperist__; also, the dsk passed to the rebuilder is
1610+
a HighLevelGraph whereas with dask.persist() and dask.optimize() it's a plain dict.
1611+
"""
1612+
import dask.graph_manipulation as gm
1613+
1614+
v = Variable(["x"], [1, 2]).chunk(-1).chunk(1) * 2
1615+
da = DataArray(v)
1616+
ds = Dataset({"d1": v[0], "d2": v[1], "d3": ("x", [3, 4])})
1617+
1618+
v2, da2, ds2 = gm.clone(v, da, ds)
1619+
1620+
assert_equal(v2, v)
1621+
assert_equal(da2, da)
1622+
assert_equal(ds2, ds)
1623+
1624+
for a, b in ((v, v2), (da, da2), (ds, ds2)):
1625+
assert a.__dask_layers__() != b.__dask_layers__()
1626+
assert len(a.__dask_layers__()) == len(b.__dask_layers__())
1627+
assert a.__dask_graph__().keys() != b.__dask_graph__().keys()
1628+
assert len(a.__dask_graph__()) == len(b.__dask_graph__())
1629+
assert a.__dask_graph__().layers.keys() != b.__dask_graph__().layers.keys()
1630+
assert len(a.__dask_graph__().layers) == len(b.__dask_graph__().layers)
1631+
1632+
# Above we performed a slice operation; adding the two slices back together creates
1633+
# a diamond-shaped dependency graph, which in turn will trigger a collision in layer
1634+
# names if we were to use HighLevelGraph.cull() instead of
1635+
# HighLevelGraph.cull_layers() in Dataset._dask_postpersist().
1636+
assert_equal(ds2.d1 + ds2.d2, ds.d1 + ds.d2)

0 commit comments

Comments
 (0)