
Skips DDP parameter sync #4301

Merged: 9 commits merged into master from ddp_no_sync on Oct 29, 2020
Conversation

@justusschock (Member) commented Oct 22, 2020

What does this PR do?

Skips DDP parameter sync in forward and backward whenever possible.

Fixes #4092 and fixes #2595

This is a revamp of #4146 after adding optimizer closures.
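
For context, the saving comes from PyTorch's DistributedDataParallel.no_sync() context manager, which suppresses the gradient all-reduce during backward. The sketch below only illustrates that mechanism during gradient accumulation; the function and variable names are placeholders, not this PR's code, and it assumes the wrapped module's forward returns a loss-like scalar tensor:

from torch.nn.parallel import DistributedDataParallel

def accumulate_then_step(ddp_model: DistributedDataParallel, optimizer, batches, accumulate_grad_batches: int = 4):
    # Skip DDP's gradient all-reduce on intermediate accumulation steps;
    # only the backward pass right before optimizer.step() synchronizes.
    for i, batch in enumerate(batches):
        last_in_window = (i + 1) % accumulate_grad_batches == 0
        if not last_in_window:
            with ddp_model.no_sync():  # gradients accumulate locally, no communication
                ddp_model(batch).sum().backward()
        else:
            ddp_model(batch).sum().backward()  # this backward triggers the all-reduce
            optimizer.step()
            optimizer.zero_grad()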

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you to create a separate PR for every change.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@justusschock justusschock added the feature Is an improvement or enhancement label Oct 22, 2020
@justusschock justusschock self-assigned this Oct 22, 2020
@pep8speaks commented Oct 22, 2020

Hello @justusschock! Thanks for updating this PR.

Line 688:121: E501 line too long (121 > 120 characters)

Comment last updated at 2020-10-29 17:04:11 UTC

@mergify mergify bot requested a review from a team October 22, 2020 09:21
@codecov bot commented Oct 22, 2020

Codecov Report

Merging #4301 into master will increase coverage by 2%.
The diff coverage is 90%.

@@           Coverage Diff           @@
##           master   #4301    +/-   ##
=======================================
+ Coverage      91%     93%    +2%     
=======================================
  Files         113     111     -2     
  Lines        8301    8134   -167     
=======================================
+ Hits         7527    7553    +26     
+ Misses        774     581   -193     

@williamFalcon (Contributor) commented:

cc @ananthsub for review

Comment on lines 684 to 694

# no ddp sync at the beginning of forward or backward due to parameter changes in this or last step required
no_sync = self._updated_model_last_step and isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel)
if no_sync:
    self.trainer.model.no_sync.__enter__()

self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)

if no_sync:
    self.trainer.model.__exit__()

Contributor:

is it possible for this logic to live in the DDP accelerators as opposed to the general training loop?

@@ -16,6 +16,7 @@
from copy import copy, deepcopy

import numpy as np
from numpy.lib.arraysetops import isin
Contributor:

where's this used?

@justusschock (Member, Author):

it's not. It was added by my IDE as an auto-import while typing isinstance :D
Removed it.

@awaelchli (Contributor) commented Oct 26, 2020

This is implementing the same as requested in #2595, right?

@ananthsub (Contributor) left a comment:

Besides my small comments, the PR looks good to me. The accelerator part is my only significant question. If it's not immediately obvious how we can move this check to the DDP accelerators, then let's land this to reap the perf savings and figure out the refactor later. I really like that the training loop now barely has any references to PyTorch specifics.

@justusschock (Member, Author):

@awaelchli yes, I hadn't seen that one.

pytorch_lightning/trainer/training_loop.py (outdated review thread; resolved)
@tchaton tchaton self-requested a review October 27, 2020 09:44
Comment on lines +688 to +691
# perform dpp sync only when performing optimizer_step
with self.block_ddp_sync_behaviour():
    self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)

@ananthsub (Contributor) commented Oct 28, 2020:

do we need block_ddp_sync_behaviour?

Suggested change
- # perform dpp sync only when performing optimizer_step
- with self.block_ddp_sync_behaviour():
-     self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
+ # perform dpp sync only when performing optimizer_step
+ if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
+     with self.trainer.model.no_sync():
+         self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
+ else:
+     self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)

Contributor:

Just to make it more explicit for our readers :)

@justusschock (Member, Author) commented Oct 29, 2020:

@ananthsub @tchaton requested this to make it more readable and to hide the conditions inside the context manager.
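
For readers following along, such a helper is essentially a thin conditional wrapper around no_sync(). A minimal standalone sketch, illustrative only and not Lightning's actual implementation (the free-function form and the trainer argument are assumptions):

from contextlib import contextmanager

import torch

@contextmanager
def block_ddp_sync_behaviour(trainer):
    # Enter DDP's no_sync() only when the model is actually DDP-wrapped;
    # otherwise yield as a no-op so the caller never needs the isinstance check.
    if isinstance(trainer.model, torch.nn.parallel.DistributedDataParallel):
        with trainer.model.no_sync():
            yield
    else:
        yield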

@tchaton (Contributor) commented Oct 29, 2020:

@ananthsub
The training and evaluation loops should look absolutely perfect and simple to understand.
As a new coder, I should recognise my training loop. As a new coder, I have no knowledge about DDP, and your suggested change would have confused me :)

One great example from this PR is: not (accumulation_done or is_final_batch) -> should_accumulate. It is pretty simple, but it makes a clear statement about what is happening. We should try to enforce those patterns as much as possible :)

I hope it makes sense :)
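
The should_accumulate pattern mentioned above amounts to binding the inline condition to an intention-revealing name. A minimal illustration; the standalone function form is ours, not the PR's code:

def should_accumulate(accumulation_done: bool, is_final_batch: bool) -> bool:
    # Same boolean as `not (accumulation_done or is_final_batch)`, but the call
    # site now reads as a statement of intent rather than a condition to decode.
    return not (accumulation_done or is_final_batch)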

@tchaton (Contributor) left a comment:

Great catch!

@@ -46,6 +46,7 @@ def __init__(self, trainer):
self.automatic_optimization = True
self._curr_step_result = None
self._cur_grad_norm_dict = None
self._updated_model_last_step = False
Contributor:

Can you please remove _updated_model_last_step? You don't use it anymore :)

@tchaton (Contributor) left a comment:

Great addition!

@s-rog (Contributor) left a comment:

Awesome! Should be a nice performance boost!

@rohitgr7 rohitgr7 merged commit bbd81df into master Oct 29, 2020
@rohitgr7 rohitgr7 deleted the ddp_no_sync branch October 29, 2020 17:36
@edenlightning edenlightning added this to the 1.0.x milestone Nov 4, 2020
Borda pushed a commit that referenced this pull request Nov 4, 2020
* ddp no-sync

* Update pytorch_lightning/trainer/training_loop.py

Co-authored-by: ananthsub <[email protected]>

* Update training_loop.py

* factor __enter__ and __exit__ out to separate context manager

* delete _updated_model_last_step

Co-authored-by: justusschock <[email protected]>
Co-authored-by: Teddy Koker <[email protected]>
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: chaton <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
(cherry picked from commit bbd81df)

Successfully merging this pull request may close these issues:

  • Avoid unnecessary DDP synchronization when gradient_accumulation_steps > 1
  • ddp no_sync