Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 85 additions & 26 deletions python/ray/data/tests/test_execution_optimizer_limit_pushdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,56 +34,86 @@ def _check_valid_plan_and_result(
assert op in ds.stats(), f"Operator {op} not found: {ds.stats()}"


def test_limit_pushdown_conservative(ray_start_regular_shared_2_cpus):
"""Test limit pushdown behavior - pushes through safe operations."""

def f1(x):
return x

def f2(x):
return x

# Test 1: Basic Limit -> Limit fusion (should still work)
def test_limit_pushdown_basic_limit_fusion(ray_start_regular_shared_2_cpus):
    """Stacked limits fuse into a single Limit carrying the smaller row count."""
    dataset = ray.data.range(100).limit(5).limit(100)
    expected_rows = [{"id": row_id} for row_id in range(5)]
    # Output order is not guaranteed by this pipeline, so skip ordering checks.
    _check_valid_plan_and_result(
        dataset,
        "Read[ReadRange] -> Limit[limit=5]",
        expected_rows,
        check_ordering=False,
    )


def test_limit_pushdown_limit_fusion_reversed(ray_start_regular_shared_2_cpus):
    """Limit fusion also applies when the larger limit precedes the smaller."""
    dataset = ray.data.range(100).limit(100).limit(5)
    expected_rows = [{"id": row_id} for row_id in range(5)]
    # Row order is not guaranteed, so ordering is not checked.
    _check_valid_plan_and_result(
        dataset,
        "Read[ReadRange] -> Limit[limit=5]",
        expected_rows,
        check_ordering=False,
    )


def test_limit_pushdown_multiple_limit_fusion(ray_start_regular_shared_2_cpus):
    """A chain of four limits collapses into one Limit with the minimum value."""
    dataset = ray.data.range(100).limit(50).limit(80).limit(5).limit(20)
    expected_rows = [{"id": row_id} for row_id in range(5)]
    # Row order is not guaranteed, so ordering is not checked.
    _check_valid_plan_and_result(
        dataset,
        "Read[ReadRange] -> Limit[limit=5]",
        expected_rows,
        check_ordering=False,
    )

# Test 2: Limit should push through MapRows operations (safe)

def test_limit_pushdown_through_maprows(ray_start_regular_shared_2_cpus):
    """A Limit is pushed below a MapRows operator in the optimized plan."""

    def f1(x):
        # Identity UDF; its name must stay `f1` because the expected
        # operator string below contains "Map(f1)".
        return x

    dataset = ray.data.range(100, override_num_blocks=100).map(f1).limit(1)
    # Row order is not guaranteed, so ordering is not checked.
    _check_valid_plan_and_result(
        dataset,
        "Read[ReadRange] -> Limit[limit=1] -> MapRows[Map(f1)]",
        [{"id": 0}],
        check_ordering=False,
    )

# Test 3: Limit should push through MapBatches operations

def test_limit_pushdown_through_mapbatches(ray_start_regular_shared_2_cpus):
    """A Limit is pushed below a MapBatches operator in the optimized plan."""

    def f2(x):
        # Identity UDF; its name must stay `f2` because the expected
        # operator string below contains "MapBatches(f2)".
        return x

    dataset = ray.data.range(100, override_num_blocks=100).map_batches(f2).limit(1)
    # Row order is not guaranteed, so ordering is not checked.
    _check_valid_plan_and_result(
        dataset,
        "Read[ReadRange] -> Limit[limit=1] -> MapBatches[MapBatches(f2)]",
        [{"id": 0}],
        check_ordering=False,
    )

# Test 4: Limit should NOT push through Filter operations (conservative)

def test_limit_pushdown_stops_at_filter(ray_start_regular_shared_2_cpus):
    """The optimizer is conservative: a Limit is NOT pushed below a Filter."""
    # The predicate must be a lambda: the expected operator string below
    # contains "Filter(<lambda>)".
    dataset = (
        ray.data.range(100, override_num_blocks=100)
        .filter(lambda x: x["id"] < 50)
        .limit(1)
    )
    # Row order is not guaranteed, so ordering is not checked.
    _check_valid_plan_and_result(
        dataset,
        "Read[ReadRange] -> Filter[Filter(<lambda>)] -> Limit[limit=1]",
        [{"id": 0}],
        check_ordering=False,
    )

# Test 5: Limit should push through Project operations (safe)

def test_limit_pushdown_through_project(ray_start_regular_shared_2_cpus):
"""Test that Limit pushes through Project operations."""
ds = ray.data.range(100, override_num_blocks=100).select_columns(["id"]).limit(5)
_check_valid_plan_and_result(
ds,
Expand All @@ -92,15 +122,26 @@ def f2(x):
check_ordering=False,
)

# Test 6: Limit should stop at Sort operations (AllToAll)

def test_limit_pushdown_stops_at_sort(ray_start_regular_shared_2_cpus):
    """Limit pushdown halts at an AllToAll operator such as Sort."""
    dataset = ray.data.range(100).sort("id").limit(5)
    expected_rows = [{"id": row_id} for row_id in range(5)]
    # Sort guarantees output order, so ordering IS checked here
    # (check_ordering is left at its default).
    _check_valid_plan_and_result(
        dataset,
        "Read[ReadRange] -> Sort[Sort] -> Limit[limit=5]",
        expected_rows,
    )

# Test 7: More complex interweaved case.

def test_limit_pushdown_complex_interweaved_operations(ray_start_regular_shared_2_cpus):
"""Test Limit pushdown with complex interweaved operations."""

def f1(x):
return x

def f2(x):
return x

ds = ray.data.range(100).sort("id").map(f1).limit(20).sort("id").map(f2).limit(5)
_check_valid_plan_and_result(
ds,
Expand All @@ -109,12 +150,22 @@ def f2(x):
[{"id": i} for i in range(5)],
)

# Test 8: Test limit pushdown between two Map operators.

def test_limit_pushdown_between_two_map_operators(ray_start_regular_shared_2_cpus):
    """A Limit sandwiched between two maps is pushed below the first map only."""

    def f1(x):
        # Identity UDFs; names must stay `f1`/`f2` because the expected
        # operator string below contains "Map(f1)" and "Map(f2)".
        return x

    def f2(x):
        return x

    dataset = ray.data.range(100, override_num_blocks=100).map(f1).limit(1).map(f2)
    # Row order is not guaranteed, so ordering is not checked.
    _check_valid_plan_and_result(
        dataset,
        "Read[ReadRange] -> Limit[limit=1] -> MapRows[Map(f1)] -> MapRows[Map(f2)]",
        [{"id": 0}],
        check_ordering=False,
    )


Expand Down Expand Up @@ -285,7 +336,9 @@ def test_limit_pushdown_union(ray_start_regular_shared_2_cpus):
ds = ds1.union(ds2).limit(5)

expected_plan = "Read[ReadRange] -> Limit[limit=5], Read[ReadRange] -> Limit[limit=5] -> Union[Union] -> Limit[limit=5]"
_check_valid_plan_and_result(ds, expected_plan, [{"id": i} for i in range(5)])
_check_valid_plan_and_result(
ds, expected_plan, [{"id": i} for i in range(5)], check_ordering=False
)


def test_limit_pushdown_union_with_maprows(ray_start_regular_shared_2_cpus):
Expand All @@ -300,7 +353,9 @@ def test_limit_pushdown_union_with_maprows(ray_start_regular_shared_2_cpus):
"Read[ReadRange] -> Limit[limit=5] -> Union[Union] -> "
"Limit[limit=5] -> MapRows[Map(<lambda>)]"
)
_check_valid_plan_and_result(ds, expected_plan, [{"id": i} for i in range(5)])
_check_valid_plan_and_result(
ds, expected_plan, [{"id": i} for i in range(5)], check_ordering=False
)


