Commit 02afe68

[Data] Refactor concurrency validation tests in test_map.py (#58549)
## Description

The original `test_concurrency` function combined multiple test scenarios into a single test with complex control flow and expensive Ray cluster initialization. This refactoring extracts the parameter validation tests into focused, independent tests that are faster, clearer, and easier to maintain.

Additionally, the original test included "validation" cases that tested valid concurrency parameters but didn't actually verify that concurrency was being limited correctly. They only checked that the output was correct, which isn't useful for validating the concurrency feature itself.

**Key improvements:**

- Split validation tests into `test_invalid_func_concurrency_raises` and `test_invalid_class_concurrency_raises`
- Use parametrized tests for different invalid concurrency values
- Switch from `shutdown_only` with explicit `ray.init()` to `ray_start_regular_shared` to eliminate cluster initialization overhead
- Minimize test data from 10 blocks to 1 element since we're only validating parameter errors
- Remove non-validation tests that didn't verify concurrency behavior

## Related issues

N/A

## Additional information

The validation tests now execute significantly faster and provide clearer failure messages. Each test has a single, well-defined purpose, making maintenance and debugging easier.

---------

Signed-off-by: Balaji Veeramani <[email protected]>
1 parent 676b86f commit 02afe68

1 file changed, +22 -25 lines changed

python/ray/data/tests/test_map.py

Lines changed: 22 additions & 25 deletions
@@ -278,34 +278,31 @@ def _get_worker_id(_):
     assert len(unique_worker_ids) == total_blocks
 
 
-def test_concurrency(shutdown_only, target_max_block_size_infinite_or_default):
-    ray.init(num_cpus=6)
-    ds = ray.data.range(10, override_num_blocks=10)
+@pytest.mark.parametrize(
+    "concurrency",
+    [
+        "spam",
+        # Two and three-tuples are valid for callable classes but not for functions.
+        (1, 2),
+        (1, 2, 3),
+        (1, 2, 3, 4),
+    ],
+)
+def test_invalid_func_concurrency_raises(ray_start_regular_shared, concurrency):
+    ds = ray.data.range(1)
+    with pytest.raises(ValueError):
+        ds.map(lambda x: x, concurrency=concurrency)
 
-    def udf(x):
-        return x
 
-    class UDFClass:
-        def __call__(self, x):
-            return x
+@pytest.mark.parametrize("concurrency", ["spam", (1, 2, 3, 4)])
+def test_invalid_class_concurrency_raises(ray_start_regular_shared, concurrency):
+    class Fn:
+        def __call__(self, row):
+            return row
 
-    # Test function and class.
-    for fn in [udf, UDFClass]:
-        # Test concurrency with None, single integer and a tuple of integers.
-        for concurrency in [2, (2, 4), (2, 6, 4)]:
-            if fn == udf and (concurrency == (2, 4) or concurrency == (2, 6, 4)):
-                error_message = "``concurrency`` is set as a tuple of integers"
-                with pytest.raises(ValueError, match=error_message):
-                    ds.map(fn, concurrency=concurrency).take_all()
-            else:
-                result = ds.map(fn, concurrency=concurrency).take_all()
-                assert sorted(extract_values("id", result)) == list(range(10)), result
-
-    # Test concurrency with an illegal value.
-    error_message = "``concurrency`` is expected to be set a"
-    for concurrency in ["dummy", (1, 3, 5, 7)]:
-        with pytest.raises(ValueError, match=error_message):
-            ds.map(UDFClass, concurrency=concurrency).take_all()
+    ds = ray.data.range(1)
+    with pytest.raises(ValueError):
+        ds.map(Fn, concurrency=concurrency)
 
 
 @pytest.mark.parametrize("udf_kind", ["gen", "func"])
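
The cases exercised by the removed and added tests can be summarized in a small, dependency-free sketch. The function `is_valid_concurrency` and the two lists below are hypothetical names introduced for illustration only; this is not Ray's validation code. It only encodes what the diff shows: a single integer was accepted for both functions and callable classes, two- and three-tuples were accepted only for callable classes, and strings and four-tuples were rejected for both.

```python
def is_valid_concurrency(concurrency, *, is_callable_class):
    """Hypothetical stand-in for the rules the tests exercise; not Ray's implementation."""
    # The removed test passed concurrency=2 for both a function and a class.
    if isinstance(concurrency, int):
        return True
    # Two- and three-tuples of integers are accepted only for callable classes.
    if isinstance(concurrency, tuple) and len(concurrency) in (2, 3):
        all_ints = all(isinstance(c, int) for c in concurrency)
        return is_callable_class and all_ints
    # Anything else (strings, four-tuples, ...) is rejected.
    return False


# Values parametrized in test_invalid_func_concurrency_raises.
FUNC_INVALID = ["spam", (1, 2), (1, 2, 3), (1, 2, 3, 4)]
# Values parametrized in test_invalid_class_concurrency_raises.
CLASS_INVALID = ["spam", (1, 2, 3, 4)]
```

In this sketch every entry of `FUNC_INVALID` is rejected when `is_callable_class=False`, and every entry of `CLASS_INVALID` is rejected when `is_callable_class=True`, which mirrors the `pytest.raises(ValueError)` expectations in the new tests.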
