Skip to content

Commit

Permalink
force crash when max_epochs < epochs in a checkpoint (Lightning-AI#3580)
Browse files Browse the repository at this point in the history
* force crash when max_epochs < epochs in a checkpoint

* force crash when max_epochs < epochs in a checkpoint
  • Loading branch information
williamFalcon authored Sep 21, 2020
1 parent a71d62d commit 2775389
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
9 changes: 9 additions & 0 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
from apex import amp
Expand Down Expand Up @@ -145,6 +146,14 @@ def restore_training_state(self, checkpoint):
self.trainer.global_step = checkpoint['global_step']
self.trainer.current_epoch = checkpoint['epoch']

# crash if max_epochs is lower than the current epoch from the checkpoint
if self.trainer.current_epoch > self.trainer.max_epochs:
m = f"""
you restored a checkpoint with current_epoch={self.trainer.current_epoch}
but the Trainer(max_epochs={self.trainer.max_epochs})
"""
raise MisconfigurationException(m)

# Division deals with global step stepping once per accumulated batch
# Inequality deals with different global step for odd vs even num_training_batches
n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
Expand Down
5 changes: 4 additions & 1 deletion tests/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from tests.base import EvalModelTemplate
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class EarlyStoppingTestRestore(EarlyStopping):
Expand Down Expand Up @@ -63,7 +64,9 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
resume_from_checkpoint=checkpoint_filepath,
early_stop_callback=early_stop_callback,
)
new_trainer.fit(model)

with pytest.raises(MisconfigurationException, match=r'.*you restored a checkpoint with current_epoch*'):
new_trainer.fit(model)


def test_early_stopping_no_extraneous_invocations(tmpdir):
Expand Down

0 comments on commit 2775389

Please sign in to comment.