[App] Fixed Multi Node and added examples #15557

Merged
merged 36 commits into master from add_multi_node_examples on Nov 7, 2022
Commits
36 commits
3e187d1
update
tchaton Nov 6, 2022
092b36a
update
tchaton Nov 6, 2022
fcb2ea2
update
tchaton Nov 6, 2022
6481338
Merge branch 'master' into add_multi_node_examples
tchaton Nov 6, 2022
e1271ce
update
tchaton Nov 6, 2022
478d0f0
Merge branch 'add_multi_node_examples' of https://github.com/Lightnin…
tchaton Nov 6, 2022
0c5a079
update
tchaton Nov 6, 2022
804c5cb
update
tchaton Nov 6, 2022
baf1cae
update
tchaton Nov 6, 2022
a393f58
update
tchaton Nov 6, 2022
38f1c72
update
tchaton Nov 6, 2022
4ddb3ae
update
tchaton Nov 6, 2022
ed93320
update
tchaton Nov 6, 2022
402b6fd
update
tchaton Nov 6, 2022
dece823
update
tchaton Nov 6, 2022
4b7e8af
update
tchaton Nov 6, 2022
db336d3
update
tchaton Nov 6, 2022
2cd0d54
update
tchaton Nov 6, 2022
7da57cd
update
tchaton Nov 6, 2022
651590e
update
tchaton Nov 6, 2022
17ac6db
update
tchaton Nov 6, 2022
589ff92
update
tchaton Nov 6, 2022
d221d35
update
tchaton Nov 6, 2022
53597c7
Merge branch 'master' into add_multi_node_examples
tchaton Nov 6, 2022
fa6def5
update
tchaton Nov 6, 2022
0adcdf3
Merge branch 'add_multi_node_examples' of https://github.com/Lightnin…
tchaton Nov 6, 2022
f2fa720
update
tchaton Nov 6, 2022
7778c58
update
tchaton Nov 6, 2022
c005373
update
tchaton Nov 6, 2022
f8fda2e
update
tchaton Nov 6, 2022
00119f6
update
tchaton Nov 6, 2022
0f4e5e5
update
tchaton Nov 6, 2022
1706574
update
tchaton Nov 6, 2022
060f726
update
tchaton Nov 6, 2022
e9c4332
Merge branch 'master' into add_multi_node_examples
lantiga Nov 7, 2022
8323b04
update
tchaton Nov 7, 2022
41 changes: 41 additions & 0 deletions examples/app_multi_node/README.md
@@ -0,0 +1,41 @@
# Lightning & Multi Node Training

Lightning makes multi-node training simple by providing a straightforward interface to orchestrate compute and data.

## Multi Node with raw PyTorch

You can run the raw PyTorch multi-node example with the following command; a condensed sketch of the app file follows it.

```bash
lightning run app app_torch_work.py
```
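
The app file wraps a plain `LightningWork` in the `MultiNode` component. The snippet below is an abridged sketch of the pattern used by `app_torch_work.py` in this PR, not a drop-in replacement for the full example: each node spawns one process per GPU and joins the default process group using the `main_address`, `main_port`, `num_nodes` and `node_rank` that `MultiNode` passes into `run`.

```python
import torch

import lightning as L
from lightning.app.components import MultiNode


def distributed_train(local_rank: int, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int):
    # Derive the global rank / world size from what MultiNode passed in and join the process group.
    global_rank = local_rank + node_rank * nprocs
    world_size = num_nodes * nprocs
    torch.distributed.init_process_group(
        "nccl" if torch.cuda.is_available() else "gloo",
        rank=global_rank,
        world_size=world_size,
        init_method=f"tcp://{main_address}:{main_port}",
    )
    # ... your usual DistributedDataParallel training loop goes here ...


class PyTorchDistributed(L.LightningWork):
    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        # Spawn one worker per local GPU (or a single CPU worker).
        nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
        torch.multiprocessing.spawn(
            distributed_train, args=(main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs
        )


app = L.LightningApp(
    MultiNode(PyTorchDistributed, num_nodes=2, cloud_compute=L.CloudCompute("gpu-fast-multi"))
)
```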

## Multi Node with raw PyTorch + Lite

You can run the raw PyTorch + Lite multi-node example with the following command. A condensed sketch of the app file is shown below it.

```bash
lightning run app app_lite_work.py
```
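
`app_lite_work.py` follows the same shape but delegates process creation and device placement to `LightningLite`. This is an abridged sketch of that file; the actual training loop body is elided here.

```python
import os

import lightning as L
from lightning.app.components import MultiNode
from lightning.lite import LightningLite


def distributed_train(lite: LightningLite):
    # Regular Lite training loop: lite.setup(model, optimizer), lite.backward(loss), optimizer.step(), ...
    ...


class PyTorchDistributed(L.LightningWork):
    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        # Expose the rendezvous information provided by MultiNode through the standard environment variables.
        os.environ["MASTER_ADDR"] = main_address
        os.environ["MASTER_PORT"] = str(main_port)
        os.environ["NODE_RANK"] = str(node_rank)

        lite = LightningLite(accelerator="auto", devices="auto", strategy="ddp_spawn", num_nodes=num_nodes)
        lite.launch(function=distributed_train)


app = L.LightningApp(
    MultiNode(PyTorchDistributed, num_nodes=2, cloud_compute=L.CloudCompute("gpu-fast-multi"))
)
```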

## Multi Node with PyTorch Lightning

Lightning supports running PyTorch Lightning multi-node training either from a script or from within a Lightning Work.

### Multi Node PyTorch Lightning Script

```bash
lightning run app app_pl_script.py
```

### Multi Node PyTorch Lightning Work

```bash
lightning run app app_pl_work.py
```
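
Abridged from `app_pl_work.py` in this PR: the Work sets the rendezvous environment variables and hands everything else to the `Trainer`. Only spawn-based strategies are supported for now.

```python
import os

import lightning as L
from lightning.app.components import MultiNode
from lightning.pytorch.demos.boring_classes import BoringModel


class PyTorchLightningDistributed(L.LightningWork):
    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        # MASTER_ADDR / MASTER_PORT / NODE_RANK tell the Trainer how to rendezvous across nodes.
        os.environ["MASTER_ADDR"] = main_address
        os.environ["MASTER_PORT"] = str(main_port)
        os.environ["NODE_RANK"] = str(node_rank)

        trainer = L.Trainer(
            max_epochs=10,
            accelerator="auto",
            devices="auto",
            num_nodes=num_nodes,
            strategy="ddp_spawn",  # only spawn-based strategies are supported for now
        )
        trainer.fit(BoringModel())


app = L.LightningApp(
    MultiNode(PyTorchLightningDistributed, num_nodes=2, cloud_compute=L.CloudCompute("gpu-fast-multi"))
)
```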

## Multi Node with any framework

```bash
lightning run app app_generic_work.py
```
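
Any framework can bootstrap from the same rendezvous information. Condensed from `app_generic_work.py`, this sketch shows the bare interface each node receives.

```python
import lightning as L
from lightning.app.components import MultiNode


class AnyDistributedComponent(L.LightningWork):
    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        # Every node receives the same rendezvous address/port plus its own rank,
        # which is all most distributed launchers need to bootstrap.
        print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {num_nodes} {node_rank}.")


app = L.LightningApp(
    MultiNode(AnyDistributedComponent, num_nodes=2, cloud_compute=L.CloudCompute("gpu"))
)
```
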
examples/app_multi_node/app_generic_work.py
@@ -1,4 +1,4 @@
import lightning.app as L
import lightning as L
from lightning.app.components import MultiNode


@@ -7,16 +7,17 @@ def run(
self,
main_address: str,
main_port: int,
num_nodes: int,
node_rank: int,
):
print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {node_rank}")
print(f"ADD YOUR DISTRIBUTED CODE: {main_address} {main_port} {num_nodes} {node_rank}.")


compute = L.CloudCompute("gpu")
app = L.LightningApp(
MultiNode(
AnyDistributedComponent,
nodes=2,
num_nodes=2,
cloud_compute=compute,
)
)
59 changes: 59 additions & 0 deletions examples/app_multi_node/app_lite_work.py
@@ -0,0 +1,59 @@
import os

import torch

import lightning as L
from lightning.app.components import MultiNode
from lightning.lite import LightningLite


def distributed_train(lite: LightningLite):
# 1. Prepare distributed model and optimizer
model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
model, optimizer = lite.setup(model, optimizer)
criterion = torch.nn.MSELoss()

# 2. Train the model for 50 steps.
for step in range(50):
model.zero_grad()
x = torch.randn(64, 32).to(lite.device)
output = model(x)
loss = criterion(output, torch.ones_like(output))
print(f"global_rank: {lite.global_rank} step: {step} loss: {loss}")
lite.backward(loss)
optimizer.step()

# 3. Verify all processes have the same weights at the end of training.
weight = model.module.weight.clone()
torch.distributed.all_reduce(weight)
assert torch.equal(model.module.weight, weight / lite.world_size)

print("Multi Node Distributed Training Done!")


class PyTorchDistributed(L.LightningWork):
def run(
self,
main_address: str,
main_port: int,
num_nodes: int,
node_rank: int,
):

os.environ["MASTER_ADDR"] = main_address
os.environ["MASTER_PORT"] = str(main_port)
os.environ["NODE_RANK"] = str(node_rank)

lite = LightningLite(accelerator="auto", devices="auto", strategy="ddp_spawn", num_nodes=num_nodes)
lite.launch(function=distributed_train)


compute = L.CloudCompute("gpu-fast-multi") # 4xV100
app = L.LightningApp(
MultiNode(
PyTorchDistributed,
num_nodes=2,
cloud_compute=compute,
)
)
38 changes: 38 additions & 0 deletions examples/app_multi_node/app_pl_work.py
@@ -0,0 +1,38 @@
import os

import lightning as L
from lightning.app.components import MultiNode
from lightning.pytorch.demos.boring_classes import BoringModel


class PyTorchLightningDistributed(L.LightningWork):
def run(
self,
main_address: str,
main_port: int,
num_nodes: int,
node_rank: int,
):
os.environ["MASTER_ADDR"] = main_address
os.environ["MASTER_PORT"] = str(main_port)
os.environ["NODE_RANK"] = str(node_rank)

model = BoringModel()
trainer = L.Trainer(
max_epochs=10,
devices="auto",
accelerator="auto",
num_nodes=num_nodes,
strategy="ddp_spawn", # Only spawn based strategies are supported for now.
)
trainer.fit(model)


compute = L.CloudCompute("gpu-fast-multi") # 4xV100
app = L.LightningApp(
MultiNode(
PyTorchLightningDistributed,
num_nodes=2,
cloud_compute=compute,
)
)
70 changes: 70 additions & 0 deletions examples/app_multi_node/app_torch_work.py
@@ -0,0 +1,70 @@
import torch
from torch.nn.parallel.distributed import DistributedDataParallel

import lightning as L
from lightning.app.components import MultiNode


def distributed_train(local_rank: int, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int):
# 1. Setting distributed environment
global_rank = local_rank + node_rank * nprocs
world_size = num_nodes * nprocs

if torch.distributed.is_available() and not torch.distributed.is_initialized():
torch.distributed.init_process_group(
"nccl" if torch.cuda.is_available() else "gloo",
rank=global_rank,
world_size=world_size,
init_method=f"tcp://{main_address}:{main_port}",
)

# 2. Prepare distributed model
model = torch.nn.Linear(32, 2)
device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
device_ids = device if torch.cuda.is_available() else None
model = DistributedDataParallel(model, device_ids=device_ids).to(device)

# 3. Prepare loss and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 4. Train the model for 50 steps.
for step in range(50):
model.zero_grad()
x = torch.randn(64, 32).to(device)
output = model(x)
loss = criterion(output, torch.ones_like(output))
print(f"global_rank: {global_rank} step: {step} loss: {loss}")
loss.backward()
optimizer.step()

# 5. Verify all processes have the same weights at the end of training.
weight = model.module.weight.clone()
torch.distributed.all_reduce(weight)
assert torch.equal(model.module.weight, weight / world_size)

print("Multi Node Distributed Training Done!")


class PyTorchDistributed(L.LightningWork):
def run(
self,
main_address: str,
main_port: int,
num_nodes: int,
node_rank: int,
):
nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1
torch.multiprocessing.spawn(
distributed_train, args=(main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs
)


compute = L.CloudCompute("gpu-fast-multi") # 4xV100
app = L.LightningApp(
MultiNode(
PyTorchDistributed,
num_nodes=2,
cloud_compute=compute,
)
)
2 changes: 0 additions & 2 deletions examples/app_multi_node/bare/.gitignore

This file was deleted.

36 changes: 0 additions & 36 deletions examples/app_multi_node/bare/multi_node.py

This file was deleted.

6 changes: 3 additions & 3 deletions examples/app_multi_node/train.py
@@ -1,7 +1,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
import lightning as L
from lightning.pytorch.demos.boring_classes import BoringModel

if __name__ == "__main__":
model = BoringModel()
trainer = Trainer(max_epochs=1)
trainer = L.Trainer(max_epochs=1)
trainer.fit(model)
2 changes: 1 addition & 1 deletion src/lightning_app/CHANGELOG.md
@@ -51,7 +51,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed missing root flow among the flows of the app ([#15531](https://github.com/Lightning-AI/lightning/pull/15531))

-
- Fixed a bug with the Multi Node Component and added some examples ([#15557](https://github.com/Lightning-AI/lightning/pull/15557))



13 changes: 7 additions & 6 deletions src/lightning_app/components/multi_node.py
@@ -11,7 +11,7 @@ class MultiNode(LightningFlow):
def __init__(
self,
work_cls: Type["LightningWork"],
nodes: int,
num_nodes: int,
cloud_compute: "CloudCompute",
*work_args: Any,
**work_kwargs: Any,
@@ -39,22 +39,22 @@ def run(
app = L.LightningApp(
MultiNode(
AnyDistributedComponent,
nodes=8,
num_nodes=8,
cloud_compute=compute,
)
)

Arguments:
work_cls: The work to be executed
nodes: Number of nodes.
num_nodes: Number of nodes.
cloud_compute: The cloud compute object used in the cloud.
work_args: Arguments to be provided to the work on instantiation.
work_kwargs: Keywords arguments to be provided to the work on instantiation.
"""
super().__init__()
self.ws = structures.List()
self._work_cls = work_cls
self.nodes = nodes
self.num_nodes = num_nodes
self._cloud_compute = cloud_compute
self._work_args = work_args
self._work_kwargs = work_kwargs
Expand All @@ -65,7 +65,7 @@ def run(self) -> None:

# 1. Create & start the works
if not self.ws:
for node_rank in range(self.nodes):
for node_rank in range(self.num_nodes):
self.ws.append(
self._work_cls(
*self._work_args,
@@ -84,12 +84,13 @@ def run(self) -> None:
self.has_started = True

# Loop over all node machines
for node_rank in range(self.nodes):
for node_rank in range(self.num_nodes):

# 3. Run the user code in a distributed way !
self.ws[node_rank].run(
main_address=self.ws[0].internal_ip,
main_port=self.ws[0].port,
num_nodes=self.num_nodes,
node_rank=node_rank,
)

8 changes: 4 additions & 4 deletions src/lightning_app/utilities/proxies.py
@@ -395,18 +395,18 @@ def run_once(self):
# 6. Create the state observer thread.
self.state_observer = WorkStateObserver(self.work, delta_queue=self.delta_queue)

# 7. Deepcopy the work state and send the first `RUNNING` status delta to the flow.
reference_state = deepcopy(self.work.state)

# Set the internal IP address.
# Set this here after the state observer is initialized, since it needs to record it as a change and send
# it back to the flow
self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", "127.0.0.1")

# 7. Patch the setattr method of the work. This needs to be done after step 4, so we don't
# 8. Patch the setattr method of the work. This needs to be done after step 4, so we don't
# send delta while calling `set_state`.
self._proxy_setattr()

# 8. Deepcopy the work state and send the first `RUNNING` status delta to the flow.
reference_state = deepcopy(self.work.state)

if self._is_starting(called, reference_state, call_hash):
return
