
fix: model weight updates with automatic_optimization=False in mixed precision training #20460

Conversation

@iamarunbrahma commented Dec 2, 2024

What does this PR do?

Fixes an issue where model weights were not being properly updated when using automatic_optimization=False with mixed precision training (16-mixed precision). The core issues were:

  1. Improper gradient scaling/unscaling in the AMP plugin
  2. Missing validation of gradient values before optimizer steps
  3. Lack of proper error handling for repeated unscale operations
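For context, a minimal sketch of the affected setup: manual optimization under 16-mixed precision. The model, loss, and class name here are illustrative, not code from this PR.

```python
import torch
import lightning.pytorch as pl

class ManualOptModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.layer = torch.nn.Linear(10, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).pow(2).mean()
        self.manual_backward(loss)  # the AMP plugin scales the loss here
        opt.step()                  # gradients should be unscaled before this step

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

# trainer = pl.Trainer(precision="16-mixed")
# trainer.fit(ManualOptModel(), train_dataloaders=...)  # dataloader omitted
```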

The changes include:

  • Enhanced AMP plugin implementation with proper gradient scaling workflow
  • Added explicit gradient unscaling before optimizer steps
  • Improved error handling for edge cases
  • Added warning when gradients become NaN/inf during training

Fixes #20215
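To illustrate the scale/unscale/validate/step workflow the bullets above refer to, here is a minimal sketch in plain PyTorch with torch.cuda.amp.GradScaler, not the Lightning plugin code itself; the model, data, and loop are illustrative. With automatic_optimization=False and precision="16-mixed", Lightning routes manual_backward and the optimizer step through roughly this scaler machinery.

```python
import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# GradScaler degrades to a no-op when disabled, so the sketch also runs on CPU
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

for _ in range(3):
    batch = torch.randn(8, 10)
    optimizer.zero_grad()
    loss = model(batch).pow(2).mean()
    scaler.scale(loss).backward()  # backprop the scaled loss
    scaler.unscale_(optimizer)     # unscale once; a second unscale_ raises a RuntimeError
    if any(p.grad is not None and not torch.isfinite(p.grad).all()
           for p in model.parameters()):
        print("warning: non-finite gradients detected")  # the NaN/inf warning described above
    scaler.step(optimizer)  # skips the step when gradients are inf/NaN
    scaler.update()         # adjusts the loss scale for the next iteration
```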

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following checklist:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--20460.org.readthedocs.build/en/20460/

@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Dec 2, 2024
@lantiga (Collaborator) commented Dec 3, 2024

Thank you for the contribution, will review today

@lantiga lantiga added the precision: amp Automatic Mixed Precision label Dec 3, 2024
codecov bot commented Dec 3, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 79%. Comparing base (be608fa) to head (4e6faf5).

❗ There is a different number of reports uploaded between BASE (be608fa) and HEAD (4e6faf5):

HEAD has 542 uploads less than BASE

Flag                BASE (be608fa)   HEAD (4e6faf5)
cpu                            146               24
lightning_fabric                21                0
pytest                          76                0
python3.9                       36                6
lightning                      109               18
python3.11                      36                6
python3.10                      19                3
gpu                              2                0
python3.12                      55                9
pytorch2.1                      27                9
pytest-full                     72               24
pytorch_lightning               18                6
pytorch2.2.2                     9                3
pytorch2.3                       9                3
pytorch2.5.1                    18                6
pytorch2.4.1                     9                3
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #20460     +/-   ##
=========================================
- Coverage      88%      79%     -9%     
=========================================
  Files         267      265      -2     
  Lines       23276    23288     +12     
=========================================
- Hits        20383    18322   -2061     
- Misses       2893     4966   +2073     

@lantiga (Collaborator) commented Dec 4, 2024

@iamarunbrahma looking at the signature for optimizer_step in this PR, are you working off a pre-2.0 version of PyTorch Lightning? optimizer_idx was removed before 2.0.0 in #16539
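For reference, the 2.x hook no longer takes optimizer_idx. A sketch of the current override pattern, assuming the Lightning 2.x API (the class name is illustrative):

```python
import lightning.pytorch as pl

class MyModel(pl.LightningModule):
    # Lightning 2.x signature: no optimizer_idx parameter
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        # default behavior: run the closure (forward + backward), then step
        optimizer.step(closure=optimizer_closure)
```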

@lantiga (Collaborator) commented Dec 4, 2024

OK, the more I look at this the more puzzled I am. Can you clarify the code in the PR, @iamarunbrahma: how does it fit with the plugin's optimizer_step, and how does it solve the original issue?

@lantiga (Collaborator) commented Dec 4, 2024

BTW the title contradicts the issue: the issue describes something happening with automatic_optimization=True, not False. I'm closing this one.

@lantiga lantiga closed this Dec 4, 2024
@iamarunbrahma iamarunbrahma deleted the automatic_optimization_fix branch December 6, 2024 18:26