Avoid rewrapping LightningModules in plugins #8593
Comments
Thanks for picking this up @ananthsub :) I may be missing some carefully thought-out logic already, but we'll need to take care in this case, right?

```python
trainer = Trainer(num_nodes=..., gpus=..., plugins=["..."])
module = MyLightningModule(...)
trainer.fit(module)

# as an example...
module = MyLightningModule(module.backbone)
trainer.test(module)  # self._model in the training type plugin points to a different wrapped module
```

This would involve having to do some checks to ensure the models are the same, I think. The case above may not be a major issue, as it's just a lightning_module comparison, but it may be important for ensuring the integrity of each stage!
Dear @ananthsub, I would prefer a slightly alternate approach, which I started to implement but which broke for the DeepSpeed plugin. Best,
We also have the option to do:

```python
self._model = DistributedDataParallel(
    LightningDistributedModule(self.lightning_module),  # here, note lightning_module instead of _model
    device_ids=self.determine_ddp_device_ids(),
    **self._ddp_kwargs,
)
```
I actually prefer what @ananthsub is suggesting over the proposals above, as we assume that the lightning module stays the same, which is not true IIRC. Hooks are attached to the lightning module (same with DeepSpeed; in FSDP it's even more destructive). By allowing the plugin to not re-create the wrapper, we remain truer to the actual standard loop (re-using the model across stages once wrapped).

EDIT: example of what a loop looks like given ananth's suggestion:

```python
model = MyModel()
model = DDP(model)

for batch in train_batches:
    train_model(model, batch)

for batch in test_batches:
    test_model(model, batch)
```
There's a high degree of overlap between this and #6977, and I think #5007 too.

As @SeanNaren points out, always rewrapping might not be safe in case applying the wrapper is not idempotent. However, @SeanNaren - could users expect this?

but this would be different from this:

because of the differences between DDP and FSDP. As a result, does this logic become plugin-specific for when to rewrap vs. not?

@SeanNaren, this is an excellent example too:

If people make changes to their LightningModule like this, what guarantees do we offer if we're relying on the previous state of the wrapped model? It feels much safer to recommend that users create a brand-new Trainer object per call to avoid dealing with state cleanup. What do you think?
Thanks @ananthsub for the great summary here!

Regarding the cases of model swapping in subsequent stages, or modifications made to the model, I think, as you suggested, we should recommend a brand-new Trainer object for these individual cases.
Hoisting out from here: #8943 (comment). @tchaton, from #8490 - do you recall why unwrapping DDP in the plugin's teardown saves memory or reduces the chance of memory leaks? That unwrapping is at odds with this issue, so I'm wondering if we should still pursue this.
This is because the DDP wrapper seems to hold memory - supposedly in the buckets. You should be able to see the memory test failing if the change is undone. If it's possible to keep the wrapper without affecting memory after fit, then this can be pursued.
Hey @ananthsub. It seems the DDP Reducer was holding some GPU memory and had to be deleted to get back to the initial state.
🚀 Feature
Background:
We are auditing the Lightning components and APIs to assess opportunities for improvement.
Motivation
Lightning has the concept of a Training Type Plugin, which functions as the distribution strategy used during execution. For this, Lightning offers an API and out-of-the-box implementations for data-parallel-based approaches (DataParallel, DistributedDataParallel, ShardedDataParallel, DeepSpeed, etc.). The Trainer wraps the whole LightningModule with a module wrapper in order to facilitate gradient synchronization.
This allows users to easily specify distributed training options like:
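For example (a minimal sketch; the argument values are illustrative, and the flags follow the Trainer API of the Lightning version this issue was filed against):

```python
from pytorch_lightning import Trainer

# Select a distributed strategy via Trainer arguments; Lightning then wraps the
# LightningModule with the matching module wrapper (e.g. DistributedDataParallel).
trainer = Trainer(num_nodes=2, gpus=8, accelerator="ddp")
```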
The situation we'd like to explicitly avoid is if the user makes successive calls to Trainer APIs like this:
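For instance (a sketch; `MyLightningModule` is a hypothetical user module):

```python
from pytorch_lightning import Trainer

model = MyLightningModule()
trainer = Trainer(num_nodes=2, gpus=8, accelerator="ddp")

trainer.fit(model)      # the plugin wraps the module for the fit stage
trainer.test(model)     # re-entering the plugin with the same Trainer
trainer.predict(model)  # ...and again
```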
Successive calls like these could cause the wrapper to be applied multiple times to the LightningModule, e.g. DistributedDataParallel(DistributedDataParallel(LightningModule)).
Pitch
Inside of the following plugins, avoid wrapping the plugin's `self._model` if the model is already an instance of the wrapped type (see the sketch after the list below).

Plugins to update:
DDP: https://github.com/PyTorchLightning/pytorch-lightning/blob/80c529351439a0f8d3d6e9449cd47d16ba3abbec/pytorch_lightning/plugins/training_type/ddp.py#L249-L256
DDP Spawn: https://github.com/PyTorchLightning/pytorch-lightning/blob/6b47cbe3ca8aa3fd82211bc9fa32e734753a6950/pytorch_lightning/plugins/training_type/ddp_spawn.py#L247-L252
Sharded Data Parallel: https://github.com/PyTorchLightning/pytorch-lightning/blob/6b47cbe3ca8aa3fd82211bc9fa32e734753a6950/pytorch_lightning/plugins/training_type/sharded.py#L37-L45
Sharded Data Parallel Spawn: https://github.com/PyTorchLightning/pytorch-lightning/blob/6b47cbe3ca8aa3fd82211bc9fa32e734753a6950/pytorch_lightning/plugins/training_type/sharded_spawn.py#L36-L41
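For the DDP case, a minimal sketch of the proposed guard is below. It reuses the wrapping call quoted in the comments above; the import path and attribute names mirror the plugin code of that era and should be read as illustrative, not as the actual patch.

```python
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.overrides import LightningDistributedModule


def configure_ddp(self):
    # Skip re-wrapping if a previous stage (fit/validate/test/predict) already
    # wrapped the model; otherwise wrap the LightningModule as before.
    if not isinstance(self._model, DistributedDataParallel):
        self._model = DistributedDataParallel(
            LightningDistributedModule(self.lightning_module),
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs,
        )
```

Whether a plain isinstance check is enough for plugins such as DeepSpeed or FSDP, where wrapping is more destructive (as noted in the comments above), is part of the open question here.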
Alternatives
Additional context
cc @Borda @awaelchli @rohitgr7 @akihironitta