Skip to content

Commit

Permalink
nb steps in early stop (#3909)
Browse files Browse the repository at this point in the history
* nb steps

* if

* skip

* rev

* seed

* seed
  • Loading branch information
Borda authored Oct 6, 2020
1 parent 39b3704 commit 064ae53
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 4 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci_test-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
pip list
- name: Cache datasets
# todo this probably does not work with docker images, rather cache dockers
uses: actions/cache@v2
with:
path: Datasets # This path is specific to Ubuntu
Expand Down
1 change: 1 addition & 0 deletions .pyrightconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"pytorch_lightning/trainer/connectors/checkpoint_connector.py",
"pytorch_lightning/trainer/connectors/data_connector.py",
"pytorch_lightning/trainer/connectors/logger_connector.py",
"pytorch_lightning/trainer/connectors/slurm_connector.py",
"pytorch_lightning/distributed/dist.py",
"pytorch_lightning/tuner",
"pytorch_lightning/plugins"
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def _run_early_stopping_check(self, trainer, pl_module):

def on_train_end(self, trainer, pl_module):
if self.stopped_epoch > 0 and self.verbose > 0:
# todo: remove this old warning
rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping triggered.')
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def restore_training_state(self, checkpoint):
self.trainer.global_step = checkpoint['global_step']
self.trainer.current_epoch = checkpoint['epoch']

# crash if max_epochs is lower than the current epoch from the checkpoint
# crash if max_epochs is lower then the current epoch from the checkpoint
if self.trainer.current_epoch > self.trainer.max_epochs:
m = f"""
you restored a checkpoint with current_epoch={self.trainer.current_epoch}
Expand Down
6 changes: 3 additions & 3 deletions tests/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from tests.base import EvalModelTemplate
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -35,7 +35,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
https://github.com/PyTorchLightning/pytorch-lightning/issues/1464
https://github.com/PyTorchLightning/pytorch-lightning/issues/1463
"""

seed_everything(42)
model = EvalModelTemplate()
checkpoint_callback = ModelCheckpoint(monitor="early_stop_on", save_top_k=1)
early_stop_callback = EarlyStoppingTestRestore()
Expand All @@ -60,7 +60,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state)
new_trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
max_epochs=1,
resume_from_checkpoint=checkpoint_filepath,
early_stop_callback=early_stop_callback,
)
Expand Down

0 comments on commit 064ae53

Please sign in to comment.