From 65a3bc0beb535e809d8ee0015b2331f3c6f4d2ea Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 12 Jun 2025 14:53:39 +0000 Subject: [PATCH 1/3] test --- src/accelerate/data_loader.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 3163312b568..87dabc5f105 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -1137,6 +1137,11 @@ def prepare_data_loader( process_index = process_index // submesh_tp_size num_processes = submesh_fsdp_size * submesh_dp_size + # put cp size here + cp_size = + process_index = process_index // cp_size + num_processes = num_processes // cp_size + # Sanity check if split_batches: if dataloader.batch_size is not None: From 7729f4404007d37a5d92deb83b1935337f314a20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 12 Jun 2025 17:23:23 +0000 Subject: [PATCH 2/3] some modif --- src/accelerate/accelerator.py | 5 +++++ src/accelerate/data_loader.py | 8 ++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index a7b9469dfb3..eeae1bbde97 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -648,6 +648,10 @@ def split_batches(self): def dispatch_batches(self): return self.dataloader_config.dispatch_batches + @property + def cp(self): + return self.dataloader_config.cp + @property def even_batches(self): return self.dataloader_config.even_batches @@ -2410,6 +2414,7 @@ def prepare_data_loader( non_blocking=self.non_blocking, use_stateful_dataloader=self.use_stateful_dataloader, torch_device_mesh=device_mesh, + cp=self.cp, ) self._dataloaders.append(prepared_data_loader) return prepared_data_loader diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 87dabc5f105..20d47d8483a 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -1007,6 +1007,7 @@ def prepare_data_loader( non_blocking: bool = False, use_stateful_dataloader: bool = False, torch_device_mesh=None, + cp=False, ) -> DataLoader: """ Wraps a PyTorch `DataLoader` to generate batches for one of the processes only. @@ -1137,10 +1138,9 @@ def prepare_data_loader( process_index = process_index // submesh_tp_size num_processes = submesh_fsdp_size * submesh_dp_size - # put cp size here - cp_size = - process_index = process_index // cp_size - num_processes = num_processes // cp_size + if cp: + process_index = 0 + num_processes = 1 # Sanity check if split_batches: From 4b5b8381856a1130f9a9823cb71aff70abb57b48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 12 Jun 2025 17:37:50 +0000 Subject: [PATCH 3/3] other modif --- src/accelerate/utils/dataclasses.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 69933998f7f..2529505ef61 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -815,6 +815,7 @@ class DataLoaderConfiguration: " underlying dataset is an `IterableDataset`, `False` otherwise." }, ) + cp: bool = field(default=False) even_batches: bool = field( default=True, metadata={