@@ -11,7 +11,7 @@
     Repartition,
     Sort,
 )
-from ray.data._internal.logical.operators.map_operator import Filter
+from ray.data._internal.logical.operators.map_operator import Filter, Project
 from ray.data._internal.logical.operators.one_to_one_operator import Limit
 from ray.data._internal.logical.optimizers import LogicalOptimizer
 from ray.data._internal.util import rows_same
@@ -543,6 +543,176 @@ def test_multiple_filters_with_renames(self, parquet_ds):
         ), "All filters should be fused, rebound, and pushed into Read"


+class TestProjectionWithFilterEdgeCases:
+    """Tests for edge cases where select_columns or with_column is followed by a filter.
+
+    These tests verify that filters correctly handle:
+    - Columns that are kept by select (should push through)
+    - Columns that are removed by select (should NOT push through)
+    - Computed columns from with_column (should NOT push through)
+    """
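+    # Rule of thumb: a filter may move below a projection only when every
+    # column it references is passed through (possibly under a new name, in
+    # which case the predicate is rebound to the input name); filters on
+    # dropped or newly computed columns must stay above the projection.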
+
+    @pytest.fixture
+    def base_ds(self, ray_start_regular_shared):
+        return ray.data.from_items(
+            [
+                {"a": 1, "b": 2, "c": 3},
+                {"a": 2, "b": 5, "c": 8},
+                {"a": 3, "b": 6, "c": 9},
+            ]
+        )
+
+    def test_select_then_filter_on_selected_column(self, base_ds):
+        """Filter on a selected column should push through the select."""
+        ds = base_ds.select_columns(["a", "b"]).filter(expr=col("a") > 1)
+
+        # Verify correctness
+        result_df = ds.to_pandas()
+        expected_df = pd.DataFrame(
+            [
+                {"a": 2, "b": 5},
+                {"a": 3, "b": 6},
+            ]
+        )
+        # Sort columns before comparison
+        result_df = result_df[sorted(result_df.columns)]
+        expected_df = expected_df[sorted(expected_df.columns)]
+        assert rows_same(result_df, expected_df)
+
+        # Verify plan: filter pushed through select
+        optimized_plan = LogicalOptimizer().optimize(ds._plan._logical_plan)
+        assert plan_operator_comes_before(
+            optimized_plan, Filter, Project
+        ), "Filter should be pushed before Project"
+
+    def test_select_then_filter_on_removed_column(self, base_ds):
+        """Filter on a removed column should fail, not push through."""
+        ds = base_ds.select_columns(["a"])
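+        # "b" no longer exists after the select; pushing the filter below the
+        # projection would incorrectly resurrect the dropped column, so the
+        # filter must raise instead.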
+
+        with pytest.raises((KeyError, ray.exceptions.RayTaskError)):
+            ds.filter(expr=col("b") == 2).take_all()
+
+    def test_with_column_then_filter_on_computed_column(self, base_ds):
+        """Filter on a computed column should not push through."""
+
+        from ray.data.expressions import lit
+
+        ds = base_ds.with_column("d", lit(4)).filter(expr=col("d") == 4)
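+        # "d" is computed by with_column and absent from the input schema, so
+        # the filter cannot be rebound below the projection that creates it.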
+
+        # Verify correctness: all rows should pass (d is always 4)
+        result_df = ds.to_pandas()
+        expected_df = pd.DataFrame(
+            [
+                {"a": 1, "b": 2, "c": 3, "d": 4},
+                {"a": 2, "b": 5, "c": 8, "d": 4},
+                {"a": 3, "b": 6, "c": 9, "d": 4},
+            ]
+        )
+        # Sort columns before comparison
+        result_df = result_df[sorted(result_df.columns)]
+        expected_df = expected_df[sorted(expected_df.columns)]
+        assert rows_same(result_df, expected_df)
+
+        # Verify plan: filter should NOT push through (stays after with_column)
+        optimized_plan = LogicalOptimizer().optimize(ds._plan._logical_plan)
+        assert plan_has_operator(
+            optimized_plan, Filter
+        ), "Filter should remain (not pushed through)"
+
+    def test_rename_then_filter_on_old_column_name(self, base_ds):
+        """Filter using the old column name after a rename should fail."""
+        ds = base_ds.rename_columns({"b": "B"})
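+        # Only "B" exists after the rename; a reference to the old name "b"
+        # should raise rather than silently rebind to the new column.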
+
+        with pytest.raises((KeyError, ray.exceptions.RayTaskError)):
+            ds.filter(expr=col("b") == 2).take_all()
+
+    @pytest.mark.parametrize(
+        "ds_factory,rename_map,filter_col,filter_value,expected_rows",
+        [
+            # In-memory dataset: rename a->b, b->b_old
+            (
+                lambda: ray.data.from_items(
+                    [
+                        {"a": 1, "b": 2, "c": 3},
+                        {"a": 2, "b": 5, "c": 8},
+                        {"a": 3, "b": 6, "c": 9},
+                    ]
+                ),
+                {"a": "b", "b": "b_old"},
+                "b",
+                1,
+                [{"b": 2, "b_old": 5, "c": 8}, {"b": 3, "b_old": 6, "c": 9}],
+            ),
+            # Parquet dataset: rename sepal.length->sepal.width, sepal.width->old_width
+            (
+                lambda: ray.data.read_parquet("example://iris.parquet"),
+                {"sepal.length": "sepal.width", "sepal.width": "old_width"},
+                "sepal.width",
+                5.0,
+                None,  # Will verify via alternative computation
+            ),
+        ],
+        ids=["in_memory", "parquet"],
+    )
+    def test_rename_chain_with_name_reuse(
+        self,
+        ray_start_regular_shared,
+        ds_factory,
+        rename_map,
+        filter_col,
+        filter_value,
+        expected_rows,
+    ):
+        """Test rename chains where an output name matches another rename's input name.
+
+        This tests the fix for a bug where rename(a->b, b->c) followed by filter(b>5)
+        would incorrectly block pushdown, even though 'b' is a valid output column
+        (created by a->b).
+
+        Example: rename({'a': 'b', 'b': 'temp'}) creates 'b' from 'a' and 'temp' from 'b'.
+        A filter on 'b' should be able to push through.
+        """
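+        # Concretely, with rename_map={"a": "b", "b": "b_old"} the predicate
+        # col("b") > 1 must be rebound to col("a") > 1 before it can move
+        # below the rename, because output "b" is produced from input "a".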
+        ds = ds_factory()
+
+        # Apply rename and filter
+        ds_renamed_filtered = ds.rename_columns(rename_map).filter(
+            expr=col(filter_col) > filter_value
+        )
+
+        # Verify correctness
+        if expected_rows is not None:
+            # For in-memory, compare against expected rows
+            result_df = ds_renamed_filtered.to_pandas()
+            expected_df = pd.DataFrame(expected_rows)
+            result_df = result_df[sorted(result_df.columns)]
+            expected_df = expected_df[sorted(expected_df.columns)]
+            assert rows_same(result_df, expected_df)
+        else:
+            # For parquet, compare against an alternative computation:
+            # filter on the original column, then rename
+            original_col = next(k for k, v in rename_map.items() if v == filter_col)
+            expected = ds.filter(expr=col(original_col) > filter_value).rename_columns(
+                rename_map
+            )
+            assert rows_same(ds_renamed_filtered.to_pandas(), expected.to_pandas())
+
+        # Verify plan optimization
+        optimized_plan = LogicalOptimizer().optimize(
+            ds_renamed_filtered._plan._logical_plan
+        )
+
+        # For parquet (supports predicate pushdown), filter should push into Read
+        if "parquet" in str(ds._plan._logical_plan.dag).lower():
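+            # (The stringified plan is expected to name the Read operator, so
+            # "parquet" appears only for the read_parquet parametrization.)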
+            assert not plan_has_operator(
+                optimized_plan, Filter
+            ), "Filter should be pushed into Read after rebinding through rename chain"
+        else:
+            # For in-memory, filter should at least push through the projection
+            assert plan_operator_comes_before(
+                optimized_plan, Filter, Project
+            ), "Filter should be pushed before Project after rebinding through rename chain"
+
+
 class TestPushIntoBranchesBehavior:
     """Tests for PUSH_INTO_BRANCHES behavior operators.
 