From 25b36c4c3c973d0a2a2a6472a4728295be8951c3 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Tue, 13 Feb 2024 15:07:10 -0500 Subject: [PATCH 01/11] Deprecate and introduce dataloader_config --- src/accelerate/accelerator.py | 46 ++++++++++++++++++++++++++--- src/accelerate/utils/__init__.py | 1 + src/accelerate/utils/dataclasses.py | 42 ++++++++++++++++++++++++++ tests/test_accelerator.py | 19 ++++++++++++ 4 files changed, 104 insertions(+), 4 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index fb60473e47c..3c1d1a0cd20 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -47,6 +47,7 @@ WEIGHTS_INDEX_NAME, WEIGHTS_NAME, AutocastKwargs, + DataLoaderConfiguration, DeepSpeedPlugin, DistributedDataParallelKwargs, DistributedType, @@ -257,6 +258,7 @@ def __init__( mixed_precision: PrecisionType | str | None = None, gradient_accumulation_steps: int = 1, cpu: bool = False, + dataloader_config: DataLoaderConfiguration | None = None, deepspeed_plugin: DeepSpeedPlugin | None = None, fsdp_plugin: FullyShardedDataParallelPlugin | None = None, megatron_lm_plugin: MegatronLMPlugin | None = None, @@ -421,10 +423,30 @@ def __init__( ) self.device_placement = device_placement - self.split_batches = split_batches - self.dispatch_batches = dispatch_batches - self.even_batches = even_batches - self.use_seedable_sampler = use_seedable_sampler + if dataloader_config is None: + self.dataloader_config = DataLoaderConfiguration() + # Deal with deprecated args + deprecated_dl_args = {} + if dispatch_batches is not None: + deprecated_dl_args["dispatch_batches"] = dispatch_batches + self.dataloader_config.dispatch_batches = dispatch_batches + if split_batches is not True: + deprecated_dl_args["split_batches"] = split_batches + self.dataloader_config.split_batches = split_batches + if not even_batches: + deprecated_dl_args["even_batches"] = even_batches + self.dataloader_config.even_batches = even_batches + if use_seedable_sampler: + deprecated_dl_args["use_seedable_sampler"] = use_seedable_sampler + self.dataloader_config.use_seedable_sampler = use_seedable_sampler + if len(deprecated_dl_args) > 0: + values = ", ".join([f"{k}={v}" for k, v in deprecated_dl_args.items()]) + warnings.warn( + f"Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: {deprecated_dl_args.keys()}. 
" + "Please use `DataLoaderConfiguration` instead: \n" + f"dataloader_config = DataLoaderConfiguration({values})", + FutureWarning, + ) self.step_scheduler_with_optimizer = step_scheduler_with_optimizer # Mixed precision attributes @@ -511,6 +533,22 @@ def local_process_index(self): def device(self): return self.state.device + @property + def split_batches(self): + return self.dataloader_config.split_batches + + @property + def dispatch_batches(self): + return self.dataloader_config.dispatch_batches + + @property + def even_batches(self): + return self.dataloader_config.even_batches + + @property + def use_seedable_sampler(self): + return self.dataloader_config.use_seedable_sampler + @property def project_dir(self): return self.project_configuration.project_dir diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 89034589499..d042cb8f51f 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -18,6 +18,7 @@ BnbQuantizationConfig, ComputeEnvironment, CustomDtype, + DataLoaderConfiguration, DeepSpeedPlugin, DistributedDataParallelKwargs, DistributedType, diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 730c777fed6..f7b5a267a28 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -454,6 +454,48 @@ class TensorInformation: dtype: torch.dtype +@dataclass +class DataLoaderConfiguration: + """ + Configuration for dataloader-related items when calling `accelerator.prepare`. + """ + + split_batches: bool = field( + default=False, + metadata={ + "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If" + " `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a" + " round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set" + " in your script multiplied by the number of processes." + }, + ) + dispatch_batches: bool = field( + default=None, + metadata={ + "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process" + " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose" + " underlying dataset is an `IterableDataslet`, `False` otherwise." + }, + ) + even_batches: bool = field( + default=True, + metadata={ + "help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the" + " dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among" + " all workers." + }, + ) + use_seedable_sampler: bool = field( + default=False, + metadata={ + "help": "Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`])." + "Ensures training results are fully reproducable using a different sampling technique. " + "While seed-to-seed results may differ, on average the differences are neglible when using" + "multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results." 
+ }, + ) + + @dataclass class ProjectConfiguration: """ diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 78d9c359147..b9f123a1019 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -55,6 +55,25 @@ def parameterized_custom_name_func(func, param_num, param): class AcceleratorTester(AccelerateTestCase): + # Should be removed after 1.0.0 release + def test_deprecated_values(self): + with self.assertWarns(FutureWarning) as cm: + accelerator = Accelerator( + dispatch_batches=True, + split_batches=False, + even_batches=False, + use_seedable_sampler=True, + ) + deprecation_warning = cm.warning.args[0] + assert "dispatch_batches" in deprecation_warning + assert accelerator.dispatch_batches is True + assert "split_batches" in deprecation_warning + assert accelerator.split_batches is False + assert "even_batches" in deprecation_warning + assert accelerator.even_batches is False + assert "use_seedable_sampler" in deprecation_warning + assert accelerator.use_seedable_sampler is True + @require_non_cpu def test_accelerator_can_be_reinstantiated(self): _ = Accelerator() From 4f25a9eb47ae9edba5eaafb5f035f036182c8862 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Tue, 13 Feb 2024 15:09:55 -0500 Subject: [PATCH 02/11] Update docs --- docs/source/quicktour.md | 2 +- src/accelerate/accelerator.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index 6271dc41457..8bb1ddce4e9 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -83,7 +83,7 @@ is shuffled the same way (if you decided to use `shuffle=True` or any kind of ra your script. For instance, training on 4 GPUs with a batch size of 16 set when creating the training dataloader will train at an actual batch size of 64 (4 * 16). If you want the batch size remain the same regardless of how many GPUs the script is run on, you can use the - option `split_batches=True` when creating and initializing [`Accelerator`]. + option `split_batches=True` when creating and initializing [`Accelerator`] by passing in a [`utils.DataLoaderConfig`]. Your training dataloader may change length when going through this method: if you run on X GPUs, it will have its length divided by X (since your actual batch size will be multiplied by X), unless you set `split_batches=True`. diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 3c1d1a0cd20..e30eebc1f40 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -163,7 +163,8 @@ class Accelerator: Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set - in your script multiplied by the number of processes. + in your script multiplied by the number of processes. Will be deprecated in version 1.0 of Accelerate, + please use the `utils.DataLoaderConfig`. mixed_precision (`str`, *optional*): Whether or not to use mixed precision training. Choose from 'no','fp16','bf16 or 'fp8'. 
Will default to the value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in the @@ -213,16 +214,18 @@ class Accelerator: dispatch_batches (`bool`, *optional*): If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose - underlying dataset is an `IterableDataset`, `False` otherwise. + underlying dataset is an `IterableDataset`, `False` otherwise. Will be deprecated in version 1.0 of + Accelerate, please use the `utils.DataLoaderConfig`. even_batches (`bool`, *optional*, defaults to `True`): If set to `True`, in cases where the total batch size across all processes does not exactly divide the dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among - all workers. + all workers. Will be deprecated in version 1.0 of Accelerate, please use the `utils.DataLoaderConfig`. use_seedable_sampler (`bool`, *optional*, defaults to `False`): Whether or not use a fully seedable random sampler ([`~data_loader.SeedableRandomSampler`]). Ensures training results are fully reproducable using a different sampling technique. While seed-to-seed results may differ, on average the differences are neglible when using multiple different seeds to compare. Should - also be ran with [`~utils.set_seed`] each time for the best results. + also be ran with [`~utils.set_seed`] each time for the best results. Will be deprecated in version 1.0 of + Accelerate, please use the `utils.DataLoaderConfig`. step_scheduler_with_optimizer (`bool`, *optional`, defaults to `True`): Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only done under certain circumstances (at the end of each epoch, for instance). From 6816a5eefeba6205c1984f5489deec78a42b9669 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Tue, 13 Feb 2024 15:15:58 -0500 Subject: [PATCH 03/11] Doc nits --- docs/source/package_reference/utilities.md | 2 ++ src/accelerate/accelerator.py | 9 +++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/source/package_reference/utilities.md b/docs/source/package_reference/utilities.md index 9afc22874a0..7bfaaaf985c 100644 --- a/docs/source/package_reference/utilities.md +++ b/docs/source/package_reference/utilities.md @@ -95,6 +95,8 @@ These are classes which can be configured and passed through to the appropriate [[autodoc]] utils.BnbQuantizationConfig +[[autodoc]] utils.DataLoaderConfiguration + [[autodoc]] utils.ProjectConfiguration ## Environmental Variables diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index e30eebc1f40..4b5f1c4d9f0 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -164,7 +164,7 @@ class Accelerator: `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set in your script multiplied by the number of processes. Will be deprecated in version 1.0 of Accelerate, - please use the `utils.DataLoaderConfig`. + please use the [`utils.DataLoaderConfiguration`]. mixed_precision (`str`, *optional*): Whether or not to use mixed precision training. Choose from 'no','fp16','bf16 or 'fp8'. 
Will default to the value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in the @@ -215,17 +215,18 @@ class Accelerator: If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose underlying dataset is an `IterableDataset`, `False` otherwise. Will be deprecated in version 1.0 of - Accelerate, please use the `utils.DataLoaderConfig`. + Accelerate, please use the [`utils.DataLoaderConfiguration`]. even_batches (`bool`, *optional*, defaults to `True`): If set to `True`, in cases where the total batch size across all processes does not exactly divide the dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among - all workers. Will be deprecated in version 1.0 of Accelerate, please use the `utils.DataLoaderConfig`. + all workers. Will be deprecated in version 1.0 of Accelerate, please use the + [`utils.DataLoaderConfiguration`]. use_seedable_sampler (`bool`, *optional*, defaults to `False`): Whether or not use a fully seedable random sampler ([`~data_loader.SeedableRandomSampler`]). Ensures training results are fully reproducable using a different sampling technique. While seed-to-seed results may differ, on average the differences are neglible when using multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] each time for the best results. Will be deprecated in version 1.0 of - Accelerate, please use the `utils.DataLoaderConfig`. + Accelerate, please use the [`utils.DataLoaderConfiguration`]. step_scheduler_with_optimizer (`bool`, *optional`, defaults to `True`): Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only done under certain circumstances (at the end of each epoch, for instance). From f21d81713793f73e147e1079febf7f20f440571f Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 14 Feb 2024 09:32:11 -0500 Subject: [PATCH 04/11] More tests, adjust based on PR review --- src/accelerate/__init__.py | 1 + src/accelerate/accelerator.py | 49 +++++++++++++---------------------- tests/test_accelerator.py | 42 +++++++++++++++++++++--------- 3 files changed, 49 insertions(+), 43 deletions(-) diff --git a/src/accelerate/__init__.py b/src/accelerate/__init__.py index 70b897ca916..91e84c02792 100644 --- a/src/accelerate/__init__.py +++ b/src/accelerate/__init__.py @@ -16,6 +16,7 @@ from .state import PartialState from .utils import ( AutocastKwargs, + DataLoaderConfiguration, DeepSpeedPlugin, DistributedDataParallelKwargs, DistributedType, diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 4b5f1c4d9f0..2c7ca9c48c8 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -150,6 +150,12 @@ logger = get_logger(__name__) +# Sentinel values for defaults +_split_batches = object() +_dispatch_batches = object() +_even_batches = object() +_use_seedable_sampler = object() + class Accelerator: """ @@ -159,12 +165,6 @@ class Accelerator: device_placement (`bool`, *optional*, defaults to `True`): Whether or not the accelerator should put objects on device (tensors yielded by the dataloader, model, etc...). - split_batches (`bool`, *optional*, defaults to `False`): - Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. 
If - `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a - round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set - in your script multiplied by the number of processes. Will be deprecated in version 1.0 of Accelerate, - please use the [`utils.DataLoaderConfiguration`]. mixed_precision (`str`, *optional*): Whether or not to use mixed precision training. Choose from 'no','fp16','bf16 or 'fp8'. Will default to the value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in the @@ -177,6 +177,8 @@ class Accelerator: cpu (`bool`, *optional*): Whether or not to force the script to execute on CPU. Will ignore GPU available if set to `True` and force the execution on one process only. + dataloader_config (`DataLoaderConfiguration`, *optional*): + A configuration for how the dataloaders should be handled in distributed scenarios. deepspeed_plugin (`DeepSpeedPlugin`, *optional*): Tweak your DeepSpeed related args using this argument. This argument is optional and can be configured directly using *accelerate config* @@ -211,22 +213,6 @@ class Accelerator: project_dir (`str`, `os.PathLike`, *optional*): A path to a directory for storing data such as logs of locally-compatible loggers and potentially saved checkpoints. - dispatch_batches (`bool`, *optional*): - If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process - and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose - underlying dataset is an `IterableDataset`, `False` otherwise. Will be deprecated in version 1.0 of - Accelerate, please use the [`utils.DataLoaderConfiguration`]. - even_batches (`bool`, *optional*, defaults to `True`): - If set to `True`, in cases where the total batch size across all processes does not exactly divide the - dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among - all workers. Will be deprecated in version 1.0 of Accelerate, please use the - [`utils.DataLoaderConfiguration`]. - use_seedable_sampler (`bool`, *optional*, defaults to `False`): - Whether or not use a fully seedable random sampler ([`~data_loader.SeedableRandomSampler`]). Ensures - training results are fully reproducable using a different sampling technique. While seed-to-seed results - may differ, on average the differences are neglible when using multiple different seeds to compare. Should - also be ran with [`~utils.set_seed`] each time for the best results. Will be deprecated in version 1.0 of - Accelerate, please use the [`utils.DataLoaderConfiguration`]. step_scheduler_with_optimizer (`bool`, *optional`, defaults to `True`): Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only done under certain circumstances (at the end of each epoch, for instance). 
@@ -258,7 +244,7 @@ class Accelerator: def __init__( self, device_placement: bool = True, - split_batches: bool = False, + split_batches: bool = _split_batches, mixed_precision: PrecisionType | str | None = None, gradient_accumulation_steps: int = 1, cpu: bool = False, @@ -271,9 +257,9 @@ def __init__( project_dir: str | os.PathLike | None = None, project_config: ProjectConfiguration | None = None, gradient_accumulation_plugin: GradientAccumulationPlugin | None = None, - dispatch_batches: bool | None = None, - even_batches: bool = True, - use_seedable_sampler: bool = False, + dispatch_batches: bool | None = _dispatch_batches, + even_batches: bool = _even_batches, + use_seedable_sampler: bool = _use_seedable_sampler, step_scheduler_with_optimizer: bool = True, kwargs_handlers: list[KwargsHandler] | None = None, dynamo_backend: DynamoBackend | str | None = None, @@ -430,24 +416,25 @@ def __init__( if dataloader_config is None: self.dataloader_config = DataLoaderConfiguration() # Deal with deprecated args + # TODO: Remove in v1.0.0 deprecated_dl_args = {} - if dispatch_batches is not None: + if dispatch_batches is not _dispatch_batches: deprecated_dl_args["dispatch_batches"] = dispatch_batches self.dataloader_config.dispatch_batches = dispatch_batches - if split_batches is not True: + if split_batches is not _split_batches: deprecated_dl_args["split_batches"] = split_batches self.dataloader_config.split_batches = split_batches - if not even_batches: + if even_batches is not _even_batches: deprecated_dl_args["even_batches"] = even_batches self.dataloader_config.even_batches = even_batches - if use_seedable_sampler: + if use_seedable_sampler is not _use_seedable_sampler: deprecated_dl_args["use_seedable_sampler"] = use_seedable_sampler self.dataloader_config.use_seedable_sampler = use_seedable_sampler if len(deprecated_dl_args) > 0: values = ", ".join([f"{k}={v}" for k, v in deprecated_dl_args.items()]) warnings.warn( f"Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: {deprecated_dl_args.keys()}. 
" - "Please use `DataLoaderConfiguration` instead: \n" + "Please pass an `accelerate.DataLoaderConfiguration` instead: \n" f"dataloader_config = DataLoaderConfiguration({values})", FutureWarning, ) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index b9f123a1019..0d00b0c28d0 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -57,22 +57,40 @@ def parameterized_custom_name_func(func, param_num, param): class AcceleratorTester(AccelerateTestCase): # Should be removed after 1.0.0 release def test_deprecated_values(self): - with self.assertWarns(FutureWarning) as cm: + # Test defaults + accelerator = Accelerator() + assert accelerator.split_batches is False, "split_batches should be False by default" + assert accelerator.dispatch_batches is None, "dispatch_batches should be None by default" + assert accelerator.even_batches is True, "even_batches should be True by default" + assert accelerator.use_seedable_sampler is False, "use_seedable_sampler should be False by default" + + # Pass some arguments only + with self.assertWarnsRegex(FutureWarning) as cm: accelerator = Accelerator( dispatch_batches=True, split_batches=False, - even_batches=False, - use_seedable_sampler=True, ) - deprecation_warning = cm.warning.args[0] - assert "dispatch_batches" in deprecation_warning - assert accelerator.dispatch_batches is True - assert "split_batches" in deprecation_warning - assert accelerator.split_batches is False - assert "even_batches" in deprecation_warning - assert accelerator.even_batches is False - assert "use_seedable_sampler" in deprecation_warning - assert accelerator.use_seedable_sampler is True + deprecation_warning = cm.warnings[0].message.args[0] + assert accelerator.split_batches is False, "split_batches should be True" + assert accelerator.dispatch_batches is True, "dispatch_batches should be True" + assert accelerator.even_batches is True, "even_batches should be True by default" + assert accelerator.use_seedable_sampler is False, "use_seedable_sampler should be False by default" + assert "dispatch_batches" in deprecation_warning + assert "split_batches" in deprecation_warning + assert "even_batches" not in deprecation_warning + assert "use_seedable_sampler" not in deprecation_warning + + # Pass in some arguments, but with their defaults + with self.assertWarns(FutureWarning) as cm: + accelerator = Accelerator( + even_batches=True, + use_seedable_sampler=False, + ) + deprecation_warning = cm.warnings[0].message.args[0] + assert "even_batches" in deprecation_warning + assert accelerator.even_batches is True + assert "use_seedable_sampler" in deprecation_warning + assert accelerator.use_seedable_sampler is False @require_non_cpu def test_accelerator_can_be_reinstantiated(self): From cbe95af531acd9dc250bf9b0bbd4f1b4441d77f4 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 14 Feb 2024 09:39:27 -0500 Subject: [PATCH 05/11] Fixup tests --- .../scripts/external_deps/test_metrics.py | 11 ++++++---- .../test_utils/scripts/test_script.py | 20 +++++++++++++------ 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py index 50d6be063bb..88fc5b07c0a 100755 --- a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py +++ b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py @@ -25,7 +25,7 @@ from torch.utils.data import DataLoader, IterableDataset from transformers import 
AutoModelForSequenceClassification, AutoTokenizer -from accelerate import Accelerator +from accelerate import Accelerator, DataLoaderConfiguration from accelerate.data_loader import DataLoaderDispatcher from accelerate.test_utils import RegressionDataset, RegressionModel, torch_device from accelerate.utils import set_seed @@ -81,7 +81,8 @@ def collate_fn(examples): def get_mrpc_setup(dispatch_batches, split_batches): - accelerator = Accelerator(dispatch_batches=dispatch_batches, split_batches=split_batches) + dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, split_batches=split_batches) + accelerator = Accelerator(data_loader_config=dataloader_config) dataloader = get_dataloader(accelerator, not dispatch_batches) model = AutoModelForSequenceClassification.from_pretrained( "hf-internal-testing/mrpc-bert-base-cased", return_dict=True @@ -242,7 +243,8 @@ def test_gather_for_metrics_drop_last(): def main(): - accelerator = Accelerator(split_batches=False, dispatch_batches=False) + dataloader_config = DataLoaderConfiguration(split_batches=False, dispatch_batches=False) + accelerator = Accelerator(dataloader_config=dataloader_config) if accelerator.is_local_main_process: datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_warning() @@ -267,7 +269,8 @@ def main(): print("**Test torch metrics**") for split_batches in [True, False]: for dispatch_batches in [True, False]: - accelerator = Accelerator(split_batches=split_batches, dispatch_batches=dispatch_batches) + dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches) + accelerator = Accelerator(dataloader_config=dataloader_config) if accelerator.is_local_main_process: print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`, length=99") test_torch_metrics(accelerator, 99) diff --git a/src/accelerate/test_utils/scripts/test_script.py b/src/accelerate/test_utils/scripts/test_script.py index 6ffccc16592..e3fc7f50d43 100644 --- a/src/accelerate/test_utils/scripts/test_script.py +++ b/src/accelerate/test_utils/scripts/test_script.py @@ -30,6 +30,7 @@ from accelerate.state import AcceleratorState from accelerate.test_utils import RegressionDataset, are_the_same_tensors from accelerate.utils import ( + DataLoaderConfiguration, DistributedType, gather, is_bf16_available, @@ -355,7 +356,9 @@ def check_seedable_sampler(): set_seed(42) train_set = RegressionDataset(length=10, seed=42) train_dl = DataLoader(train_set, batch_size=2, shuffle=True) - accelerator = Accelerator(use_seedable_sampler=True) + + config = DataLoaderConfiguration(use_seedable_sampler=True) + accelerator = Accelerator(dataloader_config=config) train_dl = accelerator.prepare(train_dl) original_items = [] for _ in range(3): @@ -424,7 +427,8 @@ def training_check(use_seedable_sampler=False): accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.") - accelerator = Accelerator(split_batches=True, use_seedable_sampler=use_seedable_sampler) + dataloader_config = DataLoaderConfiguration(split_batches=True, use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(dataloader_config=dataloader_config) train_dl = generate_baseline_dataloader( train_set, generator, batch_size * state.num_processes, use_seedable_sampler ) @@ -452,7 +456,8 @@ def training_check(use_seedable_sampler=False): # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16 
print("FP16 training check.") AcceleratorState._reset_state() - accelerator = Accelerator(mixed_precision="fp16", use_seedable_sampler=use_seedable_sampler) + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(mixed_precision="fp16", dataloader_config=dataloader_config) train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) model = RegressionModel() optimizer = torch.optim.SGD(model.parameters(), lr=0.1) @@ -492,7 +497,8 @@ def training_check(use_seedable_sampler=False): # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16 print("BF16 training check.") AcceleratorState._reset_state() - accelerator = Accelerator(mixed_precision="bf16", use_seedable_sampler=use_seedable_sampler) + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(mixed_precision="bf16", dataloader_config=dataloader_config) train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) model = RegressionModel() optimizer = torch.optim.SGD(model.parameters(), lr=0.1) @@ -516,7 +522,8 @@ def training_check(use_seedable_sampler=False): if is_ipex_available(): print("ipex BF16 training check.") AcceleratorState._reset_state() - accelerator = Accelerator(mixed_precision="bf16", cpu=True, use_seedable_sampler=use_seedable_sampler) + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(mixed_precision="bf16", cpu=True, dataloader_config=dataloader_config) train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) model = RegressionModel() optimizer = torch.optim.SGD(model.parameters(), lr=0.1) @@ -540,7 +547,8 @@ def training_check(use_seedable_sampler=False): if is_xpu_available(): print("xpu BF16 training check.") AcceleratorState._reset_state() - accelerator = Accelerator(mixed_precision="bf16", cpu=False, use_seedable_sampler=use_seedable_sampler) + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(mixed_precision="bf16", cpu=False, dataloader_config=dataloader_config) train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) model = RegressionModel() optimizer = torch.optim.SGD(model.parameters(), lr=0.1) From a51634f78a948160e3bd5b7a8e6a2f8c7a0bc173 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 14 Feb 2024 10:36:32 -0500 Subject: [PATCH 06/11] Nits --- src/accelerate/accelerator.py | 4 ++++ .../test_utils/scripts/test_distributed_data_loop.py | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 2c7ca9c48c8..a633fdd90f6 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -536,6 +536,10 @@ def dispatch_batches(self): def even_batches(self): return self.dataloader_config.even_batches + @even_batches.setter + def even_batches(self, value: bool): + self.dataloader_config.even_batches = value + @property def use_seedable_sampler(self): return self.dataloader_config.use_seedable_sampler diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py index 850f7310f79..606ade2d065 100644 --- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py +++ 
b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -22,7 +22,7 @@ import torch from torch.utils.data import DataLoader, IterableDataset, TensorDataset -from accelerate.accelerator import Accelerator +from accelerate.accelerator import Accelerator, DataLoaderConfiguration from accelerate.utils.dataclasses import DistributedType @@ -36,7 +36,8 @@ def __iter__(self): def create_accelerator(even_batches=True): - accelerator = Accelerator(even_batches=even_batches) + dataloader_config = DataLoaderConfiguration(even_batches=even_batches) + accelerator = Accelerator(dataloader_config=dataloader_config) assert accelerator.num_processes == 2, "this script expects that two GPUs are available" return accelerator From 74d305ab5bf6b686cf814c7e10f70e3babe0fa3e Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 14 Feb 2024 10:38:38 -0500 Subject: [PATCH 07/11] Update docs/source/quicktour.md Co-authored-by: Benjamin Bossan --- docs/source/quicktour.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index 8bb1ddce4e9..bf2640d3f81 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -83,7 +83,7 @@ is shuffled the same way (if you decided to use `shuffle=True` or any kind of ra your script. For instance, training on 4 GPUs with a batch size of 16 set when creating the training dataloader will train at an actual batch size of 64 (4 * 16). If you want the batch size remain the same regardless of how many GPUs the script is run on, you can use the - option `split_batches=True` when creating and initializing [`Accelerator`] by passing in a [`utils.DataLoaderConfig`]. + option `split_batches=True` when creating and initializing [`Accelerator`] by passing in a [`utils.DataLoaderConfiguration`]. Your training dataloader may change length when going through this method: if you run on X GPUs, it will have its length divided by X (since your actual batch size will be multiplied by X), unless you set `split_batches=True`. 
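For reference, a minimal sketch of the usage described in the quicktour change above, assuming `train_dataloader` is an existing `torch.utils.data.DataLoader` created with the per-device batch size you want to keep:

    from accelerate import Accelerator, DataLoaderConfiguration

    # Split each yielded batch across processes so the effective batch size
    # stays equal to the one set in the script; that batch size must then be
    # a round multiple of the number of processes.
    dataloader_config = DataLoaderConfiguration(split_batches=True)
    accelerator = Accelerator(dataloader_config=dataloader_config)

    # With split_batches=True the prepared dataloader keeps its original
    # length instead of having it divided by the number of processes.
    train_dataloader = accelerator.prepare(train_dataloader)
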
From 1e66a06361c7f6264feb1e6df943d5f3b92bc88a Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 14 Feb 2024 10:54:08 -0500 Subject: [PATCH 08/11] Clean --- .../test_utils/scripts/external_deps/test_metrics.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py index d99f60972ba..d41e5cd5964 100755 --- a/src/accelerate/test_utils/scripts/external_deps/test_metrics.py +++ b/src/accelerate/test_utils/scripts/external_deps/test_metrics.py @@ -25,7 +25,7 @@ from torch.utils.data import DataLoader, IterableDataset from transformers import AutoModelForSequenceClassification, AutoTokenizer -from accelerate import Accelerator, DistributedType, DataLoaderConfiguration +from accelerate import Accelerator, DataLoaderConfiguration, DistributedType from accelerate.data_loader import DataLoaderDispatcher from accelerate.test_utils import RegressionDataset, RegressionModel, torch_device from accelerate.utils import is_torch_xla_available, set_seed @@ -278,7 +278,9 @@ def main(): print("**Test torch metrics**") for split_batches in [True, False]: for dispatch_batches in dispatch_batches_options: - dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches) + dataloader_config = DataLoaderConfiguration( + split_batches=split_batches, dispatch_batches=dispatch_batches + ) accelerator = Accelerator(dataloader_config=dataloader_config) if accelerator.is_local_main_process: print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`, length=99") From 443745afb8cbc13be8e658c85ec43aedd2799ad8 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 14 Feb 2024 10:55:43 -0500 Subject: [PATCH 09/11] Actually create one --- src/accelerate/accelerator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 661f66b57f1..26743c089b2 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -415,7 +415,8 @@ def __init__( self.device_placement = device_placement if dataloader_config is None: - self.dataloader_config = DataLoaderConfiguration() + dataloader_config = DataLoaderConfiguration() + self.dataloader_config = dataloader_config # Deal with deprecated args # TODO: Remove in v1.0.0 deprecated_dl_args = {} From 06cde6ea23a6a33d33fce9769c4414694f3744d5 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 14 Feb 2024 11:25:08 -0500 Subject: [PATCH 10/11] Forgot to change one --- tests/test_accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 7f19d2d63e7..0d3ea877974 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -65,7 +65,7 @@ def test_deprecated_values(self): assert accelerator.use_seedable_sampler is False, "use_seedable_sampler should be False by default" # Pass some arguments only - with self.assertWarnsRegex(FutureWarning) as cm: + with self.assertWarns(FutureWarning) as cm: accelerator = Accelerator( dispatch_batches=True, split_batches=False, From 1ddad81720b60d2620c8bc552d17223211239edd Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 14 Feb 2024 12:24:16 -0500 Subject: [PATCH 11/11] Use pytest --- tests/test_accelerator.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 
0d3ea877974..c861f77e877 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -4,6 +4,7 @@ import tempfile from unittest.mock import patch +import pytest import torch from parameterized import parameterized from torch.utils.data import DataLoader, TensorDataset @@ -65,12 +66,12 @@ def test_deprecated_values(self): assert accelerator.use_seedable_sampler is False, "use_seedable_sampler should be False by default" # Pass some arguments only - with self.assertWarns(FutureWarning) as cm: + with pytest.warns(FutureWarning) as cm: accelerator = Accelerator( dispatch_batches=True, split_batches=False, ) - deprecation_warning = cm.warnings[0].message.args[0] + deprecation_warning = str(cm.list[0].message) assert accelerator.split_batches is False, "split_batches should be True" assert accelerator.dispatch_batches is True, "dispatch_batches should be True" assert accelerator.even_batches is True, "even_batches should be True by default" @@ -81,12 +82,12 @@ def test_deprecated_values(self): assert "use_seedable_sampler" not in deprecation_warning # Pass in some arguments, but with their defaults - with self.assertWarns(FutureWarning) as cm: + with pytest.warns(FutureWarning) as cm: accelerator = Accelerator( even_batches=True, use_seedable_sampler=False, ) - deprecation_warning = cm.warnings[0].message.args[0] + deprecation_warning = str(cm.list[0].message) assert "even_batches" in deprecation_warning assert accelerator.even_batches is True assert "use_seedable_sampler" in deprecation_warning