Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support checkpoint save and load with Stochastic Weight Averaging #9938

Merged
merged 94 commits into from
Aug 9, 2022

Conversation

adamreeve
Copy link
Contributor

@adamreeve adamreeve commented Oct 15, 2021

What does this PR do?

Partly addresses #6074 by supporting saving and loading the StochasticWeightAveraging callback data in checkpoints. Support for using SWA during validation will be done as a follow up PR.

Does your PR introduce any breaking changes? If yes, please list them.

No

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG**? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

  • Make sure you had fun coding 🙃

@tchaton tchaton added this to the v1.6 milestone Nov 1, 2021
@tchaton tchaton added feature Is an improvement or enhancement bug Something isn't working and removed feature Is an improvement or enhancement labels Nov 1, 2021
@tchaton tchaton modified the milestones: v1.6, v1.6.x Nov 1, 2021
@awaelchli awaelchli modified the milestones: v1.6.x, 1.5.x Nov 3, 2021
@adamreeve adamreeve changed the title [Draft] Support checkpoint save and load with Stochastic Weight Averaging Support checkpoint save and load with Stochastic Weight Averaging Nov 8, 2021
@adamreeve adamreeve marked this pull request as ready for review November 8, 2021 20:28
@rohitgr7 rohitgr7 linked an issue Aug 8, 2022 that may be closed by this pull request
Copy link
Contributor

@rohitgr7 rohitgr7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect!
Great work!

Apologies for the delayed review! but it was complex and you pulled it out smartly 😃 🚀

@mergify mergify bot added ready PRs ready to be merged has conflicts and removed has conflicts ready PRs ready to be merged labels Aug 9, 2022
@awaelchli awaelchli enabled auto-merge (squash) August 9, 2022 22:41
@mergify mergify bot added ready PRs ready to be merged and removed has conflicts ready PRs ready to be merged labels Aug 9, 2022
@awaelchli awaelchli merged commit 975a4fc into Lightning-AI:master Aug 9, 2022
@adamreeve adamreeve deleted the swa_checkpoint branch August 10, 2022 08:24
@awaelchli
Copy link
Contributor

@rohitgr7 Will this go into 1.7.x?

rohitgr7 added a commit that referenced this pull request Aug 15, 2022
)

Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Kushashwa Ravi Shrimali <[email protected]>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
@rohitgr7
Copy link
Contributor

@awaelchli yes

jessecambon pushed a commit to jessecambon/lightning that referenced this pull request Aug 16, 2022
…ghtning-AI#9938)

Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Kushashwa Ravi Shrimali <[email protected]>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
lexierule pushed a commit that referenced this pull request Aug 17, 2022
* update version and changelog for 1.7.2 release

