python/ray/data/dataset.py (85 additions & 0 deletions)
@@ -34,6 +34,7 @@
import numpy as np

import ray
import ray.cloudpickle as pickle
from ray.types import ObjectRef
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.data.block import (
@@ -2819,6 +2820,14 @@ def fully_executed(self) -> "Dataset[T]":
ds._set_uuid(self._get_uuid())
return ds

def is_fully_executed(self) -> bool:
"""Returns whether this Dataset has been fully executed.

Returns False if this Dataset is lazy and the output of its final
stage hasn't been computed yet.
"""
return self._plan.has_computed_output()

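# A hypothetical usage sketch (not part of this diff): with lazy execution
# enabled, is_fully_executed() becomes True only once the final stage has been
# computed, e.g. by fully_executed():
#
#     ds = ray.data.range(10)._experimental_lazy().map(lambda x: x + 1)
#     assert not ds.is_fully_executed()
#     ds = ds.fully_executed()
#     assert ds.is_fully_executed()
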
def stats(self) -> str:
"""Returns a string containing execution timing information."""
return self._plan.stats().summary_string()
@@ -2842,6 +2851,82 @@ def _experimental_lazy(self) -> "Dataset[T]":
self._lazy = True
return self

def is_out_of_band_serializable(self) -> bool:
"""Whether this dataset is able to be out-of-band serialized, i.e. serialized
for use across different Ray clusters. Only datasets that read from lazy
datasources (i.e. via one of the ray.data.read_*() APIs) are out-of-band
serializable.
"""
return self._plan.has_lazy_input()

@DeveloperAPI
def serialize_out_of_band(self) -> bytes:
"""
Serialize this Dataset for out-of-band use, i.e. for use across different
Ray clusters. Only the lineage of Dataset operations is serialized; all
computed data is dropped and will be recomputed from scratch after
deserialization.

Use ``Dataset.deserialize_out_of_band`` to deserialize the serialized bytes
into a Dataset.

Returns:
Serialized bytes.
"""
if not self.is_out_of_band_serializable():
raise ValueError(
"Out-of-band serialization is only supported for Datasets created from "
"lazy datasources. I.e., out-of-band serialization is not "
"supported for any ray.data.from_*() APIs. To allow this Dataset to be "
"out-of-band serialized, write the data to an external store (such as "
"AWS S3, GCS, or Azure Blob Storage) using the Dataset.write_*() APIs, "
"and serialize a new dataset reading from the external store using the "
"ray.data.read_*() APIs."
)
# Copy Dataset and clear the execution plan so the Dataset is out-of-band
# serializable.
plan_copy = self._plan.deep_copy(preserve_uuid=True)
ds = Dataset(plan_copy, self._get_epoch(), self._lazy)
ds._plan.clear()
ds._set_uuid(self._get_uuid())

def _reduce(rf: ray.remote_function.RemoteFunction):
# Custom reducer for Ray remote function handles that allows for
# cross-cluster serialization.
# TODO(Clark): Fix this in core Ray.
reconstructor, args, state = rf.__reduce__()
# Manually unset last export session and job to force re-exporting of the
# function when the handle is deserialized on a new cluster.
state["_last_export_session_and_job"] = None
return reconstructor, args, state

context = ray.worker.global_worker.get_serialization_context()
try:
context._register_cloudpickle_reducer(
ray.remote_function.RemoteFunction, _reduce
)
serialized = pickle.dumps(ds)
finally:
context._unregister_cloudpickle_reducer(ray.remote_function.RemoteFunction)
return serialized

@DeveloperAPI
@staticmethod
def deserialize_out_of_band(serialized_ds: bytes) -> "Dataset":
"""
Deserialize the provided out-of-band serialized Dataset.

This assumes that the provided bytes were produced by
``Dataset.serialize_out_of_band``.

Args:
    serialized_ds: The serialized Dataset bytes to deserialize.

Returns:
A deserialized ``Dataset`` instance.
"""
return pickle.loads(serialized_ds)

def _split(
self, index: int, return_right_half: bool
) -> ("Dataset[T]", "Dataset[T]"):
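Taken together, the two new DeveloperAPI methods enable a serialize/deserialize
round trip across Ray clusters. A minimal usage sketch based on the API in this
diff, assuming a Dataset built from a lazy datasource such as ray.data.range():

    import ray
    from ray.data.dataset import Dataset

    ray.init()
    ds = ray.data.range(100).map(lambda x: x * 2)
    assert ds.is_out_of_band_serializable()
    serialized = ds.serialize_out_of_band()
    ray.shutdown()

    # Persist `serialized` (e.g., to S3), then on a different Ray cluster:
    ray.init()
    ds = Dataset.deserialize_out_of_band(serialized)
    assert sorted(ds.take(5)) == [0, 2, 4, 6, 8]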
python/ray/data/tests/test_dataset.py (58 additions & 1 deletion)
@@ -13,12 +13,13 @@
import ray

from ray.tests.conftest import * # noqa
from ray.data.dataset import _sliding_window
from ray.data.dataset import Dataset, _sliding_window
from ray.data.datasource.csv_datasource import CSVDatasource
from ray.data.block import BlockAccessor
from ray.data.row import TableRow
from ray.data.impl.arrow_block import ArrowRow
from ray.data.impl.block_builder import BlockBuilder
from ray.data.impl.lazy_block_list import LazyBlockList
from ray.data.impl.pandas_block import PandasRow
from ray.data.aggregate import AggregateFn, Count, Sum, Min, Max, Mean, Std
from ray.data.extensions.tensor_extension import (
@@ -38,6 +39,13 @@ def maybe_pipeline(ds, enabled):
return ds


def maybe_lazy(ds, enabled):
if enabled:
return ds._experimental_lazy()
else:
return ds


class SlowCSVDatasource(CSVDatasource):
def _read_stream(self, f: "pa.NativeFile", path: str, **reader_args):
for block in CSVDatasource._read_stream(self, f, path, **reader_args):
@@ -184,6 +192,55 @@ def mapper(x):
ds.map(mapper)


@pytest.mark.parametrize("lazy", [False, True])
def test_dataset_out_of_band_serialization(shutdown_only, lazy):
ray.init()
ds = ray.data.range(10)
ds = maybe_lazy(ds, lazy)
ds = ds.map(lambda x: x + 1)
ds = ds.map(lambda x: x + 1)
ds = ds.random_shuffle()
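# Capture the Dataset's state so we can verify it survives the round trip.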
epoch = ds._get_epoch()
uuid = ds._get_uuid()
plan_uuid = ds._plan._dataset_uuid
lazy = ds._lazy

serialized_ds = ds.serialize_out_of_band()
# Confirm that the original Dataset was properly copied before clearing/mutating.
in_blocks = ds._plan._in_blocks
# Should not raise.
in_blocks._check_if_cleared()
if lazy and isinstance(in_blocks, LazyBlockList):
assert in_blocks._block_partition_refs[0] is not None
if not lazy:
assert ds._plan._snapshot_blocks is not None

ray.shutdown()
ray.init()

ds = Dataset.deserialize_out_of_band(serialized_ds)
# Check Dataset state.
assert ds._get_epoch() == epoch
assert ds._get_uuid() == uuid
assert ds._plan._dataset_uuid == plan_uuid
assert ds._lazy == lazy
# Check Dataset content.
assert ds.count() == 10
assert sorted(ds.take()) == list(range(2, 12))


@pytest.mark.parametrize("lazy", [False, True])
def test_dataset_out_of_band_serialization_in_memory(shutdown_only, lazy):
ray.init()
ds = ray.data.from_items(list(range(10)))
ds = maybe_lazy(ds, lazy)
ds = ds.map(lambda x: x + 1)
ds = ds.map(lambda x: x + 1)

with pytest.raises(ValueError):
ds.serialize_out_of_band()

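# The ValueError above is expected for in-memory datasets. The workaround the
# error message suggests is to write the data to external storage and re-read
# it through a lazy datasource. A hypothetical sketch (the local path is
# illustrative only):
#
#     ds = ray.data.from_items([{"id": i} for i in range(10)])
#     ds.write_parquet("/tmp/ds_out")
#     ds = ray.data.read_parquet("/tmp/ds_out")
#     assert ds.is_out_of_band_serializable()
#     serialized = ds.serialize_out_of_band()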

@pytest.mark.parametrize("pipelined", [False, True])
def test_basic(ray_start_regular_shared, pipelined):
ds0 = ray.data.range(5)