Commit a1167ad

[Train] Per dataset execution_option for DataConfig (#58717)
This PR adds support for per-dataset execution options in `DataConfig`, allowing users to specify different `ExecutionOptions` for different datasets. This enables fine-grained control over how Ray Data processes each dataset.

Signed-off-by: xgui <[email protected]>
Signed-off-by: Xinyuan <[email protected]>
Co-authored-by: Justin Yu <[email protected]>
1 parent 0cfeb95 commit a1167ad
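
For context, a minimal usage sketch of the new API (the dataset name "train" is illustrative; the `ExecutionOptions` import path mirrors the tests added in this PR):

```python
import ray.train
from ray.data._internal.execution.interfaces.execution_options import (
    ExecutionOptions,
)

# Options for the "train" dataset only. Any dataset name not in the dict
# (e.g. "val") falls back to DataConfig.default_ingest_options().
train_options = ExecutionOptions()
train_options.preserve_order = True

data_config = ray.train.DataConfig(execution_options={"train": train_options})
```

Passing a single `ExecutionOptions` object instead of a dict applies it to all datasets; passing `None` (the default) keeps the default ingest options everywhere.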

File tree

5 files changed: +325 −11 lines changed

python/ray/air/tests/test_new_dataset_config.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -301,7 +301,7 @@ def _run_data_config_resource_test(data_config):
     num_train_cpus = num_workers * cpus_per_worker + default_trainer_cpus
     num_train_gpus = num_workers * gpus_per_worker + default_trainer_gpus
 
-    original_execution_options = data_config._execution_options
+    original_execution_options = data_config._get_execution_options("train")
 
     ray.init(num_cpus=cluster_cpus, num_gpus=cluster_gpus)
```

python/ray/train/_internal/data_config.py

Lines changed: 26 additions & 10 deletions

```diff
@@ -1,4 +1,5 @@
 import copy
+from collections import defaultdict
 from typing import Dict, List, Literal, Optional, Union
 
 import ray
@@ -19,7 +20,9 @@ class DataConfig:
     def __init__(
         self,
         datasets_to_split: Union[Literal["all"], List[str]] = "all",
-        execution_options: Optional[ExecutionOptions] = None,
+        execution_options: Optional[
+            Union[ExecutionOptions, Dict[str, ExecutionOptions]]
+        ] = None,
         enable_shard_locality: bool = True,
     ):
         """Construct a DataConfig.
@@ -28,12 +31,14 @@ def __init__(
             datasets_to_split: Specifies which datasets should be split among workers.
                 Can be set to "all" or a list of dataset names. Defaults to "all",
                 i.e. split all datasets.
-            execution_options: The execution options to pass to Ray Data. By default,
-                the options will be optimized for data ingest. When overriding this,
-                base your options off of `DataConfig.default_ingest_options()`.
-            enable_shard_locality: If true, when sharding the datasets across Train
-                workers, locality will be considered to minimize cross-node data transfer.
-                This is on by default.
+            execution_options: The execution options to pass to Ray Data. Can be either:
+                1. A single ExecutionOptions object that is applied to all datasets.
+                2. A dict mapping dataset names to ExecutionOptions. If a dataset name
+                   is not in the dict, it defaults to ``DataConfig.default_ingest_options()``.
+                By default, the options are optimized for data ingest. When overriding,
+                base your options off ``DataConfig.default_ingest_options()``.
+            enable_shard_locality: If true, dataset sharding across Train workers will
+                consider locality to minimize cross-node data transfer. Enabled by default.
         """
         if isinstance(datasets_to_split, list) or datasets_to_split == "all":
             self._datasets_to_split = datasets_to_split
@@ -44,9 +49,16 @@ def __init__(
                 f"{type(datasets_to_split).__name__} with value {datasets_to_split}."
             )
 
-        self._execution_options: ExecutionOptions = (
-            execution_options or DataConfig.default_ingest_options()
+        default_execution_options = DataConfig.default_ingest_options()
+        if isinstance(execution_options, ExecutionOptions):
+            default_execution_options = execution_options
+        # If None, all datasets will use the default ingest options.
+        self._execution_options: Dict[str, ExecutionOptions] = defaultdict(
+            lambda: copy.deepcopy(default_execution_options)
         )
+        if isinstance(execution_options, dict):
+            self._execution_options.update(execution_options)
+
        self._enable_shard_locality = enable_shard_locality
 
         self._num_train_cpus = 0.0
@@ -62,6 +74,10 @@ def set_train_total_resources(self, num_train_cpus: float, num_train_gpus: float
         self._num_train_cpus = num_train_cpus
         self._num_train_gpus = num_train_gpus
 
+    def _get_execution_options(self, dataset_name: str) -> ExecutionOptions:
+        """Return a copy of the configured execution options for a given dataset name."""
+        return copy.deepcopy(self._execution_options[dataset_name])
+
     @DeveloperAPI
     def configure(
         self,
@@ -98,7 +114,7 @@ def configure(
 
         locality_hints = worker_node_ids if self._enable_shard_locality else None
         for name, ds in datasets.items():
-            execution_options = copy.deepcopy(self._execution_options)
+            execution_options = self._get_execution_options(name)
 
             if execution_options.is_resource_limits_default():
                 # If "resource_limits" is not overriden by the user,
```

python/ray/train/v2/BUILD.bazel

Lines changed: 16 additions & 0 deletions

```diff
@@ -676,3 +676,19 @@ py_test(
         "//:ray_lib",
     ],
 )
+
+py_test(
+    name = "test_data_config",
+    size = "medium",
+    srcs = ["tests/test_data_config.py"],
+    env = {"RAY_TRAIN_V2_ENABLED": "1"},
+    tags = [
+        "exclusive",
+        "team:ml",
+        "train_v2",
+    ],
+    deps = [
+        ":conftest",
+        "//:ray_lib",
+    ],
+)
```
python/ray/train/v2/tests/test_data_config.py (new file)

Lines changed: 104 additions & 0 deletions

```python
from ray.data._internal.execution.interfaces.execution_options import (
    ExecutionOptions,
)
from ray.train import DataConfig


def test_per_dataset_execution_options_single(ray_start_4_cpus):
    """Test that a single ExecutionOptions object applies to all datasets."""
    # Create execution options with specific settings
    execution_options = ExecutionOptions()
    execution_options.preserve_order = True
    execution_options.verbose_progress = True

    data_config = DataConfig(execution_options=execution_options)

    # Verify that all datasets get the same execution options
    train_options = data_config._get_execution_options("train")
    test_options = data_config._get_execution_options("test")
    val_options = data_config._get_execution_options("val")

    assert train_options.preserve_order is True
    assert train_options.verbose_progress is True
    assert test_options.preserve_order is True
    assert test_options.verbose_progress is True
    assert val_options.preserve_order is True
    assert val_options.verbose_progress is True


def test_per_dataset_execution_options_dict(ray_start_4_cpus):
    """Test that a dict of ExecutionOptions maps to specific datasets, and datasets
    not in the dict get default ingest options. Also tests resource limits."""
    # Create different execution options for different datasets
    train_options = ExecutionOptions()
    train_options.preserve_order = True
    train_options.verbose_progress = True
    train_options.resource_limits = train_options.resource_limits.copy(cpu=4, gpu=2)

    test_options = ExecutionOptions()
    test_options.preserve_order = False
    test_options.verbose_progress = False
    test_options.resource_limits = test_options.resource_limits.copy(cpu=2, gpu=1)

    execution_options_dict = {
        "train": train_options,
        "test": test_options,
    }

    data_config = DataConfig(execution_options=execution_options_dict)

    # Verify that each dataset in the dict gets its specific options
    retrieved_train_options = data_config._get_execution_options("train")
    retrieved_test_options = data_config._get_execution_options("test")

    assert retrieved_train_options.preserve_order is True
    assert retrieved_train_options.verbose_progress is True
    assert retrieved_test_options.preserve_order is False
    assert retrieved_test_options.verbose_progress is False

    # Verify resource limits
    assert retrieved_train_options.resource_limits.cpu == 4
    assert retrieved_train_options.resource_limits.gpu == 2
    assert retrieved_test_options.resource_limits.cpu == 2
    assert retrieved_test_options.resource_limits.gpu == 1

    # Verify that a dataset not in the dict gets default options
    default_options = DataConfig.default_ingest_options()
    retrieved_val_options = data_config._get_execution_options("val")
    assert retrieved_val_options.preserve_order == default_options.preserve_order
    assert retrieved_val_options.verbose_progress == default_options.verbose_progress
    assert (
        retrieved_val_options.resource_limits.cpu == default_options.resource_limits.cpu
    )
    assert (
        retrieved_val_options.resource_limits.gpu == default_options.resource_limits.gpu
    )


def test_per_dataset_execution_options_default(ray_start_4_cpus):
    """Test that None or empty dict execution_options results in all datasets
    using default options."""
    # Test with None
    data_config_none = DataConfig(execution_options=None)
    default_options = DataConfig.default_ingest_options()
    retrieved_train_options = data_config_none._get_execution_options("train")
    retrieved_test_options = data_config_none._get_execution_options("test")

    assert retrieved_train_options.preserve_order == default_options.preserve_order
    assert retrieved_test_options.preserve_order == default_options.preserve_order

    # Test with empty dict
    data_config_empty = DataConfig(execution_options={})
    retrieved_train_options = data_config_empty._get_execution_options("train")
    retrieved_test_options = data_config_empty._get_execution_options("test")

    assert retrieved_train_options.preserve_order == default_options.preserve_order
    assert retrieved_test_options.preserve_order == default_options.preserve_order


if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main(["-v", "-x", __file__]))
```

python/ray/train/v2/tests/test_data_integration.py

Lines changed: 178 additions & 0 deletions

```diff
@@ -293,6 +293,184 @@ def check_resource_limits(config):
     trainer.fit()
 
 
+def test_per_dataset_execution_options_single(ray_start_4_cpus):
+    """Test that a single ExecutionOptions object applies to all datasets."""
+    NUM_ROWS = 100
+    NUM_WORKERS = 2
+
+    train_ds = ray.data.range(NUM_ROWS)
+    val_ds = ray.data.range(NUM_ROWS)
+
+    # Create execution options with specific settings
+    execution_options = ExecutionOptions()
+    execution_options.preserve_order = True
+    execution_options.verbose_progress = True
+
+    data_config = ray.train.DataConfig(execution_options=execution_options)
+
+    def train_fn():
+        train_shard = ray.train.get_dataset_shard("train")
+        val_shard = ray.train.get_dataset_shard("val")
+
+        # Verify both datasets have the same execution options
+        assert train_shard.get_context().execution_options.preserve_order is True
+        assert train_shard.get_context().execution_options.verbose_progress is True
+        assert val_shard.get_context().execution_options.preserve_order is True
+        assert val_shard.get_context().execution_options.verbose_progress is True
+
+    trainer = DataParallelTrainer(
+        train_fn,
+        datasets={"train": train_ds, "val": val_ds},
+        dataset_config=data_config,
+        scaling_config=ray.train.ScalingConfig(num_workers=NUM_WORKERS),
+    )
+    trainer.fit()
+
+
+def test_per_dataset_execution_options_dict(ray_start_4_cpus):
+    """Test that a dict of ExecutionOptions maps to specific datasets, and datasets not in the dict get default ingest options. Also tests resource limits."""
+    NUM_ROWS = 100
+    NUM_WORKERS = 2
+
+    train_ds = ray.data.range(NUM_ROWS)
+    val_ds = ray.data.range(NUM_ROWS)
+    test_ds = ray.data.range(NUM_ROWS)
+    test_ds_2 = ray.data.range(NUM_ROWS)
+
+    # Create different execution options for different datasets
+    train_options = ExecutionOptions()
+    train_options.preserve_order = True
+    train_options.verbose_progress = True
+    train_options.resource_limits = train_options.resource_limits.copy(cpu=4, gpu=2)
+
+    val_options = ExecutionOptions()
+    val_options.preserve_order = False
+    val_options.verbose_progress = False
+    val_options.resource_limits = val_options.resource_limits.copy(cpu=2, gpu=1)
+
+    execution_options_dict = {
+        "train": train_options,
+        "val": val_options,
+    }
+
+    data_config = ray.train.DataConfig(execution_options=execution_options_dict)
+
+    def train_fn():
+        train_shard = ray.train.get_dataset_shard("train")
+        val_shard = ray.train.get_dataset_shard("val")
+        test_shard = ray.train.get_dataset_shard("test")
+        test_shard_2 = ray.train.get_dataset_shard("test_2")
+
+        # Verify each dataset in the dict gets its specific options
+        assert train_shard.get_context().execution_options.preserve_order is True
+        assert train_shard.get_context().execution_options.verbose_progress is True
+        assert val_shard.get_context().execution_options.preserve_order is False
+        assert val_shard.get_context().execution_options.verbose_progress is False
+
+        # Verify resource limits
+        assert train_shard.get_context().execution_options.resource_limits.cpu == 4
+        assert train_shard.get_context().execution_options.resource_limits.gpu == 2
+        assert val_shard.get_context().execution_options.resource_limits.cpu == 2
+        assert val_shard.get_context().execution_options.resource_limits.gpu == 1
+
+        # Verify dataset not in the dict gets default options
+        assert (
+            test_shard.get_context().execution_options.preserve_order
+            == test_shard_2.get_context().execution_options.preserve_order
+        )
+        assert (
+            test_shard.get_context().execution_options.verbose_progress
+            == test_shard_2.get_context().execution_options.verbose_progress
+        )
+        assert (
+            test_shard.get_context().execution_options.resource_limits.cpu
+            == test_shard_2.get_context().execution_options.resource_limits.cpu
+        )
+        assert (
+            test_shard.get_context().execution_options.resource_limits.gpu
+            == test_shard_2.get_context().execution_options.resource_limits.gpu
+        )
+
+    trainer = DataParallelTrainer(
+        train_fn,
+        datasets={
+            "train": train_ds,
+            "val": val_ds,
+            "test": test_ds,
+            "test_2": test_ds_2,
+        },
+        dataset_config=data_config,
+        scaling_config=ray.train.ScalingConfig(num_workers=NUM_WORKERS),
+    )
+    trainer.fit()
+
+
+def test_exclude_train_resources_applies_to_each_dataset(ray_start_4_cpus):
+    """Test that the default behavior of excluding train worker resources
+    applies to each dataset individually when using per-dataset execution options."""
+    NUM_ROWS = 100
+    NUM_WORKERS = 2
+
+    # Create different execution options for different datasets
+    train_options = ExecutionOptions()
+    train_options.exclude_resources = train_options.exclude_resources.copy(cpu=2, gpu=1)
+
+    test_options = ExecutionOptions()
+    test_options.exclude_resources = test_options.exclude_resources.copy(cpu=1, gpu=0)
+
+    # val dataset not in dict, should get default options
+    execution_options_dict = {
+        "train": train_options,
+        "test": test_options,
+    }
+    data_config = ray.train.DataConfig(execution_options=execution_options_dict)
+
+    def train_fn():
+        # Check that each dataset has the train resources excluded,
+        # in addition to any per-dataset exclude_resources.
+
+        # Check train dataset
+        train_ds = ray.train.get_dataset_shard("train")
+        train_exec_options = train_ds.get_context().execution_options
+        assert train_exec_options.is_resource_limits_default()
+        # Train worker resources: NUM_WORKERS CPUs (default 1 CPU per worker)
+        expected_train_cpu = NUM_WORKERS + 2  # 2 from user-defined
+        expected_train_gpu = 0 + 1  # 1 from user-defined (no GPUs allocated)
+        assert train_exec_options.exclude_resources.cpu == expected_train_cpu
+        assert train_exec_options.exclude_resources.gpu == expected_train_gpu
+
+        # Check test dataset
+        test_ds = ray.train.get_dataset_shard("test")
+        test_exec_options = test_ds.get_context().execution_options
+        assert test_exec_options.is_resource_limits_default()
+        expected_test_cpu = NUM_WORKERS + 1  # 1 from user-defined
+        expected_test_gpu = 0 + 0  # 0 from user-defined
+        assert test_exec_options.exclude_resources.cpu == expected_test_cpu
+        assert test_exec_options.exclude_resources.gpu == expected_test_gpu
+
+        # Check val dataset (should have default + train resources excluded)
+        val_ds = ray.train.get_dataset_shard("val")
+        val_exec_options = val_ds.get_context().execution_options
+        assert val_exec_options.is_resource_limits_default()
+        default_options = ray.train.DataConfig.default_ingest_options()
+        expected_val_cpu = NUM_WORKERS + default_options.exclude_resources.cpu
+        expected_val_gpu = 0 + default_options.exclude_resources.gpu
+        assert val_exec_options.exclude_resources.cpu == expected_val_cpu
+        assert val_exec_options.exclude_resources.gpu == expected_val_gpu
+
+    trainer = DataParallelTrainer(
+        train_fn,
+        datasets={
+            "train": ray.data.range(NUM_ROWS),
+            "test": ray.data.range(NUM_ROWS),
+            "val": ray.data.range(NUM_ROWS),
+        },
+        dataset_config=data_config,
+        scaling_config=ray.train.ScalingConfig(num_workers=NUM_WORKERS),
+    )
+    trainer.fit()
+
+
 if __name__ == "__main__":
     import sys
 
```