[App] Fixed Multi Node and add examples (#15557)
(cherry picked from commit 8202331)
Showing 15 changed files with 298 additions and 57 deletions.
@@ -0,0 +1,41 @@
# Lightning & Multi Node Training

Lightning makes multi-node training simple by providing an interface to orchestrate compute and data.
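
All of the examples below follow the same pattern: a `LightningWork` defines the per-node job, and the `MultiNode` component replicates it across machines, calling `run()` on each node with the cluster coordinates. A minimal sketch of that pattern, assuming only the APIs used in the examples in this commit (`MyDistributedWork` is a placeholder name):

```python
import lightning as L
from lightning.app.components import MultiNode


class MyDistributedWork(L.LightningWork):
    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        # MultiNode invokes run() once per node with the coordinates needed to
        # rendezvous: the main node's address/port, the node count, and this
        # node's rank.
        ...


app = L.LightningApp(
    MultiNode(
        MyDistributedWork,
        num_nodes=2,
        cloud_compute=L.CloudCompute("gpu-fast-multi"),  # 4xV100
    )
)
```
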
## Multi Node with raw PyTorch

You can run the raw PyTorch multi-node example with the following command:

```bash
lightning run app app_torch_work.py
```
## Multi Node with raw PyTorch + Lite

You can run the raw PyTorch + Lite multi-node example with the following command:

```bash
lightning run app app_lite_work.py
```
## Multi Node with PyTorch Lightning

Lightning supports running PyTorch Lightning from a script or within a Lightning Work.

### Multi Node PyTorch Lightning Script

```bash
lightning run app app_pl_script.py
```

### Multi Node PyTorch Lightning Work

```bash
lightning run app app_pl_work.py
```
## Multi Node with any framework

```bash
lightning run app app_generic_work.py
```
@@ -0,0 +1,59 @@
import os

import torch

import lightning as L
from lightning.app.components import MultiNode
from lightning.lite import LightningLite


def distributed_train(lite: LightningLite):
    # 1. Prepare distributed model and optimizer
    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    model, optimizer = lite.setup(model, optimizer)
    criterion = torch.nn.MSELoss()

    # 2. Train the model for 50 steps.
    for step in range(50):
        model.zero_grad()
        x = torch.randn(64, 32).to(lite.device)
        output = model(x)
        loss = criterion(output, torch.ones_like(output))
        print(f"global_rank: {lite.global_rank} step: {step} loss: {loss}")
        lite.backward(loss)
        optimizer.step()

    # 3. Verify all processes have the same weights at the end of training.
    weight = model.module.weight.clone()
    torch.distributed.all_reduce(weight)
    assert torch.equal(model.module.weight, weight / lite.world_size)

    print("Multi Node Distributed Training Done!")


class PyTorchDistributed(L.LightningWork):
    def run(
        self,
        main_address: str,
        main_port: int,
        num_nodes: int,
        node_rank: int,
    ):
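
        # Hand the cluster coordinates for the process-group rendezvous to Lite
        # through environment variables, which it reads when launching processes.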
os.environ["MASTER_ADDR"] = main_address | ||
os.environ["MASTER_PORT"] = str(main_port) | ||
os.environ["NODE_RANK"] = str(node_rank) | ||
|
||
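
        # "ddp_spawn" is used here because, per these examples, only spawn-based
        # strategies are currently supported with MultiNode.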
        lite = LightningLite(accelerator="auto", devices="auto", strategy="ddp_spawn", num_nodes=num_nodes)
        lite.launch(function=distributed_train)


compute = L.CloudCompute("gpu-fast-multi")  # 4xV100
app = L.LightningApp(
    MultiNode(
        PyTorchDistributed,
        num_nodes=2,
        cloud_compute=compute,
    )
)
File renamed without changes.
@@ -0,0 +1,38 @@
import os

import lightning as L
from lightning.app.components import MultiNode
from lightning.pytorch.demos.boring_classes import BoringModel


class PyTorchLightningDistributed(L.LightningWork):
    def run(
        self,
        main_address: str,
        main_port: int,
        num_nodes: int,
        node_rank: int,
    ):
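        # Hand the cluster coordinates to the Trainer through environment variables.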
os.environ["MASTER_ADDR"] = main_address | ||
os.environ["MASTER_PORT"] = str(main_port) | ||
os.environ["NODE_RANK"] = str(node_rank) | ||
|
||
model = BoringModel() | ||
trainer = L.Trainer( | ||
max_epochs=10, | ||
devices="auto", | ||
accelerator="auto", | ||
num_nodes=num_nodes, | ||
strategy="ddp_spawn", # Only spawn based strategies are supported for now. | ||
) | ||
trainer.fit(model) | ||
|
||
|
||
compute = L.CloudCompute("gpu-fast-multi") # 4xV100 | ||
app = L.LightningApp( | ||
MultiNode( | ||
PyTorchLightningDistributed, | ||
num_nodes=2, | ||
cloud_compute=compute, | ||
) | ||
) |
@@ -0,0 +1,70 @@
import torch
from torch.nn.parallel.distributed import DistributedDataParallel

import lightning as L
from lightning.app.components import MultiNode


def distributed_train(local_rank: int, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int):
    # 1. Set up the distributed environment.
    global_rank = local_rank + node_rank * nprocs
    world_size = num_nodes * nprocs
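
    # All processes rendezvous at main_address:main_port to form the process group.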
    if torch.distributed.is_available() and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            "nccl" if torch.cuda.is_available() else "gloo",
            rank=global_rank,
            world_size=world_size,
            init_method=f"tcp://{main_address}:{main_port}",
        )

    # 2. Prepare distributed model (moved to its device before the DDP wrap;
    # device_ids must be a list, not a bare torch.device).
    model = torch.nn.Linear(32, 2)
    device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
    device_ids = [local_rank] if torch.cuda.is_available() else None
    model = DistributedDataParallel(model.to(device), device_ids=device_ids)

    # 3. Prepare loss and optimizer
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # 4. Train the model for 50 steps.
    for step in range(50):
        model.zero_grad()
        x = torch.randn(64, 32).to(device)
        output = model(x)
        loss = criterion(output, torch.ones_like(output))
        print(f"global_rank: {global_rank} step: {step} loss: {loss}")
        loss.backward()
        optimizer.step()

    # 5. Verify all processes have the same weights at the end of training.
    weight = model.module.weight.clone()
    torch.distributed.all_reduce(weight)
    assert torch.equal(model.module.weight, weight / world_size)

    print("Multi Node Distributed Training Done!")


class PyTorchDistributed(L.LightningWork):
    def run(
        self,
        main_address: str,
        main_port: int,
        num_nodes: int,
        node_rank: int,
    ):
        nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
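        # spawn() runs distributed_train(local_rank, *args) once per local process.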
        torch.multiprocessing.spawn(
            distributed_train, args=(main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs
        )


compute = L.CloudCompute("gpu-fast-multi")  # 4xV100
app = L.LightningApp(
    MultiNode(
        PyTorchDistributed,
        num_nodes=2,
        cloud_compute=compute,
    )
)
This file was deleted.
This file was deleted.
File renamed without changes.
@@ -1,7 +1,7 @@
-from pytorch_lightning import Trainer
-from pytorch_lightning.demos.boring_classes import BoringModel
+import lightning as L
+from lightning.pytorch.demos.boring_classes import BoringModel
 
 if __name__ == "__main__":
     model = BoringModel()
-    trainer = Trainer(max_epochs=1)
+    trainer = L.Trainer(max_epochs=1)
     trainer.fit(model)