Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __init__(
fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
min_rows_per_bundled_input: Optional[int] = None,
compute: Optional[ComputeStrategy] = None,
preserve_row_count: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

preserves_row_count

ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
Expand All @@ -188,9 +189,10 @@ def __init__(
self._batch_size = batch_size
self._batch_format = batch_format
self._zero_copy_batch = zero_copy_batch
self._preserve_row_count = preserve_row_count

def can_modify_num_rows(self) -> bool:
return False
return not self._preserve_row_count


class MapRows(AbstractUDFMap):
Expand Down
7 changes: 1 addition & 6 deletions python/ray/data/_internal/logical/rules/limit_pushdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List

from ray.data._internal.logical.interfaces import LogicalOperator, LogicalPlan, Rule
from ray.data._internal.logical.operators.map_operator import AbstractMap, MapBatches
from ray.data._internal.logical.operators.map_operator import AbstractMap
from ray.data._internal.logical.operators.n_ary_operator import Union
from ray.data._internal.logical.operators.one_to_one_operator import (
AbstractOneToOne,
Expand Down Expand Up @@ -125,21 +125,16 @@ def _push_limit_down(self, limit_op: Limit) -> LogicalOperator:
# one of the stopping conditions
current_op = limit_op.input_dependency
num_rows_preserving_ops: List[LogicalOperator] = []

while (
isinstance(current_op, AbstractOneToOne)
and not current_op.can_modify_num_rows()
and not isinstance(current_op, MapBatches)
# We should push past MapBatches, but MapBatches can modify the row count
# TODO: add a flag in map_batches that allows the user to opt in ensure row preservation
):
num_rows_preserving_ops.append(current_op)
current_op = current_op.input_dependency

# If we couldn't push through any operators, return original
if not num_rows_preserving_ops:
return limit_op

# Apply per-block limit to the deepest operator if it supports it
limit_input = self._apply_per_block_limit_if_supported(
current_op, limit_op._limit
Expand Down
5 changes: 5 additions & 0 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ def map_batches(
num_gpus: Optional[float] = None,
memory: Optional[float] = None,
concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]] = None,
preserve_row_count: bool = False,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
**ray_remote_args,
) -> "Dataset":
Expand Down Expand Up @@ -627,6 +628,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
worker.
memory: The heap memory in bytes to reserve for each parallel map worker.
concurrency: This argument is deprecated. Use ``compute`` argument.
preserve_row_count: Set to True only if the UDF always emits the same number of records it receives (no drops or duplicates). When true, the optimizer can push downstream limits past this transform for better pruning.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the 2nd sentence: When set to True, the logical optimizer, in the presence of a limit(limit=k), will only scan k rows prior to executing the UDF, thereby saving on compute resources.

ray_remote_args_fn: A function that returns a dictionary of remote args
passed to each map worker. The purpose of this argument is to generate
dynamic arguments for each actor/task, and will be called each time prior
Expand Down Expand Up @@ -695,6 +697,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
num_gpus=num_gpus,
memory=memory,
concurrency=concurrency,
preserve_row_count=preserve_row_count,
ray_remote_args_fn=ray_remote_args_fn,
**ray_remote_args,
)
Expand All @@ -715,6 +718,7 @@ def _map_batches_without_batch_size_validation(
num_gpus: Optional[float],
memory: Optional[float],
concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]],
preserve_row_count: bool,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]],
**ray_remote_args,
):
Expand Down Expand Up @@ -768,6 +772,7 @@ def _map_batches_without_batch_size_validation(
fn_constructor_args=fn_constructor_args,
fn_constructor_kwargs=fn_constructor_kwargs,
compute=compute,
preserve_row_count=preserve_row_count,
ray_remote_args_fn=ray_remote_args_fn,
ray_remote_args=ray_remote_args,
)
Expand Down
1 change: 1 addition & 0 deletions python/ray/data/grouped_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def wrapped_fn(batch, *args, **kwargs):
num_gpus=num_gpus,
memory=memory,
concurrency=concurrency,
preserve_row_count=False,
ray_remote_args_fn=ray_remote_args_fn,
**ray_remote_args,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,5 +490,26 @@ def add_one(row):
assert result_with == expected


def test_limit_pushdown_preserve_row_count_with_map_batches(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this a parameterized test?

ray_start_regular_shared_2_cpus,
):
"""Test that limit pushdown preserves the row count with map batches."""
ds = ray.data.range(100).map_batches(lambda x: x, preserve_row_count=True).limit(10)
_check_valid_plan_and_result(
ds,
"Read[ReadRange] -> Limit[limit=10] -> MapBatches[MapBatches(<lambda>)]",
[{"id": i} for i in range(10)],
)

ds = (
ray.data.range(100).map_batches(lambda x: x, preserve_row_count=False).limit(10)
)
_check_valid_plan_and_result(
ds,
"Read[ReadRange] -> MapBatches[MapBatches(<lambda>)] -> Limit[limit=10]",
[{"id": i} for i in range(10)],
)


if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))