From 9e0cc177071031cd83eaecee9908fbb34d55c6a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 29 Jun 2022 21:09:35 +0200 Subject: [PATCH] Call `set_epoch` for distributed batch samplers (#13396) Co-authored-by: Jirka Co-authored-by: Rohit Gupta --- CHANGELOG.md | 7 +-- .../loops/dataloader/evaluation_loop.py | 11 +--- .../loops/dataloader/prediction_loop.py | 10 +--- pytorch_lightning/loops/fit_loop.py | 9 +-- pytorch_lightning/loops/utilities.py | 14 +++++ pytorch_lightning/trainer/supporters.py | 7 ++- tests/loops/test_evaluation_loop.py | 57 ++++++++++++++++--- tests/loops/test_utilities.py | 24 +++++++- 8 files changed, 101 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e631cc20599f0c..0c34b2b49dc0fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -243,15 +243,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fixed `estimated_stepping_batches` requiring distributed comms in `configure_optimizers` for the `DeepSpeedStrategy` ([#13350](https://github.com/PyTorchLightning/pytorch-lightning/pull/13350)) - - - Fixed bug with Python version check that prevented use with development versions of Python ([#13420](https://github.com/PyTorchLightning/pytorch-lightning/pull/13420)) - - - Fixed Model Summary when using DeepSpeed Stage 3 ([#13427](https://github.com/PyTorchLightning/pytorch-lightning/pull/13427)) - - - Fixed `pytorch_lightning.utilities.distributed.gather_all_tensors` to handle tensors of different dimensions ([#12630](https://github.com/PyTorchLightning/pytorch-lightning/pull/12630)) +- The loops now call `.set_epoch()` also on batch samplers if the dataloader has one wrapped in a distributed sampler ([#13396](https://github.com/PyTorchLightning/pytorch-lightning/pull/13396)) ## [1.6.4] - 2022-06-01 diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 2ec42fa0acbb5c..d45df5885c0ecb 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -26,6 +26,7 @@ from pytorch_lightning.accelerators import GPUAccelerator from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop +from pytorch_lightning.loops.utilities import _set_sampler_epoch from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -161,14 +162,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self._has_run = True def on_advance_start(self, *args: Any, **kwargs: Any) -> None: - dataloader = self.current_dataloader - if ( - dataloader is not None - and getattr(dataloader, "sampler", None) - and callable(getattr(dataloader.sampler, "set_epoch", None)) - ): - # set seed for distributed sampler (enables shuffling for each epoch) - dataloader.sampler.set_epoch(self.trainer.fit_loop.epoch_progress.current.processed) + if self.current_dataloader is not None: + _set_sampler_epoch(self.current_dataloader, self.trainer.fit_loop.epoch_progress.current.processed) super().on_advance_start(*args, **kwargs) diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index a14a218ef67e91..36648b7f43e345 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -5,6 +5,7 @@ from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop +from pytorch_lightning.loops.utilities import _set_sampler_epoch from pytorch_lightning.strategies import DDPSpawnStrategy from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _PREDICT_OUTPUT @@ -87,13 +88,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None: """Predicts one entire dataloader.""" void(*args, **kwargs) dataloader = self.current_dataloader - if ( - dataloader is not None - and getattr(dataloader, "sampler", None) - and callable(getattr(dataloader.sampler, "set_epoch", None)) - ): - # set seed for distributed sampler (enables shuffling for each epoch) - dataloader.sampler.set_epoch(self.trainer.fit_loop.epoch_progress.current.processed) + if dataloader is not None: + _set_sampler_epoch(dataloader, self.trainer.fit_loop.epoch_progress.current.processed) dataloader = self.trainer.strategy.process_dataloader(dataloader) dataloader_iter = enumerate(dataloader) dl_max_batches = self.max_batches[self.current_dataloader_idx] diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index ac33390a97cece..0771a4a71de9f7 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -21,7 +21,7 @@ from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE -from pytorch_lightning.loops.utilities import _is_max_limit_reached +from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.trainer.supporters import TensorRunningAccum @@ -232,11 +232,8 @@ def on_advance_start(self) -> None: # type: ignore[override] # reset outputs here instead of in `reset` as they are not accumulated between epochs self._outputs = [] - if self.trainer.train_dataloader is not None and callable( - getattr(self.trainer.train_dataloader.sampler, "set_epoch", None) - ): - # set seed for distributed sampler (enables shuffling for each epoch) - self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.processed) + if self.trainer.train_dataloader is not None: + _set_sampler_epoch(self.trainer.train_dataloader, self.epoch_progress.current.processed) # changing gradient according accumulation_scheduler self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module) diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index d84c195d758f94..15142be626587f 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -21,6 +21,7 @@ import numpy as np import torch from torch.optim import Optimizer +from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.loops import Loop @@ -228,3 +229,16 @@ def _reset_progress(loop: Loop) -> None: def _v1_8_output_format(fx: Callable) -> bool: parameters = inspect.signature(fx).parameters return "new_format" in parameters and parameters["new_format"].default is True + + +def _set_sampler_epoch(dataloader: DataLoader, epoch: int) -> None: + """Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader. + + Every PyTorch dataloader has either a sampler or a batch sampler, and if it is wrapped by a + :class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning + of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off. + """ + for sampler_name in ("sampler", "batch_sampler"): + sampler = getattr(dataloader, sampler_name, None) + if sampler is not None and callable(getattr(sampler, "set_epoch", None)): + sampler.set_epoch(epoch) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index b8f688892b318b..6d3ec88b0be6a0 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -438,9 +438,14 @@ class DataLoaderDict(dict): @property def sampler(self) -> Union[Iterable, Sequence, Mapping]: - """Return a collections of samplers extracting from loaders.""" + """Return a collections of samplers extracted from loaders.""" return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "sampler", None) + @property + def batch_sampler(self) -> Union[Iterable, Sequence, Mapping]: + """Return a collections of batch samplers extracted from loaders.""" + return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "batch_sampler", None) + def _wrap_loaders_max_size_cycle(self) -> Any: """Wraps all loaders to make sure they are cycled until the longest loader is exhausted. diff --git a/tests/loops/test_evaluation_loop.py b/tests/loops/test_evaluation_loop.py index 137608c426ee03..bddf819aafdd6d 100644 --- a/tests/loops/test_evaluation_loop.py +++ b/tests/loops/test_evaluation_loop.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from unittest import mock -from unittest.mock import Mock +from unittest.mock import call, Mock import torch from torch.utils.data.dataloader import DataLoader -from torch.utils.data.sampler import RandomSampler +from torch.utils.data.sampler import BatchSampler, RandomSampler from pytorch_lightning import Trainer from pytorch_lightning.loops import EvaluationEpochLoop @@ -44,9 +44,8 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): assert eval_epoch_end_mock.call_count == 4 -def test_set_epoch_called_eval_predict(tmpdir): - """Tests that set_epoch (if the sampler has one) is called on the DataLoader during evaluation and - prediction.""" +def test_evaluation_loop_sampler_set_epoch_called(tmpdir): + """Tests that set_epoch is called on the dataloader's sampler (if any) during training and validation.""" def _get_dataloader(): dataset = RandomDataset(32, 64) @@ -56,20 +55,60 @@ def _get_dataloader(): model = BoringModel() trainer = Trainer( - default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2, enable_model_summary=False + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=2, + enable_model_summary=False, + enable_checkpointing=False, + logger=False, + ) + + train_dataloader = _get_dataloader() + val_dataloader = _get_dataloader() + trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) + # One for each epoch + assert train_dataloader.sampler.set_epoch.call_args_list == [call(0), call(1)] + # One for each epoch + sanity check + assert val_dataloader.sampler.set_epoch.call_args_list == [call(0), call(0), call(1)] + + val_dataloader = _get_dataloader() + trainer.validate(model, val_dataloader) + assert val_dataloader.sampler.set_epoch.call_args_list == [call(2)] + + +def test_evaluation_loop_batch_sampler_set_epoch_called(tmpdir): + """Tests that set_epoch is called on the dataloader's batch sampler (if any) during training and validation.""" + + def _get_dataloader(): + dataset = RandomDataset(32, 64) + sampler = RandomSampler(dataset) + batch_sampler = BatchSampler(sampler, 2, True) + batch_sampler.set_epoch = Mock() + return DataLoader(dataset, batch_sampler=batch_sampler) + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=2, + enable_model_summary=False, + enable_checkpointing=False, + logger=False, ) train_dataloader = _get_dataloader() val_dataloader = _get_dataloader() trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) # One for each epoch - assert train_dataloader.sampler.set_epoch.call_count == 2 + assert train_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(1)] # One for each epoch + sanity check - assert val_dataloader.sampler.set_epoch.call_count == 3 + assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(0), call(1)] val_dataloader = _get_dataloader() trainer.validate(model, val_dataloader) - assert val_dataloader.sampler.set_epoch.call_count == 1 + assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(2)] @mock.patch( diff --git a/tests/loops/test_utilities.py b/tests/loops/test_utilities.py index c5d2e98d008b08..914c1de8e115b5 100644 --- a/tests/loops/test_utilities.py +++ b/tests/loops/test_utilities.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import Mock + import pytest import torch -from pytorch_lightning.loops.utilities import _extract_hiddens, _v1_8_output_format +from pytorch_lightning.loops.utilities import _extract_hiddens, _set_sampler_epoch, _v1_8_output_format from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -61,3 +63,23 @@ def training_epoch_end(outputs, new_format=True): ... assert _v1_8_output_format(training_epoch_end) + + +def test_set_sampler_epoch(): + # No samplers + dataloader = Mock() + dataloader.sampler = None + dataloader.batch_sampler = None + _set_sampler_epoch(dataloader, 55) + + # set_epoch not callable + dataloader = Mock() + dataloader.sampler.set_epoch = None + dataloader.batch_sampler.set_epoch = None + _set_sampler_epoch(dataloader, 55) + + # set_epoch callable + dataloader = Mock() + _set_sampler_epoch(dataloader, 55) + dataloader.sampler.set_epoch.assert_called_once_with(55) + dataloader.batch_sampler.set_epoch.assert_called_once_with(55)