Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/ray/data/_internal/streaming_repartition.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, target_num_rows_per_block: int):
self._total_pending_rows = 0

def _try_build_ready_bundle(self, flush_remaining: bool = False):
- if self._total_pending_rows >= self._target_num_rows or flush_remaining:
+ if self._total_pending_rows >= self._target_num_rows:
rows_needed_from_last_bundle = (
self._pending_bundles[-1].num_rows()
- self._total_pending_rows % self._target_num_rows
Expand All @@ -50,6 +50,12 @@ def _try_build_ready_bundle(self, flush_remaining: bool = False):
if remaining_bundle and remaining_bundle.num_rows() > 0:
self._pending_bundles.append(remaining_bundle)
self._total_pending_rows += remaining_bundle.num_rows()
+ elif flush_remaining:
+ self._ready_bundles.append(
+ RefBundle.merge_ref_bundles(self._pending_bundles)
+ )
+ self._pending_bundles.clear()
+ self._total_pending_rows = 0

def add_bundle(self, ref_bundle: RefBundle):
self._total_pending_rows += ref_bundle.num_rows()
Expand Down
6 changes: 6 additions & 0 deletions python/ray/data/tests/unit/test_bundler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ def _make_ref_bundles_for_unit_test(raw_bundles: List[List[List[Any]]]) -> tuple
[[[1]], [[]], [[2, 3]], [[]], [[4, 5]]],
[3, 2], # Expected: [1,2,3] and [4,5]
),
+ (
+ # Test with last block smaller than target num rows per block
+ 100,
+ [[[1]], [[2]], [[3]], [[4]], [[5]]],
+ [5],
+ ),
],
)
def test_streaming_repartition_ref_bundler(target, in_bundles, expected_row_counts):
Expand Down