Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions python/ray/air/tests/test_new_dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,143 @@ def test_configure_execution_options_carryover_context(ray_start_4_cpus):
assert ingest_options.verbose_progress is True


def test_per_dataset_execution_options_single(ray_start_4_cpus):
"""Test that a single ExecutionOptions object applies to all datasets."""
ds = ray.data.range(10)

# Create execution options with specific settings
execution_options = ExecutionOptions()
execution_options.preserve_order = True
execution_options.verbose_progress = True

data_config = DataConfig(execution_options=execution_options)

# Verify that all datasets get the same execution options
train_options = data_config._get_execution_options("train")
test_options = data_config._get_execution_options("test")
val_options = data_config._get_execution_options("val")

assert train_options.preserve_order is True
assert train_options.verbose_progress is True
assert test_options.preserve_order is True
assert test_options.verbose_progress is True
assert val_options.preserve_order is True
assert val_options.verbose_progress is True

# Test that it works in practice
test = TestBasic(
1,
True,
{"train": 10, "test": 10},
datasets={"train": ds, "test": ds},
dataset_config=data_config,
)
test.fit()


def test_per_dataset_execution_options_dict(ray_start_4_cpus):
"""Test that a dict of ExecutionOptions maps to specific datasets, and datasets
not in the dict get default ingest options. Also tests resource limits."""
ds = ray.data.range(10)

# Create different execution options for different datasets
train_options = ExecutionOptions()
train_options.preserve_order = True
train_options.verbose_progress = True
train_options.resource_limits = train_options.resource_limits.copy(cpu=4, gpu=2)

test_options = ExecutionOptions()
test_options.preserve_order = False
test_options.verbose_progress = False
test_options.resource_limits = test_options.resource_limits.copy(cpu=2, gpu=1)

execution_options_dict = {
"train": train_options,
"test": test_options,
}

data_config = DataConfig(execution_options=execution_options_dict)

# Verify that each dataset in the dict gets its specific options
retrieved_train_options = data_config._get_execution_options("train")
retrieved_test_options = data_config._get_execution_options("test")

assert retrieved_train_options.preserve_order is True
assert retrieved_train_options.verbose_progress is True
assert retrieved_test_options.preserve_order is False
assert retrieved_test_options.verbose_progress is False

# Verify resource limits
assert retrieved_train_options.resource_limits.cpu == 4
assert retrieved_train_options.resource_limits.gpu == 2
assert retrieved_test_options.resource_limits.cpu == 2
assert retrieved_test_options.resource_limits.gpu == 1

# Verify that a dataset not in the dict gets default options
default_options = DataConfig.default_ingest_options()
retrieved_val_options = data_config._get_execution_options("val")
assert retrieved_val_options.preserve_order == default_options.preserve_order
assert retrieved_val_options.verbose_progress == default_options.verbose_progress
assert (
retrieved_val_options.resource_limits.cpu == default_options.resource_limits.cpu
)
assert (
retrieved_val_options.resource_limits.gpu == default_options.resource_limits.gpu
)

# Test that it works in practice
test = TestBasic(
1,
True,
{"train": 10, "test": 10, "val": 10},
datasets={"train": ds, "test": ds, "val": ds},
dataset_config=data_config,
)
test.fit()


def test_per_dataset_execution_options_default(ray_start_4_cpus):
"""Test that None or empty dict execution_options results in all datasets
using default options."""
ds = ray.data.range(10)

# Test with None
data_config_none = DataConfig(execution_options=None)
default_options = DataConfig.default_ingest_options()
retrieved_train_options = data_config_none._get_execution_options("train")
retrieved_test_options = data_config_none._get_execution_options("test")

assert retrieved_train_options.preserve_order == default_options.preserve_order
assert retrieved_test_options.preserve_order == default_options.preserve_order

# Test with empty dict
data_config_empty = DataConfig(execution_options={})
retrieved_train_options = data_config_empty._get_execution_options("train")
retrieved_test_options = data_config_empty._get_execution_options("test")

