diff --git a/python/ray/data/_internal/streaming_repartition.py b/python/ray/data/_internal/streaming_repartition.py index 1f5de28d51ab..1e0fe3e3e442 100644 --- a/python/ray/data/_internal/streaming_repartition.py +++ b/python/ray/data/_internal/streaming_repartition.py @@ -27,7 +27,7 @@ def __init__(self, target_num_rows_per_block: int): self._total_pending_rows = 0 def _try_build_ready_bundle(self, flush_remaining: bool = False): - if self._total_pending_rows >= self._target_num_rows or flush_remaining: + if self._total_pending_rows >= self._target_num_rows: rows_needed_from_last_bundle = ( self._pending_bundles[-1].num_rows() - self._total_pending_rows % self._target_num_rows @@ -50,6 +50,12 @@ def _try_build_ready_bundle(self, flush_remaining: bool = False): if remaining_bundle and remaining_bundle.num_rows() > 0: self._pending_bundles.append(remaining_bundle) self._total_pending_rows += remaining_bundle.num_rows() + if flush_remaining and self._total_pending_rows > 0: + self._ready_bundles.append( + RefBundle.merge_ref_bundles(self._pending_bundles) + ) + self._pending_bundles.clear() + self._total_pending_rows = 0 def add_bundle(self, ref_bundle: RefBundle): self._total_pending_rows += ref_bundle.num_rows() diff --git a/python/ray/data/tests/unit/test_bundler.py b/python/ray/data/tests/unit/test_bundler.py index 6a5c7e961fba..d4d9979968c2 100644 --- a/python/ray/data/tests/unit/test_bundler.py +++ b/python/ray/data/tests/unit/test_bundler.py @@ -65,6 +65,12 @@ def _make_ref_bundles_for_unit_test(raw_bundles: List[List[List[Any]]]) -> tuple [[[1]], [[]], [[2, 3]], [[]], [[4, 5]]], [3, 2], # Expected: [1,2,3] and [4,5] ), + ( + # Test with last block smaller than target num rows per block + 100, + [[[1]], [[2]], [[3]], [[4]], [[5]]], + [5], + ), ], ) def test_streaming_repartition_ref_bundler(target, in_bundles, expected_row_counts):