diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 79b38d5e824a..39917618bd12 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -34,6 +34,7 @@
 import numpy as np
 
 import ray
+import ray.cloudpickle as pickle
 from ray.types import ObjectRef
 from ray.util.annotations import DeveloperAPI, PublicAPI
 from ray.data.block import (
@@ -2819,6 +2820,14 @@ def fully_executed(self) -> "Dataset[T]":
         ds._set_uuid(self._get_uuid())
         return ds
 
+    def is_fully_executed(self) -> bool:
+        """Returns whether this Dataset has been fully executed.
+
+        This will return False if this Dataset is lazy and if the output of its final
+        stage hasn't been computed yet.
+        """
+        return self._plan.has_computed_output()
+
     def stats(self) -> str:
         """Returns a string containing execution timing information."""
         return self._plan.stats().summary_string()
@@ -2842,6 +2851,91 @@ def _experimental_lazy(self) -> "Dataset[T]":
         self._lazy = True
         return self
 
+    def has_serializable_lineage(self) -> bool:
+        """Whether this dataset's lineage is able to be serialized for storage and
+        later deserialized, possibly on a different cluster.
+
+        Only datasets that are created from data that we know will still exist at
+        deserialization time, e.g. data external to this Ray cluster such as persistent
+        cloud object stores, support lineage-based serialization. All of the
+        ray.data.read_*() APIs support lineage-based serialization.
+        """
+        return self._plan.has_lazy_input()
+
+    @DeveloperAPI
+    def serialize_lineage(self) -> bytes:
+        """
+        Serialize this dataset's lineage, not the actual data or the existing data
+        futures, to bytes that can be stored and later deserialized, possibly on a
+        different cluster.
+
+        Note that this will drop all computed data, and that everything will be
+        recomputed from scratch after deserialization.
+
+        Use :py:meth:`Dataset.deserialize_lineage` to deserialize the serialized bytes
+        returned from this method into a Dataset.
+
+        Returns:
+            Serialized bytes containing the lineage of this dataset.
+        """
+        if not self.has_serializable_lineage():
+            raise ValueError(
+                "Lineage-based serialization is only supported for Datasets created "
+                "from data that we know will still exist at deserialization "
+                "time, e.g. external data in persistent cloud object stores or "
+                "in-memory data from long-lived clusters. Concretely, all "
+                "ray.data.read_*() APIs should support lineage-based serialization, "
+                "while all of the ray.data.from_*() APIs do not. To allow this "
+                "Dataset to be serialized to storage, write the data to an external "
+                "store (such as AWS S3, GCS, or Azure Blob Storage) using the "
+                "Dataset.write_*() APIs, and serialize a new dataset reading from the "
+                "external store using the ray.data.read_*() APIs."
+            )
+        # Copy Dataset and clear the blocks from the execution plan so only the
+        # Dataset's lineage is serialized.
+        plan_copy = self._plan.deep_copy(preserve_uuid=True)
+        ds = Dataset(plan_copy, self._get_epoch(), self._lazy)
+        ds._plan.clear_block_refs()
+        ds._set_uuid(self._get_uuid())
+
+        def _reduce(rf: ray.remote_function.RemoteFunction):
+            # Custom reducer for Ray remote function handles that allows for
+            # cross-cluster serialization.
+            # This manually unsets the last export session and job to force
+            # re-exporting of the function when the handle is deserialized on a new
+            # cluster.
+            # TODO(Clark): Fix this in core Ray, see issue:
+            # https://github.com/ray-project/ray/issues/24152.
+            reconstructor, args, state = rf.__reduce__()
+            state["_last_export_session_and_job"] = None
+            return reconstructor, args, state
+
+        context = ray.worker.global_worker.get_serialization_context()
+        try:
+            context._register_cloudpickle_reducer(
+                ray.remote_function.RemoteFunction, _reduce
+            )
+            serialized = pickle.dumps(ds)
+        finally:
+            context._unregister_cloudpickle_reducer(ray.remote_function.RemoteFunction)
+        return serialized
+
+    @DeveloperAPI
+    @staticmethod
+    def deserialize_lineage(serialized_ds: bytes) -> "Dataset":
+        """
+        Deserialize the provided lineage-serialized Dataset.
+
+        This assumes that the provided serialized bytes were serialized using
+        :py:meth:`Dataset.serialize_lineage`.
+
+        Args:
+            serialized_ds: The serialized Dataset that we wish to deserialize.
+
+        Returns:
+            A deserialized ``Dataset`` instance.
+        """
+        return pickle.loads(serialized_ds)
+
     def _split(
         self, index: int, return_right_half: bool
     ) -> ("Dataset[T]", "Dataset[T]"):
diff --git a/python/ray/data/impl/plan.py b/python/ray/data/impl/plan.py
index 6feff83268e5..2cdd3c9a6232 100644
--- a/python/ray/data/impl/plan.py
+++ b/python/ray/data/impl/plan.py
@@ -256,7 +256,7 @@ def execute(
             self._snapshot_blocks = self._snapshot_blocks.compute_to_blocklist()
         return self._snapshot_blocks
 
-    def clear(self) -> None:
+    def clear_block_refs(self) -> None:
         """Clear all cached block references of this plan, including input blocks.
 
         This will render the plan un-executable unless the root is a LazyBlockList."""
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index 25e679ffd9b3..8d34470faf43 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -13,12 +13,13 @@
 import ray
 
 from ray.tests.conftest import *  # noqa
-from ray.data.dataset import _sliding_window
+from ray.data.dataset import Dataset, _sliding_window
 from ray.data.datasource.csv_datasource import CSVDatasource
 from ray.data.block import BlockAccessor
 from ray.data.row import TableRow
 from ray.data.impl.arrow_block import ArrowRow
 from ray.data.impl.block_builder import BlockBuilder
+from ray.data.impl.lazy_block_list import LazyBlockList
 from ray.data.impl.pandas_block import PandasRow
 from ray.data.aggregate import AggregateFn, Count, Sum, Min, Max, Mean, Std
 from ray.data.extensions.tensor_extension import (
@@ -38,6 +39,13 @@ def maybe_pipeline(ds, enabled):
         return ds
 
 
+def maybe_lazy(ds, enabled):
+    if enabled:
+        return ds._experimental_lazy()
+    else:
+        return ds
+
+
 class SlowCSVDatasource(CSVDatasource):
     def _read_stream(self, f: "pa.NativeFile", path: str, **reader_args):
         for block in CSVDatasource._read_stream(self, f, path, **reader_args):
@@ -184,6 +192,56 @@ def mapper(x):
         ds.map(mapper)
 
 
+@pytest.mark.parametrize("lazy", [False, True])
+def test_dataset_lineage_serialization(shutdown_only, lazy):
+    ray.init()
+    ds = ray.data.range(10)
+    ds = maybe_lazy(ds, lazy)
+    ds = ds.map(lambda x: x + 1)
+    ds = ds.map(lambda x: x + 1)
+    ds = ds.random_shuffle()
+    epoch = ds._get_epoch()
+    uuid = ds._get_uuid()
+    plan_uuid = ds._plan._dataset_uuid
+    lazy = ds._lazy
+
+    serialized_ds = ds.serialize_lineage()
+    # Confirm that the original Dataset was properly copied before clearing/mutating.
+    in_blocks = ds._plan._in_blocks
+    # Should not raise.
+    in_blocks._check_if_cleared()
+    assert isinstance(in_blocks, LazyBlockList)
+    if lazy:
+        assert in_blocks._block_partition_refs[0] is not None
+    else:
+        assert ds._plan._snapshot_blocks is not None
+
+    ray.shutdown()
+    ray.init()
+
+    ds = Dataset.deserialize_lineage(serialized_ds)
+    # Check Dataset state.
+    assert ds._get_epoch() == epoch
+    assert ds._get_uuid() == uuid
+    assert ds._plan._dataset_uuid == plan_uuid
+    assert ds._lazy == lazy
+    # Check Dataset content.
+    assert ds.count() == 10
+    assert sorted(ds.take()) == list(range(2, 12))
+
+
+@pytest.mark.parametrize("lazy", [False, True])
+def test_dataset_lineage_serialization_in_memory(shutdown_only, lazy):
+    ray.init()
+    ds = ray.data.from_items(list(range(10)))
+    ds = maybe_lazy(ds, lazy)
+    ds = ds.map(lambda x: x + 1)
+    ds = ds.map(lambda x: x + 1)
+
+    with pytest.raises(ValueError):
+        ds.serialize_lineage()
+
+
 @pytest.mark.parametrize("pipelined", [False, True])
 def test_basic(ray_start_regular_shared, pipelined):
     ds0 = ray.data.range(5)
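
A minimal usage sketch of the serialize_lineage/deserialize_lineage API added above, modeled on test_dataset_lineage_serialization; ray.data.range() is one of the read-style, lazy-input APIs that support lineage-based serialization, whereas ray.data.from_*() datasets would raise ValueError:

    import ray
    from ray.data.dataset import Dataset

    ray.init()

    # Lineage-serializable: range() has lazy inputs, so only the lineage
    # (not the computed blocks) is captured in the serialized bytes.
    ds = ray.data.range(10).map(lambda x: x + 1)
    assert ds.has_serializable_lineage()
    serialized = ds.serialize_lineage()  # bytes; computed data is dropped

    # ...potentially after ray.shutdown()/ray.init() on a different cluster,
    # everything is recomputed from scratch on first execution...
    restored = Dataset.deserialize_lineage(serialized)
    assert sorted(restored.take()) == list(range(1, 11))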