Skip to content

Commit dade522

Browse files
committed
PR feedback.
1 parent acadb9b commit dade522

File tree

4 files changed

+117
-107
lines changed

4 files changed

+117
-107
lines changed

python/ray/data/dataset.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -990,25 +990,27 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
990990
start_time = time.perf_counter()
991991
context = DatasetContext.get_current()
992992
tasks: List[ReadTask] = []
993-
block_partitions: List[ObjectRef[BlockPartition]] = []
994-
block_partitions_meta: List[ObjectRef[BlockPartitionMetadata]] = []
993+
block_partition_refs: List[ObjectRef[BlockPartition]] = []
994+
block_partition_meta_refs: List[ObjectRef[BlockPartitionMetadata]] = []
995995

996996
datasets = [self] + list(other)
997997
for ds in datasets:
998998
bl = ds._plan.execute()
999999
if isinstance(bl, LazyBlockList):
10001000
tasks.extend(bl._tasks)
1001-
block_partitions.extend(bl._block_partitions)
1002-
block_partitions_meta.extend(bl._block_partitions_meta)
1001+
block_partition_refs.extend(bl._block_partition_refs)
1002+
block_partition_meta_refs.extend(bl._block_partition_meta_refs)
10031003
else:
10041004
tasks.extend([ReadTask(lambda: None, meta) for meta in bl._metadata])
10051005
if context.block_splitting_enabled:
1006-
block_partitions.extend(
1006+
block_partition_refs.extend(
10071007
[ray.put([(b, m)]) for b, m in bl.get_blocks_with_metadata()]
10081008
)
10091009
else:
1010-
block_partitions.extend(bl.get_blocks())
1011-
block_partitions_meta.extend([ray.put(meta) for meta in bl._metadata])
1010+
block_partition_refs.extend(bl.get_blocks())
1011+
block_partition_meta_refs.extend(
1012+
[ray.put(meta) for meta in bl._metadata]
1013+
)
10121014

10131015
epochs = [ds._get_epoch() for ds in datasets]
10141016
max_epoch = max(*epochs)
@@ -1029,7 +1031,7 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
10291031
dataset_stats.time_total_s = time.perf_counter() - start_time
10301032
return Dataset(
10311033
ExecutionPlan(
1032-
LazyBlockList(tasks, block_partitions, block_partitions_meta),
1034+
LazyBlockList(tasks, block_partition_refs, block_partition_meta_refs),
10331035
dataset_stats,
10341036
),
10351037
max_epoch,

python/ray/data/impl/lazy_block_list.py

Lines changed: 100 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,24 @@ class LazyBlockList(BlockList):
3333
def __init__(
3434
self,
3535
tasks: List[ReadTask],
36-
block_partitions: Optional[List[ObjectRef[MaybeBlockPartition]]] = None,
37-
block_partitions_meta: Optional[List[ObjectRef[BlockPartitionMetadata]]] = None,
38-
fetched_metadata: Optional[List[BlockPartitionMetadata]] = None,
36+
block_partition_refs: Optional[List[ObjectRef[MaybeBlockPartition]]] = None,
37+
block_partition_meta_refs: Optional[
38+
List[ObjectRef[BlockPartitionMetadata]]
39+
] = None,
40+
cached_metadata: Optional[List[BlockPartitionMetadata]] = None,
3941
ray_remote_args: Optional[Dict[str, Any]] = None,
4042
stats_uuid: str = None,
4143
):
4244
"""Create a LazyBlockList on the provided read tasks.
4345
4446
Args:
4547
tasks: The read tasks that will produce the blocks of this lazy block list.
46-
block_partitions: An optional list of already submitted read task futures
47-
(i.e. block partition refs). This should be the same length as the tasks
48-
argument.
49-
block_partitions_meta: An optional list of block partition metadata refs.
50-
This should be the same length as the tasks argument.
51-
fetched_metadata: An optional list of already computed AND fetched metadata.
48+
block_partition_refs: An optional list of already submitted read task
49+
futures (i.e. block partition refs). This should be the same length as
50+
the tasks argument.
51+
block_partition_meta_refs: An optional list of block partition metadata
52+
refs. This should be the same length as the tasks argument.
53+
cached_metadata: An optional list of already computed AND fetched metadata.
5254
This serves as a cache of fetched block metadata.
5355
ray_remote_args: Ray remote arguments for the read tasks.
5456
stats_uuid: UUID for the dataset stats, used to group and fetch read task
@@ -62,42 +64,42 @@ def __init__(
6264
self._execution_started = False
6365
self._remote_args = ray_remote_args or {}
6466
# Block partition metadata that have already been computed and fetched.
65-
if fetched_metadata is not None:
66-
self._fetched_metadata = fetched_metadata
67+
if cached_metadata is not None:
68+
self._cached_metadata = cached_metadata
6769
else:
68-
self._fetched_metadata = [None] * len(tasks)
70+
self._cached_metadata = [None] * len(tasks)
6971
# Block partition metadata that have already been computed.
70-
if block_partitions_meta is not None:
71-
self._block_partitions_meta = block_partitions_meta
72+
if block_partition_meta_refs is not None:
73+
self._block_partition_meta_refs = block_partition_meta_refs
7274
else:
73-
self._block_partitions_meta = [None] * len(tasks)
75+
self._block_partition_meta_refs = [None] * len(tasks)
7476
# Block partitions that have already been computed.
75-
if block_partitions is not None:
76-
self._block_partitions = block_partitions
77+
if block_partition_refs is not None:
78+
self._block_partition_refs = block_partition_refs
7779
else:
78-
self._block_partitions = [None] * len(tasks)
79-
assert len(tasks) == len(self._block_partitions), (
80+
self._block_partition_refs = [None] * len(tasks)
81+
assert len(tasks) == len(self._block_partition_refs), (
8082
tasks,
81-
self._block_partitions,
83+
self._block_partition_refs,
8284
)
83-
assert len(tasks) == len(self._block_partitions_meta), (
85+
assert len(tasks) == len(self._block_partition_meta_refs), (
8486
tasks,
85-
self._block_partitions_meta,
87+
self._block_partition_meta_refs,
8688
)
87-
assert len(tasks) == len(self._fetched_metadata), (
89+
assert len(tasks) == len(self._cached_metadata), (
8890
tasks,
89-
self._fetched_metadata,
91+
self._cached_metadata,
9092
)
9193

9294
def get_metadata(self, fetch_if_missing: bool = False) -> List[BlockMetadata]:
9395
"""Get the metadata for all blocks."""
94-
if all(meta is not None for meta in self._fetched_metadata):
96+
if all(meta is not None for meta in self._cached_metadata):
9597
# Always return fetched metadata if we already have it.
96-
metadata = self._fetched_metadata
98+
metadata = self._cached_metadata
9799
elif not fetch_if_missing:
98100
metadata = [
99101
m if m is not None else t.get_metadata()
100-
for m, t in zip(self._fetched_metadata, self._tasks)
102+
for m, t in zip(self._cached_metadata, self._tasks)
101103
]
102104
else:
103105
_, metadata = self._get_blocks_with_metadata()
@@ -136,9 +138,9 @@ def _submit_task(
136138
def copy(self) -> "LazyBlockList":
137139
return LazyBlockList(
138140
self._tasks.copy(),
139-
block_partitions=self._block_partitions.copy(),
140-
block_partitions_meta=self._block_partitions_meta.copy(),
141-
fetched_metadata=self._fetched_metadata,
141+
block_partition_refs=self._block_partition_refs.copy(),
142+
block_partition_meta_refs=self._block_partition_meta_refs.copy(),
143+
cached_metadata=self._cached_metadata,
142144
ray_remote_args=self._remote_args.copy(),
143145
stats_uuid=self._stats_uuid,
144146
)
@@ -147,9 +149,11 @@ def clear(self):
147149
"""Clears all object references (block partitions and base block partitions)
148150
from this lazy block list.
149151
"""
150-
self._block_partitions = [None for _ in self._block_partitions]
151-
self._block_partitions_meta = [None for _ in self._block_partitions_meta]
152-
self._fetched_metadata = [None for _ in self._fetched_metadata]
152+
self._block_partition_refs = [None for _ in self._block_partition_refs]
153+
self._block_partition_meta_refs = [
154+
None for _ in self._block_partition_meta_refs
155+
]
156+
self._cached_metadata = [None for _ in self._cached_metadata]
153157

154158
def _check_if_cleared(self):
155159
pass # LazyBlockList can always be re-computed.
@@ -158,10 +162,12 @@ def _check_if_cleared(self):
158162
def split(self, split_size: int) -> List["LazyBlockList"]:
159163
num_splits = math.ceil(len(self._tasks) / split_size)
160164
tasks = np.array_split(self._tasks, num_splits)
161-
block_partitions = np.array_split(self._block_partitions, num_splits)
162-
block_partitions_meta = np.array_split(self._block_partitions_meta, num_splits)
165+
block_partition_refs = np.array_split(self._block_partition_refs, num_splits)
166+
block_partition_meta_refs = np.array_split(
167+
self._block_partition_meta_refs, num_splits
168+
)
163169
output = []
164-
for t, b, m in zip(tasks, block_partitions, block_partitions_meta):
170+
for t, b, m in zip(tasks, block_partition_refs, block_partition_meta_refs):
165171
output.append(
166172
LazyBlockList(
167173
t.tolist(),
@@ -179,8 +185,8 @@ def split_by_bytes(self, bytes_per_split: int) -> List["BlockList"]:
179185
cur_size = 0
180186
for t, b, bm in zip(
181187
self._tasks,
182-
self._block_partitions,
183-
self._block_partitions_meta,
188+
self._block_partition_refs,
189+
self._block_partition_meta_refs,
184190
):
185191
m = t.get_metadata()
186192
if m.size_bytes is None:
@@ -206,13 +212,13 @@ def split_by_bytes(self, bytes_per_split: int) -> List["BlockList"]:
206212
def divide(self, part_idx: int) -> ("LazyBlockList", "LazyBlockList"):
207213
left = LazyBlockList(
208214
self._tasks[:part_idx],
209-
self._block_partitions[:part_idx],
210-
self._block_partitions_meta[:part_idx],
215+
self._block_partition_refs[:part_idx],
216+
self._block_partition_meta_refs[:part_idx],
211217
)
212218
right = LazyBlockList(
213219
self._tasks[part_idx:],
214-
self._block_partitions[part_idx:],
215-
self._block_partitions_meta[part_idx:],
220+
self._block_partition_refs[part_idx:],
221+
self._block_partition_meta_refs[part_idx:],
216222
)
217223
return left, right
218224

@@ -243,26 +249,26 @@ def _get_blocks_with_metadata(
243249
all block metadata outputted by those tasks.
244250
"""
245251
context = DatasetContext.get_current()
246-
blocks, meta_refs = [], []
247-
for block, meta_ref in self._iter_block_partition_refs():
248-
blocks.append(block)
252+
block_refs, meta_refs = [], []
253+
for block_ref, meta_ref in self._iter_block_partition_refs():
254+
block_refs.append(block_ref)
249255
meta_refs.append(meta_ref)
250256
if context.block_splitting_enabled:
251257
# If block splitting is enabled, fetch the partitions.
252-
parts = ray.get(blocks)
253-
blocks, metadata = [], []
258+
parts = ray.get(block_refs)
259+
block_refs, metadata = [], []
254260
for part in parts:
255-
for block, meta in part:
256-
blocks.append(block)
261+
for block_ref, meta in part:
262+
block_refs.append(block_ref)
257263
metadata.append(meta)
258-
self._fetched_metadata = metadata
259-
return blocks, metadata
260-
if all(meta is not None for meta in self._fetched_metadata):
264+
self._cached_metadata = metadata
265+
return block_refs, metadata
266+
if all(meta is not None for meta in self._cached_metadata):
261267
# Short-circuit on cached metadata.
262-
return blocks, self._fetched_metadata
268+
return block_refs, self._cached_metadata
263269
if not meta_refs:
264270
# Short-circuit on empty set of block partitions.
265-
assert not blocks, blocks
271+
assert not block_refs, block_refs
266272
return [], []
267273
read_progress_bar = ProgressBar("Read progress", total=len(meta_refs))
268274
# Fetch the metadata in bulk.
@@ -273,8 +279,8 @@ def _get_blocks_with_metadata(
273279
meta_ref: data for meta_ref, data in zip(unique_meta_refs, metadata)
274280
}
275281
metadata = [ref_to_data[meta_ref] for meta_ref in meta_refs]
276-
self._fetched_metadata = metadata
277-
return blocks, metadata
282+
self._cached_metadata = metadata
283+
return block_refs, metadata
278284

279285
def compute_first_block(self):
280286
"""Kick off computation for the first block in the list.
@@ -298,7 +304,11 @@ def ensure_metadata_for_first_block(self) -> Optional[BlockMetadata]:
298304
return None
299305
metadata = self._tasks[0].get_metadata()
300306
if metadata.schema is not None:
307+
# If pre-read schema is not null, we consider it to be "good enough" and use
308+
# it.
301309
return metadata
310+
# Otherwise, we trigger computation (if needed), wait until the task completes,
311+
# and fetch the block partition metadata.
302312
try:
303313
_, metadata_ref = next(self._iter_block_partition_refs())
304314
except (StopIteration, ValueError):
@@ -307,16 +317,27 @@ def ensure_metadata_for_first_block(self) -> Optional[BlockMetadata]:
307317
else:
308318
# This blocks until the underlying read task is finished.
309319
metadata = ray.get(metadata_ref)
310-
self._fetched_metadata[0] = metadata
320+
self._cached_metadata[0] = metadata
311321
return metadata
312322

313323
def iter_blocks_with_metadata(
314324
self,
325+
block_for_metadata: bool = False,
315326
) -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]:
316327
"""Iterate over the blocks along with their metadata.
317328
329+
Note that, if block_for_metadata is False (default), this iterator returns
330+
pre-read metadata from the ReadTasks given to this LazyBlockList so it doesn't
331+
have to block on the execution of the read tasks. Therefore, the metadata may be
332+
under-specified, e.g. missing schema or the number of rows. If fully-specified
333+
block metadata is required, pass block_for_metadata=True.
334+
318335
The length of this iterator is not known until execution.
319336
337+
Args:
338+
block_for_metadata: Whether we should block on the execution of read tasks
339+
in order to obtain fully-specified block metadata.
340+
320341
Returns:
321342
An iterator of block references and the corresponding block metadata.
322343
"""
@@ -339,11 +360,18 @@ def __next__(self):
339360
part_ref, _ = next(self._base_iter)
340361
partition = ray.get(part_ref)
341362
else:
342-
block, _ = next(self._base_iter)
343-
metadata = outer._tasks[self._pos].get_metadata()
344-
partition = [(block, metadata)]
345-
for ref, metadata in partition:
346-
self._buffer.append((ref, metadata))
363+
block_ref, metadata_ref = next(self._base_iter)
364+
if block_for_metadata:
365+
# This blocks until the read task completes, returning
366+
# fully-specified block metadata.
367+
metadata = ray.get(metadata_ref)
368+
else:
369+
# This does not block, returning (possibly under-specified)
370+
# pre-read block metadata.
371+
metadata = outer._tasks[self._pos].get_metadata()
372+
partition = [(block_ref, metadata)]
373+
for block_ref, metadata in partition:
374+
self._buffer.append((block_ref, metadata))
347375
return self._buffer.pop(0)
348376

349377
return Iter()
@@ -379,24 +407,24 @@ def _get_or_compute(
379407
i: int,
380408
) -> Tuple[ObjectRef[MaybeBlockPartition], ObjectRef[BlockPartitionMetadata]]:
381409
assert i < len(self._tasks), i
382-
# Check if we need to compute more block_partitions.
383-
if not self._block_partitions[i]:
410+
# Check if we need to compute more block_partition_refs.
411+
if not self._block_partition_refs[i]:
384412
# Exponentially increase the number computed per batch.
385413
for j in range(max(i + 1, i * 2)):
386-
if j >= len(self._block_partitions):
414+
if j >= len(self._block_partition_refs):
387415
break
388-
if not self._block_partitions[j]:
416+
if not self._block_partition_refs[j]:
389417
(
390-
self._block_partitions[j],
391-
self._block_partitions_meta[j],
418+
self._block_partition_refs[j],
419+
self._block_partition_meta_refs[j],
392420
) = self._submit_task(j)
393-
assert self._block_partitions[i], self._block_partitions
394-
assert self._block_partitions_meta[i], self._block_partitions_meta
395-
return self._block_partitions[i], self._block_partitions_meta[i]
421+
assert self._block_partition_refs[i], self._block_partition_refs
422+
assert self._block_partition_meta_refs[i], self._block_partition_meta_refs
423+
return self._block_partition_refs[i], self._block_partition_meta_refs[i]
396424

397425
def _num_computed(self) -> int:
398426
i = 0
399-
for b in self._block_partitions:
427+
for b in self._block_partition_refs:
400428
if b is not None:
401429
i += 1
402430
return i

0 commit comments

Comments
 (0)