Fault Tolerant Manual: Add stateful dataloader iter #10674

Merged · 41 commits · Nov 23, 2021

Changes from 31 commits

Commits
ca8dbbb
update
tchaton Nov 19, 2021
3f0d28a
update
tchaton Nov 19, 2021
0c47670
update
tchaton Nov 19, 2021
3e5b52e
update
tchaton Nov 19, 2021
24c8245
update
tchaton Nov 19, 2021
bcd5569
update
tchaton Nov 19, 2021
8d3844d
update
tchaton Nov 19, 2021
1829b46
update
tchaton Nov 19, 2021
a1a364a
typo
tchaton Nov 19, 2021
de41675
update on comments
tchaton Nov 22, 2021
8178a32
Update pytorch_lightning/utilities/auto_restart.py
kaushikb11 Nov 22, 2021
00b9355
update
tchaton Nov 22, 2021
96f0517
update
tchaton Nov 22, 2021
297fd67
Merge branch 'fault_tolerant_enum' of https://github.com/PyTorchLight…
tchaton Nov 22, 2021
9800cba
update
tchaton Nov 22, 2021
427ed03
docstring improvement
tchaton Nov 22, 2021
ae712b0
update
tchaton Nov 22, 2021
9a5166d
Rename and simplify
carmocca Nov 22, 2021
b5fa819
Add comment
carmocca Nov 22, 2021
c82b2f2
update
tchaton Nov 22, 2021
2ede205
update
tchaton Nov 22, 2021
b16c4c0
update
tchaton Nov 22, 2021
ce9c23c
update
tchaton Nov 22, 2021
2baddb9
update
tchaton Nov 22, 2021
97548bb
update
tchaton Nov 22, 2021
d953ae9
update
tchaton Nov 22, 2021
41ffbab
use_teardown
tchaton Nov 22, 2021
d04596d
Use `Protocol`
carmocca Nov 22, 2021
ff7b836
Simplify test
carmocca Nov 22, 2021
a5698e6
Update CHANGELOG.md
carmocca Nov 22, 2021
79fdacc
update
tchaton Nov 22, 2021
916b520
update
tchaton Nov 22, 2021
4b67fbf
update
tchaton Nov 22, 2021
c9481e2
update
tchaton Nov 22, 2021
ef29342
update
tchaton Nov 22, 2021
4a1fff7
update
tchaton Nov 22, 2021
cb27e30
update
tchaton Nov 22, 2021
7903d24
resolve tests
tchaton Nov 22, 2021
20d19a1
update
tchaton Nov 22, 2021
1104cbc
update
tchaton Nov 23, 2021
f071f9a
change to 0
tchaton Nov 23, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -13,7 +13,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


- Fault Tolerant Manual
* Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646))
* Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645))
* Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674))


-
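As background for the `_SupportsStateDict` protocol and the stateful-worker entries above, here is a minimal, self-contained sketch of the duck-typing contract they rely on: any object exposing `state_dict`/`load_state_dict` is treated as stateful. The protocol mirrors the one added to `auto_restart.py` below; `CounterSampler` is purely illustrative and not part of the PR.

from typing import Any, Dict, Iterator, Protocol, runtime_checkable


@runtime_checkable
class SupportsStateDict(Protocol):
    """Structural type: anything with these two methods is considered stateful."""

    def state_dict(self) -> Dict[str, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...


class CounterSampler:
    """Illustrative sampler that can checkpoint and restore its position."""

    def __init__(self, length: int) -> None:
        self.length = length
        self.current = 0

    def __iter__(self) -> Iterator[int]:
        while self.current < self.length:
            yield self.current
            self.current += 1

    def state_dict(self) -> Dict[str, Any]:
        return {"current": self.current}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.current = state_dict["current"]


# no inheritance needed: the runtime-checkable protocol performs a structural check
assert isinstance(CounterSampler(4), SupportsStateDict)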
168 changes: 160 additions & 8 deletions pytorch_lightning/utilities/auto_restart.py
@@ -16,15 +16,21 @@
from functools import partial, wraps
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Protocol, runtime_checkable, Tuple, Union

import numpy as np
import torch
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
from torch.utils.data.dataloader import (
_BaseDataLoaderIter,
_MultiProcessingDataLoaderIter,
_SingleProcessDataLoaderIter,
DataLoader,
IterableDataset,
)

import pytorch_lightning as pl
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training

@@ -441,7 +447,9 @@ def _find_current_worker(iterator: Iterator) -> Dict[str, Optional[int]]:
return {"num_workers": num_workers, "previous_worker": previous_worker}


def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: Callable) -> Dict:
def _capture_metadata_collate(
samples: List, dataset: Dataset, collate: Callable, fault_tolerant_mode: _FaultTolerantMode
) -> Any:
"""A collate function that adds the state dict of a :class:`CaptureIterableDataset` or
:class:`CaptureMapDataset` used in the worker processes. This function gets executed within the worker
processes. The structure will be:
@@ -453,10 +461,26 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate:
"__pl_restart_meta": {"sampler_name0": state_dict0, "sampler_name1": state_dict1},
}
"""
data = default_collate(samples)
if not isinstance(dataset, (CaptureIterableDataset, CaptureMapDataset)):
data = collate(samples)
if not fault_tolerant_mode.is_enabled:
return data
metadata = dataset.state_dict()
metadata = None
if fault_tolerant_mode.is_automatic:
metadata = dataset.state_dict()
else:
state_dict_fn = getattr(dataset, "state_dict", None)
info = get_worker_info()
worker_id = info.id if info else 0
if state_dict_fn:
metadata = state_dict_fn()
if worker_id not in metadata:
raise MisconfigurationException(
f"The state_dict returned by {dataset} needs to be indexed by `worker_id` integer keys."
)
if metadata is None:
metadata = {worker_id: {}}

return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata}


@@ -486,6 +510,9 @@ def patch_dataloader_iterator(
will extract the current iteration as part of the metadata returned by a custom batch.
"""

if not _FaultTolerantMode.detect_current_mode().is_automatic:
return

assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset))

def _next_data_wrapper(fn, it, dl, num_batches_fetched) -> Callable:
@@ -534,7 +561,10 @@ def wrapper():
def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
"""Wrap default collate function to retrive captured dataset state dict when fault tolerant is enabled."""
dataloader.collate_fn = partial(
_capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
_capture_metadata_collate,
dataset=dataloader.dataset,
collate=dataloader.collate_fn,
fault_tolerant_mode=_FaultTolerantMode.detect_current_mode(),
)


@@ -570,3 +600,125 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A

else:
raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")


@runtime_checkable
class _SupportsStateDict(Protocol):
def state_dict(self) -> Dict[str, Any]:
...

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
...


class _StatefulMixin:
"""This mixin is used to make PyTorch DataLoaderIter stateful."""

def _reset(self, loader: DataLoader, first_iter: bool = False):
super()._reset(loader, first_iter=first_iter)
self._loader = loader
self.num_batches_fetched = 0

def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None:
# initialize the queue if it doesn't exist.
if not hasattr(self, "_sampler_state"):
self._sampler_state = []
self._sampler_state_idx = 0

# store sampler state within a queue alongside its idx.
self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1
self._sampler_state.append((sampler_state, self._sampler_state_idx))

def _store_sampler_state(self) -> None:
"""This function is used to extract the sampler states if any."""
sampler_state = {
k: v.state_dict()
for k, v in self._loader.__dict__.items()
if isinstance(v, _SupportsStateDict) and k != "dataset"
}

self.__accumulate_state(sampler_state)

def _next_index(self) -> Any:
indexes = super()._next_index()
self._store_sampler_state()
return indexes

def _prepare_loader(self, loader):
if not isinstance(loader.collate_fn, partial):
loader.collate_fn = partial(
_capture_metadata_collate,
dataset=loader.dataset,
collate=loader.collate_fn,
fault_tolerant_mode=_FaultTolerantMode.detect_current_mode(),
)
self._loader = loader
self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher
self.num_batches_fetched = 0

def __del__(self) -> None:
if isinstance(self._loader.collate_fn, partial):
self._loader.collate_fn = self._loader.collate_fn.keywords["collate"]

def _next_data(self) -> Any:
combined_batch = super()._next_data()

batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META]

self.num_batches_fetched += 1

sampler_state, sampler_state_idx = self._sampler_state.pop(0)
# there are no workers within the samplers, so the state holds a single worker id
worker_id = list(state.keys())[0]

state = [
IteratorState(
num_workers=self._loader.num_workers,
sampler_state=sampler_state,
dataset_state=state,
worker_id=worker_id,
num_batches_fetched=self.num_batches_fetched,
)
]
# ensures there is an alignment between the sampler state and the currently fetched batch
assert sampler_state_idx == self.num_batches_fetched
self._data_fetcher._store_dataloader_iter_state(self, state)
return batch


class _SingleProcessDataLoaderIterStateful(_StatefulMixin, _SingleProcessDataLoaderIter):
def __init__(self, loader: DataLoader):
self._prepare_loader(loader)
super().__init__(loader)


class _MultiProcessingDataLoaderIterStateful(_StatefulMixin, _MultiProcessingDataLoaderIter):
def __init__(self, loader: DataLoader):
self._prepare_loader(loader)
super().__init__(loader)


def _get_iterator(self) -> "_BaseDataLoaderIter":
if not hasattr(self, "_lightning_fetcher"):
raise MisconfigurationException(
"A stateful iterator should be used only when a DataFetcher has been attached to the DataLoader."
)
if self.num_workers == 0:
return _SingleProcessDataLoaderIterStateful(self)
else:
if hasattr(self, "check_worker_number_rationality"):
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIterStateful(self)


def _patch_dataloader_get_iterators() -> None:
"""This function is used to replace the DataLoader iterator by their stateful version."""
if not hasattr(DataLoader, "_ori_get_iterator"):
DataLoader._ori_get_iterator = DataLoader._get_iterator
DataLoader._get_iterator = _get_iterator


def _teardown_dataloader_get_iterators() -> None:
"""This function is used to restore the DataLoader `get_iterator` with its original one."""
# cleanup the get_iterator replacement in case of Fault Tolerant Training.
get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
if get_iterator:
DataLoader._get_iterator = get_iterator
del DataLoader._ori_get_iterator
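`_patch_dataloader_get_iterators` and `_teardown_dataloader_get_iterators` above follow a stash-and-restore monkey-patching pattern: keep the original function under a private name, swap in the replacement, and restore it on teardown so repeated patching stays idempotent. A standalone sketch of that pattern, using a stand-in `Loader` class rather than the real `torch.utils.data.DataLoader`:

class Loader:
    """Stand-in for torch.utils.data.DataLoader."""

    def _get_iterator(self) -> str:
        return "original"


def _stateful_get_iterator(self: Loader) -> str:
    return "stateful"


def patch_get_iterator() -> None:
    # stash the original only once so double-patching does not lose it
    if not hasattr(Loader, "_ori_get_iterator"):
        Loader._ori_get_iterator = Loader._get_iterator
    Loader._get_iterator = _stateful_get_iterator


def teardown_get_iterator() -> None:
    # restore the original implementation and drop the stash
    original = getattr(Loader, "_ori_get_iterator", None)
    if original is not None:
        Loader._get_iterator = original
        del Loader._ori_get_iterator


patch_get_iterator()
assert Loader()._get_iterator() == "stateful"
teardown_get_iterator()
assert Loader()._get_iterator() == "original"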
12 changes: 12 additions & 0 deletions pytorch_lightning/utilities/fetching.py
@@ -99,6 +99,8 @@ def setup(
if self.profiler is not None and stage is None:
raise MisconfigurationException("When providing a profiler, the stage should be provided too.")

self._attach_data_fetcher()

@staticmethod
def _add_capture_metadata_collate(dataloader: Iterable) -> None:
if not isinstance(dataloader, (DataLoader, CombinedLoader)):
@@ -190,6 +192,16 @@ def collect_state(iterator: Iterator):

return apply_to_collection(self.loader_iters, Iterator, collect_state)

def _attach_data_fetcher(self):
def _attach_data_fetcher_fn(loader: DataLoader):
if isinstance(loader, CycleIterator):
loader = loader.loader

if isinstance(loader, DataLoader) and _fault_tolerant_training():
loader._lightning_fetcher = self

apply_to_collection(self.loaders, (DataLoader, CycleIterator), _attach_data_fetcher_fn)

def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
if self.dataloader is None:
raise MisconfigurationException("The iterator hasn't been provided. HINT: Did you call the setup function?")
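`_attach_data_fetcher` tags every wrapped `DataLoader` with a `_lightning_fetcher` back-reference so the patched `_get_iterator` can later hand the iterator state back to the fetcher. A rough sketch of that attach step under simplifying assumptions, using plain recursion and stand-in classes instead of Lightning's `apply_to_collection`, `CycleIterator`, and the real `DataLoader`:

from typing import Any


class FakeLoader:
    """Stand-in for a DataLoader that can carry a fetcher back-reference."""


class FakeFetcher:
    """Stand-in for a data fetcher; only the attach step is modelled."""

    def attach(self, loaders: Any) -> None:
        # walk lists/tuples/dicts of loaders and tag every loader with `self`
        if isinstance(loaders, FakeLoader):
            loaders._lightning_fetcher = self
        elif isinstance(loaders, (list, tuple)):
            for loader in loaders:
                self.attach(loader)
        elif isinstance(loaders, dict):
            for loader in loaders.values():
                self.attach(loader)


fetcher = FakeFetcher()
loaders = {"train": FakeLoader(), "extra": [FakeLoader(), FakeLoader()]}
fetcher.attach(loaders)
assert all(
    loader._lightning_fetcher is fetcher
    for loader in (loaders["train"], *loaders["extra"])
)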
134 changes: 134 additions & 0 deletions tests/utilities/test_auto_restart.py
@@ -40,6 +40,11 @@
_add_capture_metadata_collate,
_dataloader_load_state_dict,
_dataloader_to_state_dict,
_MultiProcessingDataLoaderIterStateful,
_patch_dataloader_get_iterators,
_SingleProcessDataLoaderIterStateful,
_SupportsStateDict,
_teardown_dataloader_get_iterators,
CaptureIterableDataset,
CaptureMapDataset,
FastForwardSampler,
@@ -1195,6 +1200,29 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
assert "dataloader_state_dict" in state_dict


def test_supports_state_dict_protocol():
class StatefulClass:
def state_dict(self):
pass

def load_state_dict(self, state_dict):
pass

assert isinstance(StatefulClass(), _SupportsStateDict)

class NotStatefulClass:
def state_dict(self):
pass

assert not isinstance(NotStatefulClass(), _SupportsStateDict)

class NotStateful2Class:
def load_state_dict(self, state_dict):
pass

assert not isinstance(NotStateful2Class(), _SupportsStateDict)


def test_fault_tolerant_mode_enum():
with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}):
assert _FaultTolerantMode.DISABLED == _FaultTolerantMode.detect_current_mode()
@@ -1213,3 +1241,109 @@ def test_fault_tolerant_mode_enum():
):
with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "3"}):
_FaultTolerantMode.detect_current_mode()


class StatefulRandomSampler(RandomSampler):

counter = 0

def state_dict(self):
self.counter += 1
return {"counter": self.counter}

def load_state_dict(self, state_dict):
self.counter = state_dict["counter"]


class FailingStatefulRandomDataset(RandomDataset):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.counter = 0

def __getitem__(self, index):
self.counter += 1
return super().__getitem__(index)

def state_dict(self):
return {"counter": self.counter}

def load_state_dict(self, state_dict):
self.counter = state_dict["counter"]


class StatefulRandomDataset(FailingStatefulRandomDataset):
def state_dict(self):
info = get_worker_info()
worker_id = info.id if info else 0
return {worker_id: {"counter": self.counter}}


@pytest.mark.parametrize("num_workers", [0, 2])
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"})
def test_stateful_workers(num_workers):

seed_everything(42)

_get_iterator_fn = DataLoader._get_iterator
_patch_dataloader_get_iterators()
assert DataLoader._ori_get_iterator is not None

data_fetcher = DataFetcher()
dataloader = DataLoader(FailingStatefulRandomDataset(1, 64), shuffle=True)

with pytest.raises(MisconfigurationException, match="A stateful iterator should be used"):
iter(dataloader)

# This would attach the `data_fetcher` to the DataLoader.
data_fetcher.setup(dataloader)

dataloader_iter = iter(dataloader)
assert isinstance(dataloader_iter, _SingleProcessDataLoaderIterStateful)

with pytest.raises(MisconfigurationException, match="he state_dict returned by"):
next(dataloader_iter)

data_fetcher = DataFetcher()
dataset = StatefulRandomDataset(1, 64)
dataloader = DataLoader(dataset, sampler=StatefulRandomSampler(dataset), num_workers=num_workers)

# This would attach the `data_fetcher` to the DataLoader.
data_fetcher.setup(dataloader)
data_fetcher_iter = iter(data_fetcher)

worker_type = _SingleProcessDataLoaderIterStateful if num_workers == 0 else _MultiProcessingDataLoaderIterStateful
assert isinstance(data_fetcher.dataloader_iter, worker_type)

next(data_fetcher_iter)
state = data_fetcher.dataloader_iter.state.state
assert state[0].dataset_state == {0: {"counter": 1}}
assert state[0].sampler_state["sampler"] == {"counter": 1}

next(data_fetcher_iter)
previous_state = data_fetcher.dataloader_iter.previous_state.state
state = data_fetcher.dataloader_iter.state.state
assert previous_state[0].dataset_state == {0: {"counter": 1}}
assert previous_state[0].sampler_state["sampler"] == {"counter": 1}
# TODO: Resolve the previous `sampler_state` associated with `worker_id: 0`.
worker_id = 1 if num_workers else 0
assert state[worker_id].sampler_state["sampler"] == {"counter": 2}

# each worker has its own copy of the dataset
assert state[0].dataset_state == ({0: {"counter": 2}} if num_workers == 0 else {0: {"counter": 1}})
target_previous_state = deepcopy(state)

next(data_fetcher_iter)
latest_worker_id = data_fetcher.dataloader_iter.state.latest_worker_id
assert latest_worker_id == 0
previous_state = data_fetcher.dataloader_iter.previous_state.state
state = data_fetcher.dataloader_iter.state.state

assert target_previous_state == previous_state
assert state[0].sampler_state["sampler"] == {"counter": 3}
assert state[0].dataset_state == ({0: {"counter": 3}} if num_workers == 0 else {0: {"counter": 2}})

_teardown_dataloader_get_iterators()
assert not hasattr(DataLoader, "_ori_get_iterator")
assert DataLoader._get_iterator == _get_iterator_fn

data_fetcher.teardown()
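The `FailingStatefulRandomDataset` / `StatefulRandomDataset` pair in the test above encodes the manual-mode contract: `dataset.state_dict()` must be keyed by integer `worker_id`, because each worker process holds its own copy of the dataset. A hedged sketch of a dataset honoring that contract; `RangeDataset` and its `counter` field are illustrative, not part of the PR:

from typing import Any, Dict

import torch
from torch.utils.data import Dataset, get_worker_info


class RangeDataset(Dataset):
    """Illustrative dataset whose progress is checkpointed per worker."""

    def __init__(self, size: int) -> None:
        self.size = size
        self.counter = 0  # per-process progress marker

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, index: int) -> torch.Tensor:
        self.counter += 1
        return torch.tensor(index)

    def state_dict(self) -> Dict[int, Dict[str, Any]]:
        # manual mode expects the state to be indexed by integer worker_id
        info = get_worker_info()
        worker_id = info.id if info is not None else 0
        return {worker_id: {"counter": self.counter}}

    def load_state_dict(self, state_dict: Dict[int, Dict[str, Any]]) -> None:
        # simplified restore: each worker copy reloads its own entry
        info = get_worker_info()
        worker_id = info.id if info is not None else 0
        self.counter = state_dict.get(worker_id, {}).get("counter", 0)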