Skip to content

Commit d9b2891

Browse files
committed
Refactor LazyBlockList.
1 parent 858d607 commit d9b2891

File tree

9 files changed

+427
-170
lines changed

9 files changed

+427
-170
lines changed

python/ray/data/dataset.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
ParquetDatasource,
5858
BlockWritePathProvider,
5959
DefaultBlockWritePathProvider,
60+
ReadTask,
6061
WriteResult,
6162
)
6263
from ray.data.datasource.file_based_datasource import (
@@ -988,26 +989,26 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
988989

989990
start_time = time.perf_counter()
990991
context = DatasetContext.get_current()
991-
calls: List[Callable[[], ObjectRef[BlockPartition]]] = []
992-
metadata: List[BlockPartitionMetadata] = []
992+
tasks: List[ReadTask] = []
993993
block_partitions: List[ObjectRef[BlockPartition]] = []
994+
block_partitions_meta: List[ObjectRef[BlockPartitionMetadata]] = []
994995

995996
datasets = [self] + list(other)
996997
for ds in datasets:
997998
bl = ds._plan.execute()
998999
if isinstance(bl, LazyBlockList):
999-
calls.extend(bl._calls)
1000-
metadata.extend(bl._metadata)
1000+
tasks.extend(bl._tasks)
10011001
block_partitions.extend(bl._block_partitions)
1002+
block_partitions_meta.extend(bl._block_partitions_meta)
10021003
else:
1003-
calls.extend([None] * bl.initial_num_blocks())
1004-
metadata.extend(bl._metadata)
1004+
tasks.extend([ReadTask(lambda: None, meta) for meta in bl._metadata])
10051005
if context.block_splitting_enabled:
10061006
block_partitions.extend(
10071007
[ray.put([(b, m)]) for b, m in bl.get_blocks_with_metadata()]
10081008
)
10091009
else:
10101010
block_partitions.extend(bl.get_blocks())
1011+
block_partitions_meta.extend([ray.put(meta) for meta in bl._metadata])
10111012

10121013
epochs = [ds._get_epoch() for ds in datasets]
10131014
max_epoch = max(*epochs)
@@ -1028,7 +1029,8 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
10281029
dataset_stats.time_total_s = time.perf_counter() - start_time
10291030
return Dataset(
10301031
ExecutionPlan(
1031-
LazyBlockList(calls, metadata, block_partitions), dataset_stats
1032+
LazyBlockList(tasks, block_partitions, block_partitions_meta),
1033+
dataset_stats,
10321034
),
10331035
max_epoch,
10341036
self._lazy,
@@ -2548,6 +2550,7 @@ def repeat(self, times: Optional[int] = None) -> "DatasetPipeline[T]":
25482550
# to enable fusion with downstream map stages.
25492551
ctx = DatasetContext.get_current()
25502552
if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
2553+
self._plan._in_blocks.clear()
25512554
blocks, read_stage = self._plan._rewrite_read_stage()
25522555
outer_stats = DatasetStats(stages={}, parent=None)
25532556
else:
@@ -2666,6 +2669,7 @@ def window(
26662669
# to enable fusion with downstream map stages.
26672670
ctx = DatasetContext.get_current()
26682671
if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
2672+
self._plan._in_blocks.clear()
26692673
blocks, read_stage = self._plan._rewrite_read_stage()
26702674
outer_stats = DatasetStats(stages={}, parent=None)
26712675
else:
@@ -2749,12 +2753,13 @@ def fully_executed(self) -> "Dataset[T]":
27492753
Returns:
27502754
A Dataset with all blocks fully materialized in memory.
27512755
"""
2752-
blocks = self.get_internal_block_refs()
2753-
bar = ProgressBar("Force reads", len(blocks))
2754-
bar.block_until_complete(blocks)
2756+
blocks, metadata = [], []
2757+
for b, m in self._plan.execute().get_blocks_with_metadata():
2758+
blocks.append(b)
2759+
metadata.append(m)
27552760
ds = Dataset(
27562761
ExecutionPlan(
2757-
BlockList(blocks, self._plan.execute().get_metadata()),
2762+
BlockList(blocks, metadata),
27582763
self._plan.stats(),
27592764
dataset_uuid=self._get_uuid(),
27602765
),

python/ray/data/impl/block_list.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
import math
2-
from typing import List, Iterator, Tuple, Any, Union, Optional, TYPE_CHECKING
3-
4-
if TYPE_CHECKING:
5-
import pyarrow
2+
from typing import List, Iterator, Tuple, Optional
63

74
import numpy as np
85

@@ -26,11 +23,7 @@ def __init__(self, blocks: List[ObjectRef[Block]], metadata: List[BlockMetadata]
2623
self._num_blocks = len(self._blocks)
2724
self._metadata: List[BlockMetadata] = metadata
2825

29-
def set_metadata(self, i: int, metadata: BlockMetadata) -> None:
30-
"""Set the metadata for a given block."""
31-
self._metadata[i] = metadata
32-
33-
def get_metadata(self) -> List[BlockMetadata]:
26+
def get_metadata(self, fetch_if_missing: bool = False) -> List[BlockMetadata]:
3427
"""Get the metadata for all blocks."""
3528
return self._metadata.copy()
3629

@@ -183,22 +176,22 @@ def executed_num_blocks(self) -> int:
183176
"""
184177
return len(self.get_blocks())
185178

186-
def ensure_schema_for_first_block(self) -> Optional[Union["pyarrow.Schema", type]]:
187-
"""Ensure that the schema is set for the first block.
179+
def ensure_metadata_for_first_block(self) -> BlockMetadata:
180+
"""Ensure that the metadata is fetched and set for the first block.
188181
189182
Returns None if the block list is empty.
190183
"""
191-
get_schema = cached_remote_fn(_get_schema)
184+
get_metadata = cached_remote_fn(_get_metadata)
192185
try:
193-
block = next(self.iter_blocks())
186+
block, metadata = next(self.iter_blocks_with_metadata())
194187
except (StopIteration, ValueError):
195188
# Dataset is empty (no blocks) or was manually cleared.
196189
return None
197-
schema = ray.get(get_schema.remote(block))
198-
# Set the schema.
199-
self._metadata[0].schema = schema
200-
return schema
190+
input_files = metadata.input_files
191+
metadata = ray.get(get_metadata.remote(block, input_files))
192+
self._metadata[0] = metadata
193+
return metadata
201194

202195

203-
def _get_schema(block: Block) -> Any:
204-
return BlockAccessor.for_block(block).schema()
196+
def _get_metadata(block: Block, input_files=Optional[List[str]]) -> BlockMetadata:
197+
return BlockAccessor.for_block(block).get_metadata(input_files=input_files)

0 commit comments

Comments
 (0)