Move block_backward_sync from ParallelPlugin to DDPPlugins #9101
Conversation
Force-pushed from 1c97633 to 0f6ddf7, then from 0f6ddf7 to cc6284a.
Codecov Report

@@           Coverage Diff            @@
##           master    #9101    +/-   ##
========================================
- Coverage      92%      88%     -4%
========================================
  Files         176      176
  Lines       14663    14670      +7
========================================
- Hits        13496    12892    -604
- Misses       1167     1778    +611
@contextmanager
def block_backward_sync(self):
    """
    Blocks ddp sync gradients behaviour on backwards pass.
    This is useful for skipping sync when accumulating gradients, reducing communication overhead
    Returns: context manager with sync behaviour off
    """
    if isinstance(self.model, DistributedDataParallel):
        with self.model.no_sync():
            yield None
    else:
        yield None
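For context, a minimal usage sketch of how a training loop might use this context manager to skip gradient synchronization on accumulation steps; the function and parameter names are illustrative, not the actual Lightning batch-loop API:

```python
# Hypothetical sketch: only all-reduce gradients on the micro-batch that runs the optimizer.
from contextlib import nullcontext

def backward_with_accumulation(plugin, batch_idx, accumulate_grad_batches, run_backward):
    # Sync is only needed on the final micro-batch of each accumulation window.
    is_accumulating = (batch_idx + 1) % accumulate_grad_batches != 0
    ctx = plugin.block_backward_sync() if is_accumulating else nullcontext()
    with ctx:
        run_backward()
```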
n00b question: do you think this ought to be its own mini interface to represent this trait?
class PluginWithBlockBackwardSync(ABC):
    @contextmanager
    @abstractmethod
    def block_backward_sync(self) -> Generator:
        ...
this way, we only need to check isinstance(self.trainer.training_type_plugin, PluginWithBlockBackwardSync) in the training batch loop.
otherwise i'm not sure how the isinstance check would work for custom plugins that require this
(pls ignore the verbose naming)
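For illustration, a self-contained version of that sketch plus how the batch-loop check could look; the class name follows the comment above and is not an existing Lightning API:

```python
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Generator

class PluginWithBlockBackwardSync(ABC):
    """Trait for training-type plugins that can temporarily disable gradient sync."""

    @contextmanager
    @abstractmethod
    def block_backward_sync(self) -> Generator:
        """Yield with backward-pass gradient synchronization turned off."""
        ...

# In the training batch loop, the check becomes a capability check instead of a
# check against concrete plugin classes:
#
#     if isinstance(self.trainer.training_type_plugin, PluginWithBlockBackwardSync):
#         with self.trainer.training_type_plugin.block_backward_sync():
#             ...
```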
Yeah, that's a good point, if a custom plugin needs this.
@tchaton @awaelchli @justusschock what do you think?
Yes, I think we can explore making plugins more composable too.
@ananthsub I agree that we should explore this. My only concern in this direction (we had/have something similar for the trainer and module) is that sometimes it becomes hard to track what is implemented where (especially when debugging), which is why at some point we decided to avoid patterns like this.
I still think, though, that together with good, purely abstract interfaces this should be possible and is likely the best way to tackle this.
Yes, I think it is a balance of good code taste with reliable / general abstractions.
    This is useful for skipping sync when accumulating gradients, reducing communication overhead
    Returns: context manager with sync behaviour off
    """
    if isinstance(self.model, DistributedDataParallel):
fwiw, ShardedDataParallel supports this too, but we're not taking advantage of it now due to this check 😞
https://fairscale.readthedocs.io/en/latest/_modules/fairscale/nn/data_parallel/sharded_ddp.html#ShardedDataParallel.no_sync
this is also masked by the current inheritance structure, as sharded doesn't override this:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/training_type/sharded.py
ideally, splitting up the inheritance like this:

              Parallel
    /       /        \        \
  DDP   Sharded     FSDP   DeepSpeed

will make these opportunities more apparent
fyi @SeanNaren
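To make the point concrete, here is a hedged sketch of a duck-typed variant of the plugin method body that would also pick up ShardedDataParallel's no_sync(); this is one possible shape, not what this PR implements:

```python
from contextlib import contextmanager

@contextmanager
def block_backward_sync(self):
    """Skip gradient sync for any wrapped model that exposes a no_sync() context manager.

    torch's DistributedDataParallel provides no_sync(), and so does fairscale's
    ShardedDataParallel, so this avoids hard-coding a single wrapper class.
    """
    no_sync = getattr(self.model, "no_sync", None)
    if callable(no_sync):
        with no_sync():
            yield None
    else:
        yield None
```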
Hey @ananthsub,
I am a bit concerned that making all plugins subclass directly from Parallel would result in a lot of duplicated code and higher maintenance cost, especially for sharded.

              Parallel
    /       /        \        \
  DDP   Sharded     FSDP   DeepSpeed
Alternatively, we could remove the if isinstance(self.model, DistributedDataParallel): check there.
Removing the check for DDP would affect DeepSpeed and fully sharded.
Regarding code duplication: if we better abstract the subprocess launch / start_processes in the DDP and DDP spawn plugins so that code can be shared, would that address your concern? Are there other parts of the code where you're worried about duplication?
My concern with the inheritance we have now is that things can be silently not called.
@tchaton especially with FSDP and DeepSpeed, checkpoint loading and saving are so different from DDP and sharded
IMO, I would prefer to avoid it, but it can have some pros too as you shared there.
I am not against FB trying to PoC a refactor.
@justusschock do you agree with this, as you designed it based on inheritance?
@tchaton I designed it based on inheritance to avoid code duplication. However, as we get more and more different kinds of plugins, I think it could make sense to split them out into minimal mixins (like the one @ananthsub shared above) and then make the actual plugin inherit them.
I know that we decided against mixins, but I think those mixins together with a purely abstract interface class are the best way to tackle this.
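A rough sketch of that mixin idea (all names here are invented for illustration; the real plugins carry far more responsibilities):

```python
from contextlib import contextmanager

class NoSyncBlockBackwardSyncMixin:
    """Small mixin: block backward sync for any plugin whose self.model has no_sync()."""

    @contextmanager
    def block_backward_sync(self):
        with self.model.no_sync():
            yield None

# Concrete plugins would then opt in explicitly, e.g.:
#
#     class DDPPlugin(ParallelPlugin, NoSyncBlockBackwardSyncMixin):
#         ...
#
# while a purely abstract interface (like PluginWithBlockBackwardSync above) remains
# the thing the training loop checks against.
```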
Sounds good to me! @ananthsub Mind creating an [RFC] for refactoring the Accelerator around base components and tagging the person assigned on your side?
self.trainer.lightning_module.automatic_optimization or should_block_sync
):
if (
    isinstance(self.trainer.training_type_plugin, DDPPlugin)
DDPPlugin or DDPSpawnPlugin, right?
isinstance(self.trainer.training_type_plugin, DDPPlugin)
or isinstance(self.trainer.training_type_plugin, DDPPlugin)
Suggested change:
-    isinstance(self.trainer.training_type_plugin, DDPPlugin)
-    or isinstance(self.trainer.training_type_plugin, DDPPlugin)
+    isinstance(self.trainer.training_type_plugin, (DDPPlugin, DDPSpawnPlugin))
> Alternatively, we could remove the if isinstance(self.model, DistributedDataParallel): check there.
@carmocca @awaelchli, how about this one? 🦦
This has been implemented for Lite in #14966. It'll eventually trickle into the PL interfaces as we merge implementations. cc @awaelchli
What does this PR do?
The parallel plugin should be generic and self-contained. This PR moves the reference to DistributedDataParallel out of the parallel plugin's block_backward_sync and into the DDP plugins.
This is subtask 2 of "remove reference to DistributedDataParallel from parallel plugin"; subtask 1 is #8943.
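In rough outline, the restructuring amounts to the sketch below (simplified stand-in classes; the actual method body is the one shown in the diff above, and the real plugins contain much more logic):

```python
from contextlib import contextmanager

class ParallelPluginSketch:
    """Generic parallel plugin: after this change it no longer references DistributedDataParallel."""

    def __init__(self, model=None):
        self.model = model

class DDPPluginSketch(ParallelPluginSketch):
    """DDP-specific plugin: now owns block_backward_sync."""

    @contextmanager
    def block_backward_sync(self):
        # Here self.model is expected to be wrapped in torch's DistributedDataParallel,
        # whose no_sync() context manager skips the gradient all-reduce on backward.
        with self.model.no_sync():
            yield None
```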
Tests:
python -m pytest -v tests/plugins passed
python -m pytest -v tests/accelerators passed
python -m pytest -v tests/trainer passed
Fixes #<issue_number>
Does your PR introduce any breaking changes? If yes, please list them.
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the Review guidelines.
Did you have fun?
Make sure you had fun coding 🙃