Commit b8277d5
[Data] - Make Projection pushdown tests non-flaky + predicates through projects (#58688)
## Description

1. Use the `rows_same` util for the tests in `test_projection_fusion`.
2. Properly handle pushing predicates past projections: block pushdown when the predicate references columns removed by a select, computed columns, or old column names after a rename.

Signed-off-by: Goutam <[email protected]>
1 parent 004e6dc commit b8277d5
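For a concrete picture of change 2, here is a minimal sketch of the rename-chain case the new check handles, adapted from the tests in this commit (the data is illustrative):

```python
import ray
from ray.data.expressions import col

ds = ray.data.from_items([{"a": 1, "b": 2}, {"a": 2, "b": 5}, {"a": 3, "b": 6}])

# rename(a->b, b->c) reuses the name 'b' for what used to be 'a'. Filtering
# on 'b' is valid, and the optimizer can now rebind the predicate to 'a' and
# push it below the projection instead of incorrectly blocking pushdown.
renamed = ds.rename_columns({"a": "b", "b": "c"})
print(renamed.filter(expr=col("b") > 1).take_all())
# [{'b': 2, 'c': 5}, {'b': 3, 'c': 6}]
```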

3 files changed: +316 -87 lines changed

python/ray/data/_internal/logical/rules/predicate_pushdown.py

Lines changed: 89 additions & 1 deletion
@@ -9,7 +9,7 @@
     PredicatePassThroughBehavior,
     Rule,
 )
-from ray.data._internal.logical.operators.map_operator import Filter
+from ray.data._internal.logical.operators.map_operator import Filter, Project
 from ray.data._internal.planner.plan_expression.expression_visitors import (
     _ColumnSubstitutionVisitor,
 )
@@ -63,6 +63,86 @@ def _try_fuse_filters(cls, op: LogicalOperator) -> LogicalOperator:
             predicate_expr=combined_predicate,
         )
 
+    @classmethod
+    def _can_push_filter_through_projection(
+        cls, filter_op: "Filter", projection_op: Project
+    ) -> bool:
+        """Check if a filter can be pushed through a projection operator.
+
+        Returns False (blocks pushdown) if filter references:
+        - Columns removed by select: select(['a']).filter(col('b'))
+        - Computed columns: with_column('d', 4).filter(col('d'))
+        - Old column names after rename: rename({'b': 'B'}).filter(col('b'))
+
+        Returns True (allows pushdown) for:
+        - Columns present in output: select(['a', 'b']).filter(col('a'))
+        - New column names after rename: rename({'b': 'B'}).filter(col('B'))
+        - Rename chains with name reuse: rename({'a': 'b', 'b': 'c'}).filter(col('b'))
+          (where 'b' is valid output created by a->b)
+        """
+        from ray.data._internal.logical.rules.projection_pushdown import (
+            _is_renaming_expr,
+        )
+        from ray.data._internal.planner.plan_expression.expression_visitors import (
+            _ColumnReferenceCollector,
+        )
+        from ray.data.expressions import AliasExpr
+
+        collector = _ColumnReferenceCollector()
+        collector.visit(filter_op._predicate_expr)
+        predicate_columns = set(collector.get_column_refs() or [])
+
+        output_columns = set()
+        new_names = set()
+        original_columns_being_renamed = set()
+
+        for expr in projection_op.exprs:
+            if expr.name is not None:
+                # Collect output column names
+                output_columns.add(expr.name)
+
+                # Process AliasExpr (computed columns or renames)
+                if isinstance(expr, AliasExpr):
+                    new_names.add(expr.name)
+
+                    # Check computed column: with_column('d', 4) creates AliasExpr(lit(4), 'd')
+                    if expr.name in predicate_columns and not _is_renaming_expr(expr):
+                        return False  # Computed column
+
+                    # Track old names being renamed for later check
+                    if _is_renaming_expr(expr):
+                        original_columns_being_renamed.add(expr.expr.name)
+
+        # Check if filter references columns removed by explicit select.
+        # Valid if: projection includes all columns (star) OR predicate columns exist in output
+        has_required_columns = (
+            projection_op.has_star_expr() or predicate_columns.issubset(output_columns)
+        )
+        if not has_required_columns:
+            return False
+
+        # Find old names that are:
+        # 1. Being renamed away (in original_columns_being_renamed), AND
+        # 2. Referenced in predicate (in predicate_columns), AND
+        # 3. NOT recreated as new names (not in new_names)
+        #
+        # Examples:
+        #   rename({'b': 'B'}).filter(col('b'))
+        #     → {'b'} & {'b'} - {'B'} = {'b'} → BLOCKS (old name 'b' no longer exists)
+        #   rename({'a': 'b', 'b': 'c'}).filter(col('b'))
+        #     → {'a','b'} & {'b'} - {'b','c'} = {} → ALLOWS (new 'b' created by a->b)
+        #   rename({'b': 'B'}).filter(col('B'))
+        #     → {'b'} & {'B'} - {'B'} = {} → ALLOWS (using new name 'B')
+        invalid_old_names = (
+            original_columns_being_renamed & predicate_columns
+        ) - new_names
+        if invalid_old_names:
+            return False  # Old name after rename
+
+        return True
+
     @classmethod
     def _substitute_predicate_columns(
         cls, predicate_expr: Expr, column_rename_map: dict[str, str]
@@ -135,6 +215,14 @@ def _try_push_down_predicate(cls, op: LogicalOperator) -> LogicalOperator:
                     behavior
                     == PredicatePassThroughBehavior.PASSTHROUGH_WITH_SUBSTITUTION
                 ):
+                    # Check if we can safely push the filter through this projection
+                    if isinstance(
+                        input_op, Project
+                    ) and not cls._can_push_filter_through_projection(
+                        filter_op, input_op
+                    ):
+                        return filter_op
+
                     rename_map = input_op.get_column_substitutions()
                     if rename_map:
                         predicate_expr = cls._substitute_predicate_columns(
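For intuition, here is a standalone sketch of the set arithmetic `_can_push_filter_through_projection` performs. The projection model (a plain dict plus a `has_star` flag) and the function name are illustrative simplifications, not Ray's internals:

```python
# Projection modeled as {output_name: source_name}; a None source marks a
# computed column; has_star means untouched input columns pass through
# (as with rename_columns). Illustrative only.
def can_push_through(predicate_cols: set, projection: dict, has_star: bool = False) -> bool:
    output_columns = set(projection)
    computed = {out for out, src in projection.items() if src is None}
    renames = {out: src for out, src in projection.items() if src is not None}

    # Block: the predicate touches a computed column, e.g. with_column('d', lit(4)).
    if predicate_cols & computed:
        return False
    # Block: the predicate references a column dropped by an explicit select.
    if not has_star and not predicate_cols <= output_columns:
        return False
    # Block: the predicate uses an old name that was renamed away and not recreated.
    invalid_old_names = (set(renames.values()) & predicate_cols) - set(renames)
    return not invalid_old_names

# rename({'a': 'b', 'b': 'c'}): 'b' is a valid output (created from 'a') -> allowed.
assert can_push_through({"b"}, {"b": "a", "c": "b"}, has_star=True)
# rename({'b': 'B'}): old name 'b' no longer exists -> blocked.
assert not can_push_through({"b"}, {"B": "b"}, has_star=True)
# select(['a']): filtering on the dropped column 'b' -> blocked.
assert not can_push_through({"b"}, {"a": "a"})
```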

python/ray/data/tests/test_predicate_pushdown.py

Lines changed: 171 additions & 1 deletion
@@ -11,7 +11,7 @@
     Repartition,
     Sort,
 )
-from ray.data._internal.logical.operators.map_operator import Filter
+from ray.data._internal.logical.operators.map_operator import Filter, Project
 from ray.data._internal.logical.operators.one_to_one_operator import Limit
 from ray.data._internal.logical.optimizers import LogicalOptimizer
 from ray.data._internal.util import rows_same
@@ -543,6 +543,176 @@ def test_multiple_filters_with_renames(self, parquet_ds):
         ), "All filters should be fused, rebound, and pushed into Read"
 
 
+class TestProjectionWithFilterEdgeCases:
+    """Tests for edge cases with select_columns and with_column followed by filters.
+
+    These tests verify that filters correctly handle:
+    - Columns that are kept by select (should push through)
+    - Columns that are removed by select (should NOT push through)
+    - Computed columns from with_column (should NOT push through)
+    """
+
+    @pytest.fixture
+    def base_ds(self, ray_start_regular_shared):
+        return ray.data.from_items(
+            [
+                {"a": 1, "b": 2, "c": 3},
+                {"a": 2, "b": 5, "c": 8},
+                {"a": 3, "b": 6, "c": 9},
+            ]
+        )
+
+    def test_select_then_filter_on_selected_column(self, base_ds):
+        """Filter on a selected column should push through the select."""
+        ds = base_ds.select_columns(["a", "b"]).filter(expr=col("a") > 1)
+
+        # Verify correctness
+        result_df = ds.to_pandas()
+        expected_df = pd.DataFrame(
+            [
+                {"a": 2, "b": 5},
+                {"a": 3, "b": 6},
+            ]
+        )
+        # Sort columns before comparison
+        result_df = result_df[sorted(result_df.columns)]
+        expected_df = expected_df[sorted(expected_df.columns)]
+        assert rows_same(result_df, expected_df)
+
+        # Verify plan: filter pushed through select
+        optimized_plan = LogicalOptimizer().optimize(ds._plan._logical_plan)
+        assert plan_operator_comes_before(
+            optimized_plan, Filter, Project
+        ), "Filter should be pushed before Project"
+
+    def test_select_then_filter_on_removed_column(self, base_ds):
+        """Filter on a removed column should fail, not push through."""
+        ds = base_ds.select_columns(["a"])
+
+        with pytest.raises((KeyError, ray.exceptions.RayTaskError)):
+            ds.filter(expr=col("b") == 2).take_all()
+
+    def test_with_column_then_filter_on_computed_column(self, base_ds):
+        """Filter on a computed column should not push through."""
+
+        from ray.data.expressions import lit
+
+        ds = base_ds.with_column("d", lit(4)).filter(expr=col("d") == 4)
+
+        # Verify correctness - all rows should pass (d is always 4)
+        result_df = ds.to_pandas()
+        expected_df = pd.DataFrame(
+            [
+                {"a": 1, "b": 2, "c": 3, "d": 4},
+                {"a": 2, "b": 5, "c": 8, "d": 4},
+                {"a": 3, "b": 6, "c": 9, "d": 4},
+            ]
+        )
+        # Sort columns before comparison
+        result_df = result_df[sorted(result_df.columns)]
+        expected_df = expected_df[sorted(expected_df.columns)]
+        assert rows_same(result_df, expected_df)
+
+        # Verify plan: filter should NOT push through (stays after with_column)
+        optimized_plan = LogicalOptimizer().optimize(ds._plan._logical_plan)
+        assert plan_has_operator(
+            optimized_plan, Filter
+        ), "Filter should remain (not pushed through)"
+
+    def test_rename_then_filter_on_old_column_name(self, base_ds):
+        """Filter using the old column name after a rename should fail."""
+        ds = base_ds.rename_columns({"b": "B"})
+
+        with pytest.raises((KeyError, ray.exceptions.RayTaskError)):
+            ds.filter(expr=col("b") == 2).take_all()
+
+    @pytest.mark.parametrize(
+        "ds_factory,rename_map,filter_col,filter_value,expected_rows",
+        [
+            # In-memory dataset: rename a->b, b->b_old
+            (
+                lambda: ray.data.from_items(
+                    [
+                        {"a": 1, "b": 2, "c": 3},
+                        {"a": 2, "b": 5, "c": 8},
+                        {"a": 3, "b": 6, "c": 9},
+                    ]
+                ),
+                {"a": "b", "b": "b_old"},
+                "b",
+                1,
+                [{"b": 2, "b_old": 5, "c": 8}, {"b": 3, "b_old": 6, "c": 9}],
+            ),
+            # Parquet dataset: rename sepal.length->sepal.width, sepal.width->old_width
+            (
+                lambda: ray.data.read_parquet("example://iris.parquet"),
+                {"sepal.length": "sepal.width", "sepal.width": "old_width"},
+                "sepal.width",
+                5.0,
+                None,  # Will verify via alternative computation
+            ),
+        ],
+        ids=["in_memory", "parquet"],
+    )
+    def test_rename_chain_with_name_reuse(
+        self,
+        ray_start_regular_shared,
+        ds_factory,
+        rename_map,
+        filter_col,
+        filter_value,
+        expected_rows,
+    ):
+        """Test rename chains where an output name matches another rename's input name.
+
+        This tests the fix for a bug where rename(a->b, b->c) followed by filter(b>5)
+        would incorrectly block pushdown, even though 'b' is a valid output column
+        (created by a->b).
+
+        Example: rename({'a': 'b', 'b': 'temp'}) creates 'b' from 'a' and 'temp'
+        from 'b'. A filter on 'b' should be able to push through.
+        """
+        ds = ds_factory()
+
+        # Apply rename and filter
+        ds_renamed_filtered = ds.rename_columns(rename_map).filter(
+            expr=col(filter_col) > filter_value
+        )
+
+        # Verify correctness
+        if expected_rows is not None:
+            # For in-memory, compare against expected rows
+            result_df = ds_renamed_filtered.to_pandas()
+            expected_df = pd.DataFrame(expected_rows)
+            result_df = result_df[sorted(result_df.columns)]
+            expected_df = expected_df[sorted(expected_df.columns)]
+            assert rows_same(result_df, expected_df)
+        else:
+            # For parquet, compare against an alternative computation:
+            # filter on the original column, then rename
+            original_col = next(k for k, v in rename_map.items() if v == filter_col)
+            expected = ds.filter(expr=col(original_col) > filter_value).rename_columns(
+                rename_map
+            )
+            assert rows_same(ds_renamed_filtered.to_pandas(), expected.to_pandas())
+
+        # Verify plan optimization
+        optimized_plan = LogicalOptimizer().optimize(
+            ds_renamed_filtered._plan._logical_plan
+        )
+
+        # For parquet (supports predicate pushdown), filter should push into Read
+        if "parquet" in str(ds._plan._logical_plan.dag).lower():
+            assert not plan_has_operator(
+                optimized_plan, Filter
+            ), "Filter should be pushed into Read after rebinding through rename chain"
+        else:
+            # For in-memory, filter should at least push through the projection
+            assert plan_operator_comes_before(
+                optimized_plan, Filter, Project
+            ), "Filter should be pushed before Project after rebinding through rename chain"
+
+
 class TestPushIntoBranchesBehavior:
     """Tests for PUSH_INTO_BRANCHES behavior operators.
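On change 1: the tests above compare results with `rows_same` rather than asserting on row order, which avoids flakes from nondeterministic output ordering. A rough approximation of what such an order-insensitive comparison does (my own sketch, not Ray's actual `rows_same` implementation):

```python
import pandas as pd

# Sort both frames by all columns so nondeterministic row order can't flip
# the assertion, then compare. Sketch only; the real utility lives in
# ray.data._internal.util.rows_same.
def rows_same_sketch(actual: pd.DataFrame, expected: pd.DataFrame) -> bool:
    cols = sorted(actual.columns)
    if cols != sorted(expected.columns):
        return False
    a = actual[cols].sort_values(by=cols).reset_index(drop=True)
    b = expected[cols].sort_values(by=cols).reset_index(drop=True)
    return a.equals(b)
```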
