diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py
index 148dbf168aaa..d2591f0072c6 100644
--- a/python/ray/data/_internal/planner/plan_udf_map_op.py
+++ b/python/ray/data/_internal/planner/plan_udf_map_op.py
@@ -169,6 +169,7 @@ def plan_streaming_repartition_op(
         data_context,
         name=op.name,
         compute_strategy=compute,
+        min_rows_per_bundle=op.target_num_rows_per_block,
         ray_remote_args=op._ray_remote_args,
         ray_remote_args_fn=op._ray_remote_args_fn,
         supports_fusion=False,
diff --git a/python/ray/data/tests/test_repartition_e2e.py b/python/ray/data/tests/test_repartition_e2e.py
index 28155c36b2ed..fa7c7315c9a8 100644
--- a/python/ray/data/tests/test_repartition_e2e.py
+++ b/python/ray/data/tests/test_repartition_e2e.py
@@ -127,42 +127,54 @@ def test_repartition_shuffle_arrow(
 @pytest.mark.parametrize(
-    "total_rows,target_num_rows_per_block",
+    "total_rows,target_num_rows_per_block,expected_num_blocks",
     [
-        (128, 1),
-        (128, 2),
-        (128, 4),
-        (128, 8),
-        (128, 128),
+        (128, 1, 128),
+        (128, 2, 64),
+        (128, 4, 32),
+        (128, 8, 16),
+        (128, 128, 1),
     ],
 )
 def test_repartition_target_num_rows_per_block(
     ray_start_regular_shared_2_cpus,
     total_rows,
     target_num_rows_per_block,
+    expected_num_blocks,
     disable_fallback_to_object_extension,
 ):
-    ds = ray.data.range(total_rows).repartition(
+    num_blocks = 16
+
+    # Each block is 8 ints
+    ds = ray.data.range(total_rows, override_num_blocks=num_blocks).repartition(
         target_num_rows_per_block=target_num_rows_per_block,
     )
-    rows_count = 0
+
+    num_blocks = 0
+    num_rows = 0
     all_data = []
+
     for ref_bundle in ds.iter_internal_ref_bundles():
         block, block_metadata = (
             ray.get(ref_bundle.blocks[0][0]),
             ref_bundle.blocks[0][1],
         )
-        assert block_metadata.num_rows <= target_num_rows_per_block
-        rows_count += block_metadata.num_rows
+
+        # NOTE: Because our block rows % target_num_rows_per_block == 0, we can
+        #       assert equality here
+        assert block_metadata.num_rows == target_num_rows_per_block
+
+        num_blocks += 1
+        num_rows += block_metadata.num_rows
+
         block_data = (
             BlockAccessor.for_block(block).to_pandas().to_dict(orient="records")
         )
         all_data.extend(block_data)
-    assert rows_count == total_rows
-    # Verify total rows match
-    assert rows_count == total_rows
+
+    assert num_rows == total_rows
+    assert num_blocks == expected_num_blocks
     # Verify data consistency
     all_values = [row["id"] for row in all_data]
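
For context on the planner change: passing `op.target_num_rows_per_block` through as `min_rows_per_bundle` means the map operator coalesces upstream blocks into bundles of at least the target row count before the streaming repartition transform slices them into exact-size blocks. The sketch below is a toy model of that min-rows coalescing idea, not Ray's actual bundling code; `coalesce_to_min_rows`, `block_sizes`, and `min_rows` are hypothetical names used only for illustration:

```python
from typing import Iterator, List


def coalesce_to_min_rows(
    block_sizes: List[int], min_rows: int
) -> Iterator[List[int]]:
    """Toy model of min-rows bundling: greedily group upstream block
    sizes until each bundle holds at least `min_rows` rows."""
    bundle: List[int] = []
    bundled_rows = 0
    for size in block_sizes:
        bundle.append(size)
        bundled_rows += size
        if bundled_rows >= min_rows:
            yield bundle
            bundle, bundled_rows = [], 0
    if bundle:
        # Leftover rows form a final, possibly smaller bundle.
        yield bundle


# 16 upstream blocks of 8 rows each, as in the test. With a target of 64,
# the transform sees 2 bundles of 64 rows instead of 16 tiny blocks.
print(list(coalesce_to_min_rows([8] * 16, 64)))
```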
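At the API level, the behavior the updated test pins down can be checked end to end. A minimal sketch, assuming a Ray version where `Dataset.repartition` accepts `target_num_rows_per_block` (the API exercised by the test above) and where the row count divides evenly by the target:

```python
import ray

# 128 rows split across 16 input blocks of 8 rows each.
ds = ray.data.range(128, override_num_blocks=16)

# Streaming repartition into blocks of exactly 4 rows; since
# 128 % 4 == 0, every output block hits the target exactly.
repartitioned = ds.repartition(target_num_rows_per_block=4).materialize()

assert repartitioned.num_blocks() == 32
assert repartitioned.count() == 128
```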