Call set_epoch for distributed batch samplers #13396

Merged · 11 commits · Jun 29, 2022
2 changes: 1 addition & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -270,7 +270,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `estimated_stepping_batches` requiring distributed comms in `configure_optimizers` for the `DeepSpeedStrategy` ([#13350](https://github.com/PyTorchLightning/pytorch-lightning/pull/13350))


-
- The loops now call `.set_epoch()` also on batch samplers if the dataloader has one wrapped in a distributed sampler ([#13396](https://github.com/PyTorchLightning/pytorch-lightning/pull/13396))


## [1.6.4] - 2022-06-01
11 changes: 3 additions & 8 deletions src/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -27,6 +27,7 @@
from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.loops.utilities import _set_sampler_epoch
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -161,14 +162,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
self._has_run = True

def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
dataloader = self.current_dataloader
if (
dataloader is not None
and getattr(dataloader, "sampler", None)
and callable(getattr(dataloader.sampler, "set_epoch", None))
):
# set seed for distributed sampler (enables shuffling for each epoch)
dataloader.sampler.set_epoch(self.trainer.fit_loop.epoch_progress.current.processed)
if self.current_dataloader is not None:
_set_sampler_epoch(self.current_dataloader, self.trainer.fit_loop.epoch_progress.current.processed)

super().on_advance_start(*args, **kwargs)

10 changes: 3 additions & 7 deletions src/pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -5,6 +5,7 @@

from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop
from pytorch_lightning.loops.utilities import _set_sampler_epoch
from pytorch_lightning.strategies import DDPSpawnStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
@@ -90,13 +91,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
"""Predicts one entire dataloader."""
void(*args, **kwargs)
dataloader = self.current_dataloader
if (
dataloader is not None
and getattr(dataloader, "sampler", None)
and callable(getattr(dataloader.sampler, "set_epoch", None))
):
# set seed for distributed sampler (enables shuffling for each epoch)
dataloader.sampler.set_epoch(self.trainer.fit_loop.epoch_progress.current.processed)
if dataloader is not None:
_set_sampler_epoch(dataloader, self.trainer.fit_loop.epoch_progress.current.processed)
dataloader = self.trainer.strategy.process_dataloader(dataloader)
dataloader_iter = enumerate(dataloader)
dl_max_batches = self.max_batches[self.current_dataloader_idx]
9 changes: 3 additions & 6 deletions src/pytorch_lightning/loops/fit_loop.py
@@ -21,7 +21,7 @@
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE
from pytorch_lightning.loops.utilities import _is_max_limit_reached
from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
@@ -232,11 +232,8 @@ def on_advance_start(self) -> None:  # type: ignore[override]
# reset outputs here instead of in `reset` as they are not accumulated between epochs
self._outputs = []

if self.trainer.train_dataloader is not None and callable(
getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
):
# set seed for distributed sampler (enables shuffling for each epoch)
self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.processed)
if self.trainer.train_dataloader is not None:
_set_sampler_epoch(self.trainer.train_dataloader, self.epoch_progress.current.processed)

# changing gradient according accumulation_scheduler
self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
14 changes: 14 additions & 0 deletions src/pytorch_lightning/loops/utilities.py
@@ -22,6 +22,7 @@
import torch
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loops import Loop
@@ -220,3 +221,16 @@ def _reset_progress(loop: Loop) -> None:
def _v1_8_output_format(fx: Callable) -> bool:
parameters = inspect.signature(fx).parameters
return "new_format" in parameters and parameters["new_format"].default is True


def _set_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
"""Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader.

Every PyTorch dataloader has either a sampler or a batch sampler, and if it is wrapped by a
:class:`~torch.utils.data.DistributedSampler`, ``set_epoch`` must be called at the beginning of every epoch to
ensure shuffling applies a new ordering. This has no effect if shuffling is off.
"""
for sampler_name in ("sampler", "batch_sampler"):
sampler = getattr(dataloader, sampler_name, None)
if sampler is not None and callable(getattr(sampler, "set_epoch", None)):
sampler.set_epoch(epoch)
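
The sketch below is not part of this PR; it only illustrates the case the helper targets: a user-provided, distributed-aware batch sampler that exposes ``set_epoch``. The ``MyDistributedBatchSampler`` class, the dataset, and the standalone loop are illustrative assumptions, and the loop simply mirrors what ``_set_sampler_epoch`` does at the start of each epoch.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import BatchSampler, SequentialSampler


class MyDistributedBatchSampler(BatchSampler):
    """Hypothetical batch sampler that reshuffles per epoch, like DistributedSampler does."""

    def __init__(self, sampler, batch_size, drop_last):
        super().__init__(sampler, batch_size, drop_last)
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # a real implementation would use ``self.epoch`` to seed its per-epoch shuffling
        self.epoch = epoch


dataset = TensorDataset(torch.randn(8, 2))
batch_sampler = MyDistributedBatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

# Equivalent to what the loops now do at the start of epoch 3 via ``_set_sampler_epoch(dataloader, 3)``:
for name in ("sampler", "batch_sampler"):
    candidate = getattr(dataloader, name, None)
    if candidate is not None and callable(getattr(candidate, "set_epoch", None)):
        candidate.set_epoch(3)

assert dataloader.batch_sampler.epoch == 3

Note that ``dataloader.sampler`` here is a plain ``SequentialSampler`` without ``set_epoch``, so only the batch sampler receives the call, which is exactly the gap this PR closes.
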
7 changes: 6 additions & 1 deletion src/pytorch_lightning/trainer/supporters.py
@@ -438,9 +438,14 @@ class DataLoaderDict(dict):

@property
def sampler(self) -> Union[Iterable, Sequence, Mapping]:
"""Return a collections of samplers extracting from loaders."""
"""Return a collections of samplers extracted from loaders."""
return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "sampler", None)

@property
def batch_sampler(self) -> Union[Iterable, Sequence, Mapping]:
"""Return a collections of batch samplers extracted from loaders."""
return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "batch_sampler", None)

def _wrap_loaders_max_size_cycle(self) -> Any:
"""Wraps all loaders to make sure they are cycled until the longest loader is exhausted.
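
As a brief aside, not part of this PR: the new ``batch_sampler`` property mirrors the existing ``sampler`` property, in that ``apply_to_collection`` preserves the collection structure and maps each ``DataLoader`` to the requested attribute. A minimal sketch under assumed inputs (the dict keys and datasets are made up for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.apply_func import apply_to_collection

loaders = {
    "train_a": DataLoader(TensorDataset(torch.randn(4, 1)), batch_size=2),
    "train_b": DataLoader(TensorDataset(torch.randn(4, 1)), batch_size=2),
}

# The structure of ``loaders`` is preserved; each DataLoader is replaced by the value of
# its ``batch_sampler`` attribute (or ``None`` if it has no such attribute).
batch_samplers = apply_to_collection(loaders, DataLoader, getattr, "batch_sampler", None)
assert set(batch_samplers) == {"train_a", "train_b"}
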
57 changes: 48 additions & 9 deletions tests/tests_pytorch/loops/test_evaluation_loop.py
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import Mock
from unittest.mock import call, Mock

import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import RandomSampler
from torch.utils.data.sampler import BatchSampler, RandomSampler

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
@@ -44,9 +44,8 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir):
assert eval_epoch_end_mock.call_count == 4


def test_set_epoch_called_eval_predict(tmpdir):
"""Tests that set_epoch (if the sampler has one) is called on the DataLoader during evaluation and
prediction."""
def test_evaluation_loop_sampler_set_epoch_called(tmpdir):
"""Tests that set_epoch is called on the dataloader's sampler (if any) during training and validation."""

def _get_dataloader():
dataset = RandomDataset(32, 64)
@@ -56,20 +55,60 @@ def _get_dataloader():

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2, enable_model_summary=False
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=1,
max_epochs=2,
enable_model_summary=False,
enable_checkpointing=False,
logger=False,
)

train_dataloader = _get_dataloader()
val_dataloader = _get_dataloader()
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
# One for each epoch
assert train_dataloader.sampler.set_epoch.call_args_list == [call(0), call(1)]
# One for each epoch + sanity check
assert val_dataloader.sampler.set_epoch.call_args_list == [call(0), call(0), call(1)]

val_dataloader = _get_dataloader()
trainer.validate(model, val_dataloader)
assert val_dataloader.sampler.set_epoch.call_args_list == [call(2)]


def test_evaluation_loop_batch_sampler_set_epoch_called(tmpdir):
"""Tests that set_epoch is called on the dataloader's batch sampler (if any) during training and validation."""

def _get_dataloader():
dataset = RandomDataset(32, 64)
sampler = RandomSampler(dataset)
batch_sampler = BatchSampler(sampler, 2, True)
batch_sampler.set_epoch = Mock()
return DataLoader(dataset, batch_sampler=batch_sampler)

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=1,
max_epochs=2,
enable_model_summary=False,
enable_checkpointing=False,
logger=False,
)

train_dataloader = _get_dataloader()
val_dataloader = _get_dataloader()
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
# One for each epoch
assert train_dataloader.sampler.set_epoch.call_count == 2
assert train_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(1)]
# One for each epoch + sanity check
assert val_dataloader.sampler.set_epoch.call_count == 3
assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(0), call(1)]

val_dataloader = _get_dataloader()
trainer.validate(model, val_dataloader)
assert val_dataloader.sampler.set_epoch.call_count == 1
assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(2)]


@mock.patch(
25 changes: 24 additions & 1 deletion tests/tests_pytorch/loops/test_utilities.py
@@ -11,10 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, Mock

import pytest
import torch
from torch.utils.data import DataLoader

from pytorch_lightning.loops.utilities import _extract_hiddens, _v1_8_output_format
from pytorch_lightning.loops.utilities import _extract_hiddens, _set_sampler_epoch, _v1_8_output_format
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -61,3 +64,23 @@ def training_epoch_end(outputs, new_format=True):
...

assert _v1_8_output_format(training_epoch_end)


def test_set_sampler_epoch():
# No samplers
dataloader = Mock()
dataloader.sampler = None
dataloader.batch_sampler = None
_set_sampler_epoch(dataloader, 55)

# set_epoch not callable
dataloader = Mock()
dataloader.sampler.set_epoch = None
dataloader.batch_sampler.set_epoch = None
_set_sampler_epoch(dataloader, 55)

# set_epoch callable
dataloader = Mock()
_set_sampler_epoch(dataloader, 55)
dataloader.sampler.set_epoch.assert_called_once_with(55)
dataloader.batch_sampler.set_epoch.assert_called_once_with(55)