[bugfix] Run logger.after_save_checkpoint in model checkpoint's on_train_end hook #9783
Conversation
thanks for identifying this! did you see any other points in the checkpoint callback where the logger was not triggered?
LGTM!
Codecov Report
@@            Coverage Diff            @@
##           master   #9783     +/-   ##
=========================================
- Coverage      93%     89%      -4%
=========================================
  Files         178     178
  Lines       15648   15650      +2
=========================================
- Hits        14503   13897    -606
- Misses       1145    1753    +608
PR looks good. I think the test can be improved, or combined with one of the existing tests that cover save_last.
@@ -77,6 +77,53 @@ def mock(key):
    return calls


@pytest.mark.parametrize("save_last", [False, True])  # problem with save_last = True
@pytest.mark.parametrize("save_on_train_epoch_end", [True, False])
As per the PR title, the bug was that after_save_checkpoint is not called for last.ckpt, which is saved in on_train_end. So do we need this parametrization, since on_train_end is called irrespective of save_on_train_epoch_end?
This test is meant to raise an error in all scenarios in which logger.after_save_checkpoint is not called after a checkpoint file is saved, so we need to test both save_last=True and save_last=False.
triggered_train_epoch_end: bool = False
trainer: Optional[Trainer]

def after_save_checkpoint(self, *_):
I think a better test would be to mock LightningLoggerBase.after_save_checkpoint and check its call count.
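A minimal sketch of that mock-based approach (not code from the PR; RandomDataset and DemoModel are small throwaway helpers assumed for illustration, not part of the repository):

```python
from unittest import mock

import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import LightningLoggerBase


class RandomDataset(Dataset):
    """Tiny dataset so the Trainer has something to iterate over."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(4)


class DemoModel(pl.LightningModule):
    """Minimal module: one linear layer, dummy loss."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def test_after_save_checkpoint_is_called(tmp_path):
    # patch the base-class hook so the default TensorBoardLogger inherits the mock
    with mock.patch.object(LightningLoggerBase, "after_save_checkpoint") as hook:
        trainer = pl.Trainer(
            default_root_dir=str(tmp_path),
            max_epochs=2,
            callbacks=[ModelCheckpoint(dirpath=str(tmp_path), save_last=True)],
        )
        trainer.fit(DemoModel(), DataLoader(RandomDataset(), batch_size=4))
    # the exact count depends on how many saving routines ran, which is what makes
    # a pure call-count assertion fragile, as discussed in the reply below
    assert hook.call_count > 0
```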
That's a good thought. Originally, I tried to implement the test as you suggested but did not succeed. The problem is that:

- save_checkpoint calls 3 separate saving routines (_save_top_k_checkpoint, _save_none_monitor_checkpoint, _save_last_checkpoint) followed by a single call to Logger.after_save_checkpoint.
- on_train_end calls a single saving routine (_save_last_checkpoint) followed by a single call to Logger.after_save_checkpoint.

This makes checking the call counts a fragile way to test that Logger.after_save_checkpoint is called appropriately.
Maybe you can try just this config: ModelCheckpoint(save_top_k=0, monitor=None, save_last=True). It only saves last.ckpt, thus it will ensure the after_save_checkpoint call comes from on_train_end.
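For illustration, a sketch of that configuration, reusing the DemoModel / RandomDataset helpers from the earlier sketch (the directory name is arbitrary):

```python
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# save_top_k=0 disables top-k checkpoints and monitor=None disables metric tracking,
# so the only file this callback writes is last.ckpt
ckpt_cb = ModelCheckpoint(dirpath="saved_ckpt", save_top_k=0, monitor=None, save_last=True)

trainer = pl.Trainer(max_epochs=1, callbacks=[ckpt_cb])
trainer.fit(DemoModel(), DataLoader(RandomDataset(), batch_size=4))

print(ckpt_cb.last_model_path)  # e.g. saved_ckpt/last.ckpt
```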
I have rewritten the test. It now checks that the set of ckpt files generated by the trainer matches the set of ckpt files received by logger.after_save_checkpoint.
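A rough sketch of that set-comparison idea (not the PR's exact test; it again leans on the DemoModel / RandomDataset helpers from the first sketch):

```python
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger


def test_logger_sees_every_ckpt_file(tmp_path):
    seen = set()

    def record(checkpoint_callback):
        # remember every .ckpt file present each time the hook fires
        seen.update(p.name for p in tmp_path.glob("*.ckpt"))

    logger = TensorBoardLogger(save_dir=str(tmp_path))
    logger.after_save_checkpoint = record  # instance-level override, for recording only

    trainer = pl.Trainer(
        default_root_dir=str(tmp_path),
        max_epochs=2,
        logger=logger,
        callbacks=[ModelCheckpoint(dirpath=str(tmp_path), save_last=True)],
    )
    trainer.fit(DemoModel(), DataLoader(RandomDataset(), batch_size=4))

    written = {p.name for p in tmp_path.glob("*.ckpt")}
    # with the fix, every checkpoint the trainer wrote was also visible to the logger
    assert written == seen
```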
Merge branch 'bugfix/ckpt_and_logger' of https://github.com/dalessioluca/pytorch-lightning into bugfix/ckpt_and_logger
# Conflicts:
#   tests/checkpointing/test_model_checkpoint.py
Head branch was pushed to by a user without write access
for more information, see https://pre-commit.ci
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 7 days if no further activity occurs. If you need further help see our docs: https://pytorch-lightning.readthedocs.io/en/latest/generated/CONTRIBUTING.html#pull-request or ask the assistance of a core contributor here or on Slack. Thank you for your contributions.
hey @dalessioluca, thank you for this PR and for identifying the issues here. It looks like the model does get logged, at least for wandb: https://github.com/PyTorchLightning/pytorch-lightning/blob/45c45dc7b018f9a2db60f5df1a3f7dbbb45ccb36/pytorch_lightning/loggers/wandb.py#L461-L464 but the way these are logged, and how it connects to the logger after checkpointing, isn't super convenient and can be done in a more optimal way. We should call after_save_checkpoint inside trainer.save_checkpoint(). I'll do another PR after some time which might solve more issues, including this one. If you want to make the changes to the loggers and checkpointing yourself, that would be great; otherwise you can close this one. Again, thanks a lot for this. It was really helpful in identifying more underlying issues.
Dear @dalessioluca, based on @rohitgr7's latest comment, I believe it is better to make sure after_save_checkpoint is called inside trainer.save_checkpoint(). Would you be willing to keep going with this PR based on @rohitgr7's comments? Best,
It indeed seems cleaner to call after_save_checkpoint inside save_checkpoint.
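Purely as a sketch of that placement (hypothetical and heavily simplified, not the actual Trainer source):

```python
from weakref import proxy


class Trainer:  # hypothetical, heavily simplified
    def save_checkpoint(self, filepath, weights_only=False):
        # write the checkpoint file itself
        self.checkpoint_connector.save_checkpoint(filepath, weights_only=weights_only)
        # notify the logger once per saved file, no matter which ModelCheckpoint
        # routine (top-k, none-monitor, last) requested the save
        if self.logger is not None and self.checkpoint_callback is not None:
            self.logger.after_save_checkpoint(proxy(self.checkpoint_callback))
```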
Also, you might need to update some loggers as well. For starters, wandb won't need the finalize method if we call after_save_checkpoint inside save_checkpoint.
I believe this is not necessary anymore. We call after_save_checkpoint after every checkpoint save now. Please correct me if I'm wrong.
What does this PR do?
This PR makes sure that logger.after_save_checkpoint is called after ModelCheckpoint creates a checkpoint in its on_train_end hook.
Note: A bigger refactor has been discussed here: #6231 (comment)
Details:
The repro code saves 2 ckpt files in the "saved_ckpt" folder, but logger.after_save_checkpoint is called only once (a reconstruction is sketched below).
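The original repro snippet is not preserved on this page; below is a hedged reconstruction of the kind of script being described, reusing the DemoModel / RandomDataset helpers sketched earlier. CountingLogger is an illustrative name, not repository code:

```python
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger


class CountingLogger(TensorBoardLogger):
    """TensorBoard logger that simply counts how often after_save_checkpoint fires."""

    calls = 0

    def after_save_checkpoint(self, checkpoint_callback):
        type(self).calls += 1


trainer = pl.Trainer(
    max_epochs=1,
    logger=CountingLogger(save_dir="saved_ckpt"),
    callbacks=[ModelCheckpoint(dirpath="saved_ckpt", save_last=True)],
)
trainer.fit(DemoModel(), DataLoader(RandomDataset(), batch_size=4))

# before the fix: two .ckpt files on disk (the epoch checkpoint and last.ckpt) but
# only one after_save_checkpoint call, because the save in on_train_end never
# notified the logger
print(CountingLogger.calls)
```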
Fixes #
Does your PR introduce any breaking changes? If yes, please list them.
None
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the Review guidelines.
Did you have fun?
Make sure you had fun coding 🙃