Skip to content

Commit 61d2cfd

Browse files
jianoaixProjectsByJackHe
authored and committed
[Datasets] Streaming executor fixes ray-project#3 (ray-project#32836)
Signed-off-by: Jack He <[email protected]>
1 parent 48a06e4 commit 61d2cfd

File tree

6 files changed

+80
-17
lines changed

6 files changed

+80
-17
lines changed

.buildkite/pipeline.ml.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@
297297
- sudo service mongodb stop
298298
- sudo apt-get purge -y mongodb*
299299

300-
- label: "[unstable] Dataset tests (streaming executor)"
300+
- label: "Dataset tests (streaming executor)"
301301
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_DATA_AFFECTED"]
302302
instance_size: medium
303303
commands:

python/ray/data/_internal/execution/legacy_compat.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def _blocks_to_input_buffer(blocks: BlockList, owns_blocks: bool) -> PhysicalOpe
125125

126126
if hasattr(blocks, "_tasks"):
127127
read_tasks = blocks._tasks
128+
remote_args = blocks._remote_args
128129
assert all(isinstance(t, ReadTask) for t in read_tasks), read_tasks
129130
inputs = InputDataBuffer(
130131
[
@@ -157,7 +158,9 @@ def do_read(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
157158
for read_task in blocks:
158159
yield from read_task()
159160

160-
return MapOperator.create(do_read, inputs, name="DoRead")
161+
return MapOperator.create(
162+
do_read, inputs, name="DoRead", ray_remote_args=remote_args
163+
)
161164
else:
162165
output = _block_list_to_bundles(blocks, owns_blocks=owns_blocks)
163166
for i in output:

python/ray/data/_internal/plan.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,10 @@ def clear_block_refs(self) -> None:
610610
611611
This will render the plan un-executable unless the root is a LazyBlockList."""
612612
self._in_blocks.clear()
613+
self._clear_snapshot()
614+
615+
def _clear_snapshot(self) -> None:
616+
"""Clear the snapshot kept in the plan to the beginning state."""
613617
self._snapshot_blocks = None
614618
self._snapshot_stats = None
615619
# We're erasing the snapshot, so put all stages into the "after snapshot"
@@ -691,7 +695,7 @@ def _get_source_blocks_and_stages(
691695
stats = self._snapshot_stats
692696
# Unlink the snapshot blocks from the plan so we can eagerly reclaim the
693697
# snapshot block memory after the first stage is done executing.
694-
self._snapshot_blocks = None
698+
self._clear_snapshot()
695699
else:
696700
# Snapshot exists but has been cleared, so we need to recompute from the
697701
# source (input blocks).

python/ray/data/tests/test_dataset.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1591,7 +1591,14 @@ def test_convert_types(ray_start_regular_shared):
15911591

15921592
arrow_ds = ray.data.range_table(1)
15931593
assert arrow_ds.map(lambda x: "plain_{}".format(x["value"])).take() == ["plain_0"]
1594-
assert arrow_ds.map(lambda x: {"a": (x["value"],)}).take() == [{"a": [0]}]
1594+
# In streaming, we set batch_format to "default" (because calling
1595+
# ds.dataset_format() will still invoke bulk execution and we want
1596+
# to avoid that). As a result, it's receiving PandasRow (the defaut
1597+
# batch format), which unwraps [0] to plain 0.
1598+
if ray.data.context.DatasetContext.get_current().use_streaming_executor:
1599+
assert arrow_ds.map(lambda x: {"a": (x["value"],)}).take() == [{"a": 0}]
1600+
else:
1601+
assert arrow_ds.map(lambda x: {"a": (x["value"],)}).take() == [{"a": [0]}]
15951602

15961603

15971604
def test_from_items(ray_start_regular_shared):

python/ray/data/tests/test_dataset_tfrecords.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,13 +244,29 @@ def test_readback_tfrecords(ray_start_regular_shared, tmp_path):
244244
# for type inference involving partially missing columns.
245245
parallelism=1,
246246
)
247-
248247
# Write the TFRecords.
249248
ds.write_tfrecords(tmp_path)
250-
251249
# Read the TFRecords.
252250
readback_ds = ray.data.read_tfrecords(tmp_path)
253-
assert ds.take() == readback_ds.take()
251+
if not ray.data.context.DatasetContext.get_current().use_streaming_executor:
252+
assert ds.take() == readback_ds.take()
253+
else:
254+
# In streaming, we set batch_format to "default" (because calling
255+
# ds.dataset_format() will still invoke bulk execution and we want
256+
# to avoid that). As a result, it's receiving PandasRow (the default
257+
# batch format), which doesn't have the same ordering of columns as
258+
# the ArrowRow.
259+
from ray.data.block import BlockAccessor
260+
261+
def get_rows(ds):
262+
rows = []
263+
for batch in ds.iter_batches(batch_size=None, batch_format="pyarrow"):
264+
batch = BlockAccessor.for_block(BlockAccessor.batch_to_block(batch))
265+
for row in batch.iter_rows():
266+
rows.append(row)
267+
return rows
268+
269+
assert get_rows(ds) == get_rows(readback_ds)
254270

255271

256272
def test_write_invalid_tfrecords(ray_start_regular_shared, tmp_path):

python/ray/data/tests/test_stats.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,18 @@ def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
3535
context.optimize_fuse_stages = True
3636

3737
if context.new_execution_backend:
38-
logger = DatasetLogger("ray.data._internal.execution.bulk_executor").get_logger(
39-
log_to_stdout=enable_auto_log_stats,
40-
)
38+
if context.use_streaming_executor:
39+
logger = DatasetLogger(
40+
"ray.data._internal.execution.streaming_executor"
41+
).get_logger(
42+
log_to_stdout=enable_auto_log_stats,
43+
)
44+
else:
45+
logger = DatasetLogger(
46+
"ray.data._internal.execution.bulk_executor"
47+
).get_logger(
48+
log_to_stdout=enable_auto_log_stats,
49+
)
4150
else:
4251
logger = DatasetLogger("ray.data._internal.plan").get_logger(
4352
log_to_stdout=enable_auto_log_stats,
@@ -111,9 +120,24 @@ def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
111120
stats = canonicalize(ds.fully_executed().stats())
112121

113122
if context.new_execution_backend:
114-
assert (
115-
stats
116-
== """Stage N read->MapBatches(dummy_map_batches): N/N blocks executed in T
123+
if context.use_streaming_executor:
124+
assert (
125+
stats
126+
== """Stage N read->MapBatches(dummy_map_batches)->map: N/N blocks executed in T
127+
* Remote wall time: T min, T max, T mean, T total
128+
* Remote cpu time: T min, T max, T mean, T total
129+
* Peak heap memory usage (MiB): N min, N max, N mean
130+
* Output num rows: N min, N max, N mean, N total
131+
* Output size bytes: N min, N max, N mean, N total
132+
* Tasks per node: N min, N max, N mean; N nodes used
133+
* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
134+
'obj_store_mem_peak': N}
135+
"""
136+
)
137+
else:
138+
assert (
139+
stats
140+
== """Stage N read->MapBatches(dummy_map_batches): N/N blocks executed in T
117141
* Remote wall time: T min, T max, T mean, T total
118142
* Remote cpu time: T min, T max, T mean, T total
119143
* Peak heap memory usage (MiB): N min, N max, N mean
@@ -141,7 +165,7 @@ def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
141165
* In user code: T
142166
* Total time: T
143167
"""
144-
)
168+
)
145169
else:
146170
assert (
147171
stats
@@ -364,9 +388,18 @@ def test_dataset_pipeline_stats_basic(ray_start_regular_shared, enable_auto_log_
364388
context.optimize_fuse_stages = True
365389

366390
if context.new_execution_backend:
367-
logger = DatasetLogger("ray.data._internal.execution.bulk_executor").get_logger(
368-
log_to_stdout=enable_auto_log_stats,
369-
)
391+
if context.use_streaming_executor:
392+
logger = DatasetLogger(
393+
"ray.data._internal.execution.streaming_executor"
394+
).get_logger(
395+
log_to_stdout=enable_auto_log_stats,
396+
)
397+
else:
398+
logger = DatasetLogger(
399+
"ray.data._internal.execution.bulk_executor"
400+
).get_logger(
401+
log_to_stdout=enable_auto_log_stats,
402+
)
370403
else:
371404
logger = DatasetLogger("ray.data._internal.plan").get_logger(
372405
log_to_stdout=enable_auto_log_stats,

0 commit comments

Comments
 (0)