diff --git a/python/ray/data/_internal/execution/legacy_compat.py b/python/ray/data/_internal/execution/legacy_compat.py index 0579f38c2cea..1476dce38ad6 100644 --- a/python/ray/data/_internal/execution/legacy_compat.py +++ b/python/ray/data/_internal/execution/legacy_compat.py @@ -56,6 +56,13 @@ def execute_to_legacy_block_iterator( dag, stats = get_execution_plan(plan._logical_plan).dag, None else: dag, stats = _to_operator_dag(plan, allow_clear_input_blocks) + + # Enforce to preserve ordering if the plan has stages required to do so, such as + # Zip and Sort. + # TODO(chengsu): implement this for operator as well. + if plan.require_preserve_order(): + executor._options.preserve_order = True + bundle_iter = executor.execute(dag, initial_stats=stats) for bundle in bundle_iter: @@ -84,6 +91,13 @@ def execute_to_legacy_block_list( dag, stats = get_execution_plan(plan._logical_plan).dag, None else: dag, stats = _to_operator_dag(plan, allow_clear_input_blocks) + + # Enforce to preserve ordering if the plan has stages required to do so, such as + # Zip and Sort. + # TODO(chengsu): implement this for operator as well. + if plan.require_preserve_order(): + executor._options.preserve_order = True + bundles = executor.execute(dag, initial_stats=stats) block_list = _bundles_to_block_list(bundles) # Set the stats UUID after execution finishes. diff --git a/python/ray/data/_internal/plan.py b/python/ray/data/_internal/plan.py index bf53ec0668be..caeebec71a9e 100644 --- a/python/ray/data/_internal/plan.py +++ b/python/ray/data/_internal/plan.py @@ -764,6 +764,17 @@ def _run_with_new_execution_backend(self) -> bool: and self._stages_after_snapshot ) + def require_preserve_order(self) -> bool: + """Whether this plan requires to preserve order when running with new + backend. + """ + from ray.data._internal.stage_impl import SortStage, ZipStage + + for stage in self._stages_after_snapshot: + if isinstance(stage, ZipStage) or isinstance(stage, SortStage): + return True + return False + def _pack_args( self_fn_args: Iterable[Any], diff --git a/python/ray/data/tests/test_optimize.py b/python/ray/data/tests/test_optimize.py index 9267d0b783d6..b387fa7a9c65 100644 --- a/python/ray/data/tests/test_optimize.py +++ b/python/ray/data/tests/test_optimize.py @@ -710,6 +710,15 @@ def test_optimize_lazy_reuse_base_data( assert num_reads == num_blocks, num_reads +def test_require_preserve_order(ray_start_regular_shared): + ds = ray.data.range(100).map_batches(lambda x: x).sort() + assert ds._plan.require_preserve_order() + ds2 = ray.data.range(100).map_batches(lambda x: x).zip(ds) + assert ds2._plan.require_preserve_order() + ds3 = ray.data.range(100).map_batches(lambda x: x).repartition(10) + assert not ds3._plan.require_preserve_order() + + if __name__ == "__main__": import sys