diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index d690596fd1cf..c06829cfd653 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -278,34 +278,31 @@ def _get_worker_id(_): assert len(unique_worker_ids) == total_blocks -def test_concurrency(shutdown_only, target_max_block_size_infinite_or_default): - ray.init(num_cpus=6) - ds = ray.data.range(10, override_num_blocks=10) +@pytest.mark.parametrize( + "concurrency", + [ + "spam", + # Two and three-tuples are valid for callable classes but not for functions. + (1, 2), + (1, 2, 3), + (1, 2, 3, 4), + ], +) +def test_invalid_func_concurrency_raises(ray_start_regular_shared, concurrency): + ds = ray.data.range(1) + with pytest.raises(ValueError): + ds.map(lambda x: x, concurrency=concurrency) - def udf(x): - return x - class UDFClass: - def __call__(self, x): - return x +@pytest.mark.parametrize("concurrency", ["spam", (1, 2, 3, 4)]) +def test_invalid_class_concurrency_raises(ray_start_regular_shared, concurrency): + class Fn: + def __call__(self, row): + return row - # Test function and class. - for fn in [udf, UDFClass]: - # Test concurrency with None, single integer and a tuple of integers. - for concurrency in [2, (2, 4), (2, 6, 4)]: - if fn == udf and (concurrency == (2, 4) or concurrency == (2, 6, 4)): - error_message = "``concurrency`` is set as a tuple of integers" - with pytest.raises(ValueError, match=error_message): - ds.map(fn, concurrency=concurrency).take_all() - else: - result = ds.map(fn, concurrency=concurrency).take_all() - assert sorted(extract_values("id", result)) == list(range(10)), result - - # Test concurrency with an illegal value. - error_message = "``concurrency`` is expected to be set a" - for concurrency in ["dummy", (1, 3, 5, 7)]: - with pytest.raises(ValueError, match=error_message): - ds.map(UDFClass, concurrency=concurrency).take_all() + ds = ray.data.range(1) + with pytest.raises(ValueError): + ds.map(Fn, concurrency=concurrency) @pytest.mark.parametrize("udf_kind", ["gen", "func"])