diff --git a/python/ray/data/BUILD.bazel b/python/ray/data/BUILD.bazel index ad5fb6f29fe6..66449628ecc7 100644 --- a/python/ray/data/BUILD.bazel +++ b/python/ray/data/BUILD.bazel @@ -913,6 +913,20 @@ py_test( ], ) +py_test( + name = "test_predicate_pushdown", + size = "small", + srcs = ["tests/test_predicate_pushdown.py"], + tags = [ + "exclusive", + "team:data", + ], + deps = [ + ":conftest", + "//:ray_lib", + ], +) + py_test( name = "test_path_util", size = "small", diff --git a/python/ray/data/_internal/datasource/csv_datasource.py b/python/ray/data/_internal/datasource/csv_datasource.py index f8ddc4bda6bd..2d796fab6a71 100644 --- a/python/ray/data/_internal/datasource/csv_datasource.py +++ b/python/ray/data/_internal/datasource/csv_datasource.py @@ -37,6 +37,9 @@ def __init__( self.parse_options = arrow_csv_args.pop("parse_options", csv.ParseOptions()) self.arrow_csv_args = arrow_csv_args + def supports_predicate_pushdown(self) -> bool: + return True + def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]: import pyarrow as pa from pyarrow import csv @@ -47,6 +50,12 @@ def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]: self.parse_options.invalid_row_handler ) + filter_expr = ( + self._predicate_expr.to_pyarrow() + if self._predicate_expr is not None + else None + ) + try: reader = csv.open_csv( f, @@ -61,6 +70,8 @@ def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]: table = pa.Table.from_batches([batch], schema=schema) if schema is None: schema = table.schema + if filter_expr is not None: + table = table.filter(filter_expr) yield table except StopIteration: return diff --git a/python/ray/data/_internal/datasource/parquet_datasource.py b/python/ray/data/_internal/datasource/parquet_datasource.py index 855e2d122ca1..92d387d9334c 100644 --- a/python/ray/data/_internal/datasource/parquet_datasource.py +++ b/python/ray/data/_internal/datasource/parquet_datasource.py @@ -190,6 +190,7 @@ def __init__( include_paths: bool = False, file_extensions: Optional[List[str]] = None, ): + super().__init__() _check_pyarrow_version() self._supports_distributed_reads = not _is_local_scheme(paths) @@ -284,7 +285,6 @@ def __init__( self._file_metadata_shuffler = None self._include_paths = include_paths self._partitioning = partitioning - if shuffle == "files": self._file_metadata_shuffler = np.random.default_rng() elif isinstance(shuffle, FileShuffleConfig): @@ -352,6 +352,12 @@ def get_read_tasks( ) read_tasks = [] + filter_expr = ( + self._predicate_expr.to_pyarrow() + if self._predicate_expr is not None + else None + ) + for fragments, paths in zip( np.array_split(pq_fragments, parallelism), np.array_split(pq_paths, parallelism), @@ -401,6 +407,7 @@ def get_read_tasks( f, include_paths, partitioning, + filter_expr, ), meta, schema=target_schema, @@ -424,6 +431,9 @@ def supports_distributed_reads(self) -> bool: def supports_projection_pushdown(self) -> bool: return True + def supports_predicate_pushdown(self) -> bool: + return True + def get_current_projection(self) -> Optional[List[str]]: # NOTE: In case there's no projection both file and partition columns # will be none @@ -432,6 +442,9 @@ def get_current_projection(self) -> Optional[List[str]]: return (self._data_columns or []) + (self._partition_columns or []) + def get_column_renames(self) -> Optional[Dict[str, str]]: + return self._data_columns_rename_map if self._data_columns_rename_map else None + def apply_projection( self, columns: Optional[List[str]], @@ -463,6 
+476,7 @@ def read_fragments( fragments: List[_ParquetFragment], include_paths: bool, partitioning: Partitioning, + filter_expr: Optional["pyarrow.dataset.Expression"] = None, ) -> Iterator["pyarrow.Table"]: # This import is necessary to load the tensor extension type. from ray.data.extensions.tensor_extension import ArrowTensorType # noqa @@ -484,6 +498,7 @@ def read_fragments( partition_columns=partition_columns, partitioning=partitioning, include_path=include_paths, + filter_expr=filter_expr, batch_size=default_read_batch_size_rows, to_batches_kwargs=to_batches_kwargs, ), @@ -522,7 +537,14 @@ def _read_batches_from( # NOTE: Passed in kwargs overrides always take precedence # TODO deprecate to_batches_kwargs use_threads = to_batches_kwargs.pop("use_threads", use_threads) - filter_expr = to_batches_kwargs.pop("filter", filter_expr) + # TODO: We should deprecate filter through the read_parquet API and only allow through dataset.filter() + filter_from_kwargs = to_batches_kwargs.pop("filter", None) + if filter_from_kwargs is not None: + filter_expr = ( + filter_from_kwargs + if filter_expr is None + else filter_expr & filter_from_kwargs + ) # NOTE: Arrow's ``to_batches`` expects ``batch_size`` as an int if batch_size is not None: to_batches_kwargs.setdefault("batch_size", batch_size) diff --git a/python/ray/data/_internal/logical/interfaces/__init__.py b/python/ray/data/_internal/logical/interfaces/__init__.py index 8ec39d6454ea..d45578ec093a 100644 --- a/python/ray/data/_internal/logical/interfaces/__init__.py +++ b/python/ray/data/_internal/logical/interfaces/__init__.py @@ -1,4 +1,8 @@ -from .logical_operator import LogicalOperator, LogicalOperatorSupportsProjectionPushdown +from .logical_operator import ( + LogicalOperator, + LogicalOperatorSupportsPredicatePushdown, + LogicalOperatorSupportsProjectionPushdown, +) from .logical_plan import LogicalPlan from .operator import Operator from .optimizer import Optimizer, Rule @@ -16,4 +20,5 @@ "Rule", "SourceOperator", "LogicalOperatorSupportsProjectionPushdown", + "LogicalOperatorSupportsPredicatePushdown", ] diff --git a/python/ray/data/_internal/logical/interfaces/logical_operator.py b/python/ray/data/_internal/logical/interfaces/logical_operator.py index d7141af78987..7237d487f4dc 100644 --- a/python/ray/data/_internal/logical/interfaces/logical_operator.py +++ b/python/ray/data/_internal/logical/interfaces/logical_operator.py @@ -2,6 +2,7 @@ from .operator import Operator from ray.data.block import BlockMetadata +from ray.data.expressions import Expr if TYPE_CHECKING: from ray.data.block import Schema @@ -104,3 +105,28 @@ def apply_projection( column_rename_map: Optional[Dict[str, str]], ) -> LogicalOperator: return self + + +class LogicalOperatorSupportsPredicatePushdown(LogicalOperator): + """Mixin for reading operators supporting predicate pushdown""" + + def supports_predicate_pushdown(self) -> bool: + return False + + def get_current_predicate(self) -> Optional[Expr]: + return None + + def apply_predicate( + self, + predicate_expr: Expr, + ) -> LogicalOperator: + return self + + def get_column_renames(self) -> Optional[Dict[str, str]]: + """Return the column renames applied by projection pushdown, if any. + + Returns: + A dictionary mapping old column names to new column names, + or None if no renaming has been applied. 
+ """ + return None diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index a575681545ba..6253bfd5de52 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -268,6 +268,15 @@ def __init__( def can_modify_num_rows(self) -> bool: return True + def is_expression_based(self) -> bool: + return self._predicate_expr is not None + + def _get_operator_name(self, op_name: str, fn: UserDefinedFunction): + if self.is_expression_based(): + # TODO: Use a truncated expression prefix here instead of . + return f"{op_name}()" + return super()._get_operator_name(op_name, fn) + class Project(AbstractMap): """Logical operator for all Projection Operations.""" diff --git a/python/ray/data/_internal/logical/operators/read_operator.py b/python/ray/data/_internal/logical/operators/read_operator.py index 8561e7aff75c..ba8fa16811a5 100644 --- a/python/ray/data/_internal/logical/operators/read_operator.py +++ b/python/ray/data/_internal/logical/operators/read_operator.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Union from ray.data._internal.logical.interfaces import ( + LogicalOperatorSupportsPredicatePushdown, LogicalOperatorSupportsProjectionPushdown, SourceOperator, ) @@ -14,9 +15,15 @@ ) from ray.data.context import DataContext from ray.data.datasource.datasource import Datasource, Reader +from ray.data.expressions import Expr -class Read(AbstractMap, SourceOperator, LogicalOperatorSupportsProjectionPushdown): +class Read( + AbstractMap, + SourceOperator, + LogicalOperatorSupportsProjectionPushdown, + LogicalOperatorSupportsPredicatePushdown, +): """Logical operator for read.""" # TODO: make this a frozen dataclass. 
https://github.com/ray-project/ray/issues/55747 @@ -158,6 +165,9 @@ def supports_projection_pushdown(self) -> bool: def get_current_projection(self) -> Optional[List[str]]: return self._datasource.get_current_projection() + def get_column_renames(self) -> Optional[Dict[str, str]]: + return self._datasource.get_column_renames() + def apply_projection( self, columns: Optional[List[str]], @@ -173,6 +183,21 @@ def apply_projection( return clone + def supports_predicate_pushdown(self) -> bool: + return self._datasource.supports_predicate_pushdown() + + def get_current_predicate(self) -> Optional[Expr]: + return self._datasource.get_current_predicate() + + def apply_predicate(self, predicate_expr: Expr) -> "Read": + clone = copy.copy(self) + + predicated_datasource = self._datasource.apply_predicate(predicate_expr) + clone._datasource = predicated_datasource + clone._datasource_or_legacy_reader = predicated_datasource + + return clone + def can_modify_num_rows(self) -> bool: # NOTE: Returns true, since most of the readers expands its input # and produce many rows for every single row of the input diff --git a/python/ray/data/_internal/logical/optimizers.py b/python/ray/data/_internal/logical/optimizers.py index 889e5491f3b6..994a13ab347b 100644 --- a/python/ray/data/_internal/logical/optimizers.py +++ b/python/ray/data/_internal/logical/optimizers.py @@ -16,6 +16,7 @@ ) from ray.data._internal.logical.rules.limit_pushdown import LimitPushdownRule from ray.data._internal.logical.rules.operator_fusion import FuseOperators +from ray.data._internal.logical.rules.predicate_pushdown import PredicatePushdown from ray.data._internal.logical.rules.projection_pushdown import ProjectionPushdown from ray.data._internal.logical.rules.set_read_parallelism import SetReadParallelismRule from ray.util.annotations import DeveloperAPI @@ -25,6 +26,7 @@ InheritBatchFormatRule, LimitPushdownRule, ProjectionPushdown, + PredicatePushdown, ] ) diff --git a/python/ray/data/_internal/logical/rules/predicate_pushdown.py b/python/ray/data/_internal/logical/rules/predicate_pushdown.py new file mode 100644 index 000000000000..3a230ac97472 --- /dev/null +++ b/python/ray/data/_internal/logical/rules/predicate_pushdown.py @@ -0,0 +1,142 @@ +from ray.data._internal.logical.interfaces import ( + LogicalOperator, + LogicalOperatorSupportsPredicatePushdown, + LogicalPlan, + Rule, +) +from ray.data._internal.logical.operators.map_operator import Filter +from ray.data._internal.logical.operators.n_ary_operator import Union +from ray.data._internal.planner.plan_expression.expression_visitors import ( + _ColumnSubstitutionVisitor, +) +from ray.data.expressions import Expr, col + + +class PredicatePushdown(Rule): + """Pushes down predicates across the graph. + + This rule performs the following optimizations: + 1. Combines chained Filter operators with compatible expressions + 2. Pushes filter expressions down to operators that support predicate pushdown, + rebinding column references when necessary (e.g., after projections with renames) + 3. 
Pushes filters through Union operators into each branch + """ + + def apply(self, plan: LogicalPlan) -> LogicalPlan: + """Apply predicate pushdown optimization to the logical plan.""" + dag = plan.dag + new_dag = dag._apply_transform(self._try_fuse_filters) + new_dag = new_dag._apply_transform(self._try_push_down_predicate) + return LogicalPlan(new_dag, plan.context) if dag is not new_dag else plan + + @classmethod + def _is_valid_filter_operator(cls, op: LogicalOperator) -> bool: + return isinstance(op, Filter) and op.is_expression_based() + + @classmethod + def _try_fuse_filters(cls, op: LogicalOperator) -> LogicalOperator: + """Fuse consecutive Filter operators with compatible expressions.""" + if not cls._is_valid_filter_operator(op): + return op + + input_op = op.input_dependencies[0] + if not cls._is_valid_filter_operator(input_op): + return op + + # Combine predicates + combined_predicate = op._predicate_expr & input_op._predicate_expr + + # Create new filter on the input of the lower filter + return Filter( + input_op.input_dependencies[0], + predicate_expr=combined_predicate, + ) + + @classmethod + def _rebind_predicate_columns( + cls, predicate_expr: Expr, column_rename_map: dict[str, str] + ) -> Expr: + """Rebind column references in a predicate expression. + + When pushing a predicate through a projection with column renames, + we need to rewrite column references from new names to old names. + + Args: + predicate_expr: The predicate with new column names + column_rename_map: Mapping from old_name -> new_name + + Returns: + The predicate rewritten to use old column names + """ + # Invert the mapping: new_name -> old_name (as col expression) + # This is because the predicate uses new names and we need to map + # them back to old names + column_mapping = { + new_col: col(old_col) for old_col, new_col in column_rename_map.items() + } + + visitor = _ColumnSubstitutionVisitor(column_mapping) + return visitor.visit(predicate_expr) + + @classmethod + def _try_push_down_predicate(cls, op: LogicalOperator) -> LogicalOperator: + """Push Filter down through the operator tree.""" + if not cls._is_valid_filter_operator(op): + return op + + input_op = op.input_dependencies[0] + + # Special case: Push filter through Union into each branch + # TODO: Push filter through other operators like Projection, Zip, Join, Sort, Aggregate (after expression support lands) + if isinstance(input_op, Union): + return cls._push_filter_through_union(op, input_op) + + # Check if the input operator supports predicate pushdown + if ( + isinstance(input_op, LogicalOperatorSupportsPredicatePushdown) + and input_op.supports_predicate_pushdown() + ): + predicate_expr = op._predicate_expr + + # Check if the operator has column renames that need rebinding + # This happens when projection pushdown has been applied + rename_map = input_op.get_column_renames() + if rename_map: + # Rebind the predicate to use original column names + # This is needed to ensure that the predicate expression can be pushed into the input operator. + predicate_expr = cls._rebind_predicate_columns( + predicate_expr, rename_map + ) + + # Push the predicate down and return the result without the filter + return input_op.apply_predicate(predicate_expr) + + return op + + @classmethod + def _push_filter_through_union(cls, filter_op: Filter, union_op: Union) -> Union: + """Push a Filter through a Union into each branch. 
+ + Transforms: + branch₁ ─┐ + branch₂ ─┤ Union ─> Filter(predicate) + branch₃ ─┘ + + Into: + branch₁ ─> Filter(predicate) ─┐ + branch₂ ─> Filter(predicate) ─┤ Union + branch₃ ─> Filter(predicate) ─┘ + """ + predicate_expr = filter_op._predicate_expr + + # Apply filter to each branch of the union + new_inputs = [] + for input_op in union_op.input_dependencies: + # Create a filter for this branch and recursively try to push it down + branch_filter = Filter(input_op, predicate_expr=predicate_expr) + # Recursively apply pushdown to each branch's filter + pushed_branch = cls._try_push_down_predicate(branch_filter) + new_inputs.append(pushed_branch) + + # Return a new Union with filtered branches + return Union(*new_inputs) diff --git a/python/ray/data/_internal/logical/rules/projection_pushdown.py b/python/ray/data/_internal/logical/rules/projection_pushdown.py index 4e40696bb412..31469027770c 100644 --- a/python/ray/data/_internal/logical/rules/projection_pushdown.py +++ b/python/ray/data/_internal/logical/rules/projection_pushdown.py @@ -9,7 +9,7 @@ from ray.data._internal.logical.operators.map_operator import Project from ray.data._internal.planner.plan_expression.expression_visitors import ( _ColumnReferenceCollector, - _ColumnRefRebindingVisitor, + _ColumnSubstitutionVisitor, _is_col_expr, ) from ray.data.expressions import ( @@ -42,27 +42,6 @@ def _collect_referenced_columns(exprs: List[Expr]) -> Optional[List[str]]: return collector.get_column_refs() -def _extract_simple_rename(expr: Expr) -> Optional[Tuple[str, str]]: - """ - Check if an expression is a simple column rename. - - Returns (source_name, target_name) if the expression is of form: - col("source").alias("dest") - - Returns None for other expression types. - """ - if ( - isinstance(expr, AliasExpr) - and isinstance(expr.expr, ColumnExpr) - and expr._is_rename - ): - target_name = expr.name - source_name = expr.expr.name - return source_name, target_name - - return None - - def _analyze_upstream_project( upstream_project: Project, ) -> Tuple[Set[str], dict[str, Expr], Set[str]]: @@ -203,7 +182,7 @@ def _try_fuse(upstream_project: Project, downstream_project: Project) -> Project # Upstream output column refs inside downstream expressions need to be bound # to upstream output column definitions to satisfy invariant #1 (common for both # composition/projection cases) - v = _ColumnRefRebindingVisitor(upstream_column_defs) + v = _ColumnSubstitutionVisitor(upstream_column_defs) rebound_downstream_exprs = [ v.visit(e) for e in _filter_out_star(downstream_project.exprs) diff --git a/python/ray/data/_internal/planner/plan_expression/expression_visitors.py b/python/ray/data/_internal/planner/plan_expression/expression_visitors.py index dfeda08715d4..b01ed2e87710 100644 --- a/python/ray/data/_internal/planner/plan_expression/expression_visitors.py +++ b/python/ray/data/_internal/planner/plan_expression/expression_visitors.py @@ -96,7 +96,7 @@ def visit_alias(self, expr: AliasExpr) -> None: self.visit(expr.expr) -class _ColumnRefRebindingVisitor(_ExprVisitor[Expr]): +class _ColumnSubstitutionVisitor(_ExprVisitor[Expr]): """Visitor rebinding column references in ``Expression``s. 
This visitor traverses given ``Expression`` trees and substitutes column references diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py index c217e3a2d69d..708f94430523 100644 --- a/python/ray/data/datasource/datasource.py +++ b/python/ray/data/datasource/datasource.py @@ -5,6 +5,7 @@ from ray.data._internal.util import _check_pyarrow_version from ray.data.block import Block, BlockMetadata, Schema from ray.data.datasource.util import _iter_sliced_blocks +from ray.data.expressions import Expr from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI @@ -20,6 +21,15 @@ def get_current_projection(self) -> Optional[List[str]]: """Retrurns current projection""" return None + def get_column_renames(self) -> Optional[Dict[str, str]]: + """Return the column renames applied to this datasource. + + Returns: + A dictionary mapping old column names to new column names, + or None if no renaming has been applied. + """ + return None + def apply_projection( self, columns: Optional[List[str]], @@ -28,13 +38,57 @@ def apply_projection( return self +class _DatasourcePredicatePushdownMixin: + """Mixin for reading operators supporting predicate pushdown""" + + def __init__(self): + self._predicate_expr: Optional[Expr] = None + + def supports_predicate_pushdown(self) -> bool: + return False + + def get_current_predicate(self) -> Optional[Expr]: + return self._predicate_expr + + def apply_predicate( + self, + predicate_expr: Expr, + ) -> "Datasource": + """Apply a predicate to this datasource. + + Default implementation that combines predicates using AND. + Subclasses that support predicate pushdown should have a _predicate_expr + attribute to store the predicate. + + Note: Column rebinding is handled by the PredicatePushdown rule + before this method is called, so the predicate_expr should already + reference the correct column names. + """ + import copy + + clone = copy.copy(self) + + # Combine with existing predicate using AND + clone._predicate_expr = ( + predicate_expr + if clone._predicate_expr is None + else clone._predicate_expr & predicate_expr + ) + + return clone + + @PublicAPI -class Datasource(_DatasourceProjectionPushdownMixin): +class Datasource(_DatasourceProjectionPushdownMixin, _DatasourcePredicatePushdownMixin): """Interface for defining a custom :class:`~ray.data.Dataset` datasource. To read a datasource into a dataset, use :meth:`~ray.data.read_datasource`. 
""" # noqa: E501 + def __init__(self): + """Initialize the datasource and its mixins.""" + _DatasourcePredicatePushdownMixin.__init__(self) + @Deprecated def create_reader(self, **read_args) -> "Reader": """ diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index 256927931087..eb52d564d222 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -122,6 +122,7 @@ def __init__( include_paths: bool = False, file_extensions: Optional[List[str]] = None, ): + super().__init__() _check_pyarrow_version() self._supports_distributed_reads = not _is_local_scheme(paths) diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py index d1518253cd08..c58408480c1f 100644 --- a/python/ray/data/read_api.py +++ b/python/ray/data/read_api.py @@ -1009,6 +1009,15 @@ def read_parquet( _emit_meta_provider_deprecation_warning(meta_provider) _validate_shuffle_arg(shuffle) + # Check for deprecated filter parameter + if "filter" in arrow_parquet_args: + warnings.warn( + "The `filter` argument is deprecated and will not supported in a future release. " + "Use `dataset.filter(expr=expr)` instead to filter rows.", + DeprecationWarning, + stacklevel=2, + ) + arrow_parquet_args = _resolve_parquet_args( tensor_column_schema, **arrow_parquet_args, diff --git a/python/ray/data/tests/test_filter.py b/python/ray/data/tests/test_filter.py index 2c40bf6fd669..fcab00f0ef3f 100644 --- a/python/ray/data/tests/test_filter.py +++ b/python/ray/data/tests/test_filter.py @@ -115,7 +115,11 @@ def test_filter_with_invalid_expression(ray_start_regular_shared, tmp_path): parquet_ds.filter(expr="fake_news super fake") fake_column_ds = parquet_ds.filter(expr="sepal_length_123 > 1") - with pytest.raises(KeyError): + # With predicate pushdown, the error is raised during file reading + # and wrapped in RayTaskError + with pytest.raises( + (ray.exceptions.RayTaskError, RuntimeError), match="sepal_length_123" + ): fake_column_ds.to_pandas() diff --git a/python/ray/data/tests/test_predicate_pushdown.py b/python/ray/data/tests/test_predicate_pushdown.py new file mode 100644 index 000000000000..9e21fea2e99c --- /dev/null +++ b/python/ray/data/tests/test_predicate_pushdown.py @@ -0,0 +1,374 @@ +import re +from typing import Any, List + +import pandas as pd +import pyarrow.compute as pc +import pytest + +import ray +from ray.data import Dataset +from ray.data._internal.logical.optimizers import LogicalOptimizer +from ray.data.expressions import col +from ray.data.tests.conftest import * # noqa +from ray.data.tests.test_execution_optimizer_limit_pushdown import ( + _check_valid_plan_and_result, +) +from ray.tests.conftest import * # noqa + +# Pattern to match read operators in logical plans. +# Matches Read[Read] where format is Parquet, CSV, Range, etc. +READ_OPERATOR_PATTERN = ( + r"^(Read\[Read\w+\]|ListFiles\[ListFiles\] -> ReadFiles\[ReadFiles\])" +) + + +def _check_plan_with_flexible_read( + ds: Dataset, expected_plan_suffix: str, expected_result: List[Any] +): + """Check the logical plan with flexible read operator matching. + + This function allows flexibility in the read operator part of the plan + by using a configurable pattern (READ_OPERATOR_PATTERN). + + Args: + ds: The dataset to check. + expected_plan_suffix: The expected plan after the read operator(s). + If empty string, only the read operator is expected. + expected_result: The expected result data. 
+ """ + # Optimize the logical plan before checking + logical_plan = ds._plan._logical_plan + optimized_plan = LogicalOptimizer().optimize(logical_plan) + actual_plan = optimized_plan.dag.dag_str + + match = re.match(READ_OPERATOR_PATTERN, actual_plan) + assert match, f"Expected plan to start with read operator, got: {actual_plan}" + + # Check if there's a suffix expected + if expected_plan_suffix: + # The suffix should appear after the read operator + expected_full_pattern = ( + f"{READ_OPERATOR_PATTERN} -> {re.escape(expected_plan_suffix)}" + ) + assert re.match(expected_full_pattern, actual_plan), ( + f"Expected plan to match pattern with suffix '{expected_plan_suffix}', " + f"got: {actual_plan}" + ) + # If no suffix, the plan should be just the read operator + else: + assert actual_plan == match.group( + 1 + ), f"Expected plan to be just the read operator, got: {actual_plan}" + + # Check the result + assert ds.take_all() == expected_result + + +@pytest.fixture +def parquet_ds(ray_start_regular_shared): + """Fixture to load the Parquet dataset for testing.""" + ds = ray.data.read_parquet("example://iris.parquet") + assert ds.count() == 150 + return ds + + +@pytest.fixture +def csv_ds(ray_start_regular_shared): + """Fixture to load the CSV dataset for testing.""" + ds = ray.data.read_csv("example://iris.csv") + assert ds.count() == 150 + return ds + + +def test_filter_with_udfs(parquet_ds): + """Test filtering with UDFs where predicate pushdown does not occur.""" + filtered_udf_ds = parquet_ds.filter(lambda r: r["sepal.length"] > 5.0) + filtered_udf_data = filtered_udf_ds.take_all() + assert filtered_udf_ds.count() == 118 + assert all(record["sepal.length"] > 5.0 for record in filtered_udf_data) + _check_plan_with_flexible_read( + filtered_udf_ds, + "Filter[Filter()]", # UDF filter doesn't push down + filtered_udf_data, + ) + + +def test_filter_with_expressions(parquet_ds): + """Test filtering with expressions where predicate pushdown occurs.""" + filtered_udf_data = parquet_ds.filter(lambda r: r["sepal.length"] > 5.0).take_all() + filtered_expr_ds = parquet_ds.filter(expr="sepal.length > 5.0") + _check_plan_with_flexible_read( + filtered_expr_ds, + "", # Pushed down to read, no additional operators + filtered_udf_data, + ) + + +def test_filter_pushdown_source_and_op(ray_start_regular_shared): + """Test filtering when expressions are provided both in source and operator.""" + # Test with PyArrow compute expressions + source_expr = pc.greater(pc.field("sepal.length"), pc.scalar(5.0)) + filter_expr = "sepal.width > 3.0" + + ds = ray.data.read_parquet("example://iris.parquet", filter=source_expr).filter( + expr=filter_expr + ) + result = ds.take_all() + assert all(r["sepal.length"] > 5.0 and r["sepal.width"] > 3.0 for r in result) + _check_plan_with_flexible_read( + ds, + "", # Both filters pushed down to read + result, + ) + + +def test_chained_filter_with_expressions(parquet_ds): + """Test chained filtering with expressions where combined pushdown occurs.""" + filtered_expr_chained_ds = ( + parquet_ds.filter(expr=col("sepal.length") > 1.0) + .filter(expr=col("sepal.length") > 2.0) + .filter(expr=col("sepal.length") > 3.0) + .filter(expr=col("sepal.length") > 3.0) + .filter(expr=col("sepal.length") > 5.0) + ) + filtered_udf_data = parquet_ds.filter(lambda r: r["sepal.length"] > 5.0).take_all() + _check_plan_with_flexible_read( + filtered_expr_chained_ds, + "", # All filters combined and pushed down to read + filtered_udf_data, + ) + + +@pytest.mark.parametrize( + 
"filter_fn,expected_suffix", + [ + ( + lambda ds: ds.filter(lambda r: r["sepal.length"] > 5.0), + "Filter[Filter()]", # UDF filter doesn't push down + ), + ( + lambda ds: ds.filter(expr=col("sepal.length") > 5.0), + "", # Expression filter pushes down to read + ), + ], +) +def test_filter_pushdown_csv(csv_ds, filter_fn, expected_suffix): + """Test filtering on CSV files with predicate pushdown.""" + filtered_ds = filter_fn(csv_ds) + filtered_data = filtered_ds.take_all() + assert filtered_ds.count() == 118 + assert all(record["sepal.length"] > 5.0 for record in filtered_data) + _check_plan_with_flexible_read( + filtered_ds, + expected_suffix, + filtered_data, + ) + + +def test_filter_mixed(csv_ds): + """Test that mixed function and expressions work (CSV supports predicate pushdown).""" + csv_ds = csv_ds.filter(lambda r: r["sepal.length"] < 5.0) + csv_ds = csv_ds.filter(expr="sepal.length > 3.0") + csv_ds = csv_ds.filter(expr="sepal.length > 4.0") + csv_ds = csv_ds.map(lambda x: x) + csv_ds = csv_ds.filter(expr="sepal.length > 2.0") + csv_ds = csv_ds.filter(expr="sepal.length > 1.0") + filtered_expr_data = csv_ds.take_all() + assert csv_ds.count() == 22 + assert all(record["sepal.length"] < 5.0 for record in filtered_expr_data) + assert all(record["sepal.length"] > 4.0 for record in filtered_expr_data) + # After optimization: expression filters before map get fused, expression filters after map get fused + _check_plan_with_flexible_read( + csv_ds, + "Filter[Filter()] -> Filter[Filter()] -> " + "MapRows[Map()] -> Filter[Filter()]", + filtered_expr_data, + ) + + +def test_filter_mixed_expression_first_parquet(ray_start_regular_shared): + """Test that mixed functional and expressions work with Parquet (supports predicate pushdown).""" + ds = ray.data.read_parquet("example://iris.parquet") + ds = ds.filter(expr="sepal.length > 3.0") + ds = ds.filter(expr="sepal.length > 4.0") + ds = ds.filter(lambda r: r["sepal.length"] < 5.0) + filtered_expr_data = ds.take_all() + assert ds.count() == 22 + assert all(record["sepal.length"] < 5.0 for record in filtered_expr_data) + assert all(record["sepal.length"] > 4.0 for record in filtered_expr_data) + _check_plan_with_flexible_read( + ds, + "Filter[Filter()]", # Expressions pushed down, UDF remains + filtered_expr_data, + ) + + +def test_filter_mixed_expression_first_csv(ray_start_regular_shared): + """Test that mixed functional and expressions work with CSV (supports predicate pushdown).""" + ds = ray.data.read_csv("example://iris.csv") + ds = ds.filter(expr="sepal.length > 3.0") + ds = ds.filter(expr="sepal.length > 4.0") + ds = ds.filter(lambda r: r["sepal.length"] < 5.0) + filtered_expr_data = ds.take_all() + assert ds.count() == 22 + assert all(record["sepal.length"] < 5.0 for record in filtered_expr_data) + assert all(record["sepal.length"] > 4.0 for record in filtered_expr_data) + # Expression filters pushed down to read, UDF filter remains + _check_plan_with_flexible_read( + ds, + "Filter[Filter()]", + filtered_expr_data, + ) + + +def test_filter_mixed_expression_not_readfiles(ray_start_regular_shared): + """Test that mixed functional and expressions work.""" + ds = ray.data.range(100).filter(expr="id > 1.0") + ds = ds.filter(expr="id > 2.0") + ds = ds.filter(lambda r: r["id"] < 5.0) + filtered_expr_data = ds.take_all() + assert ds.count() == 2 + assert all(record["id"] < 5.0 for record in filtered_expr_data) + assert all(record["id"] > 2.0 for record in filtered_expr_data) + _check_valid_plan_and_result( + ds, + "Read[ReadRange] -> 
Filter[Filter()] -> " + "Filter[Filter()]", + filtered_expr_data, + ) + + +def test_read_range_union_with_filter_pushdown(ray_start_regular_shared): + ds1 = ray.data.range(100, parallelism=2) + ds2 = ray.data.range(100, parallelism=2) + ds = ds1.union(ds2).filter(expr="id >= 50") + result = ds.take_all() + assert ds.count() == 100 + _check_valid_plan_and_result( + ds, + "Read[ReadRange] -> Filter[Filter()], " + "Read[ReadRange] -> Filter[Filter()] -> Union[Union]", + result, + ) + + +def test_multiple_union_with_filter_pushdown(ray_start_regular_shared): + ds1 = ray.data.read_parquet("example://iris.parquet") + ds2 = ray.data.read_parquet("example://iris.parquet") + ds3 = ray.data.read_parquet("example://iris.parquet") + ds = ds1.union(ds2).union(ds3).filter(expr="sepal.length > 5.0") + result = ds.take_all() + assert ds.count() == 354 + assert all(record["sepal.length"] > 5.0 for record in result) + + # For union operations, verify the pattern separately for each branch + actual_plan = ds._plan._logical_plan.dag.dag_str + # Check that filter was pushed down into all three reads (no Filter operator in plan) + assert ( + "Filter[Filter" not in actual_plan + ), f"Filter should be pushed down, got: {actual_plan}" + # Check that union operations are present + assert ( + actual_plan.count("Union[Union]") == 2 + ), f"Expected 2 unions, got: {actual_plan}" + # Check result + assert ds.take_all() == result + + +def test_multiple_filter_with_union_pushdown_parquet(ray_start_regular_shared): + ds1 = ray.data.read_parquet("example://iris.parquet") + ds1 = ds1.filter(expr="sepal.width > 2.0") + ds2 = ray.data.read_parquet("example://iris.parquet") + ds2 = ds2.filter(expr="sepal.width > 2.0") + ds = ds1.union(ds2).filter(expr="sepal.length < 5.0") + result = ds.take_all() + assert all(record["sepal.width"] > 2.0 for record in result) + assert all(record["sepal.length"] < 5.0 for record in result) + + assert ds.count() == 44 + + # For union operations, verify the pattern separately for each branch + actual_plan = ds._plan._logical_plan.dag.dag_str + # Check that all filters were pushed down (no Filter operator in plan) + assert ( + "Filter[Filter" not in actual_plan + ), f"Filters should be pushed down, got: {actual_plan}" + # Check that union operation is present + assert "Union[Union]" in actual_plan, f"Expected union, got: {actual_plan}" + # Check result + assert ds.take_all() == result + + +@pytest.mark.parametrize( + "operations,output_rename_map,expected_filter_expr,test_id", + [ + ( + # rename("sepal.length" -> a).filter(a) + lambda ds: ds.rename_columns({"sepal.length": "a"}).filter( + expr=col("a") > 2.0 + ), + {"a": "sepal.length"}, + col("sepal.length") > 2.0, + "rename_filter", + ), + ( + # rename("sepal.length" -> a).filter(a).rename(a -> b) + lambda ds: ds.rename_columns({"sepal.length": "a"}) + .filter(expr=col("a") > 2.0) + .rename_columns({"a": "b"}), + {"b": "sepal.length"}, + col("sepal.length") > 2.0, + "rename_filter_rename", + ), + ( + # rename("sepal.length" -> a).filter(a).rename(a -> b).filter(b) + lambda ds: ds.rename_columns({"sepal.length": "a"}) + .filter(expr=col("a") > 2.0) + .rename_columns({"a": "b"}) + .filter(expr=col("b") < 5.0), + {"b": "sepal.length"}, + (col("sepal.length") > 2.0) & (col("sepal.length") < 5.0), + "rename_filter_rename_filter", + ), + ( + # rename("sepal.length" -> a).filter(a).rename(a -> b).filter(b).rename("sepal.width" -> a) + # Here column a is referred multiple times in rename + lambda ds: ds.rename_columns({"sepal.length": "a"}) + 
.filter(expr=col("a") > 2.0) + .rename_columns({"a": "b"}) + .filter(expr=col("b") < 5.0) + .rename_columns({"sepal.width": "a"}), + {"b": "sepal.length", "a": "sepal.width"}, + (col("sepal.length") > 2.0) & (col("sepal.length") < 5.0), + "rename_filter_rename_filter_rename", + ), + ], + ids=lambda x: x if isinstance(x, str) else "", +) +def test_pushdown_with_rename_and_filter( + ray_start_regular_shared, + operations, + output_rename_map, + expected_filter_expr, + test_id, +): + """Test predicate pushdown with various combinations of rename and filter operations.""" + path = "example://iris.parquet" + ds = operations(ray.data.read_parquet(path)) + result = ds.take_all() + + # Check that plan is just the read (filters and renames pushed down/fused) + _check_plan_with_flexible_read(ds, "", result) + + ds1 = ray.data.read_parquet(path).filter(expr=expected_filter_expr) + # Convert to pandas to ensure both datasets are fully executed + df = ds.to_pandas().rename(columns=output_rename_map) + df1 = ds1.to_pandas() + assert len(df) == len(df1), f"Expected {len(df)} rows, got {len(df1)} rows" + pd.testing.assert_frame_equal(df, df1) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/data/tests/test_union.py b/python/ray/data/tests/test_union.py index 1797325fd36f..cfc041223454 100644 --- a/python/ray/data/tests/test_union.py +++ b/python/ray/data/tests/test_union.py @@ -38,6 +38,32 @@ def test_union_with_preserve_order(ray_start_10_cpus_shared, restore_data_contex assert [row["id"] for row in ds.take_all()] == [0, 1, 2] +def test_union_with_filter(ray_start_10_cpus_shared): + """Test that filters are pushed through union to both branches.""" + from ray.data._internal.logical.optimizers import LogicalOptimizer + from ray.data.expressions import col + + ds1 = ray.data.from_items([{"id": 0}, {"id": 1}, {"id": 2}]) + ds2 = ray.data.from_items([{"id": 3}, {"id": 4}, {"id": 5}]) + ds = ds1.union(ds2).filter(expr=col("id") > 2) + + # Verify the filter was pushed through the union + optimized_plan = LogicalOptimizer().optimize(ds._plan._logical_plan) + actual_plan_str = optimized_plan.dag.dag_str + + # After optimization, filter should be pushed to both union branches + # So we should see: Filter(Read), Filter(Read) -> Union + # Not: Read, Read -> Union -> Filter + assert "Union" in actual_plan_str + assert "Filter" in actual_plan_str + # Ensure Filter is before Union (pushed down), not after + assert actual_plan_str.index("Filter") < actual_plan_str.index("Union") + + # Verify correctness + result = sorted(row["id"] for row in ds.take_all()) + assert result == [3, 4, 5] + + if __name__ == "__main__": import sys
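
Note on how a third-party datasource could opt in to the hooks added in this PR. The sketch below is illustrative only and is not part of the diff: the class name SimpleJSONDatasource, the use of pyarrow.json, and the JSON-lines format are assumptions. The hook methods themselves (supports_predicate_pushdown, the _predicate_expr attribute populated by _DatasourcePredicatePushdownMixin.apply_predicate, and Expr.to_pyarrow) mirror the CSVDatasource change above.

from typing import Iterator

from ray.data.block import Block
from ray.data.datasource.file_based_datasource import FileBasedDatasource


class SimpleJSONDatasource(FileBasedDatasource):
    """Hypothetical JSON-lines datasource that honors pushed-down predicates."""

    def supports_predicate_pushdown(self) -> bool:
        # Opting in: the PredicatePushdown rule will call apply_predicate() on
        # this datasource and drop the Filter operator from the logical plan.
        return True

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        from pyarrow import json as pa_json

        # _DatasourcePredicatePushdownMixin stores the (already column-rebound)
        # predicate on the instance; convert it once per file, as the CSV and
        # Parquet datasources in this PR do.
        filter_expr = (
            self._predicate_expr.to_pyarrow()
            if self._predicate_expr is not None
            else None
        )

        table = pa_json.read_json(f)
        if filter_expr is not None:
            table = table.filter(filter_expr)
        yield table

With this in place, a pipeline like ray.data.read_datasource(SimpleJSONDatasource("data.jsonl")).filter(expr=col("x") > 0) should optimize to a bare Read, with the predicate applied while streaming each file, analogous to the CSV and Parquet paths in this PR (the "data.jsonl" path and "x" column are hypothetical).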
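The column-rebinding step of the rule can also be illustrated in isolation. This is a minimal sketch that calls the rule's private helper directly (for illustration only); the rename map {"sepal.length": "a"} is a hypothetical example in the old_name -> new_name form returned by get_column_renames().

from ray.data._internal.logical.rules.predicate_pushdown import PredicatePushdown
from ray.data.expressions import col

# A filter written after ds.rename_columns({"sepal.length": "a"}) references the
# new name "a". Before the predicate can be pushed into the Parquet read, the
# rule inverts the rename map and substitutes col("a") with col("sepal.length").
predicate = col("a") > 2.0
rebound = PredicatePushdown._rebind_predicate_columns(
    predicate, {"sepal.length": "a"}
)
# `rebound` is equivalent to col("sepal.length") > 2.0 and is what
# Read.apply_predicate() ultimately receives.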