@@ -863,72 +863,83 @@ def __dask_scheduler__(self):
863863 return da .Array .__dask_scheduler__
864864
def __dask_postcompute__(self):
    """Return the dask post-compute hook as ``(callable, extra_args)``.

    The callable is the bound ``_dask_postcompute`` method; the extra
    arguments tuple is empty because the bound method already has access to
    everything else it needs through ``self``.
    """
    return self._dask_postcompute, ()
867+
def __dask_postpersist__(self):
    """Return the dask post-persist hook as ``(callable, extra_args)``.

    The callable is the bound ``_dask_postpersist`` method; the extra
    arguments tuple is empty because the bound method already has access to
    everything else it needs through ``self``.
    """
    return self._dask_postpersist, ()
870+
def _dask_postcompute(self, results: "Iterable[Variable]") -> "Dataset":
    """Assemble a new Dataset from computed results.

    Parameters
    ----------
    results : Iterable
        Computed values, one per variable of this dataset that is a dask
        collection, in the iteration order of ``self._variables``.

    Returns
    -------
    Dataset
        Built with ``Dataset._construct_direct`` from the rebuilt variables
        and this dataset's coord names, dims, attrs, indexes, encoding and
        close callback.
    """
    import dask

    variables = {}
    results_iter = iter(results)

    for name, var in self._variables.items():
        if dask.is_dask_collection(var):
            # Each dask-backed variable consumes exactly one entry of
            # ``results`` and is rebuilt with its own finalizer.
            rebuild, args = var.__dask_postcompute__()
            var = rebuild(next(results_iter), *args)
        # Variables that are not dask collections are passed through as-is.
        variables[name] = var

    return Dataset._construct_direct(
        variables,
        self._coord_names,
        self._dims,
        self._attrs,
        self._indexes,
        self._encoding,
        self._close,
    )
882- return self ._dask_postcompute , (info , construct_direct_args )
883892
884- def __dask_postpersist__ (self ):
885- import dask
def _dask_postpersist(
    self, dsk: Mapping, *, rename: "Optional[Mapping[str, str]]" = None
) -> "Dataset":
    """Assemble a new Dataset from a dask graph handed back by dask.

    Parameters
    ----------
    dsk : Mapping
        The task graph; either a ``HighLevelGraph`` or a plain mapping.
    rename : Mapping of str to str, optional
        Name substitutions to apply to layer names / keys. When falsy, no
        renaming is done and the keyword is not forwarded to the
        per-variable rebuild callables.

    Returns
    -------
    Dataset
        Built with ``Dataset._construct_direct``. Every variable that is a
        dask collection is rebuilt from a graph culled down to that
        variable; all other variables are passed through unchanged.
    """
    from dask import is_dask_collection
    from dask.highlevelgraph import HighLevelGraph
    from dask.optimization import cull

    variables = {}

    for k, v in self._variables.items():
        if not is_dask_collection(v):
            variables[k] = v
            continue

        if isinstance(dsk, HighLevelGraph):
            # dask >= 2021.3
            # __dask_postpersist__() was called by dask.highlevelgraph.
            # Don't use dsk.cull(), as we need to prevent partial layers:
            # https://github.com/dask/dask/issues/7137
            layers = v.__dask_layers__()
            if rename:
                layers = [rename.get(layer, layer) for layer in layers]
            dsk2 = dsk.cull_layers(layers)
        elif rename:  # pragma: nocover
            # At the time of writing, this is only for forward compatibility.
            # replace_name_in_key requires dask >= 2021.3.
            from dask.base import flatten, replace_name_in_key

            keys = [
                replace_name_in_key(key, rename)
                for key in flatten(v.__dask_keys__())
            ]
            dsk2, _ = cull(dsk, keys)
        else:
            # __dask_postpersist__() was called by dask.optimize or dask.persist
            dsk2, _ = cull(dsk, v.__dask_keys__())

        rebuild, args = v.__dask_postpersist__()
        # rename was added in dask 2021.3; only forward it when it is set so
        # that rebuild callables without that keyword keep working.
        kwargs = {"rename": rename} if rename else {}
        variables[k] = rebuild(dsk2, *args, **kwargs)

    return Dataset._construct_direct(
        variables,
        self._coord_names,
        self._dims,
        self._attrs,
        self._indexes,
        self._encoding,
        self._close,
    )
901- return self ._dask_postpersist , (info , construct_direct_args )
902-
903- @staticmethod
904- def _dask_postcompute (results , info , construct_direct_args ):
905- variables = {}
906- results_iter = iter (results )
907- for k , v , rebuild , rebuild_args in info :
908- if v is None :
909- variables [k ] = rebuild (next (results_iter ), * rebuild_args )
910- else :
911- variables [k ] = v
912-
913- final = Dataset ._construct_direct (variables , * construct_direct_args )
914- return final
915-
916- @staticmethod
917- def _dask_postpersist (dsk , info , construct_direct_args ):
918- from dask .optimization import cull
919-
920- variables = {}
921- # postpersist is called in both dask.optimize and dask.persist
922- # When persisting, we want to filter out unrelated keys for
923- # each Variable's task graph.
924- for k , v , dask_keys , rebuild , rebuild_args in info :
925- if v is None :
926- dsk2 , _ = cull (dsk , dask_keys )
927- variables [k ] = rebuild (dsk2 , * rebuild_args )
928- else :
929- variables [k ] = v
930-
931- return Dataset ._construct_direct (variables , * construct_direct_args )
932943
933944 def compute (self , ** kwargs ) -> "Dataset" :
934945 """Manually trigger loading and/or computation of this dataset's data
0 commit comments