14 changes: 14 additions & 0 deletions python/ray/data/BUILD.bazel
@@ -913,6 +913,20 @@ py_test(
],
)

py_test(
name = "test_predicate_pushdown",
size = "small",
srcs = ["tests/test_predicate_pushdown.py"],
tags = [
"exclusive",
"team:data",
],
deps = [
":conftest",
"//:ray_lib",
],
)

py_test(
name = "test_path_util",
size = "small",
Expand Down
11 changes: 11 additions & 0 deletions python/ray/data/_internal/datasource/csv_datasource.py
@@ -37,6 +37,9 @@ def __init__(
self.parse_options = arrow_csv_args.pop("parse_options", csv.ParseOptions())
self.arrow_csv_args = arrow_csv_args

def supports_predicate_pushdown(self) -> bool:
return True

def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
import pyarrow as pa
from pyarrow import csv
@@ -47,6 +50,12 @@ def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
self.parse_options.invalid_row_handler
)

filter_expr = (
self._predicate_expr.to_pyarrow()
if self._predicate_expr is not None
else None
)

try:
reader = csv.open_csv(
f,
@@ -61,6 +70,8 @@ def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
table = pa.Table.from_batches([batch], schema=schema)
if schema is None:
schema = table.schema
if filter_expr is not None:
table = table.filter(filter_expr)
yield table
except StopIteration:
return
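Arrow's streaming CSV reader has no predicate argument of its own, so the pushdown here filters each materialized batch with the PyArrow expression produced by to_pyarrow(). A minimal standalone sketch of the same pattern (hypothetical column names, not the datasource's exact code):

import io

import pyarrow as pa
import pyarrow.dataset as pa_ds
from pyarrow import csv

# Stand-in for self._predicate_expr.to_pyarrow().
filter_expr = pa_ds.field("x") > 1

reader = csv.open_csv(io.BytesIO(b"x,y\n1,a\n2,b\n3,c\n"))
for batch in reader:
    table = pa.Table.from_batches([batch])
    # Table.filter accepts a pyarrow.dataset.Expression and keeps matching rows.
    table = table.filter(filter_expr)
    print(table.to_pydict())  # {'x': [2, 3], 'y': ['b', 'c']}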
26 changes: 24 additions & 2 deletions python/ray/data/_internal/datasource/parquet_datasource.py
@@ -190,6 +190,7 @@ def __init__(
include_paths: bool = False,
file_extensions: Optional[List[str]] = None,
):
super().__init__()
_check_pyarrow_version()

self._supports_distributed_reads = not _is_local_scheme(paths)
@@ -284,7 +285,6 @@ def __init__(
self._file_metadata_shuffler = None
self._include_paths = include_paths
self._partitioning = partitioning

if shuffle == "files":
self._file_metadata_shuffler = np.random.default_rng()
elif isinstance(shuffle, FileShuffleConfig):
@@ -352,6 +352,12 @@ def get_read_tasks(
)

read_tasks = []
filter_expr = (
Review thread (anchored here):

Contributor: How much work would it actually take to push our expr all the way into the reader itself? If it's not a lot, let's do the right thing right away (otherwise, do it in a stacked PR).

Contributor Author: I'm already pushing this into the reader: apply_predicate changes _predicate_expr, which to_pyarrow() converts to a pyarrow.dataset.Expression that then gets passed to fragment.to_batches().

Contributor: I meant threading our expressions through, instead of PyArrow ones.

Contributor Author: I'm not sure I follow. How would PyArrow accept Ray Data's expressions? At some point we have to convert before calling to_batches(), right?

Contributor Author: Discussed offline. As part of the next PR, I'll refactor the remaining 2 functions that are not managed by PyArrow to pass in only Ray Data's Expr.

self._predicate_expr.to_pyarrow()
if self._predicate_expr is not None
else None
)

for fragments, paths in zip(
np.array_split(pq_fragments, parallelism),
np.array_split(pq_paths, parallelism),
@@ -401,6 +407,7 @@ def get_read_tasks(
f,
include_paths,
partitioning,
filter_expr,
),
meta,
schema=target_schema,
@@ -424,6 +431,9 @@ def supports_distributed_reads(self) -> bool:
def supports_projection_pushdown(self) -> bool:
return True

def supports_predicate_pushdown(self) -> bool:
return True

def get_current_projection(self) -> Optional[List[str]]:
# NOTE: In case there's no projection both file and partition columns
# will be none
@@ -432,6 +442,9 @@ def get_current_projection(self) -> Optional[List[str]]:

return (self._data_columns or []) + (self._partition_columns or [])

def get_column_renames(self) -> Optional[Dict[str, str]]:
return self._data_columns_rename_map if self._data_columns_rename_map else None

def apply_projection(
self,
columns: Optional[List[str]],
@@ -463,6 +476,7 @@ def read_fragments(
fragments: List[_ParquetFragment],
include_paths: bool,
partitioning: Partitioning,
filter_expr: Optional["pyarrow.dataset.Expression"] = None,
) -> Iterator["pyarrow.Table"]:
# This import is necessary to load the tensor extension type.
from ray.data.extensions.tensor_extension import ArrowTensorType # noqa
@@ -484,6 +498,7 @@
partition_columns=partition_columns,
partitioning=partitioning,
include_path=include_paths,
filter_expr=filter_expr,
batch_size=default_read_batch_size_rows,
to_batches_kwargs=to_batches_kwargs,
),
@@ -522,7 +537,14 @@ def _read_batches_from(
# NOTE: Passed in kwargs overrides always take precedence
# TODO deprecate to_batches_kwargs
use_threads = to_batches_kwargs.pop("use_threads", use_threads)
filter_expr = to_batches_kwargs.pop("filter", filter_expr)
# TODO: We should deprecate the `filter` kwarg in the read_parquet API and only allow filtering through Dataset.filter()
filter_from_kwargs = to_batches_kwargs.pop("filter", None)
if filter_from_kwargs is not None:
filter_expr = (
filter_from_kwargs
if filter_expr is None
else filter_expr & filter_from_kwargs
)
# NOTE: Arrow's ``to_batches`` expects ``batch_size`` as an int
if batch_size is not None:
to_batches_kwargs.setdefault("batch_size", batch_size)
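With the change above, a legacy filter passed through read_parquet's to_batches kwargs is conjoined with the pushed-down predicate instead of one silently overriding the other. A standalone sketch of the combination and of how a fragment applies it (temporary path and column names are illustrative):

import pyarrow as pa
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq

pq.write_table(pa.table({"x": [0, 1, 2], "y": ["a", "a", "b"]}), "/tmp/t.parquet")
fragment = next(iter(pa_ds.dataset("/tmp/t.parquet").get_fragments()))

pushed = pa_ds.field("x") > 0     # e.g. from self._predicate_expr.to_pyarrow()
legacy = pa_ds.field("y") == "a"  # e.g. from to_batches_kwargs.pop("filter", None)
combined = pushed & legacy        # both must hold: (x > 0) AND (y == "a")

# ParquetFileFragment.to_batches can skip row groups whose statistics rule
# out the filter, so rows are dropped before they ever reach Ray Data.
for batch in fragment.to_batches(filter=combined):
    print(batch.to_pydict())  # {'x': [1], 'y': ['a']}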
7 changes: 6 additions & 1 deletion python/ray/data/_internal/logical/interfaces/__init__.py
@@ -1,4 +1,8 @@
from .logical_operator import LogicalOperator, LogicalOperatorSupportsProjectionPushdown
from .logical_operator import (
LogicalOperator,
LogicalOperatorSupportsPredicatePushdown,
LogicalOperatorSupportsProjectionPushdown,
)
from .logical_plan import LogicalPlan
from .operator import Operator
from .optimizer import Optimizer, Rule
@@ -16,4 +20,5 @@
"Rule",
"SourceOperator",
"LogicalOperatorSupportsProjectionPushdown",
"LogicalOperatorSupportsPredicatePushdown",
]
python/ray/data/_internal/logical/interfaces/logical_operator.py
@@ -2,6 +2,7 @@

from .operator import Operator
from ray.data.block import BlockMetadata
from ray.data.expressions import Expr

if TYPE_CHECKING:
from ray.data.block import Schema
@@ -104,3 +105,28 @@ def apply_projection(
column_rename_map: Optional[Dict[str, str]],
) -> LogicalOperator:
return self


class LogicalOperatorSupportsPredicatePushdown(LogicalOperator):
"""Mixin for reading operators supporting predicate pushdown"""

def supports_predicate_pushdown(self) -> bool:
return False

def get_current_predicate(self) -> Optional[Expr]:
return None

def apply_predicate(
self,
predicate_expr: Expr,
) -> LogicalOperator:
return self

def get_column_renames(self) -> Optional[Dict[str, str]]:
"""Return the column renames applied by projection pushdown, if any.

Returns:
A dictionary mapping old column names to new column names,
or None if no renaming has been applied.
"""
return None
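The contract mirrors the projection-pushdown mixin above it: a capability check, an accessor for the current state, and a copy-on-write apply that returns a new operator. A hypothetical implementation of the same contract, assuming Ray Data's Expr supports & for boolean conjunction:

import copy
from typing import Optional

from ray.data.expressions import Expr

class InMemoryPushdownSource:
    """Hypothetical operator-side implementation of the pushdown contract."""

    def __init__(self) -> None:
        self._predicate_expr: Optional[Expr] = None

    def supports_predicate_pushdown(self) -> bool:
        return True

    def get_current_predicate(self) -> Optional[Expr]:
        return self._predicate_expr

    def apply_predicate(self, predicate_expr: Expr) -> "InMemoryPushdownSource":
        clone = copy.copy(self)
        # Conjoin with any previously pushed predicate rather than replace it
        # (assumes Expr implements & for logical AND).
        clone._predicate_expr = (
            predicate_expr
            if self._predicate_expr is None
            else self._predicate_expr & predicate_expr
        )
        return clone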
9 changes: 9 additions & 0 deletions python/ray/data/_internal/logical/operators/map_operator.py
@@ -268,6 +268,15 @@ def __init__(
def can_modify_num_rows(self) -> bool:
return True

def is_expression_based(self) -> bool:
return self._predicate_expr is not None

def _get_operator_name(self, op_name: str, fn: UserDefinedFunction):
if self.is_expression_based():
# TODO: Use a truncated expression prefix here instead of <expression>.
return f"{op_name}(<expression>)"
return super()._get_operator_name(op_name, fn)
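Expression-based Filter operators have no user function to take a name from, hence the <expression> placeholder above. A toy illustration of the naming fallback (hypothetical free-standing helper, not the operator's actual method):

def operator_name(op_name: str, is_expression_based: bool, fn_name: str) -> str:
    if is_expression_based:
        # Per the TODO above: eventually show a truncated expression instead.
        return f"{op_name}(<expression>)"
    return f"{op_name}({fn_name})"

assert operator_name("Filter", True, "keep_row") == "Filter(<expression>)"
assert operator_name("Filter", False, "keep_row") == "Filter(keep_row)"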


class Project(AbstractMap):
"""Logical operator for all Projection Operations."""
27 changes: 26 additions & 1 deletion python/ray/data/_internal/logical/operators/read_operator.py
@@ -4,6 +4,7 @@
from typing import Any, Dict, List, Optional, Union

from ray.data._internal.logical.interfaces import (
LogicalOperatorSupportsPredicatePushdown,
LogicalOperatorSupportsProjectionPushdown,
SourceOperator,
)
@@ -14,9 +15,15 @@
)
from ray.data.context import DataContext
from ray.data.datasource.datasource import Datasource, Reader
from ray.data.expressions import Expr


class Read(AbstractMap, SourceOperator, LogicalOperatorSupportsProjectionPushdown):
class Read(
AbstractMap,
SourceOperator,
LogicalOperatorSupportsProjectionPushdown,
LogicalOperatorSupportsPredicatePushdown,
):
"""Logical operator for read."""

# TODO: make this a frozen dataclass. https://github.com/ray-project/ray/issues/55747
@@ -158,6 +165,9 @@ def supports_projection_pushdown(self) -> bool:
def get_current_projection(self) -> Optional[List[str]]:
return self._datasource.get_current_projection()

def get_column_renames(self) -> Optional[Dict[str, str]]:
return self._datasource.get_column_renames()

def apply_projection(
self,
columns: Optional[List[str]],
@@ -173,6 +183,21 @@ def apply_projection(

return clone

def supports_predicate_pushdown(self) -> bool:
return self._datasource.supports_predicate_pushdown()

def get_current_predicate(self) -> Optional[Expr]:
return self._datasource.get_current_predicate()

def apply_predicate(self, predicate_expr: Expr) -> "Read":
clone = copy.copy(self)

predicated_datasource = self._datasource.apply_predicate(predicate_expr)
clone._datasource = predicated_datasource
clone._datasource_or_legacy_reader = predicated_datasource

return clone

def can_modify_num_rows(self) -> bool:
# NOTE: Returns true, since most readers expand their input and
# produce many rows for every single input row
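Taken together, these hooks let the optimizer fold a Filter into the Read beneath it, so the predicate runs inside the file reader rather than as a downstream map. A usage sketch, assuming the pre-existing string form of Dataset.filter(expr=...) and a hypothetical local file:

import pyarrow as pa
import pyarrow.parquet as pq
import ray

pq.write_table(pa.table({"x": [0, 1, 2, 3]}), "/tmp/pushdown_demo.parquet")

ds = ray.data.read_parquet("/tmp/pushdown_demo.parquet")
# With this diff, the optimizer can rewrite Filter(Read) into a predicated
# Read instead of evaluating the predicate in a separate operator.
ds = ds.filter(expr="x > 1")
print(ds.count())  # 2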
2 changes: 2 additions & 0 deletions python/ray/data/_internal/logical/optimizers.py
@@ -16,6 +16,7 @@
)
from ray.data._internal.logical.rules.limit_pushdown import LimitPushdownRule
from ray.data._internal.logical.rules.operator_fusion import FuseOperators
from ray.data._internal.logical.rules.predicate_pushdown import PredicatePushdown
from ray.data._internal.logical.rules.projection_pushdown import ProjectionPushdown
from ray.data._internal.logical.rules.set_read_parallelism import SetReadParallelismRule
from ray.util.annotations import DeveloperAPI
@@ -25,6 +26,7 @@
InheritBatchFormatRule,
LimitPushdownRule,
ProjectionPushdown,
PredicatePushdown,
]
)
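The PredicatePushdown rule body lives elsewhere in this PR; conceptually, it walks the plan and, where a Filter's child advertises support, replaces the pair with a predicated child. A simplified sketch of that core step (hypothetical helper name):

from typing import Optional

from ray.data._internal.logical.interfaces import (
    LogicalOperator,
    LogicalOperatorSupportsPredicatePushdown,
)

def try_push_down(
    predicate_expr, child: LogicalOperator
) -> Optional[LogicalOperator]:
    """Return a predicated replacement for `child`, or None if not eligible."""
    if (
        isinstance(child, LogicalOperatorSupportsPredicatePushdown)
        and child.supports_predicate_pushdown()
    ):
        # e.g. Read.apply_predicate clones the operator and its datasource;
        # the standalone Filter above it can then be dropped from the plan.
        return child.apply_predicate(predicate_expr)
    return None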
