- 
                Notifications
    
You must be signed in to change notification settings  - Fork 3.6k
 
          Reset trainer variable should_stop when fit is called
          #19177
        
          New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
  
    Reset trainer variable should_stop when fit is called
  
  #19177
              Conversation
| 
           seems this is failing on a test that is designed to make sure the trainer stays as  @pytest.mark.parametrize(("min_epochs", "min_steps", "val_count"), [(3, None, 3), (None, 3, 2)])
def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count, tmp_path):
    """Regression test for issue #15708.
    Test that the request for `should_stop=True` only triggers validation when Trainer is allowed to stop
    (min_epochs/steps is satisfied).
    """
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmp_path,
        num_sanity_val_steps=0,
        limit_val_batches=2,
        limit_train_batches=2,
        max_epochs=3,
        min_epochs=min_epochs,
        min_steps=min_steps,
        enable_model_summary=False,
        enable_checkpointing=False,
    )
    trainer.should_stop = True  # Request to stop before min_epochs/min_steps are reached
    trainer.fit_loop.epoch_loop.val_loop.run = Mock()
    trainer.fit(model)
    assert trainer.fit_loop.epoch_loop.val_loop.run.call_count == val_count | 
    
for more information, see https://pre-commit.ci
| 
           I have changed the above test to use an  +    class NewBoring(BoringModel):
+        def training_step(self, batch, batch_idx):
+            self.log("loss", self.step(batch))
+            return {"loss": self.step(batch)}
+
-    model = BoringModel()
+    model = NewBoring()
+    # create a stopping condition with a high threshold so it triggers immediately
+    # check the condition before validation so the count is unaffected
+    stopping = EarlyStopping(monitor="loss", check_on_train_epoch_end=True, stopping_threshold=100)
     trainer = Trainer(
        default_root_dir=tmp_path,
        num_sanity_val_steps=0,
        limit_val_batches=2,
        limit_train_batches=2,
        max_epochs=3,
        min_epochs=min_epochs,
        min_steps=min_steps,
        enable_model_summary=False,
        enable_checkpointing=False,
        callbacks=[stopping],
    )
-   trainer.should_stop = True  # Request to stop before min_epochs/min_steps are reached
    trainer.fit_loop.epoch_loop.val_loop.run = Mock()
    trainer.fit(model)
    assert trainer.fit_loop.epoch_loop.val_loop.run.call_count == val_count | 
    
          ️✅ There are no secrets present in this pull request anymore.If these secrets were true positive and are still valid, we highly recommend you to revoke them. 🦉 GitGuardian detects secrets in your source code to help developers and security teams secure the modern development process. You are seeing this because you or someone else with access to this repository has authorized GitGuardian to scan your pull request.  | 
    
          Codecov ReportAll modified and coverable lines are covered by tests ✅ 
 
 Additional details and impacted files@@            Coverage Diff             @@
##           master   #19177      +/-   ##
==========================================
- Coverage      83%      48%     -35%     
==========================================
  Files         450      442       -8     
  Lines       38250    38098     -152     
==========================================
- Hits        31893    18438   -13455     
- Misses       6357    19660   +13303     🚀 New features to boost your workflow:
  | 
    
| 
           is this PR in progress??  | 
    
--------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 40c682e)
--------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 40c682e)
What does this PR do?
Reset trainer variable
should_stopwhenfitis calledIf fit is called after early stopping has already stopped training, then the model will not continue training as the trainer flag
should_stopis currently not reset when fit is called.Change this to reset
should_stopevery time fit is calledFixes #18727
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--19177.org.readthedocs.build/en/19177/