From 277538970d76ac342f89498e7035eef033f9924f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 20 Sep 2020 22:04:22 -0400 Subject: [PATCH] force crash when max_epochs < epochs in a checkpoint (#3580) * force crash when max_epochs < epochs in a checkpoint * force crash when max_epochs < epochs in a checkpoint --- .../trainer/connectors/checkpoint_connector.py | 9 +++++++++ tests/callbacks/test_early_stopping.py | 5 ++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index df6e8b8213c36..7ebd84a428de7 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -33,6 +33,7 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS from pytorch_lightning.accelerators.base_backend import Accelerator +from pytorch_lightning.utilities.exceptions import MisconfigurationException try: from apex import amp @@ -145,6 +146,14 @@ def restore_training_state(self, checkpoint): self.trainer.global_step = checkpoint['global_step'] self.trainer.current_epoch = checkpoint['epoch'] + # crash if max_epochs is lower than the current epoch from the checkpoint + if self.trainer.current_epoch > self.trainer.max_epochs: + m = f""" + you restored a checkpoint with current_epoch={self.trainer.current_epoch} + but the Trainer(max_epochs={self.trainer.max_epochs}) + """ + raise MisconfigurationException(m) + # Division deals with global step stepping once per accumulated batch # Inequality deals with different global step for odd vs even num_training_batches n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index d681d9cb894d8..d08a4f6fe35ea 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -9,6 +9,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from tests.base import EvalModelTemplate +from pytorch_lightning.utilities.exceptions import MisconfigurationException class EarlyStoppingTestRestore(EarlyStopping): @@ -63,7 +64,9 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): resume_from_checkpoint=checkpoint_filepath, early_stop_callback=early_stop_callback, ) - new_trainer.fit(model) + + with pytest.raises(MisconfigurationException, match=r'.*you restored a checkpoint with current_epoch*'): + new_trainer.fit(model) def test_early_stopping_no_extraneous_invocations(tmpdir):