4 changes: 3 additions & 1 deletion python/ray/data/_internal/logical/operators/map_operator.py
@@ -169,6 +169,7 @@ def __init__(
fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
min_rows_per_bundled_input: Optional[int] = None,
compute: Optional[ComputeStrategy] = None,
+udf_modifying_row_count: bool = True,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
@@ -188,9 +189,10 @@ def __init__(
self._batch_size = batch_size
self._batch_format = batch_format
self._zero_copy_batch = zero_copy_batch
+self._udf_modifying_row_count = udf_modifying_row_count

def can_modify_num_rows(self) -> bool:
-return False
+return self._udf_modifying_row_count


class MapRows(AbstractUDFMap):
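The `can_modify_num_rows()` hook is the signal the optimizer reads. A minimal sketch of the contract it establishes; `SimpleOp` and `can_push_limit_below` are hypothetical stand-ins, not Ray internals:

```python
# Sketch of the contract behind can_modify_num_rows().

class SimpleOp:
    def __init__(self, name: str, udf_modifying_row_count: bool):
        self.name = name
        self._udf_modifying_row_count = udf_modifying_row_count

    def can_modify_num_rows(self) -> bool:
        # Mirrors AbstractUDFMap after this change: report the user-supplied
        # hint instead of hardcoding False.
        return self._udf_modifying_row_count


def can_push_limit_below(op: SimpleOp) -> bool:
    # Limit[k] may only move below an operator that emits exactly one output
    # row per input row; otherwise the first k inputs are not the first k outputs.
    return not op.can_modify_num_rows()


assert can_push_limit_below(SimpleOp("identity map", udf_modifying_row_count=False))
assert not can_push_limit_below(SimpleOp("row filter", udf_modifying_row_count=True))
```

Defaulting the hint to `True` keeps the old conservative behavior unless a user explicitly opts in.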
7 changes: 1 addition & 6 deletions python/ray/data/_internal/logical/rules/limit_pushdown.py
@@ -2,7 +2,7 @@
from typing import List

from ray.data._internal.logical.interfaces import LogicalOperator, LogicalPlan, Rule
-from ray.data._internal.logical.operators.map_operator import AbstractMap, MapBatches
+from ray.data._internal.logical.operators.map_operator import AbstractMap
from ray.data._internal.logical.operators.n_ary_operator import Union
from ray.data._internal.logical.operators.one_to_one_operator import (
AbstractOneToOne,
@@ -125,21 +125,16 @@ def _push_limit_down(self, limit_op: Limit) -> LogicalOperator:
# one of the stopping conditions
current_op = limit_op.input_dependency
num_rows_preserving_ops: List[LogicalOperator] = []

while (
isinstance(current_op, AbstractOneToOne)
and not current_op.can_modify_num_rows()
-and not isinstance(current_op, MapBatches)
-# We should push past MapBatches, but MapBatches can modify the row count
-# TODO: add a flag in map_batches that allows the user to opt in ensure row preservation
):
num_rows_preserving_ops.append(current_op)
current_op = current_op.input_dependency

# If we couldn't push through any operators, return original
if not num_rows_preserving_ops:
return limit_op

# Apply per-block limit to the deepest operator if it supports it
limit_input = self._apply_per_block_limit_if_supported(
current_op, limit_op._limit
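With the `MapBatches` special case removed, the stopping condition reduces to `can_modify_num_rows()` alone. A toy re-enactment of the traversal above; the `Op` class is hypothetical, and the real rule additionally requires each operator to be `AbstractOneToOne`:

```python
from typing import List, Optional


class Op:
    """Hypothetical stand-in for a logical operator with one upstream input."""

    def __init__(self, name: str, modifies_rows: bool, parent: Optional["Op"] = None):
        self.name = name
        self._modifies_rows = modifies_rows
        self.input_dependency = parent

    def can_modify_num_rows(self) -> bool:
        return self._modifies_rows


def collect_push_through_ops(limit_input: Op) -> List[Op]:
    # Walk upstream from the Limit, gathering row-count-preserving operators,
    # and stop at the first one whose UDF may drop or duplicate rows.
    current, collected = limit_input, []
    while current is not None and not current.can_modify_num_rows():
        collected.append(current)
        current = current.input_dependency
    return collected


read = Op("Read", modifies_rows=True)  # a source terminates the walk
ident = Op("MapBatches(identity)", modifies_rows=False, parent=read)
assert [op.name for op in collect_push_through_ops(ident)] == ["MapBatches(identity)"]
```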
5 changes: 5 additions & 0 deletions python/ray/data/dataset.py
@@ -463,6 +463,7 @@ def map_batches(
num_gpus: Optional[float] = None,
memory: Optional[float] = None,
concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]] = None,
+udf_modifying_row_count: bool = True,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
**ray_remote_args,
) -> "Dataset":
@@ -627,6 +628,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
worker.
memory: The heap memory in bytes to reserve for each parallel map worker.
concurrency: This argument is deprecated. Use ``compute`` argument.
+udf_modifying_row_count: Set this to ``False`` only if the UDF always
+emits exactly one output row per input row (no drops or duplicates).
+When ``False``, the logical optimizer can push a downstream ``limit(k)``
+below this operator, so only about ``k`` rows are read and transformed
+before the UDF runs, saving compute.
ray_remote_args_fn: A function that returns a dictionary of remote args
passed to each map worker. The purpose of this argument is to generate
dynamic arguments for each actor/task, and will be called each time prior
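A usage sketch of the documented flag (names follow this PR's API; the saving assumes the upstream read supports per-block limits):

```python
import ray

ds = (
    ray.data.range(1000)
    # The identity UDF emits exactly the rows it receives, so it is safe to
    # declare the row count preserved and let the limit move upstream.
    .map_batches(lambda batch: batch, udf_modifying_row_count=False)
    .limit(10)
)
# With the hint, roughly 10 rows are read and transformed instead of 1000.
assert ds.take_all() == [{"id": i} for i in range(10)]
```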
@@ -695,6 +697,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
num_gpus=num_gpus,
memory=memory,
concurrency=concurrency,
+udf_modifying_row_count=udf_modifying_row_count,
ray_remote_args_fn=ray_remote_args_fn,
**ray_remote_args,
)
@@ -715,6 +718,7 @@ def _map_batches_without_batch_size_validation(
num_gpus: Optional[float],
memory: Optional[float],
concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]],
+udf_modifying_row_count: bool,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]],
**ray_remote_args,
):
@@ -768,6 +772,7 @@ def _map_batches_without_batch_size_validation(
fn_constructor_args=fn_constructor_args,
fn_constructor_kwargs=fn_constructor_kwargs,
compute=compute,
+udf_modifying_row_count=udf_modifying_row_count,
ray_remote_args_fn=ray_remote_args_fn,
ray_remote_args=ray_remote_args,
)
1 change: 1 addition & 0 deletions python/ray/data/grouped_data.py
@@ -301,6 +301,7 @@ def wrapped_fn(batch, *args, **kwargs):
num_gpus=num_gpus,
memory=memory,
concurrency=concurrency,
+udf_modifying_row_count=True,
ray_remote_args_fn=ray_remote_args_fn,
**ray_remote_args,
)
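`map_groups` hardcodes `udf_modifying_row_count=True` because a per-group UDF usually changes the row count, so a limit must never be pushed below it. A small sketch of why, assuming the default batch format where each group arrives as a dict of column arrays:

```python
import ray

ds = ray.data.from_items(
    [
        {"group": "a", "value": 1},
        {"group": "a", "value": 2},
        {"group": "b", "value": 3},
    ]
)

# Each group collapses into one summary row: 3 input rows become 2 output
# rows, so taking the "first k" rows before this UDF would change the result.
per_group_sums = ds.groupby("group").map_groups(
    lambda group: {"group": group["group"][:1], "total": [group["value"].sum()]}
)
assert per_group_sums.count() == 2
```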
@@ -490,5 +490,36 @@ def add_one(row):
assert result_with == expected


+@pytest.mark.parametrize(
+    "udf_modifying_row_count,expected_plan",
+    [
+        (
+            False,
+            "Read[ReadRange] -> Limit[limit=10] -> MapBatches[MapBatches(<lambda>)]",
+        ),
+        (
+            True,
+            "Read[ReadRange] -> MapBatches[MapBatches(<lambda>)] -> Limit[limit=10]",
+        ),
+    ],
+)
+def test_limit_pushdown_udf_modifying_row_count_with_map_batches(
+    ray_start_regular_shared_2_cpus,
+    udf_modifying_row_count,
+    expected_plan,
+):
+    """Limit pushdown moves past map_batches only when the UDF is declared
+    row-count preserving."""
+    ds = (
+        ray.data.range(100)
+        .map_batches(lambda x: x, udf_modifying_row_count=udf_modifying_row_count)
+        .limit(10)
+    )
+    _check_valid_plan_and_result(
+        ds,
+        expected_plan,
+        [{"id": i} for i in range(10)],
+    )
+
+
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
7 changes: 5 additions & 2 deletions python/ray/data/tests/test_operator_fusion.py
@@ -308,7 +308,7 @@ def test_read_with_map_batches_fused_successfully(
),
(
# Fusion
-MapBatches(InputData([]), lambda x: x),
+MapBatches(InputData([]), lambda x: x, udf_modifying_row_count=False),
True,
),
(
@@ -337,7 +337,9 @@ def test_map_batches_batch_size_fusion(
LogicalPlan(input_op, context),
)

-mapped_ds = ds.map_batches(lambda x: x, batch_size=2,).map_batches(
+mapped_ds = ds.map_batches(
+    lambda x: x, batch_size=2, udf_modifying_row_count=False
+).map_batches(
lambda x: x,
batch_size=5,
)
@@ -383,6 +385,7 @@ def test_map_batches_with_batch_size_specified_fusion(
mapped_ds = ds.map_batches(
lambda x: x,
batch_size=upstream_batch_size,
+udf_modifying_row_count=False,
).map_batches(
lambda x: x,
batch_size=downstream_batch_size,