* Reset all results on epoch end (#14061)

Co-authored-by: Carlos Mocholí <[email protected]>

* Skip ddp fork tests on windows (#14121)

* Fix device placement when `.cuda()` called without specifying index (#14128)

* Convert subprocess test to standalone test (#14101)

* Fix entry point test for Python 3.10 (#14154)

* Fix flaky test caused by weak reference (#14157)

* Fix saving hyperparameters in a composition where parent is not a LM or LDM (#14151)



Co-authored-by: Rohit Gupta <[email protected]>

* Remove DeepSpeed version restriction from Lite (#13967)

* Configure the check-group app (#14165)

Co-authored-by: Jirka <[email protected]>

* Update onnxruntime requirement from <=1.12.0 to <1.13.0 in /requirements (#14083)

Updates the requirements on [onnxruntime](https://github.com/microsoft/onnxruntime) to permit the latest version.
- [Release notes](https://github.com/microsoft/onnxruntime/releases)
- [Changelog](https://github.com/microsoft/onnxruntime/blob/master/docs/ReleaseManagement.md)
- [Commits](microsoft/onnxruntime@v0.1.4...v1.12.1)

---
updated-dependencies:
- dependency-name: onnxruntime
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <[email protected]>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Update gcsfs requirement from <2022.6.0,>=2021.5.0 to >=2021.5.0,<2022.8.0 in /requirements (#14079)

Update gcsfs requirement in /requirements

Updates the requirements on [gcsfs](https://github.com/fsspec/gcsfs) to permit the latest version.
- [Release notes](https://github.com/fsspec/gcsfs/releases)
- [Commits](fsspec/gcsfs@2021.05.0...2022.7.1)

---
updated-dependencies:
- dependency-name: gcsfs
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <[email protected]>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Fix a bug that caused spurious `AttributeError` when multiple `DataLoader` classes are imported (#14117)


fix

* CI: Replace `_` of in GHA workflow filenames with `-` (#13917)

* Rename workflow files

* Update docs

* Fix azure badges

* Update the main readme

* bad rebase

* Update doc

* CI: Update Windows version from 2019 to 2022 (#14129)

Update windows

* CI/CD: Add CUDA version to docker image tags (#13831)

* append cuda version to tags

* revertme: push to hub

* Update docker readme

* Build base-conda-py3.9-torch1.12-cuda11.3.1

* Use new images in conda tests

* revertme: push to hub

* Revert "revertme: push to hub"

This reverts commit 0f7d534.

* Revert "revertme: push to hub"

This reverts commit 46a05fc.

* Run conda if workflow edited

* Run gpu testing if workflow edited

* Use new tags in release/Dockerfile

* Build base-cuda and PL release images with all combinations

* Update release docker

* Update conda from py3.9-torch1.12 to py3.10-torch.1.12

* Fix ubuntu version

* Revert conda

* revertme: push to hub

* Don't build Python 3.10 for now...

* Fix pl release builder

* updating version contribute to the error? docker/buildx#456

* Update actions' versions

* Update slack user to notify

* Don't use 11.6.0 to avoid bagua incompatibility

* Don't use 11.1, and use 11.1.1

* Update .github/workflows/ci-pytorch_test-conda.yml

Co-authored-by: Luca Medeiros <[email protected]>

* Update trigger

* Ignore artfacts from tutorials

* Trim docker images to distribute

* Add an image for tutorials

* Update conda image 3.8x1.10

* Try different conda variants

* No need to set cuda for conda jobs

* Update who to notify ipu failure

* Don't push

* update filenaem

Co-authored-by: Luca Medeiros <[email protected]>

* Avoid entry_points deprecation warning (#14052)

Co-authored-by: Adam J. Stewart <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>

* Configure the check-group app (#14165)

Co-authored-by: Jirka <[email protected]>

* Profile batch transfer and gradient clipping hooks (#14069)

Co-authored-by: Rohit Gupta <[email protected]>

* Avoid false positive warning about using `sync_dist` when using torchmetrics (#14143)

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>

* Avoid raising the sampler warning if num_replicas=1 (#14097)

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>

Co-authored-by: otaj <[email protected]>

* Remove skipping logic in favor of path filtering (#14170)

* Support checkpoint save and load with Stochastic Weight Averaging (#9938)

Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Kushashwa Ravi Shrimali <[email protected]>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>

* Use fsdp module to initialize precision scalar for fsdp native (#14092)

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Laverne Henderson <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>

* add more issues types (#14174)

* add more issues types

* Update .github/ISSUE_TEMPLATE/config.yml

Co-authored-by: Mansy <[email protected]>

* typo

Co-authored-by: Adrian Wälchli <[email protected]>

Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Mansy <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Laverne Henderson <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>

* CI: clean building docs (#14216)

* CI: clean building docs

* group

* .

* CI: docker focus on PL only (#14246)

* CI: docker focus on PL only

* group

* Allowed setting attributes on `DataLoader` and `BatchSampler` when instantiated inside `*_dataloader` hooks (#14212)


Co-authored-by: otaj <[email protected]>

* Revert "Remove skipping logic in favor of path filtering (#14170)" (#14244)

* Update defaults for WandbLogger's run name and project name (#14145)

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>
Co-authored-by: Luca Medeiros <[email protected]>
Co-authored-by: Adam J. Stewart <[email protected]>
Co-authored-by: otaj <[email protected]>
Co-authored-by: Adam Reeve <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kushashwa Ravi Shrimali <[email protected]>
Co-authored-by: Laverne Henderson <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Mansy <[email protected]>
@kampelmuehler
Copy link

kampelmuehler commented Jan 16, 2023

@adamreeve
Am I correct that the "follow up PR" on using the _average_model for validation hasn't yet been done?

For those looking to do it manually, this would be a very rough sketch

    def on_fit_start(self):
        for cb in self.trainer.callbacks:
            if isinstance(cb, StochasticWeightAveraging):
                self.swa = cb

    def validation_step(self, batch, batch_idx):
        x, y = batch

        loss = self.loss_fn(self(x), y)
        print(f"{loss.item():.02f}", end="")
        if self.swa._initialized:
            loss_swa = self.loss_fn(self.swa._average_model(x), y)
            print(f" {loss_swa.item():.02f}", end="")
        print()
        return loss

@kampelmuehler
Copy link

kampelmuehler commented Jan 16, 2023

@adamreeve
I think that these features should also somehow be documented. Currently the docs include nothing on how to actually use the SWA model (or that indeed only in the end of training the weights are transferred to the original model). What do you think?

Edit:
Couple of users unsure about SWA usage
https://lightning.ai/forums/t/how-to-implement-swa/1761
https://lightning.ai/forums/t/stochasticweightaveraging-validation-logging-and-checkpoints/2023/2

@kampelmuehler
Copy link

kampelmuehler commented Jan 16, 2023

Also:
The checkpointing is not in any sense automated, or is it?
Only way I see thus far is to manually load the checkpoint, get the state_dict for swa and load it manually into the callback instance.

If I just declare the model, add an swa callback to the trainer and pass the checkpoint path to trainer.fit() it won't load the swa callback states.

EDIT:
I was wrong, passing to trainer.fit() correctly loads the callback state as well, trainer.predict(), however, will not load the swa callback state. Which might be intentional?

@kampelmuehler
Copy link

cc @awaelchli

@adamreeve
Copy link
Contributor Author

adamreeve commented Jan 16, 2023

Hi @kampelmuehler

Am I correct that the "follow up PR" on using the _average_model for validation hasn't yet been done?

Yes, that was originally part of this PR but the scope was reduced, so #6074 probably shouldn't have been closed. I had started working towards that in a separate branch (https://github.com/adamreeve/pytorch-lightning/commits/swa_validation) but that's now quite far behind master and doesn't include all the changes from this PR. Supporting batch normalization in conjunction with SWA made that a lot more complicated, but I think it was working. I don't currently have any plans to continue with that work.

This PR didn't really add any new user-visible feature but just fixed checkpointing to work correctly with SWA so that training could be resumed. Eg. it fixed #11665. But I agree that the documentation could be improved to better explain how SWA works.

It's a while since I looked at this now so I'm not sure whether trainer.predict() not loading the SWA parameters is intentional or just a limitation of the current approach, but it sounds consistent with the behaviour of the averaged parameters not being transferred until training is completed.

@kampelmuehler
Copy link

Hi @adamreeve - thanks for the quick response and all the insights!

@zhong-yy
Copy link

Hi, is swa_validation removed in the latest version?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
callback: swa community This PR is from the community feature Is an improvement or enhancement pl Generic label for PyTorch Lightning package ready PRs ready to be merged
Projects
No open projects
Status: Done
Development

Successfully merging this pull request may close these issues.

Cant reload from checkpoint when using SWA