From b353f4f1125a81eb4eda2b45eb27951a6b4e9944 Mon Sep 17 00:00:00 2001 From: Xingyu Long Date: Sun, 27 Oct 2024 14:52:32 -0700 Subject: [PATCH 1/7] [Data] support batch_format for Sort Signed-off-by: Xingyu Long --- .../logical/operators/all_to_all_operator.py | 2 ++ .../planner/exchange/sort_task_spec.py | 12 +++++++--- .../_internal/planner/plan_all_to_all_op.py | 4 +++- python/ray/data/_internal/planner/sort.py | 5 +++- python/ray/data/dataset.py | 6 +++-- python/ray/data/grouped_data.py | 1 + .../data/tests/test_execution_optimizer.py | 24 +++++++++++++++++++ 7 files changed, 47 insertions(+), 7 deletions(-) diff --git a/python/ray/data/_internal/logical/operators/all_to_all_operator.py b/python/ray/data/_internal/logical/operators/all_to_all_operator.py index 3179871c3685..b1978534c1cb 100644 --- a/python/ray/data/_internal/logical/operators/all_to_all_operator.py +++ b/python/ray/data/_internal/logical/operators/all_to_all_operator.py @@ -120,6 +120,7 @@ def __init__( self, input_op: LogicalOperator, sort_key: SortKey, + batch_format: Optional[str] = "default", ): super().__init__( "Sort", @@ -131,6 +132,7 @@ def __init__( ], ) self._sort_key = sort_key + self._batch_format = batch_format def aggregate_output_metadata(self) -> BlockMetadata: assert len(self._input_dependencies) == 1, len(self._input_dependencies) diff --git a/python/ray/data/_internal/planner/exchange/sort_task_spec.py b/python/ray/data/_internal/planner/exchange/sort_task_spec.py index edeea0639464..1972c6e39404 100644 --- a/python/ray/data/_internal/planner/exchange/sort_task_spec.py +++ b/python/ray/data/_internal/planner/exchange/sort_task_spec.py @@ -6,6 +6,7 @@ from ray.data._internal.planner.exchange.interfaces import ExchangeTaskSpec from ray.data._internal.progress_bar import ProgressBar from ray.data._internal.remote_fn import cached_remote_fn +from ray.data._internal.table_block import TableBlockAccessor from ray.data.block import Block, BlockAccessor, BlockExecStats, BlockMetadata from ray.types import ObjectRef @@ -116,10 +117,11 @@ def __init__( self, boundaries: List[T], sort_key: SortKey, + batch_format: Optional[str] = "default", ): super().__init__( map_args=[boundaries, sort_key], - reduce_args=[sort_key], + reduce_args=[sort_key, batch_format], ) @staticmethod @@ -138,11 +140,15 @@ def map( @staticmethod def reduce( sort_key: SortKey, + batch_format: str, *mapper_outputs: List[Block], partial_reduce: bool = False, ) -> Tuple[Block, BlockMetadata]: - return BlockAccessor.for_block(mapper_outputs[0]).merge_sorted_blocks( - mapper_outputs, sort_key + normalized_blocks = TableBlockAccessor.normalize_block_types( + mapper_outputs, normalize_type=batch_format + ) + return BlockAccessor.for_block(normalized_blocks[0]).merge_sorted_blocks( + normalized_blocks, sort_key ) @staticmethod diff --git a/python/ray/data/_internal/planner/plan_all_to_all_op.py b/python/ray/data/_internal/planner/plan_all_to_all_op.py index fc7f7fdac954..a546a4931ee7 100644 --- a/python/ray/data/_internal/planner/plan_all_to_all_op.py +++ b/python/ray/data/_internal/planner/plan_all_to_all_op.py @@ -71,7 +71,9 @@ def plan_all_to_all_op( "debug_limit_shuffle_execution_to_num_blocks", None ) ) - fn = generate_sort_fn(op._sort_key, debug_limit_shuffle_execution_to_num_blocks) + fn = generate_sort_fn( + op._sort_key, op._batch_format, debug_limit_shuffle_execution_to_num_blocks + ) target_max_block_size = DataContext.get_current().target_shuffle_max_block_size elif isinstance(op, Aggregate): debug_limit_shuffle_execution_to_num_blocks = ( diff --git a/python/ray/data/_internal/planner/sort.py b/python/ray/data/_internal/planner/sort.py index bf46fdad7039..1a14e9f260ae 100644 --- a/python/ray/data/_internal/planner/sort.py +++ b/python/ray/data/_internal/planner/sort.py @@ -20,6 +20,7 @@ def generate_sort_fn( sort_key: SortKey, + batch_format: str, _debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None, ) -> AllToAllTransformFn: """Generate function to sort blocks by the specified key column or key function.""" @@ -56,7 +57,9 @@ def fn( _, ascending = sort_key.to_pandas_sort_args() if not ascending: boundaries.reverse() - sort_spec = SortTaskSpec(boundaries=boundaries, sort_key=sort_key) + sort_spec = SortTaskSpec( + boundaries=boundaries, sort_key=sort_key, batch_format=batch_format + ) if DataContext.get_current().use_push_based_shuffle: scheduler = PushBasedShuffleTaskScheduler(sort_spec) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index d4c8d887da34..f9623ed36fc0 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -229,6 +229,9 @@ def __init__( self._current_executor: Optional["Executor"] = None self._write_ds = None + # Handle sort on empty blocks and still carry the correct batch_format + self._batch_format = "default" + self._set_uuid(StatsManager.get_dataset_id_from_stats_actor()) @staticmethod @@ -2397,8 +2400,7 @@ def sort( sort_key = SortKey(key, descending, boundaries) plan = self._plan.copy() op = Sort( - self._logical_plan.dag, - sort_key=sort_key, + self._logical_plan.dag, sort_key=sort_key, batch_format=self._batch_format ) logical_plan = LogicalPlan(op, self.context) return Dataset(plan, logical_plan) diff --git a/python/ray/data/grouped_data.py b/python/ray/data/grouped_data.py index 8f7b7dde118d..9252b0f36c5f 100644 --- a/python/ray/data/grouped_data.py +++ b/python/ray/data/grouped_data.py @@ -195,6 +195,7 @@ def map_groups( The return type is determined by the return type of ``fn``, and the return value is combined from results of all groups. """ + self._dataset._batch_format = batch_format # Globally sort records by key. # Note that sort() will ensure that records of the same key partitioned # into the same block. diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 234162871d3c..9cffb4a5784e 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -1172,6 +1172,30 @@ def test_sort_validate_keys(ray_start_regular_shared): ds_named.sort(invalid_col_name).take_all() +def test_sort_on_group_data(ray_start_regular_shared): + ds = ray.data.from_items( + [ + {"col1": 1, "col2": 2}, + {"col1": 1, "col2": 4}, + {"col1": 5, "col2": 6}, + {"col1": 7, "col2": 8}, + ] + ) + df_expected = pd.DataFrame( + { + "col1": [7, 5, 1, 1], + "col2": [8, 6, 4, 2], + } + ) + df_actual = ( + ds.groupby("col1") + .map_groups(lambda g: g, batch_format="pandas") + .sort("col2", descending=True) + .to_pandas() + ) + pd.testing.assert_frame_equal(df_actual, df_expected) + + def test_aggregate_operator(ray_start_regular_shared): ctx = DataContext.get_current() From 86af289eb02137764747e1149e18b751245c1fb7 Mon Sep 17 00:00:00 2001 From: Xingyu Long Date: Mon, 28 Oct 2024 21:28:18 -0700 Subject: [PATCH 2/7] Remove batch_format at dataset level to address Scott's comments Signed-off-by: Xingyu Long --- .../ray/data/_internal/planner/exchange/sort_task_spec.py | 2 +- python/ray/data/dataset.py | 6 ++---- python/ray/data/grouped_data.py | 1 - 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/python/ray/data/_internal/planner/exchange/sort_task_spec.py b/python/ray/data/_internal/planner/exchange/sort_task_spec.py index 1972c6e39404..5bef0a9db3f7 100644 --- a/python/ray/data/_internal/planner/exchange/sort_task_spec.py +++ b/python/ray/data/_internal/planner/exchange/sort_task_spec.py @@ -117,7 +117,7 @@ def __init__( self, boundaries: List[T], sort_key: SortKey, - batch_format: Optional[str] = "default", + batch_format: Optional[str], ): super().__init__( map_args=[boundaries, sort_key], diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index f9623ed36fc0..d4c8d887da34 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -229,9 +229,6 @@ def __init__( self._current_executor: Optional["Executor"] = None self._write_ds = None - # Handle sort on empty blocks and still carry the correct batch_format - self._batch_format = "default" - self._set_uuid(StatsManager.get_dataset_id_from_stats_actor()) @staticmethod @@ -2400,7 +2397,8 @@ def sort( sort_key = SortKey(key, descending, boundaries) plan = self._plan.copy() op = Sort( - self._logical_plan.dag, sort_key=sort_key, batch_format=self._batch_format + self._logical_plan.dag, + sort_key=sort_key, ) logical_plan = LogicalPlan(op, self.context) return Dataset(plan, logical_plan) diff --git a/python/ray/data/grouped_data.py b/python/ray/data/grouped_data.py index 9252b0f36c5f..8f7b7dde118d 100644 --- a/python/ray/data/grouped_data.py +++ b/python/ray/data/grouped_data.py @@ -195,7 +195,6 @@ def map_groups( The return type is determined by the return type of ``fn``, and the return value is combined from results of all groups. """ - self._dataset._batch_format = batch_format # Globally sort records by key. # Note that sort() will ensure that records of the same key partitioned # into the same block. From 21eb19d1d402a70b9a24c26cbdb6b812d194cee6 Mon Sep 17 00:00:00 2001 From: Xingyu Long Date: Mon, 28 Oct 2024 21:37:40 -0700 Subject: [PATCH 3/7] Add inherit_batch_format rule Signed-off-by: Xingyu Long --- .../logical/operators/all_to_all_operator.py | 2 + .../ray/data/_internal/logical/optimizers.py | 2 + .../logical/rules/inherit_batch_format.py | 40 +++++++++++++++++++ .../ray/data/_internal/planner/aggregate.py | 2 + .../planner/exchange/aggregate_task_spec.py | 11 +++-- .../_internal/planner/plan_all_to_all_op.py | 5 ++- 6 files changed, 58 insertions(+), 4 deletions(-) create mode 100644 python/ray/data/_internal/logical/rules/inherit_batch_format.py diff --git a/python/ray/data/_internal/logical/operators/all_to_all_operator.py b/python/ray/data/_internal/logical/operators/all_to_all_operator.py index b1978534c1cb..745103f0036f 100644 --- a/python/ray/data/_internal/logical/operators/all_to_all_operator.py +++ b/python/ray/data/_internal/logical/operators/all_to_all_operator.py @@ -147,6 +147,7 @@ def __init__( input_op: LogicalOperator, key: Optional[str], aggs: List[AggregateFn], + batch_format: Optional[str] = "default", ): super().__init__( "Aggregate", @@ -159,3 +160,4 @@ def __init__( ) self._key = key self._aggs = aggs + self._batch_format = batch_format diff --git a/python/ray/data/_internal/logical/optimizers.py b/python/ray/data/_internal/logical/optimizers.py index e50013d4a13c..a7c2b68c06fe 100644 --- a/python/ray/data/_internal/logical/optimizers.py +++ b/python/ray/data/_internal/logical/optimizers.py @@ -6,6 +6,7 @@ PhysicalPlan, Rule, ) +from ray.data._internal.logical.rules.inherit_batch_format import InheritBatchFormatRule from ray.data._internal.logical.rules.inherit_target_max_block_size import ( InheritTargetMaxBlockSizeRule, ) @@ -20,6 +21,7 @@ _LOGICAL_RULES = [ ReorderRandomizeBlocksRule, + InheritBatchFormatRule, ] _PHYSICAL_RULES = [ diff --git a/python/ray/data/_internal/logical/rules/inherit_batch_format.py b/python/ray/data/_internal/logical/rules/inherit_batch_format.py new file mode 100644 index 000000000000..264da93b6c41 --- /dev/null +++ b/python/ray/data/_internal/logical/rules/inherit_batch_format.py @@ -0,0 +1,40 @@ +from collections import deque +from typing import Iterable + +from ray.data._internal.logical.interfaces import LogicalOperator, LogicalPlan, Rule +from ray.data._internal.logical.operators.all_to_all_operator import Aggregate, Sort + + +class InheritBatchFormatRule(Rule): + """For Sort and Aggregate logicla operator, apply this rule + to inherit batch_format from upstream operator by traversing + the entire DAG.""" + + def apply(self, plan: LogicalPlan) -> LogicalPlan: + optimized_dag: LogicalOperator = self._apply(plan.dag) + new_plan = LogicalPlan(dag=optimized_dag, context=plan.context) + return new_plan + + def _apply(self, op: LogicalOperator): + # Post-order traversal. + nodes: Iterable[LogicalOperator] = deque() + for node in op.post_order_iter(): + nodes.appendleft(node) + + while len(nodes) > 0: + current_op = nodes.pop() + + if isinstance(current_op, (Sort, Aggregate)): + # traversal up the DAG until we find first operator with batch_format + # or we reach to source op and do nothing + upstream_op = current_op.input_dependencies[0] + while ( + upstream_op.input_dependencies + and getattr(upstream_op, "_batch_format", None) is None + ): + upstream_op = upstream_op.input_dependencies[0] + if getattr(upstream_op, "_batch_format", None): + current_op._batch_format = upstream_op._batch_format + + # just return the default op + return op diff --git a/python/ray/data/_internal/planner/aggregate.py b/python/ray/data/_internal/planner/aggregate.py index 6a2a6c1482d1..8f177add41d9 100644 --- a/python/ray/data/_internal/planner/aggregate.py +++ b/python/ray/data/_internal/planner/aggregate.py @@ -24,6 +24,7 @@ def generate_aggregate_fn( key: Optional[str], aggs: List[AggregateFn], + batch_format: str, _debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None, ) -> AllToAllTransformFn: """Generate function to aggregate blocks by the specified key column or key @@ -67,6 +68,7 @@ def fn( boundaries=boundaries, key=key, aggs=aggs, + batch_format=batch_format, ) if DataContext.get_current().use_push_based_shuffle: scheduler = PushBasedShuffleTaskScheduler(agg_spec) diff --git a/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py b/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py index 91d77863e40b..7b0aa0dc7ad8 100644 --- a/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py +++ b/python/ray/data/_internal/planner/exchange/aggregate_task_spec.py @@ -29,10 +29,11 @@ def __init__( boundaries: List[KeyType], key: Optional[str], aggs: List[AggregateFn], + batch_format: str, ): super().__init__( map_args=[boundaries, key, aggs], - reduce_args=[key, aggs], + reduce_args=[key, aggs, batch_format], ) @staticmethod @@ -62,11 +63,15 @@ def map( def reduce( key: Optional[str], aggs: List[AggregateFn], + batch_format: str, *mapper_outputs: List[Block], partial_reduce: bool = False, ) -> Tuple[Block, BlockMetadata]: - return BlockAccessor.for_block(mapper_outputs[0]).aggregate_combined_blocks( - list(mapper_outputs), key, aggs, finalize=not partial_reduce + normalized_blocks = TableBlockAccessor.normalize_block_types( + mapper_outputs, normalize_type=batch_format + ) + return BlockAccessor.for_block(normalized_blocks[0]).aggregate_combined_blocks( + list(normalized_blocks), key, aggs, finalize=not partial_reduce ) @staticmethod diff --git a/python/ray/data/_internal/planner/plan_all_to_all_op.py b/python/ray/data/_internal/planner/plan_all_to_all_op.py index a546a4931ee7..13c13ea6a9a2 100644 --- a/python/ray/data/_internal/planner/plan_all_to_all_op.py +++ b/python/ray/data/_internal/planner/plan_all_to_all_op.py @@ -82,7 +82,10 @@ def plan_all_to_all_op( ) ) fn = generate_aggregate_fn( - op._key, op._aggs, debug_limit_shuffle_execution_to_num_blocks + op._key, + op._aggs, + op._batch_format, + debug_limit_shuffle_execution_to_num_blocks, ) target_max_block_size = DataContext.get_current().target_shuffle_max_block_size else: From 9e351887bf14d47b26fdd55ebc2acbb27138277a Mon Sep 17 00:00:00 2001 From: Xingyu Long Date: Sun, 3 Nov 2024 09:59:27 -0800 Subject: [PATCH 4/7] Update tests to verify inherit_batch_format rule Signed-off-by: Xingyu Long --- .../data/tests/test_execution_optimizer.py | 44 ++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 9cffb4a5784e..16f26deaac71 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -1172,7 +1172,25 @@ def test_sort_validate_keys(ray_start_regular_shared): ds_named.sort(invalid_col_name).take_all() -def test_sort_on_group_data(ray_start_regular_shared): +def test_inherit_batch_format_rule(): + from ray.data._internal.logical.rules.inherit_batch_format import ( + InheritBatchFormatRule, + ) + + ctx = DataContext.get_current() + + operator1 = get_parquet_read_logical_op() + operator2 = MapBatches(operator1, fn=lambda g: g, batch_format="pandas") + sort_key = SortKey("number", descending=True) + operator3 = Sort(operator2, sort_key) + original_plan = LogicalPlan(dag=operator3, context=ctx) + + rule = InheritBatchFormatRule() + optimized_plan = rule.apply(original_plan) + assert optimized_plan.dag._batch_format == "pandas" + + +def test_batch_format_on_sort(ray_start_regular_shared): ds = ray.data.from_items( [ {"col1": 1, "col2": 2}, @@ -1196,6 +1214,30 @@ def test_sort_on_group_data(ray_start_regular_shared): pd.testing.assert_frame_equal(df_actual, df_expected) +def test_batch_format_on_aggregate(ray_start_regular_shared): + from ray.data.aggregate import AggregateFn + + ds = ray.data.from_items( + [ + {"col1": 1, "col2": 2}, + {"col1": 1, "col2": 4}, + {"col1": 5, "col2": 6}, + {"col1": 7, "col2": 8}, + ] + ) + aggregation = AggregateFn( + init=lambda column: 1, + accumulate_row=lambda a, row: a * row["col2"], + merge=lambda a1, a2: a1 * a2, + name="prod", + ) + assert ( + ds.groupby("col1") + .map_groups(lambda g: g, batch_format="pandas") + .aggregate(aggregation) + ) == {"prod": 384} + + def test_aggregate_operator(ray_start_regular_shared): ctx = DataContext.get_current() From e4f5009a2a7936f4506205405fc8bb6c3d1e6892 Mon Sep 17 00:00:00 2001 From: Xingyu Long Date: Thu, 7 Nov 2024 18:55:51 -0800 Subject: [PATCH 5/7] address the comments Signed-off-by: Xingyu Long --- .../logical/rules/inherit_batch_format.py | 16 +++++++++------- .../_internal/planner/exchange/sort_task_spec.py | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/python/ray/data/_internal/logical/rules/inherit_batch_format.py b/python/ray/data/_internal/logical/rules/inherit_batch_format.py index 264da93b6c41..38009a1225c9 100644 --- a/python/ray/data/_internal/logical/rules/inherit_batch_format.py +++ b/python/ray/data/_internal/logical/rules/inherit_batch_format.py @@ -3,6 +3,7 @@ from ray.data._internal.logical.interfaces import LogicalOperator, LogicalPlan, Rule from ray.data._internal.logical.operators.all_to_all_operator import Aggregate, Sort +from ray.data._internal.logical.operators.map_operator import MapBatches class InheritBatchFormatRule(Rule): @@ -25,16 +26,17 @@ def _apply(self, op: LogicalOperator): current_op = nodes.pop() if isinstance(current_op, (Sort, Aggregate)): - # traversal up the DAG until we find first operator with batch_format + # traversal up the DAG until we find MapBatches with batch_format # or we reach to source op and do nothing upstream_op = current_op.input_dependencies[0] - while ( - upstream_op.input_dependencies - and getattr(upstream_op, "_batch_format", None) is None - ): + while upstream_op.input_dependencies: + if ( + isinstance(upstream_op, MapBatches) + and upstream_op._batch_format + ): + current_op._batch_format = upstream_op._batch_format + break upstream_op = upstream_op.input_dependencies[0] - if getattr(upstream_op, "_batch_format", None): - current_op._batch_format = upstream_op._batch_format # just return the default op return op diff --git a/python/ray/data/_internal/planner/exchange/sort_task_spec.py b/python/ray/data/_internal/planner/exchange/sort_task_spec.py index 5bef0a9db3f7..299e8793774f 100644 --- a/python/ray/data/_internal/planner/exchange/sort_task_spec.py +++ b/python/ray/data/_internal/planner/exchange/sort_task_spec.py @@ -117,7 +117,7 @@ def __init__( self, boundaries: List[T], sort_key: SortKey, - batch_format: Optional[str], + batch_format: str, ): super().__init__( map_args=[boundaries, sort_key], From afeb03ceb0a422c193b1f06c53ddafbfc95be53c Mon Sep 17 00:00:00 2001 From: Xingyu Long Date: Fri, 8 Nov 2024 16:24:02 -0800 Subject: [PATCH 6/7] Use AbstractAllToAll instead of limiting to Sort and Aggregate Signed-off-by: Xingyu Long --- .../data/_internal/logical/rules/inherit_batch_format.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/data/_internal/logical/rules/inherit_batch_format.py b/python/ray/data/_internal/logical/rules/inherit_batch_format.py index 38009a1225c9..2dd265cd08b1 100644 --- a/python/ray/data/_internal/logical/rules/inherit_batch_format.py +++ b/python/ray/data/_internal/logical/rules/inherit_batch_format.py @@ -2,12 +2,12 @@ from typing import Iterable from ray.data._internal.logical.interfaces import LogicalOperator, LogicalPlan, Rule -from ray.data._internal.logical.operators.all_to_all_operator import Aggregate, Sort +from ray.data._internal.logical.operators.all_to_all_operator import AbstractAllToAll from ray.data._internal.logical.operators.map_operator import MapBatches class InheritBatchFormatRule(Rule): - """For Sort and Aggregate logicla operator, apply this rule + """For AbstractAllToAll based operator, apply this rule to inherit batch_format from upstream operator by traversing the entire DAG.""" @@ -25,7 +25,7 @@ def _apply(self, op: LogicalOperator): while len(nodes) > 0: current_op = nodes.pop() - if isinstance(current_op, (Sort, Aggregate)): + if isinstance(current_op, AbstractAllToAll): # traversal up the DAG until we find MapBatches with batch_format # or we reach to source op and do nothing upstream_op = current_op.input_dependencies[0] From b92c5f6b8171c444a5fb60260f1a5f889e213848 Mon Sep 17 00:00:00 2001 From: Xingyu Long Date: Tue, 12 Nov 2024 17:35:51 -0800 Subject: [PATCH 7/7] add comments for test cases Signed-off-by: Xingyu Long --- python/ray/data/tests/test_execution_optimizer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 16f26deaac71..d657ce1c9d98 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -1191,6 +1191,7 @@ def test_inherit_batch_format_rule(): def test_batch_format_on_sort(ray_start_regular_shared): + """Checks that the Sort op can inherit batch_format from upstream ops correctly.""" ds = ray.data.from_items( [ {"col1": 1, "col2": 2}, @@ -1215,6 +1216,8 @@ def test_batch_format_on_sort(ray_start_regular_shared): def test_batch_format_on_aggregate(ray_start_regular_shared): + """Checks that the Aggregate op can inherit batch_format + from upstream ops correctly.""" from ray.data.aggregate import AggregateFn ds = ray.data.from_items(