Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/ray/data/_internal/streaming_repartition.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, target_num_rows_per_block: int):
self._total_pending_rows = 0

def _try_build_ready_bundle(self, flush_remaining: bool = False):
if self._total_pending_rows >= self._target_num_rows or flush_remaining:
if self._total_pending_rows >= self._target_num_rows:
rows_needed_from_last_bundle = (
self._pending_bundles[-1].num_rows()
- self._total_pending_rows % self._target_num_rows
Expand All @@ -50,6 +50,12 @@ def _try_build_ready_bundle(self, flush_remaining: bool = False):
if remaining_bundle and remaining_bundle.num_rows() > 0:
self._pending_bundles.append(remaining_bundle)
self._total_pending_rows += remaining_bundle.num_rows()
if flush_remaining and self._total_pending_rows > 0:
self._ready_bundles.append(
RefBundle.merge_ref_bundles(self._pending_bundles)
)
self._pending_bundles.clear()
self._total_pending_rows = 0

def add_bundle(self, ref_bundle: RefBundle):
self._total_pending_rows += ref_bundle.num_rows()
Expand Down
6 changes: 6 additions & 0 deletions python/ray/data/tests/unit/test_bundler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ def _make_ref_bundles_for_unit_test(raw_bundles: List[List[List[Any]]]) -> tuple
[[[1]], [[]], [[2, 3]], [[]], [[4, 5]]],
[3, 2], # Expected: [1,2,3] and [4,5]
),
(
# Test with last block smaller than target num rows per block
100,
[[[1]], [[2]], [[3]], [[4]], [[5]]],
[5],
),
],
)
def test_streaming_repartition_ref_bundler(target, in_bundles, expected_row_counts):
Expand Down