
Remember the eval mode of submodules when switching trainer stages #18951

Merged
34 commits merged on Nov 16, 2023

Conversation

@awaelchli (Contributor) commented Nov 5, 2023

What does this PR do?

Fixes #18930
Part of #16827

A common issue users face is that the loop calls train() on the LightningModule even though the user has frozen certain layers. For example:

  • user finetunes a model and freezes certain layers they don't want to train
  • user's model has a feature extractor they don't train

This leads to a surprise when the user finds that the running statistics of their batch norm layers have changed, even though those layers were set explicitly to eval() mode. To avoid this, the user has to learn to override the on_validation_model_eval() and on_validation_model_train() hooks in the module, but this detail is difficult to find in our docs and to get right. Most users who hit this end up asking for help on Slack or GitHub.
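To make the failure mode concrete, here is a minimal, self-contained sketch. The `Module` class below is a stand-in for `torch.nn.Module` (not the real API), reproducing only the relevant behavior: `train()`/`eval()` set the `.training` flag and recurse into children.

```python
# Stand-in for torch.nn.Module (illustration only, not the real PyTorch API).
class Module:
    def __init__(self, **children):
        self.training = True  # PyTorch's default mode
        self._children = dict(children)
        for name, child in children.items():
            setattr(self, name, child)

    def train(self, mode=True):
        # train()/eval() recurse into all children, as in PyTorch.
        self.training = mode
        for child in self._children.values():
            child.train(mode)
        return self

    def eval(self):
        return self.train(False)


# The user freezes a feature extractor in __init__ ...
model = Module(backbone=Module(), head=Module())
model.backbone.eval()
assert not model.backbone.training  # frozen, as intended

# ... but the loop's blanket train() call silently undoes it:
model.train()
assert model.backbone.training  # the "frozen" backbone is training again
```

With a real batch norm layer, that last state is exactly what causes the running statistics to drift during training despite the explicit `eval()` call.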

The PR makes the following changes to automate this for the user:

  • The validation loop now captures the .training mode of every submodule before calling .eval(). When the validation loop ends, and before switching back to training, it restores the .training mode of all submodules to what it was before. This ensures that layers the user chose to keep in eval mode remain in eval mode!
  • The fit loop no longer calls .train() at the start, for the same reason: the user can now set a subset of their model to .eval() mode / freeze it explicitly in the LightningModule's __init__ without hook acrobatics, and the Trainer will respect and preserve it (see the added test). Note: this is not a breaking change, because PyTorch's default is .training=True.
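The capture/restore idea from the first bullet can be sketched as follows. The helper names (`capture_training_modes`, `restore_training_modes`) are assumptions for illustration, not the actual Lightning internals, and `Module` is again a minimal stand-in for `torch.nn.Module`:

```python
# Stand-in for torch.nn.Module (illustration only, not the real PyTorch API).
class Module:
    def __init__(self, **children):
        self.training = True
        self._children = dict(children)
        for name, child in children.items():
            setattr(self, name, child)

    def modules(self):
        # Yield self and all descendants, as nn.Module.modules() does.
        yield self
        for child in self._children.values():
            yield from child.modules()

    def train(self, mode=True):
        for m in self.modules():
            m.training = mode
        return self

    def eval(self):
        return self.train(False)


def capture_training_modes(model):
    # Snapshot the .training flag of every submodule before switching stages.
    return [(m, m.training) for m in model.modules()]


def restore_training_modes(snapshot):
    # Put every submodule back into the mode it was in before validation.
    for module, mode in snapshot:
        module.training = mode


model = Module(backbone=Module(), head=Module())
model.backbone.eval()                     # user freezes the backbone

snapshot = capture_training_modes(model)  # validation loop begins
model.eval()                              # everything to eval for validation
restore_training_modes(snapshot)          # validation loop ends

assert model.training and model.head.training
assert not model.backbone.training        # the frozen layer stays frozen
```

The key design point is that restoring per-module flags, rather than calling a blanket `.train()`, preserves whatever mixed train/eval configuration the user set up.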

📚 Documentation preview 📚: https://pytorch-lightning--18951.org.readthedocs.build/en/18951/

cc @Borda @justusschock @awaelchli

@github-actions bot added the pl (Generic label for PyTorch Lightning package) label Nov 5, 2023
@awaelchli added the fun (Staff contributions outside working hours), trainer: validate, and trainer: fit labels Nov 12, 2023
@awaelchli added this to the 2.2 milestone Nov 12, 2023
@awaelchli added the feature (Is an improvement or enhancement) label Nov 12, 2023
@awaelchli changed the title from "[WIP] Remember the eval mode of submodules when switching trainer stages" to "Remember the eval mode of submodules when switching trainer stages" Nov 12, 2023
@github-actions bot added the docs (Documentation related) label Nov 12, 2023
@awaelchli awaelchli marked this pull request as ready for review November 12, 2023 03:28
@codecov bot commented Nov 12, 2023

Codecov Report

Merging #18951 (e71ab68) into master (b80107e) will decrease coverage by 27%.
Report is 1 commit behind head on master.
The diff coverage is 100%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18951      +/-   ##
==========================================
- Coverage      76%      48%     -27%     
==========================================
  Files         450      442       -8     
  Lines       36508    36383     -125     
==========================================
- Hits        27583    17572   -10011     
- Misses       8925    18811    +9886     

@carmocca (Contributor) left a comment:

Does this PR replace #18826?

Review comments (outdated, resolved) on:
  • src/lightning/pytorch/utilities/model_helpers.py
  • tests/tests_pytorch/models/test_hooks.py
  • src/lightning/pytorch/loops/evaluation_loop.py
@awaelchli awaelchli requested a review from carmocca November 16, 2023 05:30
@awaelchli (Contributor, Author) replied:

Does this PR replace #18826?

@carmocca Good find! Yes, in fact it does; I verified this using the repro example posted there. After this PR lands, we could still use #18826 to add an additional test case for the tuner.

@mergify bot added the ready (PRs ready to be merged) label Nov 16, 2023