def test_limit_pushdown_union_with_sort(ray_start_regular_shared_2_cpus):
Expand Down Expand Up @@ -334,7 +389,9 @@ def test_limit_pushdown_multiple_unions(ray_start_regular_shared_2_cpus):
"Read[ReadRange] -> Limit[limit=5] -> Union[Union] -> Limit[limit=5], "
"Read[ReadRange] -> Limit[limit=5] -> Union[Union] -> Limit[limit=5]"
)
_check_valid_plan_and_result(ds, expected_plan, [{"id": i} for i in range(5)])
_check_valid_plan_and_result(
ds, expected_plan, [{"id": i} for i in range(5)], check_ordering=False
)


def test_limit_pushdown_union_with_groupby(ray_start_regular_shared_2_cpus):
Expand Down Expand Up @@ -427,7 +484,9 @@ def test_limit_pushdown_union_maps_projects(ray_start_regular_shared_2_cpus):

expected_result = [{"id": i} for i in range(3)] # First 3 rows from left branch.

_check_valid_plan_and_result(ds, expected_plan, expected_result)
_check_valid_plan_and_result(
ds, expected_plan, expected_result, check_ordering=False
)


def test_limit_pushdown_map_per_block_limit_applied(ray_start_regular_shared_2_cpus):
Expand Down