
early stopping checks on_validation_end #1458

Merged: 13 commits, May 25, 2020

Conversation

@baldassarreFe (Contributor) commented Apr 11, 2020

EarlyStopping should check the metric of interest on_validation_end rather than on_epoch_end.
In a normal scenario, this does not cause a problem, but in combination with check_val_every_n_epoch>1 in the Trainer it results in a warning or in a RuntimeError depending on strict.
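
For illustration, a configuration along these lines reproduces the mismatch (a rough sketch assuming the 0.7.x-era Trainer arguments, not code from this PR):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Validation only runs every 3rd epoch, but the callback used to look for
# `val_loss` at the end of *every* epoch via `on_epoch_end`.
early_stop = EarlyStopping(monitor='val_loss', strict=True)  # strict=True -> RuntimeError
trainer = Trainer(check_val_every_n_epoch=3, early_stop_callback=early_stop)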

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

What does this PR do?

Fixes #490.

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@mergify mergify bot requested a review from a team April 11, 2020 17:55
@Borda Borda changed the title Fixes PyTorchLightning/pytorch-lightning#490 early stopping checks on_validation_end Apr 11, 2020
codecov bot commented Apr 11, 2020

Codecov Report

Merging #1458 into master will decrease coverage by 0%.
The diff coverage is 100%.

@@          Coverage Diff           @@
##           master   #1458   +/-   ##
======================================
- Coverage      88%     88%   -0%     
======================================
  Files          74      74           
  Lines        4643    4645    +2     
======================================
  Hits         4068    4068           
- Misses        575     577    +2     

@awaelchli (Contributor)

Does ES still work on training metrics when no validation loop is defined? I thought there was a test for this, and I would have expected this change to break it.

@Borda (Member) commented Apr 12, 2020

This should be covered by a test...

@Borda Borda added the feature Is an improvement or enhancement label Apr 14, 2020
@jeremyjordan (Contributor) commented Apr 16, 2020

I don't think this (changing on_epoch_end to on_validation_end) is how we want to solve the root problem. We should enable early stopping to use metrics from either the train or val steps, but we shouldn't complain about missing validation metrics on epochs where validation wasn't run.

@Borda (Member) commented Apr 16, 2020

@PyTorchLightning/core-contributors suggestion about this issue?

@Borda Borda added discussion In a discussion stage help wanted Open to be worked on labels Apr 16, 2020
@baldassarreFe (Contributor, Author) commented Apr 16, 2020

Thanks for the input, @jeremyjordan. In general, I'm also in favor of offering flexibility to users, but I figured in this case the intention was to provide a simple way to monitor a validation metric (the default metric is val_loss).

This makes sense to me, since training metrics should always improve during training (e.g. when minimizing a loss with SGD), and the common practice to prevent overfitting is to keep an eye on validation metrics instead. That's why I think, at least by default, monitoring a validation metric is the most sensible thing to do.

If a more flexible solution is required, I suggest something similar to:

class EarlyStopping(Callback):
    def __init__(self, ..., when='validation'):
        ...
        self.when = when

    def _should_stop(self, trainer, pl_module):
        # existing stopping logic
        ...

    def on_epoch_end(self, trainer, pl_module):
        if self.when != 'training':
            return False
        return self._should_stop(trainer, pl_module)

    def on_validation_end(self, trainer, pl_module):
        if self.when != 'validation':
            return False
        return self._should_stop(trainer, pl_module)
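
Hypothetical usage of the suggested when argument (not an existing option), e.g. to monitor a training metric:

early_stop = EarlyStopping(monitor='train_loss', patience=3, when='training')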

I agree it's not beautiful, but it does the job. Or would you prefer splitting early stopping into two classes, e.g. EarlyStoppingTrain and EarlyStoppingVal?

@Borda Borda added this to the 0.7.5 milestone Apr 25, 2020
@Borda (Member) commented Apr 30, 2020

@jeremyjordan ^^

@williamFalcon (Contributor)

@baldassarreFe, I agree with @jeremyjordan that the solution might be more involved.

We need to enumerate the cases and work this out properly.

Early stopping should work every time validation ends (as you have changed here).

But there's the case where there is no val loop... in that case, should early stopping run at the end of every epoch? I'm also fine with requiring early stopping to work only on val...

Thoughts?

@jeremyjordan (Contributor)

I would also be fine with requiring early stopping to only work with a validation loop. However, we would need more changes to enable this. Take, for example, the patience argument, which is currently defined as the number of epochs with no improvement. If we instead change this to run every time validation runs, we need to consider:

  • when validation runs multiple times per epoch, we only track the optimal value
  • when validation runs once per epoch, we do the same thing the current code does
  • when validation runs every n epochs, the patience argument becomes fragile

@shijie-wu commented May 6, 2020

> I would also be fine with requiring early stopping to only work with a validation loop. However, we would need more changes to enable this. Take, for example, the patience argument, which is currently defined as the number of epochs with no improvement. If we instead change this to run every time validation runs, we need to consider:
>
>   • when validation runs multiple times per epoch, we only track the optimal value
>   • when validation runs once per epoch, we do the same thing the current code does
>   • when validation runs every n epochs, the patience argument becomes fragile

IMO the patience argument should be defined as the number of validation runs with no improvement, which offers better support for the following two cases (see the sketch after this list):

  • when validation runs multiple times per epoch, which is common for large datasets, e.g. BERT-style pretraining
  • when validation runs every n epochs
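
A minimal sketch of that counting scheme (illustrative only; the class and method names are not the library API):

class PatiencePerValidation:
    """Counts validation runs, not epochs, without improvement."""

    def __init__(self, patience=3):
        self.patience = patience
        self.best = float('inf')
        self.wait = 0  # validation runs since the last improvement

    def update(self, val_loss):
        """Call once per validation run; returns True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience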

@jeremyjordan (Contributor)

I agree with @shijie-wu: we should enforce that the out-of-the-box EarlyStopping callback runs when validation is present, with its arguments defined with respect to validation runs. If a user has a different (rare) use case, they can always implement their own logic in a custom callback.

This change will have an effect on existing users, but I think it's a reasonable one.

@williamFalcon (Contributor)

agreed!

@baldassarreFe (Contributor, Author)

So, to summarize:

  • the callback should run only on validation epochs (therefore it will not monitor training metrics for early stopping)
  • the patience parameter counts the number of validation checks performed by the callback, not training epochs or partial training epochs

If this is correct, I think the code in the initial pull request will satisfy this behavior. I would also update the docs to clarify what we changed, warn the user about possible misunderstandings, and suggest that advanced behavior can be obtained by implementing a custom early stopping callback. Then we're done.
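
For the docs, a concrete example of these semantics might look like the following (a sketch assuming the 0.7.x Trainer arguments):

# Validation runs every 2nd epoch; patience counts *validation runs*, so
# training stops after 3 validations (up to 6 epochs) without improvement.
early_stop = EarlyStopping(monitor='val_loss', patience=3)
trainer = Trainer(check_val_every_n_epoch=2, early_stop_callback=early_stop)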

Shall I?

@williamFalcon (Contributor)

yes, let's do it

@jeremyjordan (Contributor)

There's an existing test which we'll need to remove:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/callbacks/test_callbacks.py#L204

As @awaelchli pointed out, I'm surprised this branch isn't currently failing on that test.

@Borda (Member) commented May 10, 2020

> As @awaelchli pointed out, I'm surprised this branch isn't currently failing on that test.

Do we have a bug in the tests?

@Borda (Member) commented May 11, 2020

@baldassarreFe how is it going? Can we finish it for the next release?

@Borda Borda modified the milestones: 0.7.6, 0.8.0 May 12, 2020
@mergify bot (Contributor) commented May 12, 2020

This pull request is now in conflict... :(

@baldassarreFe baldassarreFe force-pushed the bugfix/early-stopping branch from 9706712 to 38ed30d Compare May 12, 2020 17:48
@mergify bot (Contributor) commented May 12, 2020

This pull request is now in conflict... :(

@awaelchli (Contributor) commented May 23, 2020

@jeremyjordan Good observation. I just quickly tested this locally, and commenting out the last two if clauses does not change anything; the test runs forever. I agree that this validation step pattern should be an option in some test and not the default :)

@awaelchli (Contributor) commented May 23, 2020

Found the issue. This PR accidentally disables ES because:

  • the callback method was renamed from on_epoch_end to on_validation_end
  • the Trainer still manually calls on_epoch_end (see training_loop.py line 365)
  • the ES callback is not actually in the list of callbacks, and is therefore never called

This is why it is disabled.

@jeremyjordan (Contributor)

Ah yes, I forgot about that; in my WIP PR (#1504) I updated the trainer to not invoke this directly.

@jeremyjordan (Contributor)

@baldassarreFe would you mind updating the PR, swapping

should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())

to

should_stop = self.early_stop_callback.on_validation_end(self, self.get_model())

and then I'll take care of removing this explicit invocation in #1504?

@mergify bot (Contributor) commented May 25, 2020

This pull request is now in conflict... :(

@Borda Borda force-pushed the bugfix/early-stopping branch from 0db51fa to f9c0ac5 Compare May 25, 2020 16:59
@Borda Borda removed the discussion In a discussion stage label May 25, 2020
@mergify mergify bot merged commit 65b4352 into Lightning-AI:master May 25, 2020
@mergify bot (Contributor) commented May 25, 2020

Great job! =)

@baldassarreFe (Contributor, Author)

Great, we made it! Thanks to you all ;)

@Borda Borda modified the milestones: 0.7.7, 0.8.0 May 26, 2020
@collinmccarthy

Thank you for this! But shouldn't on_validation_end return the flag to stop training? I think there may be a missing return statement here: https://github.com/PyTorchLightning/pytorch-lightning/blob/3af4994d5a84bc80738b50983b4b42c3eb946433/pytorch_lightning/callbacks/early_stopping.py#L113

@baldassarreFe (Contributor, Author)

Oh yeah, @collinmccarthy I think you're right. The trainer relies on the return value to stop:

if self.enable_early_stop:
    if (met_min_epochs and met_min_steps) or self.fast_dev_run:
        should_stop = self.early_stop_callback.on_validation_end(self, self.get_model())
        # stop training
        stop = should_stop and met_min_epochs
        if stop:
            self.run_training_teardown()
            return

In 006a067 I only changed the name of the method from on_epoch_end to on_validation_end, so the return value was the same.

commit 006a0678f833626dc255a139a87335423d23a0d7
Author: Federico Baldassarre <[email protected]>
Date:   Sat Apr 11 19:33:09 2020 +0200

    Fixes PyTorchLightning/pytorch-lightning#490
    
    `EarlyStopping` should check the metric of interest `on_validation_end` rather than `on_epoch_end`.
    In a normal scenario, this does not cause a problem, but in combination with `check_val_every_n_epoch>1` in the `Trainer` it results in a warning or in a `RuntimeError` depending on `strict`.

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 100c3171..32e37ac5 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -109,7 +109,7 @@ class EarlyStopping(Callback):
         self.stopped_epoch = 0
         self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf
 
-    def on_epoch_end(self, trainer, pl_module):
+    def on_validation_end(self, trainer, pl_module):
         logs = trainer.callback_metrics
         stop_training = False
         if not self._validate_condition_metric(logs):

Then in 1680c88 the stopping logic was moved to a separate method. I think that's where we missed the return.

commit 1680c88c17a6cc072b2105a28dbf2e9a15f96a15
Author: William Falcon <[email protected]>
Date:   Sun May 17 08:42:15 2020 -0400

    Update early_stopping.py

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 9be02a28..61abfb87 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -110,6 +110,9 @@ class EarlyStopping(Callback):
         self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf
 
     def on_validation_end(self, trainer, pl_module):
+        self._run_early_stopping_check(trainer, pl_module)
+
+    def _run_early_stopping_check(self, trainer, pl_module):
         logs = trainer.callback_metrics
         stop_training = False
         if not self._validate_condition_metric(logs):

@williamFalcon do you think that's why that test was failing? Maybe it was because the new callback is actually not stopping training.

@jeremyjordan (Contributor)

This will be addressed in #1504; I switched from returning a value to setting the trainer attribute should_stop, so that the callback can be appropriately managed by the callback handler.
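
Roughly, the approach described above looks like the following (a sketch; the exact change in #1504 may differ):

def _run_early_stopping_check(self, trainer, pl_module):
    logs = trainer.callback_metrics
    if not self._validate_condition_metric(logs):
        return  # monitored metric not available on this run; nothing to do

    current = logs.get(self.monitor)
    if self.monitor_op(current - self.min_delta, self.best):
        self.best = current
        self.wait = 0
    else:
        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = trainer.current_epoch
            # key change: set the trainer attribute instead of returning a flag,
            # so the callback handler can manage stopping
            trainer.should_stop = True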

Labels
  • feature: Is an improvement or enhancement
  • help wanted: Open to be worked on
  • ready: PRs ready to be merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Early stopping conditioned on metric val_loss isn't recognised when setting the val_check_interval
8 participants