[Data] map_batches support limit_pushdown
#57880
File 1 of 2: `Dataset.map_batches` API (signature, docstring, and internal call sites)
@@ -463,6 +463,7 @@ def map_batches(
         num_gpus: Optional[float] = None,
         memory: Optional[float] = None,
         concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]] = None,
+        preserve_row_count: bool = False,
         ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
         **ray_remote_args,
     ) -> "Dataset":
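A minimal usage sketch of the new keyword (the `add_one` UDF is illustrative, and the example assumes the signature above lands as shown): the UDF emits exactly one output row per input row, so it is safe to advertise that it preserves the row count.

```python
# Hedged sketch: a 1-to-1 UDF can safely opt in to preserve_row_count.
import ray


def add_one(batch):
    # Emits exactly one output row per input row, so the row count is preserved.
    batch["id"] = batch["id"] + 1
    return batch


ds = (
    ray.data.range(1_000)
    .map_batches(add_one, preserve_row_count=True)  # safe: add_one never drops rows
    .limit(5)
)
print(ds.take_all())  # five rows: ids 1 through 5
```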

@@ -627,6 +628,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
                 worker.
             memory: The heap memory in bytes to reserve for each parallel map worker.
             concurrency: This argument is deprecated. Use ``compute`` argument.
+            preserve_row_count: Set to True only if the UDF always emits the same number of records it receives (no drops or duplicates). When true, the optimizer can push downstream limits past this transform for better pruning.
             ray_remote_args_fn: A function that returns a dictionary of remote args
                 passed to each map worker. The purpose of this argument is to generate
                 dynamic arguments for each actor/task, and will be called each time prior
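Conversely, a sketch of the caveat in the docstring entry above (the `keep_even` UDF is hypothetical): a UDF that drops rows must keep the default `preserve_row_count=False`, so the limit stays downstream of the transform and the query still returns the requested number of rows.

```python
# Hedged sketch: this UDF filters rows, so it must NOT claim to preserve the
# row count; with the default (False), the limit is applied after the map.
import ray


def keep_even(batch):
    # Keeps only even ids, so the output can have fewer rows than the input.
    mask = batch["id"] % 2 == 0
    return {"id": batch["id"][mask]}


ds = ray.data.range(100).map_batches(keep_even).limit(10)
print(ds.count())  # 10 rows: ids 0, 2, ..., 18
```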

@@ -695,6 +697,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
            num_gpus=num_gpus,
            memory=memory,
            concurrency=concurrency,
+           preserve_row_count=preserve_row_count,
            ray_remote_args_fn=ray_remote_args_fn,
            **ray_remote_args,
        )

@@ -715,6 +718,7 @@ def _map_batches_without_batch_size_validation(
        num_gpus: Optional[float],
        memory: Optional[float],
        concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]],
+       preserve_row_count: bool,
        ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]],
        **ray_remote_args,
    ):

@@ -768,6 +772,7 @@ def _map_batches_without_batch_size_validation(
            fn_constructor_args=fn_constructor_args,
            fn_constructor_kwargs=fn_constructor_kwargs,
            compute=compute,
+           preserve_row_count=preserve_row_count,
            ray_remote_args_fn=ray_remote_args_fn,
            ray_remote_args=ray_remote_args,
        )
File 2 of 2: limit pushdown tests
@@ -490,5 +490,26 @@ def add_one(row):
     assert result_with == expected


+def test_limit_pushdown_preserve_row_count_with_map_batches(
+    ray_start_regular_shared_2_cpus,
+):
+    """Test that limit pushdown preserves the row count with map batches."""
+    ds = ray.data.range(100).map_batches(lambda x: x, preserve_row_count=True).limit(10)
+    _check_valid_plan_and_result(
+        ds,
+        "Read[ReadRange] -> Limit[limit=10] -> MapBatches[MapBatches(<lambda>)]",
+        [{"id": i} for i in range(10)],
+    )
+
+    ds = (
+        ray.data.range(100).map_batches(lambda x: x, preserve_row_count=False).limit(10)
+    )
+    _check_valid_plan_and_result(
+        ds,
+        "Read[ReadRange] -> MapBatches[MapBatches(<lambda>)] -> Limit[limit=10]",
+        [{"id": i} for i in range(10)],
+    )
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", __file__]))
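For anyone who wants to see this behavior outside the test harness, a hedged sketch: both pipelines from the test return the same ten rows, and after execution the per-operator statistics from `Dataset.stats()` should reflect the two operator orders that `_check_valid_plan_and_result` asserts above.

```python
# Hedged sketch: reproduce the two pipelines from the test by hand.
import ray

with_flag = (
    ray.data.range(100)
    .map_batches(lambda b: b, preserve_row_count=True)
    .limit(10)
)
without_flag = (
    ray.data.range(100)
    .map_batches(lambda b: b, preserve_row_count=False)
    .limit(10)
)

# Same results either way; only the optimized plan differs.
assert with_flag.take_all() == without_flag.take_all()

# After execution, the stats string lists the operators in their executed order.
print(with_flag.stats())
print(without_flag.stats())
```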
Review comment: suggest naming the parameter `preserves_row_count`.