Fix gradient accumulation for ShardedDataParallel #9122
Conversation
Codecov Report
@@            Coverage Diff            @@
##           master   #9122      +/-  ##
=========================================
- Coverage      93%     89%       -4%
=========================================
  Files         179     179
  Lines       15303   15317       +14
=========================================
- Hits        14200   13590      -610
- Misses       1103    1727      +624
LGTM!
Thanks for the PR @ananthsub! What we can ensure in a test is that the context manager is called correctly during grad accumulation, maybe by mocking/patching the function?
@tchaton I'll add a test first before merging. Still on my todo list.
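A hedged sketch of what such a mock-based check could look like. The `DDPShardedPlugin()` constructor call, the `_model` attribute patched here, and the `block_backward_sync` hook are assumptions about the plugin internals at the time, not details confirmed in this thread:

```python
from unittest import mock

from fairscale.nn.data_parallel import ShardedDataParallel
from pytorch_lightning.plugins import DDPShardedPlugin


def test_no_sync_entered_during_grad_accumulation():
    plugin = DDPShardedPlugin()
    # Stand-in for the wrapped model; spec= makes the isinstance check pass.
    model = mock.MagicMock(spec=ShardedDataParallel)
    with mock.patch.object(plugin, "_model", model):
        with plugin.block_backward_sync():
            pass
    # The plugin should have entered fairscale's no_sync context exactly once.
    model.no_sync.assert_called_once()
```

Mocking keeps the test single-process, so it can run without launching a distributed job while still guarding against a future regression in which the context manager stops being used.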
Force-pushed from eda6d7c to 024cc87
Force-pushed from 64e3e7c to bc9b7ee
Force-pushed from e163f72 to a727b24
for more information, see https://pre-commit.ci

* Fix gradient accumulation for `ShardedDataParallel`
* Update changelog
* Update pytorch_lightning/plugins/training_type/sharded.py
* add test
* Update test_sharded_plugin.py
* Update test_sharded_plugin.py
* Update test_sharded_plugin.py
What does this PR do?
Follow-up to this comment: #9101 (comment)

`ShardedDataParallel` supports the `no_sync` context manager too: https://fairscale.readthedocs.io/en/latest/_modules/fairscale/nn/data_parallel/sharded_ddp.html#ShardedDataParallel.no_sync

But we're not taking advantage of it right now because of this check: https://github.com/PyTorchLightning/pytorch-lightning/blob/9d62f248476c6358d8707188f7b20fafa79f8a4f/pytorch_lightning/plugins/training_type/parallel.py#L131-L133

since the model here is wrapped with `ShardedDataParallel`, not `DistributedDataParallel`. Breaking the inheritance chain in these plugins will make these opportunities clearer.
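A minimal sketch of the shape such a fix could take: overriding the accumulation hook in the sharded plugin so it enters fairscale's `no_sync` context. The `block_backward_sync` name and the `self.model` attribute mirror the plugin code linked above, but treat the subclass structure here as an illustration rather than the final implementation:

```python
from contextlib import contextmanager

from fairscale.nn.data_parallel import ShardedDataParallel
from pytorch_lightning.plugins import DDPShardedPlugin


class ShardedPluginWithNoSync(DDPShardedPlugin):  # illustrative subclass name
    @contextmanager
    def block_backward_sync(self):
        """Skip gradient all-reduce on backward during accumulation steps.

        Checks for the fairscale wrapper instead of DistributedDataParallel,
        which is what the base ParallelPlugin check looks for.
        """
        if isinstance(self.model, ShardedDataParallel):
            with self.model.no_sync():
                yield None
        else:
            yield None
```

With an override like this, Lightning's gradient accumulation path skips the per-step gradient reduction and only syncs on the final accumulation step, which is exactly the communication saving `no_sync` exists to provide.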
@SeanNaren n00b question: do you have suggestions for how to verify this with a unit/integration test, especially to prevent future regressions?
Does your PR introduce any breaking changes? If yes, please list them.
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the Review guidelines. In short, see the following bullet list:
Did you have fun?
Make sure you had fun coding 🙃