
[bugfix] Run logger.after_save_checkpoint in model checkpoint's on_train_end hook #9783

Closed
wants to merge 16 commits into from

Conversation

dalessioluca

@dalessioluca dalessioluca commented Oct 1, 2021

What does this PR do?

This PR makes sure that logger.after_save_checkpoint is called after ModelCheckpoint creates a checkpoint in its on_train_end hook.

Note: A bigger refactor has been discussed here: #6231 (comment)

Details:

import torch
from pytorch_lightning.loggers import CSVLogger 
from pytorch_lightning.trainer import Trainer
from pytorch_lightning import LightningModule
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint


class Logger(CSVLogger):
    call_counter = 0
    
    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
        print("called after new ckpt is saved")
        self.call_counter += 1


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    # Together, these two callbacks ensure a checkpoint is saved every `every_n_epochs` epochs AND at the end of training
    ckpt_train_end = ModelCheckpoint(
        filename="last",
        save_on_train_epoch_end=False,
        save_last=True,
    )

    ckpt_train_interval = ModelCheckpoint(
        filename="my_checkpoint-{epoch}",
        save_on_train_epoch_end=True,
        save_last=False,
        every_n_epochs=6,
    )

    model = BoringModel()

    trainer = Trainer(
        logger=Logger(save_dir='./logging_dir', flush_logs_every_n_steps=1),
        max_epochs=9,
        log_every_n_steps=1,
        weights_save_path="saved_ckpt",
        callbacks=[ckpt_train_interval, ckpt_train_end],
    )

    trainer.fit(model, train_dataloaders=train_data)

    assert trainer.logger.call_counter == 2

if __name__ == "__main__":
    run()

This code saves two checkpoint files in the "saved_ckpt" folder, but logger.after_save_checkpoint is called only once.


Fixes #

Does your PR introduce any breaking changes? If yes, please list them.

None

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • [x] Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@ananthsub ananthsub added bug Something isn't working checkpointing Related to checkpointing labels Oct 1, 2021
Contributor

@ananthsub ananthsub left a comment

thanks for identifying this! did you see any other points in the checkpoint callback where the logger was not triggered?

@ananthsub ananthsub changed the title bugfix. ModelCheckpoint callback creates checkpoint "on_train_end" but does not trigger the logger.after_save_checkpoint [bugfix] Run logger.after_save_checkpoint in model checkpoint's on_train_end hook Oct 2, 2021
@mergify mergify bot removed the has conflicts label Oct 12, 2021
Contributor

@tchaton tchaton left a comment

LGTM !

@tchaton tchaton enabled auto-merge (squash) October 12, 2021 09:55
@codecov

codecov bot commented Oct 12, 2021

Codecov Report

Merging #9783 (5801163) into master (6da5829) will decrease coverage by 4%.
The diff coverage is 100%.

❗ Current head 5801163 differs from pull request most recent head 0c2b15c. Consider uploading reports for the commit 0c2b15c to get more accurate results

@@           Coverage Diff           @@
##           master   #9783    +/-   ##
=======================================
- Coverage      93%     89%    -4%     
=======================================
  Files         178     178            
  Lines       15648   15650     +2     
=======================================
- Hits        14503   13897   -606     
- Misses       1145    1753   +608     

Contributor

@rohitgr7 rohitgr7 left a comment

PR looks good. I think the test can be improved or combined with one of the existing tests that cover save_last.

@@ -77,6 +77,53 @@ def mock(key):
return calls


@pytest.mark.parametrize("save_last", [False, True]) # problem with save_last = True
@pytest.mark.parametrize("save_on_train_epoch_end", [True, False])
Contributor

as per PR title, the bug was that it's not calling after_save_ckpt for last.ckpt which is saved in on_train_end. so do we need this? since on_train_end is called irrespective of save_on_train_epoch_end.

Author

This test is meant to raise an error in every scenario in which logger.after_save_checkpoint is not called after a checkpoint file is saved, so we need to test both save_last=True and save_last=False.

triggered_train_epoch_end: bool = False
trainer: Optional[Trainer]

def after_save_checkpoint(self, *_):
Contributor

Suggested change
def after_save_checkpoint(self, *_):
def after_save_checkpoint(self, *_):

I think a better test would be to mock LightningLoggerBase.after_save_checkpoint and check its call count

Author

That's a good thought. Originally, I tried to implement the test as you suggested but did not succeed. The problem is that:

  1. save_checkpoint calls 3 separate saving routines (_save_top_k_checkpoint, _save_none_monitor_checkpoint, _save_last_checkpoint) followed by a single call to Logger.after_save_checkpoint.
  2. on_train_end calls a single saving routine (_save_last_checkpoint) followed by a single call to Logger.after_save_checkpoint.

This makes checking the call counts a fragile way to test that Logger.after_save_checkpoint is called appropriately.
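The asymmetry described above can be sketched with plain stand-ins and unittest.mock. Note these are hypothetical classes written for illustration, not the real ModelCheckpoint:

```python
from unittest.mock import MagicMock

# Hypothetical stand-in mirroring the structure described above (NOT the real
# ModelCheckpoint). It shows why after_save_checkpoint call counts do not map
# 1:1 to the number of checkpoint files written.
class FakeModelCheckpoint:
    def __init__(self, logger):
        self.logger = logger
        self.files_written = 0

    def save_checkpoint(self):
        # Three separate saving routines may each write a file...
        for _ in range(3):  # _save_top_k / _save_none_monitor / _save_last
            self.files_written += 1
        # ...followed by a single logger notification.
        self.logger.after_save_checkpoint(self)

    def on_train_end(self):
        self.files_written += 1  # _save_last_checkpoint only
        self.logger.after_save_checkpoint(self)

logger = MagicMock()
cb = FakeModelCheckpoint(logger)
cb.save_checkpoint()
cb.on_train_end()
print(cb.files_written)                         # 4
print(logger.after_save_checkpoint.call_count)  # 2
```

Up to four files were written but the hook fired only twice, so a bare call count cannot distinguish "every save was reported" from "some saves were silently skipped".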

Contributor

maybe you can try just this config:

ModelCheckpoint(save_top_k=0, monitor=None, save_last=True)

it only saves the last.ckpt, thus will ensure the after_save_checkpoint call for on_train_end.

Author

I have rewritten the test. It now checks that the set of ckpt files generated by the trainer matches the set of ckpt files received by logger.after_save_checkpoint.
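The strategy can be sketched with hypothetical stand-ins (not the real Lightning objects): record every checkpoint path the logger is told about, then compare against the files actually written to disk.

```python
import os
import tempfile

# Hypothetical stand-ins illustrating the set-comparison test strategy.
class RecordingLogger:
    def __init__(self):
        self.reported = set()

    def after_save_checkpoint(self, checkpoint_callback):
        self.reported.add(checkpoint_callback.last_path)

class FakeModelCheckpoint:
    def __init__(self, logger, dirpath):
        self.logger = logger
        self.dirpath = dirpath
        self.last_path = None

    def _save(self, filename):
        path = os.path.join(self.dirpath, filename)
        open(path, "w").close()  # pretend to write a checkpoint file
        self.last_path = path
        self.logger.after_save_checkpoint(self)

with tempfile.TemporaryDirectory() as tmp:
    logger = RecordingLogger()
    cb = FakeModelCheckpoint(logger, tmp)
    cb._save("my_checkpoint-epoch=5.ckpt")
    cb._save("last.ckpt")  # the on_train_end save this PR fixes
    on_disk = {os.path.join(tmp, f) for f in os.listdir(tmp)}
    assert on_disk == logger.reported  # every saved file was reported
    print(sorted(os.path.basename(p) for p in logger.reported))
```

Unlike a bare call count, a set comparison fails loudly if any saving routine writes a file without notifying the logger.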

dalessioluca added 2 commits October 12, 2021 09:46
Merge branch 'bugfix/ckpt_and_logger' of https://github.com/dalessioluca/pytorch-lightning into bugfix/ckpt_and_logger
# Conflicts:
#	tests/checkpointing/test_model_checkpoint.py
auto-merge was automatically disabled October 12, 2021 13:49

Head branch was pushed to by a user without write access

@stale

stale bot commented Oct 27, 2021

This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 7 days if no further activity occurs. If you need further help see our docs: https://pytorch-lightning.readthedocs.io/en/latest/generated/CONTRIBUTING.html#pull-request or ask the assistance of a core contributor here or on Slack. Thank you for your contributions.

@stale stale bot added the won't fix This will not be worked on label Oct 27, 2021
@rohitgr7
Contributor

rohitgr7 commented Nov 1, 2021

hey @dalessioluca, thank you for this PR and for identifying issues here. Looks like it does log the model here, at least for wandb: https://github.com/PyTorchLightning/pytorch-lightning/blob/45c45dc7b018f9a2db60f5df1a3f7dbbb45ccb36/pytorch_lightning/loggers/wandb.py#L461-L464

but the way these are logged and how it connects to the logger after checkpointing isn't very convenient and can be done in a more optimal way. We should call after_save_checkpoint inside trainer.save_checkpoint(). I'll open another PR after some time, which might solve more issues, including this one. If you want to make the changes to the loggers and checkpointing yourself, that would be great; otherwise you can close this one. Again, thanks a lot for this.. it was really helpful in identifying more underlying issues.
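The proposal to centralize the hook can be sketched with a hypothetical stand-in (not the real Trainer): if save_checkpoint itself notifies the logger, every code path that writes a file fires the hook exactly once, including the on_train_end save.

```python
from unittest.mock import MagicMock

# Hypothetical sketch of the proposed refactor, NOT the real Trainer.
class FakeTrainer:
    def __init__(self, logger, checkpoint_callback):
        self.logger = logger
        self.checkpoint_callback = checkpoint_callback
        self.saved_paths = []

    def save_checkpoint(self, filepath):
        self.saved_paths.append(filepath)  # write the checkpoint file
        # Centralized notification: no individual saving routine can forget it.
        self.logger.after_save_checkpoint(self.checkpoint_callback)

logger = MagicMock()
trainer = FakeTrainer(logger, checkpoint_callback=object())
trainer.save_checkpoint("epoch=5.ckpt")
trainer.save_checkpoint("last.ckpt")  # the on_train_end save
print(logger.after_save_checkpoint.call_count)  # 2
```

With this shape, callbacks and hooks like on_train_end never call the logger directly, so the bug class this PR patches cannot reappear in a new saving routine.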

@stale stale bot removed the won't fix This will not be worked on label Nov 1, 2021
@tchaton
Contributor

tchaton commented Nov 1, 2021

Dear @dalessioluca,

Based on @rohitgr7's latest comment, I believe it is better to make sure after_save_checkpoint and save_checkpoint are called together.

Would you be willing to keep going with this PR based on @rohitgr7's comments?

Best,
T.C

@tchaton tchaton modified the milestones: v1.6, v1.6.x Nov 1, 2021
@dalessioluca
Author

It indeed seems cleaner to call after_save_checkpoint inside save_checkpoint.
I will modify the PR accordingly.

@rohitgr7
Contributor

rohitgr7 commented Nov 1, 2021

also you might need to update some loggers as well. For starters, wandb won't need the finalize method if we call after_save_checkpoint right after save_checkpoint.

@awaelchli awaelchli modified the milestones: v1.6.x, 1.5.x Nov 3, 2021
@Borda Borda modified the milestones: 1.5.x, 1.6 Mar 21, 2022
@rohitgr7 rohitgr7 self-assigned this Mar 21, 2022
@rohitgr7 rohitgr7 modified the milestones: 1.6, 1.6.x Mar 21, 2022
@carmocca
Contributor

I believe this is not necessary anymore. We call after_save_checkpoint already whenever we save: https://github.com/Lightning-AI/lightning/blob/233b36b185ae71b480ddd5c87ff5d12cd70c88fe/src/pytorch_lightning/callbacks/model_checkpoint.py#L386-L394

Please, correct me if I'm wrong.

@carmocca carmocca closed this Jul 28, 2022
Labels
bug Something isn't working checkpointing Related to checkpointing has conflicts
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants