@@ -57,6 +57,7 @@
     ParquetDatasource,
     BlockWritePathProvider,
     DefaultBlockWritePathProvider,
+    ReadTask,
     WriteResult,
 )
 from ray.data.datasource.file_based_datasource import (
@@ -988,26 +989,26 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":

         start_time = time.perf_counter()
         context = DatasetContext.get_current()
-        calls: List[Callable[[], ObjectRef[BlockPartition]]] = []
-        metadata: List[BlockPartitionMetadata] = []
+        tasks: List[ReadTask] = []
         block_partitions: List[ObjectRef[BlockPartition]] = []
+        block_partitions_meta: List[ObjectRef[BlockPartitionMetadata]] = []

         datasets = [self] + list(other)
         for ds in datasets:
             bl = ds._plan.execute()
             if isinstance(bl, LazyBlockList):
-                calls.extend(bl._calls)
-                metadata.extend(bl._metadata)
+                tasks.extend(bl._tasks)
                 block_partitions.extend(bl._block_partitions)
+                block_partitions_meta.extend(bl._block_partitions_meta)
             else:
-                calls.extend([None] * bl.initial_num_blocks())
-                metadata.extend(bl._metadata)
+                tasks.extend([ReadTask(lambda: None, meta) for meta in bl._metadata])
                 if context.block_splitting_enabled:
                     block_partitions.extend(
                         [ray.put([(b, m)]) for b, m in bl.get_blocks_with_metadata()]
                     )
                 else:
                     block_partitions.extend(bl.get_blocks())
+                block_partitions_meta.extend([ray.put(meta) for meta in bl._metadata])

         epochs = [ds._get_epoch() for ds in datasets]
         max_epoch = max(*epochs)
@@ -1028,7 +1029,8 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
         dataset_stats.time_total_s = time.perf_counter() - start_time
         return Dataset(
             ExecutionPlan(
-                LazyBlockList(calls, metadata, block_partitions), dataset_stats
+                LazyBlockList(tasks, block_partitions, block_partitions_meta),
+                dataset_stats,
             ),
             max_epoch,
             self._lazy,
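
For context on the data structure this hunk switches to: a ReadTask couples a deferred read function with block metadata that is known up front, which is why union() can now carry a single tasks list plus cached partition and metadata refs. The sketch below is a hypothetical, Ray-free stand-in (FakeReadTask and FakeMetadata are illustrative names, not Ray's implementation); it only mirrors the ReadTask(read_fn, metadata) shape visible in the hunk above.

    # Hypothetical stand-in, not Ray's code: metadata is queryable before any read runs.
    from typing import Callable, List

    class FakeMetadata:
        def __init__(self, num_rows: int):
            self.num_rows = num_rows

    class FakeReadTask:
        def __init__(self, read_fn: Callable[[], List[int]], metadata: FakeMetadata):
            self._read_fn = read_fn
            self._metadata = metadata

        def __call__(self) -> List[int]:
            # Executing the task runs the deferred read.
            return self._read_fn()

        def get_metadata(self) -> FakeMetadata:
            # Metadata is available without executing the read.
            return self._metadata

    tasks = [FakeReadTask(lambda n=n: list(range(n)), FakeMetadata(n)) for n in (2, 4, 8)]
    total_rows = sum(t.get_metadata().num_rows for t in tasks)  # known lazily: 14
    blocks = [t() for t in tasks]  # the deferred reads execute only here
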
@@ -2548,6 +2550,7 @@ def repeat(self, times: Optional[int] = None) -> "DatasetPipeline[T]":
         # to enable fusion with downstream map stages.
         ctx = DatasetContext.get_current()
         if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
+            self._plan._in_blocks.clear()
             blocks, read_stage = self._plan._rewrite_read_stage()
             outer_stats = DatasetStats(stages={}, parent=None)
         else:
@@ -2666,6 +2669,7 @@ def window(
         # to enable fusion with downstream map stages.
         ctx = DatasetContext.get_current()
         if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
+            self._plan._in_blocks.clear()
             blocks, read_stage = self._plan._rewrite_read_stage()
             outer_stats = DatasetStats(stages={}, parent=None)
         else:
@@ -2749,12 +2753,13 @@ def fully_executed(self) -> "Dataset[T]":
         Returns:
             A Dataset with all blocks fully materialized in memory.
         """
-        blocks = self.get_internal_block_refs()
-        bar = ProgressBar("Force reads", len(blocks))
-        bar.block_until_complete(blocks)
+        blocks, metadata = [], []
+        for b, m in self._plan.execute().get_blocks_with_metadata():
+            blocks.append(b)
+            metadata.append(m)
         ds = Dataset(
             ExecutionPlan(
-                BlockList(blocks, self._plan.execute().get_metadata()),
+                BlockList(blocks, metadata),
                 self._plan.stats(),
                 dataset_uuid=self._get_uuid(),
             ),
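
As a side note on the fully_executed() hunk above: collecting blocks and metadata in a single pass over get_blocks_with_metadata() keeps the two lists aligned by construction, instead of fetching block refs and metadata in two separate passes. A minimal, Ray-free sketch of that pattern follows (produce_pairs and the dict-based metadata are illustrative assumptions, not Ray's API).

    # Illustrative only: one pass over (block, metadata) pairs keeps both lists in lockstep.
    from typing import Dict, Iterator, List, Tuple

    def produce_pairs() -> Iterator[Tuple[List[int], Dict[str, int]]]:
        for i in range(3):
            block = list(range(i * 10, i * 10 + 5))
            yield block, {"num_rows": len(block)}

    blocks: List[List[int]] = []
    metadata: List[Dict[str, int]] = []
    for b, m in produce_pairs():
        blocks.append(b)
        metadata.append(m)

    assert [m["num_rows"] for m in metadata] == [5, 5, 5]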