@@ -866,79 +866,69 @@ def __dask_postcompute__(self):
866866 import dask
867867
868868 info = [
869- (True , k , v .__dask_postcompute__ () )
869+ (k , None ) + v .__dask_postcompute__ ()
870870 if dask .is_dask_collection (v )
871- else (False , k , v )
871+ else (k , v , None , None )
872872 for k , v in self ._variables .items ()
873873 ]
874- args = (
875- info ,
874+ construct_direct_args = (
876875 self ._coord_names ,
877876 self ._dims ,
878877 self ._attrs ,
879878 self ._indexes ,
880879 self ._encoding ,
881880 self ._close ,
882881 )
883- return self ._dask_postcompute , args
882+ return self ._dask_postcompute , ( info , construct_direct_args )
884883
885884 def __dask_postpersist__ (self ):
886885 import dask
887886
888887 info = [
889- (True , k , v .__dask_postpersist__ () )
888+ (k , None , v .__dask_keys__ ()) + v . __dask_postpersist__ ( )
890889 if dask .is_dask_collection (v )
891- else (False , k , v )
890+ else (k , v , None , None , None )
892891 for k , v in self ._variables .items ()
893892 ]
894- args = (
895- info ,
893+ construct_direct_args = (
896894 self ._coord_names ,
897895 self ._dims ,
898896 self ._attrs ,
899897 self ._indexes ,
900898 self ._encoding ,
901899 self ._close ,
902900 )
903- return self ._dask_postpersist , args
901+ return self ._dask_postpersist , ( info , construct_direct_args )
904902
905903 @staticmethod
906- def _dask_postcompute (results , info , * args ):
904+ def _dask_postcompute (results , info , construct_direct_args ):
907905 variables = {}
908- results2 = list (results [::- 1 ])
909- for is_dask , k , v in info :
910- if is_dask :
911- func , args2 = v
912- r = results2 .pop ()
913- result = func (r , * args2 )
906+ results_iter = iter (results )
907+ for k , v , rebuild , rebuild_args in info :
908+ if v is None :
909+ variables [k ] = rebuild (next (results_iter ), * rebuild_args )
914910 else :
915- result = v
916- variables [k ] = result
911+ variables [k ] = v
917912
918- final = Dataset ._construct_direct (variables , * args )
913+ final = Dataset ._construct_direct (variables , * construct_direct_args )
919914 return final
920915
921916 @staticmethod
922- def _dask_postpersist (dsk , info , * args ):
917+ def _dask_postpersist (dsk , info , construct_direct_args ):
918+ from dask .optimization import cull
919+
923920 variables = {}
924921 # postpersist is called in both dask.optimize and dask.persist
925922 # When persisting, we want to filter out unrelated keys for
926923 # each Variable's task graph.
927- is_persist = len (dsk ) == len (info )
928- for is_dask , k , v in info :
929- if is_dask :
930- func , args2 = v
931- if is_persist :
932- name = args2 [1 ][0 ]
933- dsk2 = {k : v for k , v in dsk .items () if k [0 ] == name }
934- else :
935- dsk2 = dsk
936- result = func (dsk2 , * args2 )
924+ for k , v , dask_keys , rebuild , rebuild_args in info :
925+ if v is None :
926+ dsk2 , _ = cull (dsk , dask_keys )
927+ variables [k ] = rebuild (dsk2 , * rebuild_args )
937928 else :
938- result = v
939- variables [k ] = result
929+ variables [k ] = v
940930
941- return Dataset ._construct_direct (variables , * args )
931+ return Dataset ._construct_direct (variables , * construct_direct_args )
942932
943933 def compute (self , ** kwargs ) -> "Dataset" :
944934 """Manually trigger loading and/or computation of this dataset's data
0 commit comments