Avoid rewrapping LightningModules in plugins #8593

Closed
ananthsub opened this issue Jul 28, 2021 · 10 comments · Fixed by #13738
Assignee: ananthsub
Labels: distributed, feature, strategy
Milestone: v1.5

Comments

@ananthsub
Contributor

ananthsub commented Jul 28, 2021

🚀 Feature

Background:

We are auditing the Lightning components and APIs to assess opportunities for improvement.

Motivation

Lightning has the concept of a Training Type Plugin, which functions as the distribution strategy used during execution. For this, Lightning offers an API and out-of-the-box implementations for data-parallel-based approaches (DataParallel, DistributedDataParallel, ShardedDataParallel, DeepSpeed, etc.). The trainer wraps the whole LightningModule with a module wrapper in order to facilitate gradient synchronization.

This allows users to easily specify distributed training options like:

trainer = Trainer(num_nodes=..., gpus=..., plugins=["..."]) 
module = MyLightningModule(...)
trainer.fit(module)

The situation we'd like to explicitly avoid is when the user makes successive calls to Trainer APIs like this:

trainer = Trainer(num_nodes=..., gpus=..., plugins=["..."]) 
module = MyLightningModule(...)
trainer.fit(module)
trainer.test(module)

potentially causing the wrapper to be applied multiple times to the LightningModule, e.g. DistributedDataParallel(DistributedDataParallel(LightningModule)).

Pitch

In the following plugins, avoid wrapping the plugin's self._model if the model is already an instance of the wrapped type (see the sketch after the list below).

Plugins to update
DDP: https://github.com/PyTorchLightning/pytorch-lightning/blob/80c529351439a0f8d3d6e9449cd47d16ba3abbec/pytorch_lightning/plugins/training_type/ddp.py#L249-L256
DDP Spawn: https://github.com/PyTorchLightning/pytorch-lightning/blob/6b47cbe3ca8aa3fd82211bc9fa32e734753a6950/pytorch_lightning/plugins/training_type/ddp_spawn.py#L247-L252
Sharded Data Parallel: https://github.com/PyTorchLightning/pytorch-lightning/blob/6b47cbe3ca8aa3fd82211bc9fa32e734753a6950/pytorch_lightning/plugins/training_type/sharded.py#L37-L45
Sharded Data Parallel Spawn: https://github.com/PyTorchLightning/pytorch-lightning/blob/6b47cbe3ca8aa3fd82211bc9fa32e734753a6950/pytorch_lightning/plugins/training_type/sharded_spawn.py#L36-L41
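
For illustration, a hedged sketch of what that guard amounts to (wrap_once is a hypothetical helper, not existing plugin code; inside the plugins listed above, the same isinstance check would guard the assignment to self._model, with the sharded/DeepSpeed wrappers handled analogously):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def wrap_once(model: nn.Module, **ddp_kwargs) -> nn.Module:
    # Reuse the wrapper created by a previous trainer entry point instead of
    # nesting DistributedDataParallel(DistributedDataParallel(...)).
    # Note: constructing DistributedDataParallel requires an initialized process group.
    if isinstance(model, DistributedDataParallel):
        return model
    return DistributedDataParallel(model, **ddp_kwargs)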

Alternatives

Additional context


If you enjoy Lightning, check out our other projects! ⚡

  • Metrics: Machine learning metrics for distributed, scalable PyTorch applications.

  • Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, finetuning and solving problems with deep learning.

  • Bolts: Pretrained SOTA deep learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch.

  • Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers, leveraging PyTorch Lightning, Transformers, and Hydra.

cc @Borda @awaelchli @rohitgr7 @akihironitta

ananthsub added the feature and distributed labels on Jul 28, 2021
ananthsub added this to the v1.5 milestone on Jul 28, 2021
ananthsub self-assigned this on Jul 28, 2021
@SeanNaren
Contributor

Thanks for picking this up @ananthsub :) I may be missing some carefully thought-out logic that already exists, but we'll need to take care in this case, right?

trainer = Trainer(num_nodes=..., gpus=..., plugins=["..."]) 
module = MyLightningModule(...)
trainer.fit(module)

# as an example...
module = MyLightningModule(module.backbone)
trainer.test(module) # self._model in the training type plugin points to a different wrapped module

This would involve some checks to ensure the models are the same, I think. The case above may not be a major issue, as it's just a lightning_module comparison, but it may be important for ensuring the integrity of each stage!

@tchaton
Contributor

tchaton commented Jul 29, 2021

Dear @ananthsub,

I would prefer a slightly different approach, which I started to implement but which broke for the DeepSpeed plugin.
On teardown, we should drop the wrapper as we currently do for DistributedDataParallel: https://github.com/PyTorchLightning/pytorch-lightning/blob/a64cc373946a755ce6c3aef57c1be607dfe29a0c/pytorch_lightning/plugins/training_type/parallel.py#L138.
We should extend this logic to the other plugins, but it seems to fail.
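
A minimal sketch of that unwrapping idea (not the verbatim library code; self.model and self.lightning_module are the plugin attributes referenced in the link above):

from torch.nn.parallel import DistributedDataParallel


class ParallelPlugin:  # sketch only, mirroring the linked teardown logic
    def teardown(self) -> None:
        # Drop the DDP wrapper so the plugin holds the bare LightningModule
        # again; a later trainer call can then re-wrap it cleanly.
        if isinstance(self.model, DistributedDataParallel):
            self.model = self.lightning_module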

Best,
T.C

@awaelchli
Contributor

We also have the option to do

self._model = DistributedDataParallel(
    LightningDistributedModule(self.lightning_module),  # note: lightning_module instead of _model
    device_ids=self.determine_ddp_device_ids(),
    **self._ddp_kwargs,
)

self.lightning_module is guaranteed to be the unwrapped model so we can prevent any accidental double wrapping.

@SeanNaren
Contributor

SeanNaren commented Jul 29, 2021

I actually prefer what @ananthsub is suggesting over the proposals above, as those assume that the lightning module stays the same, which is not true IIRC. Hooks are attached to the lightning module (the same goes for DeepSpeed; with FSDP it's even more destructive). By allowing the plugin not to re-create the wrapper, we stay truer to the standard loop (re-using the model across stages once it's wrapped).

EDIT:

Example of what a loop looks like given @ananthsub's suggestion:

model = MyModel()
model = DDP(model)
for batch in train_batches:
    train_model(model, batch)
for batch in test_batches:
    test_model(model, batch)

@ananthsub
Contributor Author

ananthsub commented Jul 30, 2021

There's a high degree of overlap between this and #6977, and I think #5007 too.

  • The plugins wrap the LightningModule as a new wrapped model
  • The plugins persist this wrapped model
  • This wrapped model is visible to the trainer
  • This wrapped model sticks around across calls to the trainer
  • So we are responsible for state management of the wrapped model as well
  • This extends to "Duplicate epochs when calling .fit() twice" (#5007), as it raises the question of what guarantees the trainer makes regarding its state across successive calls to the Trainer

As @SeanNaren points out, always rewrapping might not be safe in case applying the wrapper is not idempotent. (Separately, should users ensure configure_sharded_model is idempotent?)

However, @SeanNaren - could users expect this?

model = MyModel()
ddp_model = DDP(model)
for batch in train_batches:
    train_model(ddp_model, batch)
for batch in test_batches:
    test_model(model, batch)

But this would be different from the same pattern with FSDP:

model = MyModel()
fsdp_model = FSDP(model)
for batch in train_batches:
    train_model(fsdp_model, batch)
for batch in test_batches:
    test_model(model, batch)

because of the differences between DDP and FSDP. As a result, does this logic become plugin-specific for when to rewrap vs not?

@SeanNaren this is an excellent example too:

trainer = Trainer(num_nodes=..., gpus=..., plugins=["..."]) 
module = MyLightningModule(...)
trainer.fit(module)

# as an example...
module = MyLightningModule(module.backbone)
trainer.test(module) # self._model in the training type plugin points to a different wrapped module

If people make changes to their LightningModule like this, what guarantees do we offer if we're relying on the previous state of self._model to make decisions? Things can quickly fall out of sync. More fundamentally, the amount of state set on the trainer is very hard to unwind and verify across calls.

It feels much safer to recommend that users create a brand-new Trainer object per call to avoid dealing with state cleanup. What do you think?
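
As a hypothetical usage pattern (purely illustrative, reusing the placeholders from the snippets above):

module = MyLightningModule(...)

# One trainer per entry point: no wrapped-model or plugin state carried over between calls.
fit_trainer = Trainer(num_nodes=..., gpus=..., plugins=["..."])
fit_trainer.fit(module)

test_trainer = Trainer(num_nodes=..., gpus=..., plugins=["..."])
test_trainer.test(module)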

@SeanNaren
Contributor

thanks @ananthsub for the great summary here!

Separately, should users ensure configure_sharded_model is idempotent?

After the changes proposed here, configure_sharded_model should be called once in a lifecycle for both DeepSpeed and FSDP, and not called again, as the model remains the same!
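
As a loose sketch of that "called once in a lifecycle" guard (the class, method name, and _configured flag below are hypothetical, not existing Lightning attributes):

class ShardedPluginSketch:  # illustrative only
    def _maybe_configure_sharded_model(self) -> None:
        # Invoke the user hook only the first time; on later trainer calls
        # the already-materialized/wrapped model is reused as-is.
        if not getattr(self, "_configured", False):
            self.lightning_module.configure_sharded_model()
            self._configured = True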

Regarding the cases of model swapping in subsequent stages or modifications made to the model, I think, as you suggested, we should recommend a brand-new trainer object for those individual cases.

@carmocca
Contributor

This extends to "Duplicate epochs when calling .fit() twice" (#5007), as it raises the question of what guarantees the trainer makes regarding its state across successive calls to the Trainer

#5007 is not really related, as it's due to an error in the current_epoch increase logic.

@ananthsub
Contributor Author

Hoisting out from here: #8943 (comment)

@tchaton, from #8490: do you recall why unwrapping DDP in the plugin's teardown saves memory or reduces the chance of memory leaks? That unwrapping is at odds with this issue, so I'm wondering if we should still pursue this.

@carmocca
Contributor

do you recall why unwrapping DDP in the plugin's teardown saves memory or reduces the chance of memory leaks?

This is because the DDP wrapper seems to hold memory, supposedly in the buckets. You should be able to see the memory test fail if the change is undone.

If it's possible to keep the wrap without affecting the memory after fit, then it can be pursued.

@tchaton
Contributor

tchaton commented Aug 20, 2021

Hoisting out from here: #8943 (comment)

@tchaton, from #8490: do you recall why unwrapping DDP in the plugin's teardown saves memory or reduces the chance of memory leaks? That unwrapping is at odds with this issue, so I'm wondering if we should still pursue this.

Hey @ananthsub. It seems the DDP Reducer was holding some GPU memory and had to be deleted to get back to the initial state.
