add fault-tolerance for global random state in map-style datasets (#8950)

Co-authored-by: tchaton <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
5 people authored Aug 26, 2021
1 parent 0752bcd commit b13749b
Showing 7 changed files with 402 additions and 72 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -54,6 +54,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))
* Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))

- Checkpoint saving & loading extensibility:
* Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))
18 changes: 17 additions & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -14,7 +14,7 @@

import logging
from contextlib import suppress
from typing import Optional
from typing import Any, Dict, Optional

from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
@@ -40,6 +40,8 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] =
self.min_epochs = min_epochs
self.epoch_loop: Optional[TrainingEpochLoop] = None
self.epoch_progress = Progress()
# caches the loaded dataloader state until dataloader objects are available
self._dataloader_state_dict: Dict[str, Any] = {}

@property
def current_epoch(self) -> int:
@@ -175,6 +177,10 @@ def on_advance_start(self) -> None:
if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch:
self.trainer.reset_train_dataloader(model)

if self._dataloader_state_dict:
self.trainer.train_dataloader.load_state_dict(self._dataloader_state_dict)
self._dataloader_state_dict = {}

# TODO: specify the possible exception
with suppress(Exception):
# set seed for distributed sampler (enables shuffling for each epoch)
@@ -234,3 +240,13 @@ def should_accumulate(self) -> bool:

def teardown(self) -> None:
self.epoch_loop.teardown()

def on_save_checkpoint(self) -> Dict:
state_dict = super().on_save_checkpoint()
# FIXME(@tchaton) Should pass has_completed=True when iterator is exhausted ?
state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
return state_dict

def on_load_checkpoint(self, state_dict: Dict) -> None:
# cache the dataloader state dict until the dataloader objects are available
self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
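
Taken together, these hooks make the checkpoint carry a `dataloader_state_dict` entry and replay it lazily: the state is only handed to `trainer.train_dataloader` at the start of the next epoch, once the dataloader objects exist. A minimal, self-contained sketch of that pattern (an illustrative stand-in, not the actual Lightning loop):

from typing import Any, Dict


class _SketchFitLoop:
    """Illustrative stand-in for the loop above, showing the save/restore round trip."""

    def __init__(self, train_dataloader=None) -> None:
        self.train_dataloader = train_dataloader          # e.g. a CombinedLoader, created lazily
        self._dataloader_state_dict: Dict[str, Any] = {}  # cached until dataloaders exist

    def on_save_checkpoint(self) -> Dict:
        # ask the wrapped dataloader for its fault-tolerance state and store it in the checkpoint
        return {"dataloader_state_dict": self.train_dataloader.state_dict(has_completed=False)}

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        # dataloaders may not have been created yet, so only cache the state here
        self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})

    def on_advance_start(self) -> None:
        # by the time an epoch starts the dataloaders exist, so the cached state can be applied
        if self._dataloader_state_dict:
            self.train_dataloader.load_state_dict(self._dataloader_state_dict)
            self._dataloader_state_dict = {}
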
151 changes: 93 additions & 58 deletions pytorch_lightning/trainer/supporters.py
@@ -13,20 +13,23 @@
# limitations under the License.

from collections.abc import Iterable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader
from torch.utils.data.dataset import IterableDataset

from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.auto_restart import (
_cycle_to_next_worker_and_reset,
_find_current_worker,
_find_fast_forward_samplers,
CaptureIterableDataset,
CaptureMapDataset,
IteratorState,
MergedIteratorState,
patch_dataloader_iterator,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -167,6 +170,7 @@ def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycle
self.loader = loader
self._loader_iter = None
self.counter = 0
self.state = state

def __iter__(self) -> Any:
"""
@@ -176,6 +180,7 @@ def __iter__(self) -> Any:
CycleIterator: self
"""
self.counter = 0
self.state.reset()
self._loader_iter = iter(self.loader)
return self

@@ -205,6 +210,12 @@ def __next__(self) -> Any:
raise StopIteration

self._loader_iter = iter(self.loader)
# if fault tolerant is enabled, we need to patch the iterator to collect the states
# before the batch gets returned.
fetcher = getattr(self.loader, "_lightning_fetcher", None)
if fetcher:
patch_dataloader_iterator(self.loader, self._loader_iter, fetcher)

return next(self._loader_iter)

finally:
@@ -302,11 +313,6 @@ def __len__(self) -> int:
return self._calc_num_data(self.datasets, self.mode)


class DataLoaderDict(Dict):
# behaves exactly like a dict, this is used to simplify apply_to_collection.
pass


class CombinedLoader:
"""
Combines different dataloaders and allows sampling in parallel.
@@ -360,80 +366,110 @@ def __init__(self, loaders: Any, mode: str = "min_size"):
self._iterator = None # assigned in __iter__

@staticmethod
def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], num_batches_processed: int) -> Dict:
# find next worker if multiple workers were used
state = _find_current_worker(iterator)
if isinstance(dataloader.dataset, CaptureIterableDataset):
# the sampler state dict are extracted in `CombinedLoaderIterator`
if iterator is not None and getattr(iterator, "_sampler_state_dict", None) is not None:
state.update(iterator._sampler_state_dict[0])
else:
# fetch directly from fast forward sampler
state.update(dataloader.fast_forward_sampler.state_dict(num_batches_processed))
return DataLoaderDict(state)

def state_dict(self, num_batches_processed: int) -> Dict:
def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], has_completed: int) -> Dict:
if isinstance(dataloader, CycleIterator):
iterator = dataloader._loader_iter
state = getattr(iterator, "state", None) if has_completed else getattr(iterator, "previous_state", None)
if state:
return asdict(state)
return {}

def state_dict(self, has_completed: bool = False) -> Dict:
"""
The state dict includes all states from wrapped dataloaders and their samplers through the
``CaptureIterableDataset`` and fast-forward samplers.
Args:
num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
may have already prefetched more batches by the time a state dict is requested.
has_completed: whether the current state of data fetching is considered completed or not. If it is, the
current state gets returned, otherwise the previously cached state.
"""
if not _fault_tolerant_training():
return DataLoaderDict()

state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)
if not _fault_tolerant_training() or self._iterator is None:
return {}

return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn)
return apply_to_collections(
self.loaders,
self._iterator.loader_iters,
(Iterator, DataLoader),
partial(self._state_dict_fn, has_completed=has_completed),
)
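
With the new signature, `state_dict` walks the collection of loaders alongside the live iterators and returns, per loader, either the iterator's current state or the previously cached one, depending on `has_completed`. A hedged usage sketch (the toy dataloaders below are made up; an empty dict comes back unless fault-tolerant training is enabled):

from torch.utils.data import DataLoader

from pytorch_lightning.trainer.supporters import CombinedLoader

# toy map-style dataloaders, purely for illustration
loader_a = DataLoader(range(10), batch_size=2)
loader_b = DataLoader(range(20), batch_size=2)

combined = CombinedLoader({"a": loader_a, "b": loader_b}, mode="max_size_cycle")
iter(combined)                                     # builds the internal per-loader iterators
state = combined.state_dict(has_completed=False)   # {} unless fault-tolerant training is enabled

# on restart, a freshly built CombinedLoader over the same dataloaders consumes it:
restored = CombinedLoader({"a": loader_a, "b": loader_b}, mode="max_size_cycle")
restored.load_state_dict(state)                    # cached here; applied in `on_restart` once iterators exist
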

def load_state_dict(self, state_dict):
def load_state_dict(self, state_dict) -> None:
# store the samplers' state.
# They will be reloaded once the `CombinedIterator` has been created
# and the workers are created.
self._loaders_iter_state_dict = state_dict

def mock_reset_fn(self, *_, **__):
pass

# mock reset call, so we can rotate the `_worker_queue_idx_cycle` to failed worker
# and get the first batch from it
_MultiProcessingDataLoaderIter._original_reset = _MultiProcessingDataLoaderIter._reset
_MultiProcessingDataLoaderIter._reset = mock_reset_fn

def on_restart(self, iterator: Iterator):
def on_restart(self, iterator: Iterator) -> None:
if not self._loaders_iter_state_dict:
return

# this happens inside the workers, if any were specified.
def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
"""Function used to reload the iterator state before once the workers are created."""

dataloader_to_iter_on = dataloader
if isinstance(dataloader, CycleIterator):
dataloader = dataloader_to_iter_on.loader

dataset = dataloader.dataset

# We reload the states before creating the workers
# The specific type of dataset will then decide if the state should be applied before or after
# spawning the workers
if isinstance(dataset, CaptureMapDataset):
iterator_state = state_dict["state"][0]

if not isinstance(iterator_state, IteratorState):
iterator_state = IteratorState.from_state_dict(iterator_state)

# reload sampler state
ff_sampler = _find_fast_forward_samplers(dataloader)
ff_sampler.load_state_dict(iterator_state.sampler_state)
# reload dataset state
dataset.load_state_dict(
iterator_state.dataset_state,
latest_worker_id=state_dict["latest_worker_id"],
num_workers=iterator_state.num_workers,
)

elif isinstance(dataset, CaptureIterableDataset):
dataset_dict = {
sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()
}
dataset.load_state_dict(dataset_dict)

def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict):
if isinstance(dataloader.dataset, CaptureIterableDataset):
# provide the `state_dict` to the `CaptureIterableDataset`
# as it is responsible for passing down the state to associated `FastForwardSampler`
dataloader.dataset.load_state_dict(state_dict)
else:
# for `Mapping-based` dataset, the `fast_forward_sampler` was attached
# on the dataloader for simplicity
dataloader.fast_forward_sampler.load_state_dict(state_dict)
raise MisconfigurationException(
"This shouldn't happen. Please, open an issue on PyTorch Lightning Github."
)

# Finally, spawn the workers (if any).
it = iter(dataloader_to_iter_on)

# cycle back the iterator to the failed worker if multiple workers were provided
iterator = _cycle_to_next_worker_and_reset(dataloader, state_dict)
# restore caching state
state = MergedIteratorState.from_state_dict(state_dict)

if isinstance(dataloader.dataset, CaptureIterableDataset):
# remove keys related to iterator
state_dict = {k: v for k, v in state_dict.items() if k not in ("num_worker", "previous_worker")}
# need to re-attach the state dict into the iterator for future collection.
iterator._sampler_state_dict = [state_dict]
return iterator
if isinstance(dataloader_to_iter_on, CycleIterator):
it._loader_iter.state = state
else:
it.state = state
return it

# create a dedicated marker type, so it doesn't activate for anything other than the stored iterator states.
class DataLoaderDict(dict):
pass

# apply the `create_loader_iters` on the collection of `DataLoader / Iterator`.
# each `Iterator` was created from the `DataLoader`.
iterator._loader_iters = apply_to_collections(
self.loaders, self._loaders_iter_state_dict, (DataLoader, DataLoaderDict), create_loader_iters
self.loaders,
self._loaders_iter_state_dict,
(Iterable, DataLoaderDict),
create_loader_iters,
wrong_dtype=(Sequence, Mapping),
)

self._loaders_iter_state_dict = None

@property
def sampler(self) -> Union[Iterable, Sequence, Mapping]:
"""Return a collections of samplers extracting from loaders."""
@@ -457,7 +493,6 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
self.loaders = apply_to_collection(
self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping)
)

state.reset()

def __iter__(self) -> Any:
38 changes: 34 additions & 4 deletions pytorch_lightning/utilities/auto_restart.py
@@ -16,8 +16,12 @@
from copy import deepcopy
from dataclasses import dataclass, field
from functools import partial, wraps
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset

@@ -168,6 +172,16 @@ def update(self, generator_name: Optional[str], new_state: IteratorState) -> None:
state[latest_worker_id] = new_state
self.latest_worker_id = latest_worker_id

@property
def sampler_states(self) -> Dict[int, Any]:
"""Returns the merged sampler states for all worker processes."""
return {0: self.state[k].sampler_state[0] for k in self.state.keys()}

@property
def dataset_states(self) -> Dict[int, Any]:
"""Returns the merged dataset states for all worker processes."""
return {k: self.state[k].dataset_state[k] for k in self.state.keys()}
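
For orientation, the nested structure that these two properties flatten looks roughly like the following for a map-style dataset read by two workers; this is a sketch assembled only from the fields referenced in this diff, with everything else elided:

# illustrative layout only; real entries are produced by `IteratorState` and `CaptureMapDataset`
merged_state_dict = {
    "state": {
        0: {
            "sampler_state": {0: ...},                    # fast-forward sampler state
            "dataset_state": {0: {"rng_states": {...}}},  # from `CaptureMapDataset._state_dict()`
            "worker_id": 0,
            "num_workers": 2,
        },
        1: {...},                                         # same layout for the second worker
    },
    "latest_worker_id": 1,           # worker that produced the most recent batch
    "represent_map_dataset": True,   # checked by `from_state_dict` below
}
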

@classmethod
def from_state_dict(cls, state_dict) -> "MergedIteratorState":
if state_dict["represent_map_dataset"]:
@@ -188,7 +202,12 @@ def __len__(self) -> int:


class CaptureMapDataset(Dataset):
"""This class is used to capture the state from the map-based state dataset."""
"""This class is used to capture the state from the map-based state dataset.
Note:
We currently don't support restoring if we fail during the first `N = num_workers` batches, where
`num_workers` is the number of workers spawned by the dataloader.
"""

def __init__(self, dataset: Dataset) -> None:
self.dataset = dataset
@@ -202,8 +221,7 @@ def worker_id(self) -> int:
def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
if self._cached_state_dict is not None:
if self.worker_id in self._cached_state_dict:
# TODO: reset random states
pass
set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"])
self._cached_state_dict = None

data = self.dataset[item]
@@ -227,7 +245,19 @@ def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num
self._cached_state_dict = state_dict

def _state_dict(self) -> Dict[int, Dict[str, Any]]:
return {self.worker_id: {"rng_states": {}}}
return {self.worker_id: {"rng_states": collect_rng_states()}}


def collect_rng_states() -> Dict[str, Any]:
"""Collect the global random state of :mod:`torch`, :mod:`numpy` and Python."""
return {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": python_get_rng_state()}


def set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
"""Set the global random state of :mod:`torch`, :mod:`numpy` and Python in the current process."""
torch.set_rng_state(rng_state_dict.get("torch"))
np.random.set_state(rng_state_dict.get("numpy"))
python_set_rng_state(rng_state_dict.get("python"))
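
These two helpers snapshot and restore the process-global generators of `torch`, NumPy and Python's `random` module, which is what lets a resumed map-style dataset reproduce the same augmentation stream. A quick round-trip check, assuming the helpers are importable from this module:

import random

import numpy as np
import torch

from pytorch_lightning.utilities.auto_restart import collect_rng_states, set_rng_states

snapshot = collect_rng_states()    # {"torch": ..., "numpy": ..., "python": ...}
first = (torch.rand(1).item(), np.random.rand(), random.random())

set_rng_states(snapshot)           # rewind all three global generators
second = (torch.rand(1).item(), np.random.rand(), random.random())

assert first == second             # the exact same "random" draws are reproduced
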


class CaptureIterableDataset(IterableDataset):