assert retrieved_train_options.preserve_order == default_options.preserve_order
assert retrieved_test_options.preserve_order == default_options.preserve_order

# Test that it works in practice
test = TestBasic(
1,
True,
{"train": 10, "test": 10},
datasets={"train": ds, "test": ds},
dataset_config=data_config_none,
)
test.fit()

test = TestBasic(
1,
True,
{"train": 10, "test": 10},
datasets={"train": ds, "test": ds},
dataset_config=data_config_empty,
)
test.fit()


@pytest.mark.parametrize("enable_locality", [True, False])
def test_configure_locality(enable_locality):
data_config = DataConfig(enable_shard_locality=enable_locality)
Expand Down
38 changes: 27 additions & 11 deletions python/ray/train/_internal/data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ class DataConfig:
def __init__(
self,
datasets_to_split: Union[Literal["all"], List[str]] = "all",
execution_options: Optional[ExecutionOptions] = None,
execution_options: Optional[
Union[ExecutionOptions, Dict[str, ExecutionOptions]]
] = None,
enable_shard_locality: bool = True,
):
"""Construct a DataConfig.
Expand All @@ -28,12 +30,14 @@ def __init__(
datasets_to_split: Specifies which datasets should be split among workers.
Can be set to "all" or a list of dataset names. Defaults to "all",
i.e. split all datasets.
execution_options: The execution options to pass to Ray Data. By default,
the options will be optimized for data ingest. When overriding this,
base your options off of `DataConfig.default_ingest_options()`.
enable_shard_locality: If true, when sharding the datasets across Train
workers, locality will be considered to minimize cross-node data transfer.
This is on by default.
execution_options: The execution options to pass to Ray Data. Can be either:
1. A single ExecutionOptions object that is applied to all datasets.
2. A dict mapping dataset names to ExecutionOptions. If a dataset name
is not in the dict, it defaults to ``DataConfig.default_ingest_options()``.
By default, the options are optimized for data ingest. When overriding,
base your options off ``DataConfig.default_ingest_options()``.
enable_shard_locality: If true, dataset sharding across Train workers will
consider locality to minimize cross-node data transfer. Enabled by default.
"""
if isinstance(datasets_to_split, list) or datasets_to_split == "all":
self._datasets_to_split = datasets_to_split
Expand All @@ -44,9 +48,8 @@ def __init__(
f"{type(datasets_to_split).__name__} with value {datasets_to_split}."
)

self._execution_options: ExecutionOptions = (
execution_options or DataConfig.default_ingest_options()
)
# If None, all datasets will use the default ingest options.
self._execution_options = execution_options or {}
self._enable_shard_locality = enable_shard_locality

self._num_train_cpus = 0.0
Expand All @@ -62,6 +65,19 @@ def set_train_total_resources(self, num_train_cpus: float, num_train_gpus: float
self._num_train_cpus = num_train_cpus
self._num_train_gpus = num_train_gpus

def _get_execution_options(self, dataset_name: str) -> ExecutionOptions:
"""Return a copy of the configured execution options for a given dataset name."""
if isinstance(self._execution_options, dict):
res = self._execution_options.get(
dataset_name, DataConfig.default_ingest_options()
)
else:
assert isinstance(
self._execution_options, ExecutionOptions
), "execution_options must be a dictionary of ExecutionOptions objects by dataset name or a single ExecutionOptions object."
res = self._execution_options
return copy.deepcopy(res)

@DeveloperAPI
def configure(
self,
Expand Down Expand Up @@ -98,7 +114,7 @@ def configure(

locality_hints = worker_node_ids if self._enable_shard_locality else None
for name, ds in datasets.items():
execution_options = copy.deepcopy(self._execution_options)
execution_options = self._get_execution_options(name)

if execution_options.is_resource_limits_default():
# If "resource_limits" is not overriden by the user,
Expand Down
Loading