diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index a7b9469dfb3..eeae1bbde97 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -648,6 +648,10 @@ def split_batches(self): def dispatch_batches(self): return self.dataloader_config.dispatch_batches + @property + def cp(self): + return self.dataloader_config.cp + @property def even_batches(self): return self.dataloader_config.even_batches @@ -2410,6 +2414,7 @@ def prepare_data_loader( non_blocking=self.non_blocking, use_stateful_dataloader=self.use_stateful_dataloader, torch_device_mesh=device_mesh, + cp=self.cp, ) self._dataloaders.append(prepared_data_loader) return prepared_data_loader diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 3163312b568..20d47d8483a 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -1007,6 +1007,7 @@ def prepare_data_loader( non_blocking: bool = False, use_stateful_dataloader: bool = False, torch_device_mesh=None, + cp=False, ) -> DataLoader: """ Wraps a PyTorch `DataLoader` to generate batches for one of the processes only. @@ -1137,6 +1138,10 @@ def prepare_data_loader( process_index = process_index // submesh_tp_size num_processes = submesh_fsdp_size * submesh_dp_size + if cp: + process_index = 0 + num_processes = 1 + # Sanity check if split_batches: if dataloader.batch_size is not None: diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 69933998f7f..2529505ef61 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -815,6 +815,7 @@ class DataLoaderConfiguration: " underlying dataset is an `IterableDataset`, `False` otherwise." }, ) + cp: bool = field(default=False) even_batches: bool = field( default=True, metadata={