Commit
init DDP
jundet committed Nov 7, 2019
1 parent 08af110 commit 2016ea3
Showing 14 changed files with 1,996 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -102,3 +102,5 @@ venv.bak/

# mypy
.mypy_cache/

.idea/
164 changes: 164 additions & 0 deletions DDP.py
@@ -0,0 +1,164 @@
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP
import os


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'

    # initialize the process group; MASTER_ADDR/MASTER_PORT tell every
    # process where rank 0 listens so they can rendezvous
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Explicitly set the seed so that the models created in the two
    # processes start from the same random weights and biases.
    torch.manual_seed(42)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    setup(rank, world_size)

    # setup devices for this process; with 8 GPUs and world_size 2,
    # rank 0 uses GPUs [0, 1, 2, 3] and rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // world_size
    device_ids = list(range(rank * n, (rank + 1) * n))

    # create model and move it to device_ids[0]
    model = ToyModel().to(device_ids[0])
    # output_device defaults to device_ids[0]
    ddp_model = DDP(model, device_ids=device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)
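
# Note: mp.spawn calls demo_fn(i, *args) with i being the process index, so
# each demo function receives (rank, world_size) even though args above only
# supplies (world_size,).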


def demo_checkpoint(rank, world_size):
    setup(rank, world_size)

    # setup devices for this process; with 8 GPUs and world_size 2,
    # rank 0 uses GPUs [0, 1, 2, 3] and rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // world_size
    device_ids = list(range(rank * n, (rank + 1) * n))

    model = ToyModel().to(device_ids[0])
    # output_device defaults to device_ids[0]
    ddp_model = DDP(model, device_ids=device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes see the same parameters, as they all start from the
        # same random parameters and gradients are synchronized in backward
        # passes. Therefore, saving the model in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure the other processes load the model only
    # after process 0 has saved it.
    dist.barrier()
    # configure map_location to remap rank 0's devices onto this rank's
    rank0_devices = [x - rank * len(device_ids) for x in device_ids]
    device_pairs = zip(rank0_devices, device_ids)
    map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))
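    # For example, with 4 GPUs per process, rank 1 has device_ids
    # [4, 5, 6, 7] and map_location is {'cuda:0': 'cuda:4', ...,
    # 'cuda:3': 'cuda:7'}, so tensors saved from rank 0's devices load
    # straight onto this rank's devices.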

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Use a barrier() to make sure that all processes have finished reading
    # the checkpoint before rank 0 deletes it.
    dist.barrier()

    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()


class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super(ToyMpModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)
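
# ToyMpModel places net1 on dev0 and net2 on dev1 and moves the activations
# between devices inside forward(); wrapping it in DDP below replicates this
# two-GPU model across processes (model parallel within a process, data
# parallel across processes).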


def demo_model_parallel(rank, world_size):
    setup(rank, world_size)

    # setup mp_model and devices for this process; each rank gets two
    # consecutive GPUs, e.g. rank 0 -> (0, 1), rank 1 -> (2, 3).
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    mp_model = ToyMpModel(dev0, dev1)
    ddp_mp_model = DDP(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    # outputs will be on dev1
    outputs = ddp_mp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()


if __name__ == "__main__":
    run_demo(demo_basic, 2)
    run_demo(demo_checkpoint, 2)

    if torch.cuda.device_count() >= 8:
        run_demo(demo_model_parallel, 4)
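
# demo_basic and demo_checkpoint need at least one GPU per process
# (world_size 2 here), while demo_model_parallel uses two GPUs for each of
# its 4 processes, hence the >= 8 device check.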

67 changes: 67 additions & 0 deletions allreduce.py
@@ -0,0 +1,67 @@
#!/usr/bin/env python
import os
import torch as th
import torch.distributed as dist
from torch.multiprocessing import Process


def allreduce(send, recv):
    """ Implementation of a ring-reduce with addition. """
    rank = dist.get_rank()
    size = dist.get_world_size()
    # The buffers must start as copies of `send`: the first iteration ships
    # send_buff to the right neighbour, and clone() keeps everything on the
    # same device as the input.
    send_buff = send.clone()
    recv_buff = send.clone()
    accum = send.clone()
    # th.cuda.synchronize()

    left = ((rank - 1) + size) % size
    right = (rank + 1) % size

    for i in range(size - 1):
        if i % 2 == 0:
            # Send send_buff, accumulate what arrives in recv_buff
            send_req = dist.isend(send_buff, right)
            dist.recv(recv_buff, left)
            accum[:] += recv_buff[:]
        else:
            # Send recv_buff, accumulate what arrives in send_buff
            send_req = dist.isend(recv_buff, right)
            dist.recv(send_buff, left)
            accum[:] += send_buff[:]
        send_req.wait()
    # th.cuda.synchronize()
    recv[:] = accum[:]
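
# Note: each of the (size - 1) iterations above sends the full tensor, so
# this is a simple ring reduce. A bandwidth-optimal ring all-reduce would
# instead split the tensor into `size` chunks and run a reduce-scatter
# followed by an all-gather.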


def run(rank, size):
    """ All-reduce a random tensor a few times and print the result. """
    # t = th.ones(2, 2)
    t = th.rand(2, 2).cuda()
    # for _ in range(10000000):
    for _ in range(4):
        c = t.clone()
        dist.all_reduce(c, dist.ReduceOp.SUM)
        # allreduce(t, c)
        t.set_(c)
    print(t)

def init_processes(rank, size, fn, backend='mpi'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    # With the MPI backend the rank is supplied by the MPI runtime, so it is
    # not passed here; other backends need it explicitly:
    # dist.init_process_group(backend, rank=rank, world_size=size)
    dist.init_process_group(backend, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 1
    processes = []
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
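
# Note: the 'mpi' backend requires a PyTorch build with MPI support and is
# typically launched under mpirun (e.g. `mpirun -n 4 python allreduce.py`),
# which supplies rank and world size. To drive the script purely from the
# Process loop above, 'gloo' with an explicit rank=rank is the usual choice.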