Skip to content

Commit

Permalink
[App] Fixed Multi Node and add examples (#15557)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton authored Nov 7, 2022
1 parent 96c5744 commit 8202331
Show file tree
Hide file tree
Showing 15 changed files with 298 additions and 57 deletions.
41 changes: 41 additions & 0 deletions examples/app_multi_node/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Lightning & Multi Node Training

Lightning supports makes multi-node training simple by providing a simple interface to orchestrate compute and data.

## Multi Node with raw PyTorch

You can run the multi-node raw PyTorch by running the following commands.

```bash
lightning run app app_torch_work.py
```

## Multi Node with raw PyTorch + Lite

You can run the multi-node raw PyTorch and Lite by running the following commands.

```bash
lightning run app app_lite_work.py
```

## Multi Node with PyTorch Lightning

Lightning supports running PyTorch Lightning from a script or within a Lightning Work.

### Multi Node PyTorch Lightning Script

```bash
lightning run app app_pl_script.py
```

### Multi Node PyTorch Lightning Work

```bash
lightning run app app_pl_work.py
```

## Multi Node with any frameworks

```bash
lightning run app app_generic_work.py
```
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import lightning.app as L
import lightning as L
from lightning.app.components import MultiNode


Expand All @@ -7,16 +7,17 @@ def run(
self,
main_address: str,
main_port: int,
num_nodes: int,
node_rank: int,
):
print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {node_rank}")
print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {num_nodes} {node_rank}.")


compute = L.CloudCompute("gpu")
app = L.LightningApp(
MultiNode(
AnyDistributedComponent,
nodes=2,
num_nodes=2,
cloud_compute=compute,
)
)
59 changes: 59 additions & 0 deletions examples/app_multi_node/app_lite_work.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os

import torch

import lightning as L
from lightning.app.components import MultiNode
from lightning.lite import LightningLite


def distributed_train(lite: LightningLite):
# 1. Prepare distributed model and optimizer
model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
model, optimizer = lite.setup(model, optimizer)
criterion = torch.nn.MSELoss()

# 2. Train the model for 50 steps.
for step in range(50):
model.zero_grad()
x = torch.randn(64, 32).to(lite.device)
output = model(x)
loss = criterion(output, torch.ones_like(output))
print(f"global_rank: {lite.global_rank} step: {step} loss: {loss}")
lite.backward(loss)
optimizer.step()

# 3. Verify all processes have the same weights at the end of training.
weight = model.module.weight.clone()
torch.distributed.all_reduce(weight)
assert torch.equal(model.module.weight, weight / lite.world_size)

print("Multi Node Distributed Training Done!")


class PyTorchDistributed(L.LightningWork):
def run(
self,
main_address: str,
main_port: int,
num_nodes: int,
node_rank: int,
):

os.environ["MASTER_ADDR"] = main_address
os.environ["MASTER_PORT"] = str(main_port)
os.environ["NODE_RANK"] = str(node_rank)

lite = LightningLite(accelerator="auto", devices="auto", strategy="ddp_spawn", num_nodes=num_nodes)
lite.launch(function=distributed_train)


compute = L.CloudCompute("gpu-fast-multi") # 4xV100
app = L.LightningApp(
MultiNode(
PyTorchDistributed,
num_nodes=2,
cloud_compute=compute,
)
)
File renamed without changes.
38 changes: 38 additions & 0 deletions examples/app_multi_node/app_pl_work.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os

import lightning as L
from lightning.app.components import MultiNode
from lightning.pytorch.demos.boring_classes import BoringModel


class PyTorchLightningDistributed(L.LightningWork):
def run(
self,
main_address: str,
main_port: int,
num_nodes: int,
node_rank: int,
):
os.environ["MASTER_ADDR"] = main_address
os.environ["MASTER_PORT"] = str(main_port)
os.environ["NODE_RANK"] = str(node_rank)

model = BoringModel()
trainer = L.Trainer(
max_epochs=10,
devices="auto",
accelerator="auto",
num_nodes=num_nodes,
strategy="ddp_spawn", # Only spawn based strategies are supported for now.
)
trainer.fit(model)


compute = L.CloudCompute("gpu-fast-multi") # 4xV100
app = L.LightningApp(
MultiNode(
PyTorchLightningDistributed,
num_nodes=2,
cloud_compute=compute,
)
)
70 changes: 70 additions & 0 deletions examples/app_multi_node/app_torch_work.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import torch
from torch.nn.parallel.distributed import DistributedDataParallel

import lightning as L
from lightning.app.components import MultiNode


def distributed_train(local_rank: int, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int):
# 1. Setting distributed environment
global_rank = local_rank + node_rank * nprocs
world_size = num_nodes * nprocs

if torch.distributed.is_available() and not torch.distributed.is_initialized():
torch.distributed.init_process_group(
"nccl" if torch.cuda.is_available() else "gloo",
rank=global_rank,
world_size=world_size,
init_method=f"tcp://{main_address}:{main_port}",
)

# 2. Prepare distributed model
model = torch.nn.Linear(32, 2)
device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
device_ids = device if torch.cuda.is_available() else None
model = DistributedDataParallel(model, device_ids=device_ids).to(device)

# 3. Prepare loss and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 4. Train the model for 50 steps.
for step in range(50):
model.zero_grad()
x = torch.randn(64, 32).to(device)
output = model(x)
loss = criterion(output, torch.ones_like(output))
print(f"global_rank: {global_rank} step: {step} loss: {loss}")
loss.backward()
optimizer.step()

# 5. Verify all processes have the same weights at the end of training.
weight = model.module.weight.clone()
torch.distributed.all_reduce(weight)
assert torch.equal(model.module.weight, weight / world_size)

print("Multi Node Distributed Training Done!")


class PyTorchDistributed(L.LightningWork):
def run(
self,
main_address: str,
main_port: int,
num_nodes: int,
node_rank: int,
):
nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
torch.multiprocessing.spawn(
distributed_train, args=(main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs
)


compute = L.CloudCompute("gpu-fast-multi") # 4xV100
app = L.LightningApp(
MultiNode(
PyTorchDistributed,
num_nodes=2,
cloud_compute=compute,
)
)
2 changes: 0 additions & 2 deletions examples/app_multi_node/bare/.gitignore

This file was deleted.

36 changes: 0 additions & 36 deletions examples/app_multi_node/bare/multi_node.py

This file was deleted.

File renamed without changes.
6 changes: 3 additions & 3 deletions examples/app_multi_node/train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
import lightning as L
from lightning.pytorch.demos.boring_classes import BoringModel

if __name__ == "__main__":
model = BoringModel()
trainer = Trainer(max_epochs=1)
trainer = L.Trainer(max_epochs=1)
trainer.fit(model)
2 changes: 1 addition & 1 deletion src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed missing root flow among the flows of the app ([#15531](https://github.com/Lightning-AI/lightning/pull/15531))

-
- Fixed bug with Multi Node Component and add some examples ([#15557](https://github.com/Lightning-AI/lightning/pull/15557))



Expand Down
13 changes: 7 additions & 6 deletions src/lightning_app/components/multi_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class MultiNode(LightningFlow):
def __init__(
self,
work_cls: Type["LightningWork"],
nodes: int,
num_nodes: int,
cloud_compute: "CloudCompute",
*work_args: Any,
**work_kwargs: Any,
Expand Down Expand Up @@ -39,22 +39,22 @@ def run(
app = L.LightningApp(
MultiNode(
AnyDistributedComponent,
nodes=8,
num_nodes=8,
cloud_compute=compute,
)
)
Arguments:
work_cls: The work to be executed
nodes: Number of nodes.
num_nodes: Number of nodes.
cloud_compute: The cloud compute object used in the cloud.
work_args: Arguments to be provided to the work on instantiation.
work_kwargs: Keywords arguments to be provided to the work on instantiation.
"""
super().__init__()
self.ws = structures.List()
self._work_cls = work_cls
self.nodes = nodes
self.num_nodes = num_nodes
self._cloud_compute = cloud_compute
self._work_args = work_args
self._work_kwargs = work_kwargs
Expand All @@ -65,7 +65,7 @@ def run(self) -> None:

# 1. Create & start the works
if not self.ws:
for node_rank in range(self.nodes):
for node_rank in range(self.num_nodes):
self.ws.append(
self._work_cls(
*self._work_args,
Expand All @@ -84,12 +84,13 @@ def run(self) -> None:
self.has_started = True

# Loop over all node machines
for node_rank in range(self.nodes):
for node_rank in range(self.num_nodes):

# 3. Run the user code in a distributed way !
self.ws[node_rank].run(
main_address=self.ws[0].internal_ip,
main_port=self.ws[0].port,
num_nodes=self.num_nodes,
node_rank=node_rank,
)

Expand Down
8 changes: 4 additions & 4 deletions src/lightning_app/utilities/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,18 +395,18 @@ def run_once(self):
# 6. Create the state observer thread.
self.state_observer = WorkStateObserver(self.work, delta_queue=self.delta_queue)

# 7. Deepcopy the work state and send the first `RUNNING` status delta to the flow.
reference_state = deepcopy(self.work.state)

# Set the internal IP address.
# Set this here after the state observer is initialized, since it needs to record it as a change and send
# it back to the flow
self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", "127.0.0.1")

# 7. Patch the setattr method of the work. This needs to be done after step 4, so we don't
# 8. Patch the setattr method of the work. This needs to be done after step 4, so we don't
# send delta while calling `set_state`.
self._proxy_setattr()

# 8. Deepcopy the work state and send the first `RUNNING` status delta to the flow.
reference_state = deepcopy(self.work.state)

if self._is_starting(called, reference_state, call_hash):
return

Expand Down
Loading

0 comments on commit 8202331

Please sign in to comment.