Skip to content

Commit dad8002

Browse files
authored
[Data] Avoid slicing block when total_pending_rows < target (#58699)
## Description Previously we will try slice the block when `self._total_pending_rows >= self._target_num_rows` or `flush_remaining` is True, but flush_remaining doesn't mean `self._total_pending_rows >= self._target_num_rows ` so it could make the slicing failed because our slicing logic is based on assumption there should be at least one full block. This PR fix the logic and added test for such case. --------- Signed-off-by: You-Cheng Lin <[email protected]>
1 parent 83a456e commit dad8002

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

python/ray/data/_internal/streaming_repartition.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(self, target_num_rows_per_block: int):
2727
self._total_pending_rows = 0
2828

2929
def _try_build_ready_bundle(self, flush_remaining: bool = False):
30-
if self._total_pending_rows >= self._target_num_rows or flush_remaining:
30+
if self._total_pending_rows >= self._target_num_rows:
3131
rows_needed_from_last_bundle = (
3232
self._pending_bundles[-1].num_rows()
3333
- self._total_pending_rows % self._target_num_rows
@@ -50,6 +50,12 @@ def _try_build_ready_bundle(self, flush_remaining: bool = False):
5050
if remaining_bundle and remaining_bundle.num_rows() > 0:
5151
self._pending_bundles.append(remaining_bundle)
5252
self._total_pending_rows += remaining_bundle.num_rows()
53+
if flush_remaining and self._total_pending_rows > 0:
54+
self._ready_bundles.append(
55+
RefBundle.merge_ref_bundles(self._pending_bundles)
56+
)
57+
self._pending_bundles.clear()
58+
self._total_pending_rows = 0
5359

5460
def add_bundle(self, ref_bundle: RefBundle):
5561
self._total_pending_rows += ref_bundle.num_rows()

python/ray/data/tests/unit/test_bundler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def _make_ref_bundles_for_unit_test(raw_bundles: List[List[List[Any]]]) -> tuple
6565
[[[1]], [[]], [[2, 3]], [[]], [[4, 5]]],
6666
[3, 2], # Expected: [1,2,3] and [4,5]
6767
),
68+
(
69+
# Test with last block smaller than target num rows per block
70+
100,
71+
[[[1]], [[2]], [[3]], [[4]], [[5]]],
72+
[5],
73+
),
6874
],
6975
)
7076
def test_streaming_repartition_ref_bundler(target, in_bundles, expected_row_counts):

0 commit comments

Comments
 (0)