Showing 14 changed files with 1,996 additions and 0 deletions.
@@ -102,3 +102,5 @@ venv.bak/

# mypy
.mypy_cache/

.idea/
@@ -0,0 +1,164 @@
import os
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Explicitly set the seed so that the models created in the two processes
    # start from the same random weights and biases.
    torch.manual_seed(42)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    setup(rank, world_size)

    # set up devices for this process; with 8 GPUs and world_size=2,
    # rank 0 uses GPUs [0, 1, 2, 3] and rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // world_size
    device_ids = list(range(rank * n, (rank + 1) * n))

    # create model and move it to device_ids[0]
    model = ToyModel().to(device_ids[0])
    # output_device defaults to device_ids[0]
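    # DDP broadcasts rank 0's parameters to all processes at construction time
    # and averages gradients across processes during backward().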
    ddp_model = DDP(model, device_ids=device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)
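    # mp.spawn calls demo_fn(i, *args) in each child process, so demo_fn
    # receives its rank as the first argument.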


def demo_checkpoint(rank, world_size):
    setup(rank, world_size)

    # set up devices for this process; with 8 GPUs and world_size=2,
    # rank 0 uses GPUs [0, 1, 2, 3] and rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // world_size
    device_ids = list(range(rank * n, (rank + 1) * n))

    model = ToyModel().to(device_ids[0])
    # output_device defaults to device_ids[0]
    ddp_model = DDP(model, device_ids=device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes start from the same random parameters and gradients are
        # synchronized in backward passes, so every process sees the same
        # parameters; saving the model from one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model only after
    # process 0 has saved it.
    dist.barrier()
    # configure map_location properly
    rank0_devices = [x - rank * len(device_ids) for x in device_ids]
    device_pairs = zip(rank0_devices, device_ids)
    map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
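    # map_location remaps the 'cuda:<id>' device tags that rank 0 used when
    # saving to the devices owned by this rank, so each process loads the
    # checkpoint onto its own GPUs.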
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Use a barrier() to make sure that all processes have finished reading the
    # checkpoint before rank 0 removes it.
    dist.barrier()

    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()


class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super(ToyMpModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)


def demo_model_parallel(rank, world_size):
    setup(rank, world_size)

    # set up mp_model and devices for this process; each rank gets two
    # exclusive GPUs
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    mp_model = ToyMpModel(dev0, dev1)
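    # The module already spans two devices, so DDP is constructed without
    # device_ids/output_device and leaves the parameters on the devices the
    # module placed them on.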
    ddp_mp_model = DDP(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    # outputs will be on dev1
    outputs = ddp_mp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()


if __name__ == "__main__":
    run_demo(demo_basic, 2)
    run_demo(demo_checkpoint, 2)

    if torch.cuda.device_count() >= 8:
        run_demo(demo_model_parallel, 4)
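
For reference, the same kind of training step can be launched with torchrun instead of mp.spawn: torchrun exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for every worker, so init_process_group() can pick them up through the default env:// rendezvous. The sketch below is illustrative only and is not part of this commit; the file name ddp_torchrun_sketch.py and the single-Linear stand-in model are assumptions.

# ddp_torchrun_sketch.py -- illustrative sketch, not part of this commit.
# Launch with: torchrun --nproc_per_node=2 ddp_torchrun_sketch.py
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    # torchrun sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT, so no
    # explicit rank/world_size arguments are needed here.
    dist.init_process_group("gloo")

    model = nn.Linear(10, 5)      # stand-in for ToyModel, kept on CPU
    ddp_model = DDP(model)        # CPU module under gloo: no device_ids
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    loss_fn(outputs, torch.randn(20, 5)).backward()
    optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()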
@@ -0,0 +1,67 @@
#!/usr/bin/env python
import os
import torch as th
import torch.distributed as dist
from torch.multiprocessing import Process


def allreduce(send, recv):
    """ Implementation of a ring-reduce. """
    rank = dist.get_rank()
    size = dist.get_world_size()
    # The working buffers live on the same device as `send`; the running
    # total starts from this rank's own values.
    send_buff = send.clone()
    recv_buff = th.zeros_like(send)
    accum = send.clone()
    # th.cuda.synchronize()

    left = ((rank - 1) + size) % size
    right = (rank + 1) % size
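
    # Each of the (size - 1) steps sends a buffer to the right neighbour and
    # receives one from the left, adding the received values into accum.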
    for i in range(size - 1):
        if i % 2 == 0:
            # Send send_buff
            send_req = dist.isend(send_buff, right)
            dist.recv(recv_buff, left)
            accum[:] += recv_buff[:]
        else:
            # Send recv_buff
            send_req = dist.isend(recv_buff, right)
            dist.recv(send_buff, left)
            accum[:] += send_buff[:]
        send_req.wait()
    # th.cuda.synchronize()
    recv[:] = accum[:]


def run(rank, size):
    """ Repeatedly all-reduce a random tensor across the group. """
    # t = th.ones(2, 2)
    t = th.rand(2, 2).cuda()
    # for _ in range(10000000):
    for _ in range(4):
        c = t.clone()
        dist.all_reduce(c, op=dist.ReduceOp.SUM)
        # allreduce(t, c)
        t.set_(c)
    print(t)


def init_processes(rank, size, fn, backend='mpi'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    # dist.init_process_group(backend, rank=rank, world_size=size)
    dist.init_process_group(backend, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 1
    processes = []
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
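
As a usage sketch, the hand-rolled ring allreduce above can be checked against the built-in dist.all_reduce on the gloo backend, which needs no MPI build and supports CPU send/recv. Everything below is illustrative and not part of this commit: the check() helper, port 29501, the two-process setup, and the module name ring_allreduce used to import allreduce() from the file above are all assumptions.

import os
import torch as th
import torch.distributed as dist
from torch.multiprocessing import Process

from ring_allreduce import allreduce  # hypothetical module name for the file above


def check(rank, size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group('gloo', rank=rank, world_size=size)

    send = th.full((2, 2), float(rank + 1))
    recv = th.zeros(2, 2)
    allreduce(send, recv)                             # ring version from above

    expected = send.clone()
    dist.all_reduce(expected, op=dist.ReduceOp.SUM)   # built-in reference
    print(rank, th.allclose(recv, expected))          # both ranks should print True


if __name__ == "__main__":
    size = 2
    processes = [Process(target=check, args=(r, size)) for r in range(size)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()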