3 changes: 2 additions & 1 deletion src/accelerate/accelerator.py
@@ -2038,6 +2038,7 @@ def prepare_data_loader(
slice_fn_for_dispatch=slice_fn_for_dispatch,
use_seedable_sampler=self.use_seedable_sampler,
non_blocking=self.non_blocking,
stateful=self.dataloader_config.stateful,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
@@ -3481,7 +3482,7 @@ def skip_first_batches(self, dataloader, num_batches: int = 0):
... ...
```
"""
return skip_first_batches(dataloader, num_batches=num_batches)
return skip_first_batches(dataloader, num_batches=num_batches, stateful=self.dataloader_config.stateful)

def __deepcopy__(self, memo):
logger.info("Deep copying the `Accelerator` object, note that this will point to the same original object.")
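For context, a minimal sketch of how these accelerator-level changes would be used (illustrative only, not part of the PR; the toy dataset is made up and `stateful=True` assumes `torchdata>=0.8.0` is installed):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# Toy data purely for illustration.
dataset = TensorDataset(torch.arange(64, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=8)

# `stateful=True` is the new flag introduced by this PR; it requires `torchdata>=0.8.0`.
accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(stateful=True))
dataloader = accelerator.prepare(dataloader)

for step, batch in enumerate(dataloader):
    if step == 2:
        # The prepared dataloader now exposes `state_dict()` / `load_state_dict()`,
        # so its position can be checkpointed alongside the model and optimizer.
        dataloader_state = dataloader.state_dict()
        break
```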
158 changes: 136 additions & 22 deletions src/accelerate/data_loader.py
@@ -30,6 +30,7 @@
get_data_structure,
initialize_tensors,
is_torch_version,
is_torchdata_available,
send_to_device,
slice_tensors,
synchronize_rng_states,
@@ -64,6 +65,21 @@
_PYTORCH_DATALOADER_KWARGS.update(additional_kwargs)


def get_dataloader_init_class(stateful=False):
"""
Internal use only. Returns the class to build a dataloader from: either `torch.utils.data.DataLoader` or
`torchdata.stateful_dataloader.StatefulDataLoader`.
"""
if not stateful:
return DataLoader
elif not is_torchdata_available():
raise ImportError("Using `stateful=True` requires `torchdata>=0.8.0`; please install it with `pip install -U torchdata`")
else:
from torchdata.stateful_dataloader import StatefulDataLoader

return StatefulDataLoader
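As a quick illustration of the helper above (not part of the diff; the tiny list dataset is made up):

```python
# With stateful=False this is just `torch.utils.data.DataLoader`; with stateful=True it
# would return torchdata's `StatefulDataLoader` instead (requires `torchdata>=0.8.0`).
dl_cls = get_dataloader_init_class(stateful=False)
loader = dl_cls(list(range(16)), batch_size=4)
```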


class SeedableRandomSampler(RandomSampler):
"""
Same as a random sampler, except that in `__iter__` a seed can be used.
@@ -351,6 +367,73 @@ def __iter__(self):
for i in process_slice:
yield current_batch[i]

class ResumableMixin:
"""
Mixin class that adds helper methods to resume the dataloader from a checkpoint.
"""
def __init__(self, *args, **kwargs):
stateful = kwargs.pop("stateful", False)
snapshot_every_n_steps = kwargs.pop("snapshot_every_n_steps", 1)

super().__init__(*args, **kwargs)
self.stateful = stateful
self.snapshot_every_n_steps = snapshot_every_n_steps
self.next_iter_state = None
# When a state dict is requested before `__iter__` is called, we create the
# iterator early so we can grab a copy of the initial state from its workers.
# In that case we can reuse it instead of spinning up a new multiprocessing
# iterator on the next `__iter__` call; this flag tracks that situation.
self._initial_iter_for_state_dict = False


def _get_iterator(self):
if not self.stateful:
return super()._get_iterator()
from torchdata.stateful_dataloader.stateful_dataloader import _StatefulSingleProcessDataLoaderIter, _StatefulMultiProcessingDataLoaderIter
if self.num_workers == 0:
it = _StatefulSingleProcessDataLoaderIter(self, self.next_iter_state)
else:
self.check_worker_number_rationality()
it = _StatefulMultiProcessingDataLoaderIter(self, self.next_iter_state)
self.next_iter_state = None
return it

def __iter__(self):
if not self.stateful:
return super().__iter__()
if self._initial_iter_for_state_dict:
self._initial_iter_for_state_dict = False
elif self.persistent_workers and self.num_workers > 0:
if self._iterator is None:
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
else:
self._iterator = self._get_iterator()
if self._iterator._finished:
if self.persistent_workers:
self._iterator._reset(self)
else:
self._iterator = self._get_iterator()

return self._iterator

def state_dict(self):
if not self.stateful:
raise ValueError("To get the state_dict, please set `stateful=True` in the `DataLoaderConfiguration`")
if self._iterator is None:
self._iterator = self._get_iterator()
self._initial_iter_for_state_dict = True
return self._iterator.state_dict()

def load_state_dict(self, state_dict):
if not self.stateful:
raise ValueError("To load the state_dict, please set `stateful=True` in the `DataLoaderConfiguration`")
self._iterator = None
self._initial_iter_for_state_dict = False
if state_dict == {}:
return
self.next_iter_state = state_dict
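For orientation, here is a minimal sketch (not part of the diff) of how the mixin composes with a regular `DataLoader`, the same way `DataLoaderShard` and `DataLoaderDispatcher` below pick it up. The `ResumableDataLoader` name is hypothetical, and `stateful=True` assumes `torchdata>=0.8.0` is installed.

```python
from torch.utils.data import DataLoader


class ResumableDataLoader(ResumableMixin, DataLoader):
    """Hypothetical example class: the mixin sits before `DataLoader` in the MRO."""


loader = ResumableDataLoader(list(range(32)), batch_size=4, stateful=True)

it = iter(loader)
first_batch = next(it)
state = loader.state_dict()  # snapshot of the underlying stateful iterator

# A fresh loader can pick up from that snapshot on its next `__iter__` call.
resumed = ResumableDataLoader(list(range(32)), batch_size=4, stateful=True)
resumed.load_state_dict(state)
next_batch = next(iter(resumed))  # continues from where `state` was taken
```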

class DataLoaderStateMixin:
"""
Expand Down Expand Up @@ -388,7 +471,7 @@ def end(self):
self.gradient_state._remove_dataloader(self)


class DataLoaderShard(DataLoader, DataLoaderStateMixin):
class DataLoaderShard(DataLoaderStateMixin, ResumableMixin, DataLoader):
"""
Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup.

@@ -428,11 +511,13 @@ def __init__(
rng_types=None,
synchronized_generator=None,
skip_batches=0,
stateful=False,
_drop_last: bool = False,
_non_blocking: bool = False,
**kwargs,
):
super().__init__(dataset, **kwargs)
super().__init__(dataset, stateful=stateful, **kwargs)

self.device = device
self.rng_types = rng_types
self.synchronized_generator = synchronized_generator
@@ -559,7 +644,7 @@ def batch_sampler(self):
return self._loader.batch_sampler


class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
class DataLoaderDispatcher(DataLoaderStateMixin, ResumableMixin, DataLoader):
"""
Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each
process their part of the batch.
@@ -589,10 +674,11 @@ def __init__(
dataset,
split_batches: bool = False,
skip_batches=0,
stateful=False,
_drop_last: bool = False,
_non_blocking: bool = False,
slice_fn=None,
**kwargs,
**kwargs
):
shuffle = False
if is_torch_version(">=", "1.11.0"):
@@ -601,7 +687,7 @@ def __init__(
# We need to save the shuffling state of the DataPipe
if isinstance(dataset, ShufflerIterDataPipe):
shuffle = dataset._shuffle_enabled
super().__init__(dataset, **kwargs)
super().__init__(dataset, stateful=stateful, **kwargs)
self.split_batches = split_batches
if shuffle:
torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
@@ -807,6 +893,7 @@ def prepare_data_loader(
slice_fn_for_dispatch: Optional[Callable] = None,
use_seedable_sampler: bool = False,
non_blocking: bool = False,
stateful: bool = False,
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
Expand Down Expand Up @@ -868,6 +955,10 @@ def prepare_data_loader(
non_blocking (`bool`, *optional*, defaults to `False`):
If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
`pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
stateful (`bool`, *optional*, defaults to `False`):
If set to `True`, will use `torchdata.stateful_dataloader.StatefulDataLoader` when creating the wrapped
`DataLoader` rather than `torch.utils.data.DataLoader`. This gives the created dataloader a `state_dict()`
that can be saved and restored via `load_state_dict()`. Requires `torchdata>=0.8.0`.


Returns:
@@ -880,6 +971,7 @@

</Tip>
"""
should_replace_sampler = use_seedable_sampler or stateful
if dispatch_batches is None:
if not put_on_device:
dispatch_batches = False
@@ -924,17 +1016,26 @@ def prepare_data_loader(
synchronized_generator = None

sampler = get_sampler(dataloader)
if isinstance(sampler, RandomSampler) and use_seedable_sampler:
# When iterating through the dataloader during distributed processes
# we want to ensure that on each process we are iterating through the same
# samples in the same order if a seed is set. This requires a tweak
# to the `torch.utils.data.RandomSampler` class (if used).
sampler = SeedableRandomSampler(
data_source=sampler.data_source,
replacement=sampler.replacement,
num_samples=sampler._num_samples,
generator=getattr(sampler, "generator", torch.Generator()),
)
if isinstance(sampler, RandomSampler):
if use_seedable_sampler:
# When iterating through the dataloader during distributed processes
# we want to ensure that on each process we are iterating through the same
# samples in the same order if a seed is set. This requires a tweak
# to the `torch.utils.data.RandomSampler` class (if used).
sampler = SeedableRandomSampler(
data_source=sampler.data_source,
replacement=sampler.replacement,
num_samples=sampler._num_samples,
generator=getattr(sampler, "generator", torch.Generator()),
)
elif stateful:
from torchdata.stateful_dataloader.sampler import RandomSampler as StatefulRandomSampler
sampler = StatefulRandomSampler(
data_source=sampler.data_source,
replacement=sampler.replacement,
num_samples=sampler._num_samples,
generator=getattr(sampler, "generator", torch.Generator()),
)

if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
# isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
@@ -1001,6 +1102,7 @@ def prepare_data_loader(
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
slice_fn=slice_fn_for_dispatch,
stateful=stateful,
**kwargs,
)
elif sampler_is_batch_sampler:
@@ -1013,6 +1115,7 @@ def prepare_data_loader(
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
synchronized_generator=synchronized_generator,
stateful=stateful,
**kwargs,
)
else:
@@ -1024,10 +1127,11 @@ def prepare_data_loader(
synchronized_generator=synchronized_generator,
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
stateful=stateful,
**kwargs,
)

if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
stateful_sampler_types = ()
if stateful and is_torchdata_available():
    # `StatefulRandomSampler` only exists when `torchdata` is installed, so guard the import.
    from torchdata.stateful_dataloader.sampler import RandomSampler as StatefulRandomSampler

    stateful_sampler_types = (StatefulRandomSampler,)
if should_replace_sampler and isinstance(sampler, (SeedableRandomSampler, *stateful_sampler_types)):
dataloader.set_sampler(sampler)
if state.distributed_type == DistributedType.XLA:
return MpDeviceLoaderWrapper(dataloader, device)
@@ -1056,7 +1160,7 @@ def __len__(self):
return len(self.batch_sampler) - self.skip_batches


class SkipDataLoader(DataLoader):
class SkipDataLoader(ResumableMixin, DataLoader):
"""
Subclass of a PyTorch `DataLoader` that will skip the first batches.

@@ -1065,21 +1169,27 @@ class SkipDataLoader(DataLoader):
The dataset to use to build this dataloader.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning.
stateful (`bool`, *optional*, defaults to `False`):
If set to `True`, will use `torchdata.stateful_dataloader.StatefulDataLoader` when creating the wrapped
`DataLoader` rather than `torch.utils.data.DataLoader`. This gives the created dataloader a `state_dict()`
that can be saved and restored via `load_state_dict()`. Requires `torchdata>=0.8.0`.
kwargs:
All other keyword arguments to pass to the regular `DataLoader` initialization.
"""

def __init__(self, dataset, skip_batches=0, **kwargs):
def __init__(self, dataset, skip_batches=0, stateful=False, **kwargs):
super().__init__(dataset, stateful=stateful, **kwargs)

self.skip_batches = skip_batches


def __iter__(self):
for index, batch in enumerate(super().__iter__()):
if index >= self.skip_batches:
yield batch


def skip_first_batches(dataloader, num_batches=0):
def skip_first_batches(dataloader, num_batches=0, stateful=False):
"""
Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
"""
@@ -1122,6 +1232,7 @@ def skip_first_batches(dataloader, num_batches=0):
batch_sampler=new_batch_sampler,
_drop_last=dataloader._drop_last,
stateful=stateful,
**kwargs,
)
elif isinstance(dataloader, DataLoaderShard):
if new_batch_sampler is None:
@@ -1137,13 +1248,16 @@ def skip_first_batches(dataloader, num_batches=0):
device=dataloader.device,
rng_types=dataloader.rng_types,
synchronized_generator=dataloader.synchronized_generator,
stateful=stateful,
**kwargs,
)
else:
if new_batch_sampler is None:
# Need to manually skip batches in the dataloader
dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
dataloader = SkipDataLoader(dataset, skip_batches=num_batches, stateful=stateful, **kwargs)
else:
dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
# Choose the appropriate `DataLoader` class
dl_init = get_dataloader_init_class(stateful=stateful)
dataloader = dl_init(dataset, batch_sampler=new_batch_sampler, **kwargs)

return dataloader
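Putting the `data_loader.py` changes together, a hedged end-to-end sketch (illustrative only, not part of the PR: single process, `num_workers=0`, toy data, and `torchdata>=0.8.0` assumed):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate.data_loader import prepare_data_loader

dataset = TensorDataset(torch.arange(100, dtype=torch.float32))

prepared = prepare_data_loader(DataLoader(dataset, batch_size=10), stateful=True)

# Consume a few batches, then snapshot the dataloader position.
for step, batch in enumerate(prepared):
    if step == 3:
        checkpoint = prepared.state_dict()
        break

# After a restart: rebuild the same dataloader and restore the saved position.
resumed = prepare_data_loader(DataLoader(dataset, batch_size=10), stateful=True)
resumed.load_state_dict(checkpoint)
for batch in resumed:
    ...  # resumes from the saved iterator position
```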
1 change: 1 addition & 0 deletions src/accelerate/test_utils/__init__.py
@@ -39,6 +39,7 @@
require_single_gpu,
require_single_xpu,
require_torch_min_version,
require_torchdata,
require_torchvision,
require_tpu,
require_xpu,
8 changes: 8 additions & 0 deletions src/accelerate/test_utils/testing.py
@@ -52,6 +52,7 @@
is_timm_available,
is_torch_version,
is_torch_xla_available,
is_torchdata_available,
is_torchvision_available,
is_transformers_available,
is_triton_available,
@@ -404,6 +405,13 @@ def require_import_timer(test_case):
return unittest.skipUnless(is_import_timer_available(), "test requires tuna interpreter")(test_case)


def require_torchdata(test_case):
"""
Decorator marking a test that requires torchdata to be installed. These tests are skipped when torchdata isn't installed.
"""
return unittest.skipUnless(is_torchdata_available(), "test requires torchdata")(test_case)


_atleast_one_tracker_available = (
any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available()
)
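A hypothetical test using the new decorator could look like this (the test class and assertion are illustrative, not part of the PR):

```python
import unittest

from accelerate.test_utils import require_torchdata
from accelerate.utils import DataLoaderConfiguration


class StatefulDataLoaderConfigTest(unittest.TestCase):
    @require_torchdata
    def test_stateful_flag(self):
        # Skipped automatically when `torchdata` is not installed.
        config = DataLoaderConfiguration(stateful=True)
        self.assertTrue(config.stateful)
```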
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -107,6 +107,7 @@
is_tensorboard_available,
is_timm_available,
is_torch_xla_available,
is_torchdata_available,
is_torchvision_available,
is_transformer_engine_available,
is_transformers_available,
4 changes: 4 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -720,6 +720,10 @@ class DataLoaderConfiguration:
" prepared dataloader has `pin_memory` set to `True` to work properly."
},
)
stateful: bool = field(
default=False,
metadata={"help": "If set to `True`, the dataloader prepared by the Accelerator will support `state_dict`/`load_state_dict` so that its position can be checkpointed and resumed. Requires `torchdata>=0.8.0`."},
)


@dataclass