From 3a87dcc88e9c891148a912e7ecbcc74f4266530a Mon Sep 17 00:00:00 2001
From: Clark Zinzow
Date: Fri, 15 Apr 2022 00:49:11 +0000
Subject: [PATCH 1/3] Add out-of-band serialization.

---
 python/ray/data/dataset.py            | 85 +++++++++++++++++++++++++++
 python/ray/data/tests/test_dataset.py | 59 ++++++++++++++++++-
 2 files changed, 143 insertions(+), 1 deletion(-)

diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 79b38d5e824a..09aeb03a2a03 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -34,6 +34,7 @@
 import numpy as np
 
 import ray
+import ray.cloudpickle as pickle
 from ray.types import ObjectRef
 from ray.util.annotations import DeveloperAPI, PublicAPI
 from ray.data.block import (
@@ -2819,6 +2820,14 @@ def fully_executed(self) -> "Dataset[T]":
         ds._set_uuid(self._get_uuid())
         return ds
 
+    def is_fully_executed(self) -> bool:
+        """Returns whether this Dataset has been fully executed.
+
+        This will return False if this Dataset is lazy and if the output of its final
+        stage hasn't been computed yet.
+        """
+        return self._plan.has_computed_output()
+
     def stats(self) -> str:
         """Returns a string containing execution timing information."""
         return self._plan.stats().summary_string()
@@ -2842,6 +2851,82 @@ def _experimental_lazy(self) -> "Dataset[T]":
         self._lazy = True
         return self
 
+    def is_out_of_band_serializable(self) -> bool:
+        """Whether this dataset is able to be out-of-band serialized, i.e. serialized
+        for use across different Ray clusters. Only datasets that read from lazy
+        datasources (i.e. via one of the ray.data.read_*() APIs) are out-of-band
+        serializable.
+        """
+        return self._plan.has_lazy_input()
+
+    @DeveloperAPI
+    def serialize_out_of_band(self) -> bytes:
+        """
+        Serialize the Dataset for out-of-band use, i.e. for use across different Ray
+        clusters. This method serializes the lineage of the Dataset operations. Note
+        that this will drop all computed data, and that everything will be recomputed
+        from scratch after deserialization.
+
+        Use ``Dataset.deserialize_out_of_band`` to deserialize the serialized bytes
+        into a Dataset.
+
+        Returns:
+            Serialized bytes.
+        """
+        if not self.is_out_of_band_serializable():
+            raise ValueError(
+                "Out-of-band serialization is only supported for Datasets created from "
+                "lazy datasources. I.e., out-of-band serialization is not "
+                "supported for any ray.data.from_*() APIs. To allow this Dataset to be "
+                "out-of-band serialized, write the data to an external store (such as "
+                "AWS S3, GCS, or Azure Blob Storage) using the Dataset.write_*() APIs, "
+                "and serialize a new dataset reading from the external store using the "
+                "ray.data.read_*() APIs."
+            )
+        # Copy Dataset and clear the execution plan so the Dataset is out-of-band
+        # serializable.
+        plan_copy = self._plan.deep_copy(preserve_uuid=True)
+        ds = Dataset(plan_copy, self._get_epoch(), self._lazy)
+        ds._plan.clear()
+        ds._set_uuid(self._get_uuid())
+
+        def _reduce(rf: ray.remote_function.RemoteFunction):
+            # Custom reducer for Ray remote function handles that allows for
+            # cross-cluster serialization.
+            # TODO(Clark): Fix this in core Ray.
+            reconstructor, args, state = rf.__reduce__()
+            # Manually unset last export session and job to force re-exporting of the
+            # function when the handle is deserialized on a new cluster.
+            state["_last_export_session_and_job"] = None
+            return reconstructor, args, state
+
+        context = ray.worker.global_worker.get_serialization_context()
+        try:
+            context._register_cloudpickle_reducer(
+                ray.remote_function.RemoteFunction, _reduce
+            )
+            serialized = pickle.dumps(ds)
+        finally:
+            context._unregister_cloudpickle_reducer(ray.remote_function.RemoteFunction)
+        return serialized
+
+    @DeveloperAPI
+    @staticmethod
+    def deserialize_out_of_band(serialized_ds: bytes) -> "Dataset":
+        """
+        Deserialize the provided out-of-band serialized Dataset.
+
+        This assumes that the provided serialized bytes were serialized using
+        ``Dataset.serialize_out_of_band``.
+
+        Args:
+            serialized_ds: The serialized Dataset that we wish to deserialize.
+
+        Returns:
+            A deserialized ``Dataset`` instance.
+        """
+        return pickle.loads(serialized_ds)
+
     def _split(
         self, index: int, return_right_half: bool
     ) -> ("Dataset[T]", "Dataset[T]"):
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index 25e679ffd9b3..c5070f0b1821 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -13,12 +13,13 @@
 import ray
 
 from ray.tests.conftest import *  # noqa
-from ray.data.dataset import _sliding_window
+from ray.data.dataset import Dataset, _sliding_window
 from ray.data.datasource.csv_datasource import CSVDatasource
 from ray.data.block import BlockAccessor
 from ray.data.row import TableRow
 from ray.data.impl.arrow_block import ArrowRow
 from ray.data.impl.block_builder import BlockBuilder
+from ray.data.impl.lazy_block_list import LazyBlockList
 from ray.data.impl.pandas_block import PandasRow
 from ray.data.aggregate import AggregateFn, Count, Sum, Min, Max, Mean, Std
 from ray.data.extensions.tensor_extension import (
@@ -38,6 +39,13 @@ def maybe_pipeline(ds, enabled):
         return ds
 
 
+def maybe_lazy(ds, enabled):
+    if enabled:
+        return ds._experimental_lazy()
+    else:
+        return ds
+
+
 class SlowCSVDatasource(CSVDatasource):
     def _read_stream(self, f: "pa.NativeFile", path: str, **reader_args):
         for block in CSVDatasource._read_stream(self, f, path, **reader_args):
@@ -184,6 +192,55 @@ def mapper(x):
         ds.map(mapper)
 
 
+@pytest.mark.parametrize("lazy", [False, True])
+def test_dataset_out_of_band_serialization(shutdown_only, lazy):
+    ray.init()
+    ds = ray.data.range(10)
+    ds = maybe_lazy(ds, lazy)
+    ds = ds.map(lambda x: x + 1)
+    ds = ds.map(lambda x: x + 1)
+    ds = ds.random_shuffle()
+    epoch = ds._get_epoch()
+    uuid = ds._get_uuid()
+    plan_uuid = ds._plan._dataset_uuid
+    lazy = ds._lazy
+
+    serialized_ds = ds.serialize_out_of_band()
+    # Confirm that the original Dataset was properly copied before clearing/mutating.
+    in_blocks = ds._plan._in_blocks
+    # Should not raise.
+    in_blocks._check_if_cleared()
+    if lazy and isinstance(in_blocks, LazyBlockList):
+        assert in_blocks._block_partition_refs[0] is not None
+    if not lazy:
+        assert ds._plan._snapshot_blocks is not None
+
+    ray.shutdown()
+    ray.init()
+
+    ds = Dataset.deserialize_out_of_band(serialized_ds)
+    # Check Dataset state.
+    assert ds._get_epoch() == epoch
+    assert ds._get_uuid() == uuid
+    assert ds._plan._dataset_uuid == plan_uuid
+    assert ds._lazy == lazy
+    # Check Dataset content.
+    assert ds.count() == 10
+    assert sorted(ds.take()) == list(range(2, 12))
+
+
+@pytest.mark.parametrize("lazy", [False, True])
+def test_dataset_out_of_band_serialization_in_memory(shutdown_only, lazy):
+    ray.init()
+    ds = ray.data.from_items(list(range(10)))
+    ds = maybe_lazy(ds, lazy)
+    ds = ds.map(lambda x: x + 1)
+    ds = ds.map(lambda x: x + 1)
+
+    with pytest.raises(ValueError):
+        ds.serialize_out_of_band()
+
+
 @pytest.mark.parametrize("pipelined", [False, True])
 def test_basic(ray_start_regular_shared, pipelined):
     ds0 = ray.data.range(5)

From f124af9ffdb2de1f19cc0e936ab78d094738267e Mon Sep 17 00:00:00 2001
From: Clark Zinzow
Date: Sun, 24 Apr 2022 18:07:35 +0000
Subject: [PATCH 2/3] PR feedback: orient APIs around serialization of
 lineage; add issue link for remote function handle serialization issue.

---
 python/ray/data/dataset.py            | 65 +++++++++++++++------------
 python/ray/data/impl/plan.py          |  2 +-
 python/ray/data/tests/test_dataset.py | 10 ++---
 3 files changed, 43 insertions(+), 34 deletions(-)

diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 09aeb03a2a03..a2e1090bf203 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -2851,52 +2851,61 @@ def _experimental_lazy(self) -> "Dataset[T]":
         self._lazy = True
         return self
 
-    def is_out_of_band_serializable(self) -> bool:
-        """Whether this dataset is able to be out-of-band serialized, i.e. serialized
-        for use across different Ray clusters. Only datasets that read from lazy
-        datasources (i.e. via one of the ray.data.read_*() APIs) are out-of-band
-        serializable.
+    def has_serializable_lineage(self) -> bool:
+        """Whether this dataset's lineage is able to be serialized for storage and
+        later deserialized, possibly on a different cluster.
+
+        Only datasets that are created from data that we know will still exist at
+        deserialization time, e.g. data external to this Ray cluster such as persistent
+        cloud object stores, support lineage-based serialization. All of the
+        ray.data.read_*() APIs support lineage-based serialization.
         """
         return self._plan.has_lazy_input()
 
     @DeveloperAPI
-    def serialize_out_of_band(self) -> bytes:
+    def serialize_lineage(self) -> bytes:
         """
-        Serialize the Dataset for out-of-band use, i.e. for use across different Ray
-        clusters. This method serializes the lineage of the Dataset operations. Note
-        that this will drop all computed data, and that everything will be recomputed
-        from scratch after deserialization.
+        Serialize this dataset's lineage, not the actual data or the existing data
+        futures, to bytes that can be stored and later deserialized, possibly on a
+        different cluster.
 
-        Use ``Dataset.deserialize_out_of_band`` to deserialize the serialized bytes
-        into a Dataset.
+        Note that this will drop all computed data, and that everything will be
+        recomputed from scratch after deserialization.
+
+        Use :py:meth:`Dataset.deserialize_lineage` to deserialize the serialized bytes
+        returned from this method into a Dataset.
 
         Returns:
-            Serialized bytes.
+            Serialized bytes containing the lineage of this dataset.
         """
-        if not self.is_out_of_band_serializable():
+        if not self.has_serializable_lineage():
             raise ValueError(
-                "Out-of-band serialization is only supported for Datasets created from "
-                "lazy datasources. I.e., out-of-band serialization is not "
-                "supported for any ray.data.from_*() APIs. To allow this Dataset to be "
-                "out-of-band serialized, write the data to an external store (such as "
-                "AWS S3, GCS, or Azure Blob Storage) using the Dataset.write_*() APIs, "
-                "and serialize a new dataset reading from the external store using the "
-                "ray.data.read_*() APIs."
+                "Lineage-based serialization is only supported for Datasets created "
+                "from data that we know will still exist at deserialization "
+                "time, e.g. external data in persistent cloud object stores or "
+                "in-memory data from long-lived clusters. Concretely, all "
+                "ray.data.read_*() APIs should support lineage-based serialization, "
+                "while all of the ray.data.from_*() APIs do not. To allow this "
+                "Dataset to be serialized to storage, write the data to an external "
+                "store (such as AWS S3, GCS, or Azure Blob Storage) using the "
+                "Dataset.write_*() APIs, and serialize a new dataset reading from the "
+                "external store using the ray.data.read_*() APIs."
             )
         # Copy Dataset and clear the execution plan so the Dataset is out-of-band
         # serializable.
         plan_copy = self._plan.deep_copy(preserve_uuid=True)
         ds = Dataset(plan_copy, self._get_epoch(), self._lazy)
-        ds._plan.clear()
+        ds._plan.clear_block_refs()
         ds._set_uuid(self._get_uuid())
 
         def _reduce(rf: ray.remote_function.RemoteFunction):
             # Custom reducer for Ray remote function handles that allows for
             # cross-cluster serialization.
-            # TODO(Clark): Fix this in core Ray.
+            # This manually unsets the last export session and job to force re-exporting
+            # of the function when the handle is deserialized on a new cluster.
+            # TODO(Clark): Fix this in core Ray, see issue:
+            # https://github.com/ray-project/ray/issues/24152.
             reconstructor, args, state = rf.__reduce__()
-            # Manually unset last export session and job to force re-exporting of the
-            # function when the handle is deserialized on a new cluster.
             state["_last_export_session_and_job"] = None
             return reconstructor, args, state
 
@@ -2912,12 +2921,12 @@ def _reduce(rf: ray.remote_function.RemoteFunction):
 
     @DeveloperAPI
     @staticmethod
-    def deserialize_out_of_band(serialized_ds: bytes) -> "Dataset":
+    def deserialize_lineage(serialized_ds: bytes) -> "Dataset":
         """
-        Deserialize the provided out-of-band serialized Dataset.
+        Deserialize the provided lineage-serialized Dataset.
 
         This assumes that the provided serialized bytes were serialized using
-        ``Dataset.serialize_out_of_band``.
+        :py:meth:`Dataset.serialize_lineage`.
 
         Args:
             serialized_ds: The serialized Dataset that we wish to deserialize.
diff --git a/python/ray/data/impl/plan.py b/python/ray/data/impl/plan.py
index 6feff83268e5..2cdd3c9a6232 100644
--- a/python/ray/data/impl/plan.py
+++ b/python/ray/data/impl/plan.py
@@ -256,7 +256,7 @@ def execute(
             self._snapshot_blocks = self._snapshot_blocks.compute_to_blocklist()
         return self._snapshot_blocks
 
-    def clear(self) -> None:
+    def clear_block_refs(self) -> None:
         """Clear all cached block references of this plan, including input blocks.
 
         This will render the plan un-executable unless the root is a LazyBlockList."""
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index c5070f0b1821..ab4e783f582e 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -193,7 +193,7 @@ def mapper(x):
 
 
 @pytest.mark.parametrize("lazy", [False, True])
-def test_dataset_out_of_band_serialization(shutdown_only, lazy):
+def test_dataset_lineage_serialization(shutdown_only, lazy):
     ray.init()
     ds = ray.data.range(10)
     ds = maybe_lazy(ds, lazy)
@@ -205,7 +205,7 @@
     plan_uuid = ds._plan._dataset_uuid
     lazy = ds._lazy
 
-    serialized_ds = ds.serialize_out_of_band()
+    serialized_ds = ds.serialize_lineage()
     # Confirm that the original Dataset was properly copied before clearing/mutating.
     in_blocks = ds._plan._in_blocks
     # Should not raise.
@@ -218,7 +218,7 @@
     ray.shutdown()
     ray.init()
 
-    ds = Dataset.deserialize_out_of_band(serialized_ds)
+    ds = Dataset.deserialize_lineage(serialized_ds)
    # Check Dataset state.
     assert ds._get_epoch() == epoch
     assert ds._get_uuid() == uuid
@@ -230,7 +230,7 @@
 
 
 @pytest.mark.parametrize("lazy", [False, True])
-def test_dataset_out_of_band_serialization_in_memory(shutdown_only, lazy):
+def test_dataset_lineage_serialization_in_memory(shutdown_only, lazy):
     ray.init()
     ds = ray.data.from_items(list(range(10)))
     ds = maybe_lazy(ds, lazy)
@@ -238,7 +238,7 @@
     ds = ds.map(lambda x: x + 1)
 
     with pytest.raises(ValueError):
-        ds.serialize_out_of_band()
+        ds.serialize_lineage()
 
 
 @pytest.mark.parametrize("pipelined", [False, True])

From 30462dc38a20925a7397682455baba9913f37a4b Mon Sep 17 00:00:00 2001
From: Clark Zinzow
Date: Sun, 24 Apr 2022 23:33:49 +0000
Subject: [PATCH 3/3] PR feedback: update comment; fix test.

---
 python/ray/data/dataset.py            | 4 ++--
 python/ray/data/tests/test_dataset.py | 5 +++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index a2e1090bf203..39917618bd12 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -2891,8 +2891,8 @@ def serialize_lineage(self) -> bytes:
                 "Dataset.write_*() APIs, and serialize a new dataset reading from the "
                 "external store using the ray.data.read_*() APIs."
             )
-        # Copy Dataset and clear the execution plan so the Dataset is out-of-band
-        # serializable.
+        # Copy Dataset and clear the blocks from the execution plan so only the
+        # Dataset's lineage is serialized.
         plan_copy = self._plan.deep_copy(preserve_uuid=True)
         ds = Dataset(plan_copy, self._get_epoch(), self._lazy)
         ds._plan.clear_block_refs()
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index ab4e783f582e..8d34470faf43 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -210,9 +210,10 @@ def test_dataset_lineage_serialization(shutdown_only, lazy):
     in_blocks = ds._plan._in_blocks
     # Should not raise.
     in_blocks._check_if_cleared()
-    if lazy and isinstance(in_blocks, LazyBlockList):
+    assert isinstance(in_blocks, LazyBlockList)
+    if lazy:
         assert in_blocks._block_partition_refs[0] is not None
-    if not lazy:
+    else:
         assert ds._plan._snapshot_blocks is not None
 
     ray.shutdown()
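
Usage sketch (not part of the diffs above): a minimal end-to-end example of the final
API from this series, assuming a Ray build that includes all three patches. The paired
ray.shutdown()/ray.init() calls stand in for tearing down one cluster and starting
another, mirroring the round trip exercised by test_dataset_lineage_serialization.

    import ray
    from ray.data.dataset import Dataset

    ray.init()

    # Datasets backed by lazy datasources (ray.data.read_*(), ray.data.range())
    # have serializable lineage; ray.data.from_*() datasets do not.
    ds = ray.data.range(10).map(lambda x: x + 1)
    assert ds.has_serializable_lineage()

    # Serialize only the lineage (the chain of operations), not the block data.
    serialized = ds.serialize_lineage()

    # ...store `serialized` externally, then later, on a fresh cluster...
    ray.shutdown()
    ray.init()

    # Rebuild the dataset; all blocks are recomputed from the lineage on use.
    ds2 = Dataset.deserialize_lineage(serialized)
    assert sorted(ds2.take()) == list(range(1, 11))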
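For reference, the remote-function pickling workaround inside serialize_lineage can be
read in isolation as the following sketch. The helper name pickle_with_reexport is
hypothetical; the serialization-context hooks (_register_cloudpickle_reducer and
_unregister_cloudpickle_reducer) are the same internal Ray APIs used verbatim in
PATCH 1/3, pending the upstream fix tracked in ray-project/ray#24152.

    import ray
    import ray.cloudpickle as pickle

    def pickle_with_reexport(obj) -> bytes:
        # Hypothetical standalone helper extracted from serialize_lineage.
        def _reduce(rf: ray.remote_function.RemoteFunction):
            reconstructor, args, state = rf.__reduce__()
            # Unset the cached export session/job so the function definition is
            # re-exported when the handle is unpickled on a different cluster.
            state["_last_export_session_and_job"] = None
            return reconstructor, args, state

        context = ray.worker.global_worker.get_serialization_context()
        try:
            # Route RemoteFunction handles through the custom reducer only for
            # the duration of this dump.
            context._register_cloudpickle_reducer(
                ray.remote_function.RemoteFunction, _reduce
            )
            return pickle.dumps(obj)
        finally:
            context._unregister_cloudpickle_reducer(
                ray.remote_function.RemoteFunction
            )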