
Commit 6ca08a8

alexeykudinkin and landscapepainter authored and committed
[Data] Make streaming repartition combine small blocks (ray-project#58020)
## Description

Currently, streaming repartition doesn't combine blocks up to `target_num_rows_per_block`; it can only split blocks, not recombine them. This PR addresses that by letting it recombine smaller blocks into bigger ones. One caveat remains: the remainder block can still fall under `target_num_rows_per_block`.

Signed-off-by: Alexey Kudinkin <[email protected]>
1 parent 327bde8 · commit 6ca08a8
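As a quick orientation for the behavior this change targets, here is a minimal usage sketch; the row and block counts are illustrative, and the API calls mirror those exercised in the test below:

```python
import ray

# Build a dataset of 128 rows split across 16 small blocks (8 rows each).
ds = ray.data.range(128, override_num_blocks=16)

# Before this change, streaming repartition could only split blocks, so a
# target of 32 rows per block left the 8-row input blocks untouched. With
# small-block combining, the 16 inputs are recombined into 4 blocks of 32 rows.
repartitioned = ds.repartition(target_num_rows_per_block=32)

for bundle in repartitioned.iter_internal_ref_bundles():
    _, metadata = bundle.blocks[0]
    print(metadata.num_rows)  # expected: 32 for each block
```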

File tree

2 files changed: 26 additions, 13 deletions

python/ray/data/_internal/planner/plan_udf_map_op.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -169,6 +169,7 @@ def plan_streaming_repartition_op(
         data_context,
         name=op.name,
         compute_strategy=compute,
+        min_rows_per_bundle=op.target_num_rows_per_block,
         ray_remote_args=op._ray_remote_args,
         ray_remote_args_fn=op._ray_remote_args_fn,
         supports_fusion=False,
```
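The one-line change above routes `target_num_rows_per_block` into the map operator's `min_rows_per_bundle`, which makes the operator accumulate input blocks until a bundle reaches that row count before dispatching work. A simplified, standalone sketch of that bundling idea follows; this is not Ray's actual bundler, and all names here are illustrative:

```python
from typing import Iterable, Iterator, List, Tuple

Block = Tuple[str, int]  # illustrative stand-in: (block_id, num_rows)


def bundle_min_rows(blocks: Iterable[Block], min_rows: int) -> Iterator[List[Block]]:
    """Greedily combine consecutive blocks until a bundle has >= min_rows.

    Mirrors the caveat in the commit description: the final (remainder)
    bundle may still hold fewer than min_rows rows.
    """
    bundle: List[Block] = []
    rows = 0
    for block in blocks:
        bundle.append(block)
        rows += block[1]
        if rows >= min_rows:
            yield bundle
            bundle, rows = [], 0
    if bundle:  # remainder bundle, possibly under min_rows
        yield bundle


# 16 blocks of 8 rows bundled with min_rows=32 -> 4 bundles of 4 blocks each.
for b in bundle_min_rows([(f"b{i}", 8) for i in range(16)], min_rows=32):
    print(len(b), sum(r for _, r in b))
```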

python/ray/data/tests/test_repartition_e2e.py

Lines changed: 25 additions & 13 deletions

```diff
@@ -127,42 +127,54 @@ def test_repartition_shuffle_arrow(
 
 
 @pytest.mark.parametrize(
-    "total_rows,target_num_rows_per_block",
+    "total_rows,target_num_rows_per_block,expected_num_blocks",
     [
-        (128, 1),
-        (128, 2),
-        (128, 4),
-        (128, 8),
-        (128, 128),
+        (128, 1, 128),
+        (128, 2, 64),
+        (128, 4, 32),
+        (128, 8, 16),
+        (128, 128, 1),
     ],
 )
 def test_repartition_target_num_rows_per_block(
     ray_start_regular_shared_2_cpus,
     total_rows,
     target_num_rows_per_block,
+    expected_num_blocks,
     disable_fallback_to_object_extension,
 ):
-    ds = ray.data.range(total_rows).repartition(
+    num_blocks = 16
+
+    # Each block is 8 ints
+    ds = ray.data.range(total_rows, override_num_blocks=num_blocks).repartition(
         target_num_rows_per_block=target_num_rows_per_block,
     )
-    rows_count = 0
+
+    num_blocks = 0
+    num_rows = 0
     all_data = []
+
     for ref_bundle in ds.iter_internal_ref_bundles():
         block, block_metadata = (
             ray.get(ref_bundle.blocks[0][0]),
             ref_bundle.blocks[0][1],
         )
-        assert block_metadata.num_rows <= target_num_rows_per_block
-        rows_count += block_metadata.num_rows
+
+        # NOTE: Because our block rows % target_num_rows_per_block == 0, we can
+        # assert equality here
+        assert block_metadata.num_rows == target_num_rows_per_block
+
+        num_blocks += 1
+        num_rows += block_metadata.num_rows
+
         block_data = (
             BlockAccessor.for_block(block).to_pandas().to_dict(orient="records")
         )
         all_data.extend(block_data)
 
-    assert rows_count == total_rows
-
     # Verify total rows match
-    assert rows_count == total_rows
+    assert num_rows == total_rows
+    assert num_blocks == expected_num_blocks
 
     # Verify data consistency
     all_values = [row["id"] for row in all_data]
```
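The parametrized cases above all divide evenly, so the test can assert exact block sizes. The remainder caveat from the description shows up when they don't; here is a quick arithmetic sketch of the idealized split (an illustrative helper, not a Ray API, and actual streaming execution may leave a small remainder per task rather than a single global one):

```python
def expected_block_sizes(total_rows: int, target: int) -> list:
    """Block sizes an ideal repartition would produce: full blocks of
    `target` rows plus one possibly-smaller remainder block."""
    full, rem = divmod(total_rows, target)
    return [target] * full + ([rem] if rem else [])


print(expected_block_sizes(128, 8))  # [8] * 16 -> matches expected_num_blocks == 16
print(expected_block_sizes(100, 8))  # twelve 8-row blocks plus a 4-row remainder
```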
