Fix val_loop run on restart #11552

Merged
merged 8 commits on Feb 2, 2022
Changes from 2 commits
CHANGELOG.md (3 additions, 0 deletions)
@@ -454,6 +454,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))


- Fixed an issue where the validation loop would run on restart ([#11552](https://github.com/PyTorchLightning/pytorch-lightning/pull/11552))


## [1.5.8] - 2022-01-05

### Fixed
pytorch_lightning/loops/epoch/training_epoch_loop.py (5 additions, 0 deletions)
@@ -483,6 +483,11 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:

        # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
        is_val_check_batch = is_last_batch

        # when restarting without fault-tolerant training, batch_progress.current.ready is -1,
        # so batch_idx comes through as -1 and validation must not run
        if batch_idx == -1:
            return False

        if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
            is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
        elif self.trainer.val_check_batch != float("inf"):
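To illustrate why the guard belongs before the last-batch logic, here is a minimal, self-contained sketch of the check. This is not Lightning's actual API: the helper name and the `val_check_batch` parameter are illustrative. The point is that on a restart without fault-tolerant training the progress counter reports -1, so the short-circuit returns `False` before `is_last_batch` can trigger a validation run.

```python
# Minimal sketch (illustrative names, not Lightning's actual API): on restart
# without fault-tolerant training, the batch progress counter reports -1,
# so batch_idx == -1 and the check must return False before any other rule fires.
def should_check_val(batch_idx: int, is_last_batch: bool, val_check_batch: int = 2) -> bool:
    if batch_idx == -1:  # restarting: no batch has been processed yet
        return False
    if is_last_batch:
        return True
    return (batch_idx + 1) % val_check_batch == 0


assert should_check_val(batch_idx=-1, is_last_batch=True) is False  # restart: guard wins
assert should_check_val(batch_idx=1, is_last_batch=False) is True   # normal mid-epoch check
```

Without the guard, the restart case above would return `True` via the `is_last_batch` branch, which is the spurious validation run this PR fixes.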
tests/loops/epoch/test_training_epoch_loop.py (16 additions, 0 deletions)
@@ -14,6 +14,8 @@
import pytest

from pytorch_lightning.loops import TrainingEpochLoop
from pytorch_lightning.trainer.trainer import Trainer
from tests.helpers.boring_model import BoringModel

_out00 = {"loss": 0.0}
_out01 = {"loss": 0.1}
@@ -141,3 +143,17 @@ def test_prepare_outputs_training_batch_end_manual(batch_end_outputs, expected):
        num_optimizers=-1,  # does not matter for manual optimization
    )
    assert prepared == expected


def test_no_val_on_train_epoch_loop_restart(tmpdir):
    """Test that the validation loop does not get triggered at the beginning of a restart."""
    trainer = Trainer()
    model = BoringModel()
    trainer.strategy.connect(model)
    trainer._data_connector.attach_data(model)
    trainer.reset_train_dataloader()
    training_epoch_loop = trainer.fit_loop.epoch_loop
    training_epoch_loop.restarting = True
    assert not training_epoch_loop._should_check_val_fx(
        training_epoch_loop.batch_idx, training_epoch_loop.batch_progress.is_last_batch
    )
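As a usage note, here is a hedged sketch of the end-to-end scenario the fix addresses, assuming standard checkpoint-resume semantics via `Trainer.fit(ckpt_path=...)`. The paths, hyperparameters, and checkpoint filename are illustrative, not taken from this PR.

```python
# Sketch of the user-facing restart scenario (paths and values illustrative).
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel

model = BoringModel()
trainer = Trainer(default_root_dir="lightning_logs", max_epochs=2, val_check_interval=0.5)
trainer.fit(model)  # an interrupted run leaves a checkpoint behind

# Before this fix, resuming could trigger a validation run immediately at the
# start of the restart, because batch_idx == -1 slipped past the val-check logic.
resumed = Trainer(default_root_dir="lightning_logs", max_epochs=2, val_check_interval=0.5)
resumed.fit(model, ckpt_path="lightning_logs/last.ckpt")  # illustrative checkpoint path
```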