Move block_backward_sync from ParallelPlugin to DDPPlugins #9101

Closed

Conversation

@four4fish (Contributor) commented on Aug 25, 2021

What does this PR do?

The parallel plugin should be generic and self-contained. This PR moves block_backward_sync, which references DistributedDataParallel, out of ParallelPlugin and into the DDP plugins (a rough sketch of the result is shown below).

This is subtask 2 of "remove reference to DistributedDataParallel from parallel plugin"; subtask 1 is #8943.
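For orientation only, here is a rough sketch of the shape after the move, assuming the Lightning 1.4-era layout (the authoritative version is the diff in this PR; the import path for ParallelPlugin is an assumption):

# Rough sketch only -- mirrors the diff in this PR, not a verbatim copy.
from contextlib import contextmanager

from torch.nn.parallel.distributed import DistributedDataParallel

from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin  # path assumed


class DDPPlugin(ParallelPlugin):
    @contextmanager
    def block_backward_sync(self):
        """Blocks DDP gradient sync on the backward pass, e.g. while accumulating gradients."""
        if isinstance(self.model, DistributedDataParallel):
            with self.model.no_sync():
                yield None
        else:
            yield None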

Tests:
python -m pytest -v tests/plugins passed
python -m pytest -v tests/accelerators passed
python -m pytest -v tests/trainer passed

Fixes #<issue_number>

Does your PR introduce any breaking changes? If yes, please list them.

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@four4fish force-pushed the remove/block_backward branch from 0f6ddf7 to cc6284a on August 25, 2021 06:45
codecov bot commented on Aug 25, 2021

Codecov Report

Merging #9101 (cc6284a) into master (de57fef) will decrease coverage by 4%.
The diff coverage is 50%.

@@           Coverage Diff           @@
##           master   #9101    +/-   ##
=======================================
- Coverage      92%     88%    -4%     
=======================================
  Files         176     176            
  Lines       14663   14670     +7     
=======================================
- Hits        13496   12892   -604     
- Misses       1167    1778   +611     

Comment on lines +447 to +458
    @contextmanager
    def block_backward_sync(self):
        """
        Blocks ddp sync gradients behaviour on backwards pass.
        This is useful for skipping sync when accumulating gradients, reducing communication overhead
        Returns: context manager with sync behaviour off
        """
        if isinstance(self.model, DistributedDataParallel):
            with self.model.no_sync():
                yield None
        else:
            yield None
Contributor:

n00b question: do you think this ought to be its own mini interface to represent this trait?

from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Generator


class PluginWithBlockBackwardSync(ABC):
    @contextmanager
    @abstractmethod
    def block_backward_sync(self) -> Generator:
        ...

this way, we only need to check isinstance(self.trainer.training_type_plugin, PluginWithBlockBackwardSync) in the training batch loop (a usage sketch follows below).

otherwise i'm not sure about the isinstance check for custom plugins that require this

(pls ignore the verbose naming)
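For illustration only, a usage sketch of how the batch-loop check could look with such a trait; the function name and the nullcontext fallback are hypothetical, not Lightning APIs:

from contextlib import nullcontext

# Assumes the PluginWithBlockBackwardSync interface sketched in the comment above.
def backward_sync_context(plugin, block: bool):
    # One trait check instead of enumerating DDPPlugin, DDPSpawnPlugin, ... in the loop.
    if block and isinstance(plugin, PluginWithBlockBackwardSync):
        return plugin.block_backward_sync()
    return nullcontext()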

Contributor Author:

Yeah, that's a good point, especially if a custom plugin needs this.

Contributor:

@tchaton @awaelchli @justusschock what do you think?

Contributor:

Yes, I think we can explore making plugins more composable too.

Member:

@ananthsub I agree that we should explore this. My only concern in this direction (we had/have something similar for the trainer and module) is that sometimes it becomes hard to track what is implemented where (especially when debugging), which is why at some point we decided to avoid patterns like this.

I still think though, that together with good purely abstract interfaces this should be possible and is likely the best way to tackle this.

Contributor:

Yes, I think it is a balance of good code taste with reliable / general abstractions.

        This is useful for skipping sync when accumulating gradients, reducing communication overhead
        Returns: context manager with sync behaviour off
        """
        if isinstance(self.model, DistributedDataParallel):
@ananthsub (Contributor) commented on Aug 25, 2021

fwiw, ShardedDataParallel supports this too, but we're not taking advantage of it now due to this check 😞
https://fairscale.readthedocs.io/en/latest/_modules/fairscale/nn/data_parallel/sharded_ddp.html#ShardedDataParallel.no_sync

this is also masked with the current inheritance structure, as sharded doesn't override this
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/training_type/sharded.py

ideally, splitting up this inheritance like this

                       Parallel
           /      /             \       \
         DDP      Sharded      FSDP      Deepspeed

will make these opportunities more apparent

fyi @SeanNaren
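For illustration (not part of this PR), a sketch of a sharded-specific override that would use fairscale's no_sync(); the class here is a stand-in for the real DDPShardedPlugin, and the model attribute is assumed to hold the wrapped module:

from contextlib import contextmanager

from fairscale.nn.data_parallel import ShardedDataParallel


class ShardedPluginSketch:  # stand-in name for the real sharded plugin
    model: ShardedDataParallel

    @contextmanager
    def block_backward_sync(self):
        # Mirrors the DDP version, but keyed on ShardedDataParallel so fairscale's
        # no_sync() is no longer masked by the DistributedDataParallel check.
        if isinstance(self.model, ShardedDataParallel):
            with self.model.no_sync():
                yield None
        else:
            yield None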

Contributor:

Hey @ananthsub,

I am a bit concerned that making all plugins subclass directly from parallel would result in a lot of duplicated code and higher maintenance cost, especially for sharded.

                       Parallel
           /      /             \       \
         DDP      Sharded      FSDP      Deepspeed

Contributor:

Alternatively, we could remove the if isinstance(self.model, DistributedDataParallel): check there.
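For reference, dropping the check would reduce the method to roughly the following sketch; as noted in the next comment, this relies on every wrapped model exposing a compatible no_sync():

from contextlib import contextmanager


class ParallelPluginSketch:  # illustrative stand-in; `self.model` is the wrapped model
    @contextmanager
    def block_backward_sync(self):
        # No isinstance guard: assumes every wrapped model (DDP, sharded, ...) provides
        # a DDP-compatible no_sync() context manager.
        with self.model.no_sync():
            yield None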

@ananthsub (Contributor) commented on Aug 25, 2021

Removing the check for DDP would affect DeepSpeed and fully sharded.

Regarding code duplication: if we better abstract the subprocess launch or start_processes in the DDP and DDP spawn plugins so that the code can be shared, would that address your concern? Are there other parts of the code where you're worried about duplication? (A rough sketch of one possible launcher abstraction follows below.)

My concern with the inheritance we have now is if things are silently not called.
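One possible reading of the "abstract the subprocess launch / start_processes" idea, sketched with purely hypothetical names (this is not a Lightning API, just an illustration of sharing launch code between the DDP and DDP-spawn plugins):

from abc import ABC, abstractmethod
from typing import Any, Callable

import torch.multiprocessing as mp


class _Launcher(ABC):  # hypothetical shared interface
    @abstractmethod
    def launch(self, fn: Callable, *args: Any) -> None:
        """Start the worker processes and run `fn` in each of them."""


class _SpawnLauncher(_Launcher):  # hypothetical DDP-spawn flavour
    def __init__(self, num_processes: int) -> None:
        self.num_processes = num_processes

    def launch(self, fn: Callable, *args: Any) -> None:
        # Spawn workers from the current interpreter; fn receives the local rank first.
        mp.spawn(fn, args=args, nprocs=self.num_processes)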

Contributor:

@tchaton especially with FSDP and DeepSpeed, checkpoint loading and saving are so different from DDP and sharded.

@tchaton (Contributor) commented on Aug 26, 2021

IMO, I would prefer to avoid it, but it can have some pros too, as you shared there.
I am not against FB trying to PoC a refactor.
@justusschock, do you agree with this, as you designed it based on inheritance?

Member:

@tchaton I designed it based on inheritance to avoid code duplication. However, as we get more and more different kinds of plugins, I think it could make sense to split them out into minimal mixins (like the one @ananthsub shared above) and then make the actual plugins inherit from them.

I know that we decided against mixins, but I think those mixins together with a purely abstract interface class are the best way to tackle this.

Contributor:

Sounds good to me! @ananthsub mind creating an [RFC] for refactoring the Accelerator around base components and tagging the person assigned on your side?

@ananthsub added the distributed (Generic distributed-related topic) and refactor labels on Aug 25, 2021
@ananthsub added this to the v1.5 milestone on Aug 25, 2021
self.trainer.lightning_module.automatic_optimization or should_block_sync
):
if (
isinstance(self.trainer.training_type_plugin, DDPPlugin)
Contributor:

DDPPlugin or DDPSpawnPlugin, right?

Comment on lines +434 to +435
isinstance(self.trainer.training_type_plugin, DDPPlugin)
or isinstance(self.trainer.training_type_plugin, DDPPlugin)
Contributor:

Suggested change:
- isinstance(self.trainer.training_type_plugin, DDPPlugin)
- or isinstance(self.trainer.training_type_plugin, DDPPlugin)
+ isinstance(self.trainer.training_type_plugin, (DDPPlugin, DDPSpawnPlugin))

@Borda (Member) commented on Sep 12, 2022

@carmocca @awaelchli, how about this one? 🦦

@carmocca (Contributor) commented on Nov 8, 2022

This has been implemented for Lite in #14966. It'll eventually trickle into the PL interfaces as we merge implementations. cc @awaelchli

@carmocca closed this on Nov 8, 2022
Labels: distributed (Generic distributed-related topic), has conflicts, refactor
7 participants