@@ -293,6 +293,184 @@ def check_resource_limits(config):
     trainer.fit()


+def test_per_dataset_execution_options_single(ray_start_4_cpus):
+    """Test that a single ExecutionOptions object applies to all datasets."""
+    NUM_ROWS = 100
+    NUM_WORKERS = 2
+
+    train_ds = ray.data.range(NUM_ROWS)
+    val_ds = ray.data.range(NUM_ROWS)
+
+    # Create execution options with specific settings.
+    execution_options = ExecutionOptions()
+    execution_options.preserve_order = True
+    execution_options.verbose_progress = True
+
+    data_config = ray.train.DataConfig(execution_options=execution_options)
+
+    def train_fn():
+        train_shard = ray.train.get_dataset_shard("train")
+        val_shard = ray.train.get_dataset_shard("val")
+
+        # Verify both datasets have the same execution options.
+        assert train_shard.get_context().execution_options.preserve_order is True
+        assert train_shard.get_context().execution_options.verbose_progress is True
+        assert val_shard.get_context().execution_options.preserve_order is True
+        assert val_shard.get_context().execution_options.verbose_progress is True
+
+    trainer = DataParallelTrainer(
+        train_fn,
+        datasets={"train": train_ds, "val": val_ds},
+        dataset_config=data_config,
+        scaling_config=ray.train.ScalingConfig(num_workers=NUM_WORKERS),
+    )
+    trainer.fit()
+
+
+def test_per_dataset_execution_options_dict(ray_start_4_cpus):
+    """Test that a dict of ExecutionOptions maps to specific datasets, and datasets not in the dict get default ingest options. Also tests resource limits."""
+    NUM_ROWS = 100
+    NUM_WORKERS = 2
+
+    train_ds = ray.data.range(NUM_ROWS)
+    val_ds = ray.data.range(NUM_ROWS)
+    test_ds = ray.data.range(NUM_ROWS)
+    test_ds_2 = ray.data.range(NUM_ROWS)
+
+    # Create different execution options for different datasets.
+    train_options = ExecutionOptions()
+    train_options.preserve_order = True
+    train_options.verbose_progress = True
+    train_options.resource_limits = train_options.resource_limits.copy(cpu=4, gpu=2)
+
+    val_options = ExecutionOptions()
+    val_options.preserve_order = False
+    val_options.verbose_progress = False
+    val_options.resource_limits = val_options.resource_limits.copy(cpu=2, gpu=1)
+
+    execution_options_dict = {
+        "train": train_options,
+        "val": val_options,
+    }
+
+    data_config = ray.train.DataConfig(execution_options=execution_options_dict)
+
+    def train_fn():
+        train_shard = ray.train.get_dataset_shard("train")
+        val_shard = ray.train.get_dataset_shard("val")
+        test_shard = ray.train.get_dataset_shard("test")
+        test_shard_2 = ray.train.get_dataset_shard("test_2")
+
+        # Verify each dataset in the dict gets its specific options.
+        assert train_shard.get_context().execution_options.preserve_order is True
+        assert train_shard.get_context().execution_options.verbose_progress is True
+        assert val_shard.get_context().execution_options.preserve_order is False
+        assert val_shard.get_context().execution_options.verbose_progress is False
+
+        # Verify the per-dataset resource limits.
+        assert train_shard.get_context().execution_options.resource_limits.cpu == 4
+        assert train_shard.get_context().execution_options.resource_limits.gpu == 2
+        assert val_shard.get_context().execution_options.resource_limits.cpu == 2
+        assert val_shard.get_context().execution_options.resource_limits.gpu == 1
+
+        # Verify datasets not in the dict get the same default options.
+        assert (
+            test_shard.get_context().execution_options.preserve_order
+            == test_shard_2.get_context().execution_options.preserve_order
+        )
+        assert (
+            test_shard.get_context().execution_options.verbose_progress
+            == test_shard_2.get_context().execution_options.verbose_progress
+        )
+        assert (
+            test_shard.get_context().execution_options.resource_limits.cpu
+            == test_shard_2.get_context().execution_options.resource_limits.cpu
+        )
+        assert (
+            test_shard.get_context().execution_options.resource_limits.gpu
+            == test_shard_2.get_context().execution_options.resource_limits.gpu
+        )
+
+    trainer = DataParallelTrainer(
+        train_fn,
+        datasets={
+            "train": train_ds,
+            "val": val_ds,
+            "test": test_ds,
+            "test_2": test_ds_2,
+        },
+        dataset_config=data_config,
+        scaling_config=ray.train.ScalingConfig(num_workers=NUM_WORKERS),
+    )
+    trainer.fit()
+
+
+def test_exclude_train_resources_applies_to_each_dataset(ray_start_4_cpus):
+    """Test that the default behavior of excluding train worker resources
+    applies to each dataset individually when using per-dataset execution options."""
+    NUM_ROWS = 100
+    NUM_WORKERS = 2
+
+    # Create different execution options for different datasets.
+    train_options = ExecutionOptions()
+    train_options.exclude_resources = train_options.exclude_resources.copy(cpu=2, gpu=1)
+
+    test_options = ExecutionOptions()
+    test_options.exclude_resources = test_options.exclude_resources.copy(cpu=1, gpu=0)
+
+    # The val dataset is not in the dict, so it should get default options.
+    execution_options_dict = {
+        "train": train_options,
+        "test": test_options,
+    }
+    data_config = ray.train.DataConfig(execution_options=execution_options_dict)
+
+    def train_fn():
+        # Check that each dataset has the train worker resources excluded,
+        # in addition to any per-dataset exclude_resources.
+
+        # Check the train dataset.
+        train_ds = ray.train.get_dataset_shard("train")
+        train_exec_options = train_ds.get_context().execution_options
+        assert train_exec_options.is_resource_limits_default()
+        # Train worker resources: NUM_WORKERS CPUs (default 1 CPU per worker).
+        expected_train_cpu = NUM_WORKERS + 2  # 2 from user-defined
+        expected_train_gpu = 0 + 1  # 1 from user-defined (no GPUs allocated)
+        assert train_exec_options.exclude_resources.cpu == expected_train_cpu
+        assert train_exec_options.exclude_resources.gpu == expected_train_gpu
+
+        # Check the test dataset.
+        test_ds = ray.train.get_dataset_shard("test")
+        test_exec_options = test_ds.get_context().execution_options
+        assert test_exec_options.is_resource_limits_default()
+        expected_test_cpu = NUM_WORKERS + 1  # 1 from user-defined
+        expected_test_gpu = 0 + 0  # 0 from user-defined
+        assert test_exec_options.exclude_resources.cpu == expected_test_cpu
+        assert test_exec_options.exclude_resources.gpu == expected_test_gpu
+
+        # Check the val dataset (should have default + train resources excluded).
+        val_ds = ray.train.get_dataset_shard("val")
+        val_exec_options = val_ds.get_context().execution_options
+        assert val_exec_options.is_resource_limits_default()
+        default_options = ray.train.DataConfig.default_ingest_options()
+        expected_val_cpu = NUM_WORKERS + default_options.exclude_resources.cpu
+        expected_val_gpu = 0 + default_options.exclude_resources.gpu
+        assert val_exec_options.exclude_resources.cpu == expected_val_cpu
+        assert val_exec_options.exclude_resources.gpu == expected_val_gpu
+
+    trainer = DataParallelTrainer(
+        train_fn,
+        datasets={
+            "train": ray.data.range(NUM_ROWS),
+            "test": ray.data.range(NUM_ROWS),
+            "val": ray.data.range(NUM_ROWS),
+        },
+        dataset_config=data_config,
+        scaling_config=ray.train.ScalingConfig(num_workers=NUM_WORKERS),
+    )
+    trainer.fit()
+
+
 if __name__ == "__main__":
     import sys
