37 changes: 22 additions & 15 deletions python/ray/data/dataset.py
@@ -57,6 +57,7 @@
     ParquetDatasource,
     BlockWritePathProvider,
     DefaultBlockWritePathProvider,
+    ReadTask,
     WriteResult,
 )
 from ray.data.datasource.file_based_datasource import (
@@ -988,26 +989,28 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":

         start_time = time.perf_counter()
         context = DatasetContext.get_current()
-        calls: List[Callable[[], ObjectRef[BlockPartition]]] = []
-        metadata: List[BlockPartitionMetadata] = []
-        block_partitions: List[ObjectRef[BlockPartition]] = []
+        tasks: List[ReadTask] = []
+        block_partition_refs: List[ObjectRef[BlockPartition]] = []
+        block_partition_meta_refs: List[ObjectRef[BlockPartitionMetadata]] = []

         datasets = [self] + list(other)
         for ds in datasets:
             bl = ds._plan.execute()
             if isinstance(bl, LazyBlockList):
-                calls.extend(bl._calls)
-                metadata.extend(bl._metadata)
-                block_partitions.extend(bl._block_partitions)
+                tasks.extend(bl._tasks)
+                block_partition_refs.extend(bl._block_partition_refs)
+                block_partition_meta_refs.extend(bl._block_partition_meta_refs)
             else:
-                calls.extend([None] * bl.initial_num_blocks())
-                metadata.extend(bl._metadata)
+                tasks.extend([ReadTask(lambda: None, meta) for meta in bl._metadata])
                 if context.block_splitting_enabled:
-                    block_partitions.extend(
+                    block_partition_refs.extend(
                         [ray.put([(b, m)]) for b, m in bl.get_blocks_with_metadata()]
                     )
                 else:
-                    block_partitions.extend(bl.get_blocks())
+                    block_partition_refs.extend(bl.get_blocks())
+                block_partition_meta_refs.extend(
+                    [ray.put(meta) for meta in bl._metadata]
+                )

         epochs = [ds._get_epoch() for ds in datasets]
         max_epoch = max(*epochs)
@@ -1028,7 +1031,8 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
         dataset_stats.time_total_s = time.perf_counter() - start_time
         return Dataset(
             ExecutionPlan(
-                LazyBlockList(calls, metadata, block_partitions), dataset_stats
+                LazyBlockList(tasks, block_partition_refs, block_partition_meta_refs),
+                dataset_stats,
             ),
             max_epoch,
             self._lazy,
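
The `union()` change above normalizes every input to the new `ReadTask`-based representation: lazy inputs contribute their real read tasks, while already-materialized block lists are wrapped in no-op `ReadTask`s (a `lambda: None` read function paired with the existing metadata), with the block and metadata object refs carried alongside. A minimal sketch of that wrapping pattern, using simplified stand-ins for `ReadTask` and `BlockMetadata` rather than the real `ray.data` classes:

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple


@dataclass
class BlockMetadata:
    """Simplified stand-in for ray.data.block.BlockMetadata."""
    num_rows: Optional[int]


@dataclass
class ReadTask:
    """Simplified stand-in for ray.data.datasource.ReadTask."""
    read_fn: Callable[[], Any]  # produces the block when called
    metadata: BlockMetadata     # known before the task executes


def merge_as_read_tasks(
    lazy_tasks: List[ReadTask],
    materialized: List[Tuple[Any, BlockMetadata]],
) -> List[ReadTask]:
    """Merge lazy and eager inputs into one uniform task list.

    Lazy inputs keep their real read functions; already-materialized
    blocks get a no-op read function (their data travels separately as
    object refs), mirroring the union() change above.
    """
    tasks = list(lazy_tasks)
    tasks.extend(ReadTask(lambda: None, meta) for _, meta in materialized)
    return tasks


# Example: one pending read plus one in-memory block.
pending = ReadTask(lambda: [1, 2, 3], BlockMetadata(num_rows=3))
eager = [([4, 5], BlockMetadata(num_rows=2))]
print(len(merge_as_read_tasks([pending], eager)))  # -> 2
```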
@@ -2548,6 +2552,7 @@ def repeat(self, times: Optional[int] = None) -> "DatasetPipeline[T]":
         # to enable fusion with downstream map stages.
         ctx = DatasetContext.get_current()
         if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
+            self._plan._in_blocks.clear()
             blocks, read_stage = self._plan._rewrite_read_stage()
             outer_stats = DatasetStats(stages={}, parent=None)
         else:
@@ -2666,6 +2671,7 @@ def window(
         # to enable fusion with downstream map stages.
         ctx = DatasetContext.get_current()
         if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
+            self._plan._in_blocks.clear()
             blocks, read_stage = self._plan._rewrite_read_stage()
             outer_stats = DatasetStats(stages={}, parent=None)
         else:
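
`repeat()` and `window()` both gain the same one-line change: the plan's input blocks are cleared before the read stage is rewritten, so the resulting pipeline re-executes reads lazily instead of keeping previously materialized blocks pinned in the object store. A rough sketch of the idea, with a hypothetical `Plan` class standing in for the private `ExecutionPlan` internals:

```python
from typing import Any, List, Tuple


class Plan:
    """Hypothetical stand-in for ExecutionPlan; all names illustrative."""

    def __init__(self, in_blocks: List[Any]) -> None:
        self.in_blocks = in_blocks

    def is_read_stage(self) -> bool:
        return True  # elided: check whether the only stage is a read

    def clear(self) -> None:
        # Drop block references so the object store can reclaim them;
        # the rewritten read stage re-creates the blocks on demand.
        self.in_blocks = []

    def rewrite_read_stage(self) -> Tuple[List[Any], str]:
        return [], "read_stage"  # elided: rebuild blocks from read tasks


def to_pipeline(plan: Plan, optimize_fuse_read_stages: bool = True):
    if plan.is_read_stage() and optimize_fuse_read_stages:
        plan.clear()  # mirrors the new self._plan._in_blocks.clear() line
        return plan.rewrite_read_stage()
    return plan.in_blocks, None
```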
@@ -2749,12 +2755,13 @@ def fully_executed(self) -> "Dataset[T]":
         Returns:
             A Dataset with all blocks fully materialized in memory.
         """
-        blocks = self.get_internal_block_refs()
-        bar = ProgressBar("Force reads", len(blocks))
-        bar.block_until_complete(blocks)
+        blocks, metadata = [], []
+        for b, m in self._plan.execute().get_blocks_with_metadata():
+            blocks.append(b)
+            metadata.append(m)
         ds = Dataset(
             ExecutionPlan(
-                BlockList(blocks, self._plan.execute().get_metadata()),
+                BlockList(blocks, metadata),
                 self._plan.stats(),
                 dataset_uuid=self._get_uuid(),
             ),
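
`fully_executed()` now materializes blocks and metadata from a single `execute()` pass instead of executing the plan twice (once for block refs, once for metadata). User-facing behavior should be unchanged; a typical call, assuming a small synthetic dataset:

```python
import ray

ray.init()

# A read-based dataset stays lazy until its blocks are needed;
# fully_executed() forces every block into the object store.
ds = ray.data.range(1000)
ds = ds.fully_executed()
print(ds.num_blocks(), ds.count())
```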
34 changes: 3 additions & 31 deletions python/ray/data/impl/block_list.py
@@ -1,15 +1,11 @@
 import math
-from typing import List, Iterator, Tuple, Any, Union, Optional, TYPE_CHECKING
-
-if TYPE_CHECKING:
-    import pyarrow
+from typing import List, Iterator, Tuple

 import numpy as np

 import ray
 from ray.types import ObjectRef
-from ray.data.block import Block, BlockMetadata, BlockAccessor
-from ray.data.impl.remote_fn import cached_remote_fn
+from ray.data.block import Block, BlockMetadata


 class BlockList:
@@ -26,11 +22,7 @@ def __init__(self, blocks: List[ObjectRef[Block]], metadata: List[BlockMetadata]
         self._num_blocks = len(self._blocks)
         self._metadata: List[BlockMetadata] = metadata

-    def set_metadata(self, i: int, metadata: BlockMetadata) -> None:
-        """Set the metadata for a given block."""
-        self._metadata[i] = metadata
-
-    def get_metadata(self) -> List[BlockMetadata]:
+    def get_metadata(self, fetch_if_missing: bool = False) -> List[BlockMetadata]:
         """Get the metadata for all blocks."""
         return self._metadata.copy()

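In the base `BlockList`, the new `fetch_if_missing` flag is accepted but unused, since eagerly-materialized lists always have metadata; presumably `LazyBlockList` overrides it to compute metadata for blocks whose read tasks haven't run yet. A hedged sketch of how such an override might look (the `BlockMetadataStub` class and `_fetch_missing_metadata` helper are illustrative, not the actual implementation):

```python
from typing import List, Optional


class BlockMetadataStub:
    """Stand-in for ray.data.block.BlockMetadata (illustrative)."""


class LazyMetadataSketch:
    def __init__(self) -> None:
        # None marks blocks whose read task hasn't produced metadata yet.
        self._metadata: List[Optional[BlockMetadataStub]] = [None, BlockMetadataStub()]

    def _fetch_missing_metadata(self) -> None:
        # Illustrative: run pending read tasks just far enough to
        # learn each block's metadata.
        self._metadata = [m if m is not None else BlockMetadataStub()
                          for m in self._metadata]

    def get_metadata(self, fetch_if_missing: bool = False) -> List[BlockMetadataStub]:
        if fetch_if_missing and any(m is None for m in self._metadata):
            self._fetch_missing_metadata()
        return [m for m in self._metadata if m is not None]
```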
@@ -182,23 +174,3 @@ def executed_num_blocks(self) -> int:
         doesn't know how many blocks will be produced until tasks finish.
         """
         return len(self.get_blocks())
-
-    def ensure_schema_for_first_block(self) -> Optional[Union["pyarrow.Schema", type]]:
-        """Ensure that the schema is set for the first block.
-
-        Returns None if the block list is empty.
-        """
-        get_schema = cached_remote_fn(_get_schema)
-        try:
-            block = next(self.iter_blocks())
-        except (StopIteration, ValueError):
-            # Dataset is empty (no blocks) or was manually cleared.
-            return None
-        schema = ray.get(get_schema.remote(block))
-        # Set the schema.
-        self._metadata[0].schema = schema
-        return schema
-
-
-def _get_schema(block: Block) -> Any:
-    return BlockAccessor.for_block(block).schema()
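
The removed `ensure_schema_for_first_block()` pulled the first block's schema via a cached remote function and stored it on that block's metadata. For reference, the same remote-inspection pattern can be written with a plain `@ray.remote` function; a standalone sketch on a simple list block (the real helper used `BlockAccessor.for_block(...).schema()`):

```python
import ray

ray.init()


@ray.remote
def get_schema(block) -> type:
    # Inspect the block remotely so only the (small) schema result
    # travels back to the driver, not the block contents.
    return type(block[0]) if block else type(None)


block_ref = ray.put([1, 2, 3])
print(ray.get(get_schema.remote(block_ref)))  # <class 'int'>
```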