
Commit 20b806a

Authored by tchaton, SeanNaren, carmocca, and Borda
[feat] 3/n pp (Lightning-AI#5036)
* add pp doc
* udpate doc
* update doc
* update doc
* Update docs
* update doc
* udpate
* update doc
* update doc
* Formatting, update sharded zip link
* Update docs/source/multi_gpu.rst
  Co-authored-by: Carlos Mocholí <[email protected]>
* Apply suggestions from code review
* Reference directly to section

Co-authored-by: SeanNaren <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 69725ad commit 20b806a

File tree: 4 files changed, +99 −8 lines

- .pre-commit-config.yaml
- docs/source/multi_gpu.rst
- docs/source/performance.rst
- docs/source/training_tricks.rst

.pre-commit-config.yaml (+1 −1)

@@ -32,6 +32,6 @@ repos:
       types: [python]
 
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: master
+    rev: v0.790
     hooks:
       - id: mypy

docs/source/multi_gpu.rst (+81 −6)

@@ -612,6 +612,7 @@ This is useful when dealing with large Transformer based models, or in environme
 Lightning currently offers the following methods to leverage model parallelism:
 
 - Sharded Training (partitioning your gradients and optimizer state across multiple GPUs, for reduced memory overhead with **no performance loss**)
+- Sequential Model Parallelism with Checkpointing (partition your :class:`nn.Sequential <torch.nn.Sequential>` module across multiple GPUs, leverage checkpointing and micro-batching for further memory improvements and device utilization)
 
 Sharded Training
 ^^^^^^^^^^^^^^^^
@@ -666,7 +667,7 @@ To use Sharded Training, you need to first install FairScale using the command b
 
 .. code-block:: bash
 
-    pip install https://github.com/facebookresearch/fairscale/archive/bb468670838b98dc8f8d67be4eabf195042a7994.zip
+    pip install https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip
 
 
 .. code-block:: python
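For orientation (not part of the diff): a minimal sketch of enabling Sharded Training once FairScale is installed, assuming the in-code ``plugins='ddp_sharded'`` argument mirrors the ``--plugins ddp_sharded`` CLI flag described in the surrounding docs, and using a placeholder ``MyLitModel``:

.. code-block:: python

    from pytorch_lightning import Trainer

    # Any ordinary LightningModule works unchanged; the plugin shards the
    # optimizer state and gradients across processes internally.
    model = MyLitModel()  # placeholder LightningModule

    # Assumed in-code equivalent of launching with `--plugins ddp_sharded`.
    trainer = Trainer(gpus=2, accelerator='ddp', plugins='ddp_sharded')
    trainer.fit(model)

No changes to the model or training loop are needed, which matches the "no code changes are required" note in the next hunk.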
@@ -678,6 +679,80 @@ Sharded Training can work across all DDP variants by adding the additional ``--p
 
 Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required.
 
+----------
+
+.. _sequential-parallelism:
+
+Sequential Model Parallelism with Checkpointing
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+PyTorch Lightning integration for Sequential Model Parallelism using `FairScale <https://github.com/facebookresearch/fairscale>`_.
+Sequential Model Parallelism splits a sequential module onto multiple GPUs, reducing peak GPU memory requirements substantially.
+We also provide auto-balancing techniques through FairScale to find optimal balances for the model across GPUs.
+In addition, we use Gradient Checkpointing to reduce GPU memory requirements further, and micro-batches to minimize device under-utilization automatically.
+
+Reference: https://arxiv.org/abs/1811.06965
+
+.. note:: DDPSequentialPlugin is currently supported only for PyTorch 1.6.
+
+To get started, install FairScale through extras with ``pip install pytorch-lightning["extra"]``
+
+or directly using
+
+.. code-block:: bash
+
+    pip install https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip
+
+To use Sequential Model Parallelism, you must define a :class:`nn.Sequential <torch.nn.Sequential>` module containing the layers you wish to parallelize across GPUs.
+This should be kept within the ``sequential_module`` attribute of your ``LightningModule``, like below.
+
+.. code-block:: python
+
+    import torch
+    from pytorch_lightning import LightningModule, Trainer
+    from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin
+
+    class MyModel(LightningModule):
+        def __init__(self):
+            ...
+            self.sequential_module = torch.nn.Sequential(my_layers)
+
+    # Split my module across 4 GPUs, one layer each
+    model = MyModel()
+    plugin = DDPSequentialPlugin(balance=[1, 1, 1, 1])
+    trainer = Trainer(accelerator='ddp', gpus=4, plugins=[plugin])
+    trainer.fit(model)
+
+
+We provide a minimal example of Sequential Model Parallelism using a convolutional model trained on CIFAR-10 and split across GPUs `here <https://github.com/PyTorchLightning/pytorch-lightning/tree/master/pl_examples/basic_examples/conv_sequential_example.py>`_.
+To run the example, you need to install `Bolts <https://github.com/PyTorchLightning/pytorch-lightning-bolts>`_. Install with ``pip install pytorch-lightning-bolts``.
+
+When running the Sequential Model Parallelism example on 2 GPUs, we achieve these memory savings:
+
+.. list-table:: GPU Memory Utilization
+   :widths: 25 25 50
+   :header-rows: 1
+
+   * - GPUs
+     - Without Balancing
+     - With Balancing
+   * - GPU 0
+     - 4436 MB
+     - 1554 MB
+   * - GPU 1
+     - ~0
+     - 994 MB
+
+To run the example with Sequential Model Parallelism:
+
+.. code-block:: bash
+
+    python pl_examples/basic_examples/conv_sequential_example.py --batch_size 1024 --gpus 2 --accelerator ddp --use_ddp_sequential
+
+To run the same example without Sequential Model Parallelism:
+
+.. code-block:: bash
+
+    python pl_examples/basic_examples/conv_sequential_example.py --batch_size 1024 --gpus 1
+
 
 Batch size
 ----------
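Aside (not part of the diff): part of the memory saving described in the new section comes from gradient checkpointing of the sequential module. As intuition only, here is a minimal plain-PyTorch sketch of that idea using ``torch.utils.checkpoint.checkpoint_sequential``; the ``DDPSequentialPlugin`` handles this internally, so nothing below is Lightning API:

.. code-block:: python

    import torch
    from torch.utils.checkpoint import checkpoint_sequential

    # A toy sequential stack: normally every intermediate activation stays
    # alive until backward, which dominates peak memory for deep models.
    model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)])
    x = torch.randn(64, 1024, requires_grad=True)

    # Split the stack into 4 segments: only segment-boundary activations are
    # stored, the rest are recomputed during backward, trading compute for memory.
    out = checkpoint_sequential(model, 4, x)
    out.sum().backward()

Micro-batching (the GPipe scheme referenced above) then pipelines the partitions across devices, so GPUs holding earlier segments are not left idle while later ones compute.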
Both hunks below are whitespace-only changes (trailing whitespace stripped from otherwise blank lines); the surrounding context is shown for reference.

@@ -728,17 +803,17 @@ Lightning supports the use of TorchElastic to enable fault-tolerant and elastic
 .. code-block:: python
 
     Trainer(gpus=8, accelerator='ddp')
-
-
+
+
 Following the `TorchElastic Quickstart documentation <https://pytorch.org/elastic/latest/quickstart.html>`_, you then need to start a single-node etcd server on one of the hosts:
 
 .. code-block:: bash
 
     etcd --enable-v2
     --listen-client-urls http://0.0.0.0:2379,http://127.0.0.1:4001
     --advertise-client-urls PUBLIC_HOSTNAME:2379
-
-
+
+
 And then launch the elastic job with:
 
 .. code-block:: bash
 
@@ -750,7 +825,7 @@ And then launch the elastic job with:
     --rdzv_backend=etcd
     --rdzv_endpoint=ETCD_HOST:ETCD_PORT
     YOUR_LIGHTNING_TRAINING_SCRIPT.py (--arg1 ... train script args...)
-
+
 
 See the official `TorchElastic documentation <https://pytorch.org/elastic>`_ for details
 on installation and more use cases.

docs/source/performance.rst (+9 −1)

@@ -131,4 +131,12 @@ To use Optimizer Sharded Training, refer to :ref:`model-parallelism`.
 
 Sharded DDP can work across all DDP variants by adding the additional ``--plugins ddp_sharded`` flag.
 
-Refer to the :ref:`distributed computing guide for more details <multi_gpu>`.
+Refer to the :ref:`distributed computing guide for more details <multi_gpu>`.
+
+
+Sequential Model Parallelism with Checkpointing
+-----------------------------------------------
+PyTorch Lightning integration for Sequential Model Parallelism using `FairScale <https://github.com/facebookresearch/fairscale>`_.
+Sequential Model Parallelism splits a sequential module onto multiple GPUs, reducing peak GPU memory requirements substantially.
+
+For more information, refer to :ref:`sequential-parallelism`.

docs/source/training_tricks.rst (+8 −0)

@@ -123,3 +123,11 @@ The algorithm in short works by:
     :members: scale_batch_size
 
 .. warning:: Batch size finder is not supported for DDP yet, it is coming soon.
+
+
+Sequential Model Parallelism with Checkpointing
+-----------------------------------------------
+PyTorch Lightning integration for Sequential Model Parallelism using `FairScale <https://github.com/facebookresearch/fairscale>`_.
+Sequential Model Parallelism splits a sequential module onto multiple GPUs, reducing peak GPU memory requirements substantially.
+
+For more information, refer to :ref:`sequential-parallelism`.
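Aside (not part of the diff): the context lines above come from the batch-size finder (``scale_batch_size``) section of training_tricks.rst. A minimal usage sketch, assuming the ``auto_scale_batch_size`` Trainer flag and a ``batch_size`` attribute on a placeholder ``LitModel``:

.. code-block:: python

    from pytorch_lightning import Trainer

    model = LitModel(batch_size=32)  # placeholder module; the finder updates model.batch_size

    # 'binsearch' grows the batch size until an OOM is hit, then binary-searches down.
    # Kept to a single GPU, since the warning above says DDP is not yet supported.
    trainer = Trainer(gpus=1, auto_scale_batch_size='binsearch')
    trainer.tune(model)  # runs the batch-size finder before fitting
    trainer.fit(model)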
