Merged
Commits
29 commits
4778832
add nccl 1-bit optim.
awan-10 Nov 23, 2020
567232b
temporary commit to save stuff.
awan-10 Nov 24, 2020
79f6404
Use dist collectives instead of mpi routines.
awan-10 Nov 24, 2020
39b5949
Merge branch 'master' into amawa/1bit-adam-nccl
awan-10 Nov 24, 2020
57ab220
remove old code for comm.
awan-10 Nov 25, 2020
ebec1fe
Fix bugs. still does not work.
awan-10 Nov 25, 2020
3e6974d
modify to test the nccl side code path
awan-10 Nov 25, 2020
a72049b
Initial gather impl. Works intra-node.
awan-10 Nov 30, 2020
1bf1c27
Updates to comm. phase 2. nccl comm. passed the tests.
awan-10 Nov 30, 2020
886ebb5
refactor code to introduce nccl/mpi as backends for onebit adam.
awan-10 Dec 3, 2020
a38351e
Refactor updates to test/engine.
awan-10 Dec 3, 2020
716ac13
Merge branch 'master' into amawa/1-bit-refactor
awan-10 Dec 3, 2020
be75d88
Fix compile/runtime errors.
awan-10 Dec 3, 2020
7b7f122
simplify support for nccl/mpi backends.
awan-10 Dec 3, 2020
fd2c366
Add missing file
awan-10 Dec 3, 2020
df8c40d
Add compression backend in constructor. Revert later.
awan-10 Dec 4, 2020
f29ea3f
modify test with some perf counting.
awan-10 Dec 4, 2020
170ef02
Implement a true non-blocking gather for nccl side.
awan-10 Dec 7, 2020
e2ddf48
Revert "Add compression backend in constructor. Revert later."
awan-10 Dec 7, 2020
dbd3cff
improve the 1-bit adam test.
awan-10 Dec 7, 2020
7edc3ab
Refactor comm. and compression backend in 1-bit adam.
awan-10 Dec 8, 2020
0813d11
Fix the test.
awan-10 Dec 8, 2020
4c3c777
Fix runtime errors and typos in nccl backend
awan-10 Dec 8, 2020
d495c7a
fix mpi backend. modify tests.
awan-10 Dec 8, 2020
60f3344
modify nccl perf test.
awan-10 Dec 9, 2020
c1ab39e
fix mpi side errors.
awan-10 Dec 9, 2020
70938e1
Add an mpi perf test
awan-10 Dec 9, 2020
de63497
Merge branch 'master' into amawa/1-bit-refactor
awan-10 Dec 9, 2020
7aac018
Sync DSE.
awan-10 Dec 10, 2020
Empty file.
299 changes: 299 additions & 0 deletions deepspeed/runtime/comm/mpi.py
@@ -0,0 +1,299 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''

import torch
import cupy
import time
import numpy as np
from mpi4py import MPI

from deepspeed.runtime.compression.cupy import CupyBackend


class MpiBackend(object):
    def __init__(self, cuda_aware):
        self.comm = MPI.COMM_WORLD
        self.rank = self.comm.Get_rank()
        self.size = self.comm.Get_size()
        self.cuda_aware = cuda_aware
        self.compression_backend = CupyBackend()

    def my_igather(self, rank, size, comm, sendbuf, recbuf, root):
        req = []
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    req.append(comm.Irecv(recbuf[idx], source=idx))
                else:
                    recbuf[rank] = sendbuf
        else:
            req.append(comm.Isend(sendbuf, dest=root))
        return req

    def gather_cuda(self, rank, world_size, comm, cupy_sign_list_packed,
                    cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale):
        # We do in-place operations on cupy buffers so we do not return any buffers
        requests = []
        for idx in range(world_size):
            req_sign = self.my_igather(rank,
                                       world_size,
                                       comm,
                                       cupy_sign_list_packed[idx],
                                       cupy_recvbuf_sign,
                                       root=idx)
            requests += req_sign

        for idx in range(world_size):
            req_scale = self.my_igather(rank,
                                        world_size,
                                        comm,
                                        cupy_worker_scale,
                                        cupy_recvbuf_scale,
                                        root=idx)
            requests += req_scale

        MPI.Request.Waitall(requests)

    def gather_host(self, rank, world_size, comm, cupy_sign_list_packed,
                    cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale):

        # In-place operations are not possible for newly created cupy arrays
        # so we need to return the new buffers
        numpy_recvbuf_sign = np.zeros([world_size,
                                       cupy_sign_list_packed[rank].size],
                                      dtype=cupy_sign_list_packed[0].dtype)
        numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)

        # 1. convert from cupy to numpy
        numpy_sign_list_packed = cupy_sign_list_packed

        for idx in range(world_size):
            numpy_sign_list_packed[idx] = cupy.asnumpy(cupy_sign_list_packed[idx])

        numpy_worker_scale = cupy.asnumpy(cupy_worker_scale)
        numpy_recvbuf_scale = cupy.asnumpy(cupy_recvbuf_scale)

        cupy.cuda.get_current_stream().synchronize()

        # 2. use numpy buffers for communication
        requests = []

        for idx in range(world_size):
            req_sign = self.my_igather(rank,
                                       world_size,
                                       comm,
                                       numpy_sign_list_packed[idx],
                                       numpy_recvbuf_sign,
                                       root=idx)
            requests += req_sign

        for idx in range(world_size):
            req_scale = self.my_igather(rank,
                                        world_size,
                                        comm,
                                        numpy_worker_scale,
                                        numpy_recvbuf_scale,
                                        root=idx)
            requests += req_scale

        MPI.Request.Waitall(requests)

        # 3. Convert back from numpy to cupy
        cupy_recvbuf_sign = cupy.asarray(numpy_recvbuf_sign)
        for idx in range(world_size):
            cupy_sign_list_packed[idx] = cupy.asarray(numpy_sign_list_packed[idx])

        cupy_worker_scale = cupy.asarray(numpy_worker_scale)
        cupy_recvbuf_scale = cupy.asarray(numpy_recvbuf_scale)
        cupy.cuda.get_current_stream().synchronize()

        return cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale

    def allgather_cuda(self, comm, cupy_server_sign_packed,
                       cupy_recvbuf_sign_server, cupy_server_scale,
                       cupy_recvbuf_scale_server):
        comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server)
        comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server)

    def allgather_host(self, comm, cupy_server_sign_packed,
                       cupy_recvbuf_sign_server, cupy_server_scale,
                       cupy_recvbuf_scale_server):

        # 1. Convert cupy to numpy
        numpy_recvbuf_sign_server = np.zeros(
            [comm.Get_size(), cupy_server_sign_packed.size],
            dtype=cupy_server_sign_packed.dtype)
        numpy_recvbuf_scale_server = np.zeros([comm.Get_size(), 1],
                                              dtype=cupy_server_scale.dtype)

        numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed)
        numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server)
        numpy_server_scale = cupy.asnumpy(cupy_server_scale)
        numpy_recvbuf_scale_server = cupy.asnumpy(cupy_recvbuf_scale_server)
        cupy.cuda.get_current_stream().synchronize()

        # 2. Communicate numpy buffers
        comm.Allgather(numpy_server_sign_packed, numpy_recvbuf_sign_server)
        comm.Allgather(numpy_server_scale, numpy_recvbuf_scale_server)
        comm.Barrier()

        # 3. Convert numpy back to cupy
        cupy_server_sign_packed = cupy.asarray(numpy_server_sign_packed)
        cupy_recvbuf_sign_server = cupy.asarray(numpy_recvbuf_sign_server)
        cupy_server_scale = cupy.asarray(numpy_server_scale)
        cupy_recvbuf_scale_server = cupy.asarray(numpy_recvbuf_scale_server)
        cupy.cuda.get_current_stream().synchronize()

        return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server

    def compressed_allreduce(self,
                             buffer_m: torch.tensor,
                             worker_error,
                             server_error,
                             local_rank):

        all_start_time = time.time()
        original_size = buffer_m.numel()
        cupy.cuda.Device(local_rank).use()

        # Pad buffer_m so that its size matches the persistent worker_error buffer
        if torch.numel(buffer_m) != torch.numel(worker_error):
            empty_tensor = torch.zeros(torch.numel(worker_error) - torch.numel(buffer_m),
                                       device=buffer_m.device)
            buffer_m = torch.cat([buffer_m, empty_tensor])

        # Worker-side error compensation: fold the residual from the previous
        # step into buffer_m before compressing
        buffer_m.add_(worker_error)
        worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
        # Map signs to {-1.0, +1.0} and store the new compression residual
        sign_buffer_m = buffer_m.sign().add_(1).bool()
        sign_buffer_m = sign_buffer_m.float()
        sign_buffer_m.add_(-0.5).mul_(2.0)
        worker_error.set_((buffer_m - worker_scale * sign_buffer_m))
        sign_buffer_m = None

        # Convert the signs to a boolean tensor, move to cupy, and bit-pack
        # them into self.size chunks (one per rank)
        compensated_buffer_m = buffer_m
        compensated_buffer_m.sign_()
        compensated_buffer_m = compensated_buffer_m.add_(1).bool()
        cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale)
        cupy_compensated_buffer_m = self.compression_backend.torch2cupy(
            compensated_buffer_m)
        compensated_buffer_m = None

        cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
            cupy_compensated_buffer_m,
            self.size)
        cupy_compensated_buffer_m = None

        # Receive buffers for the packed signs and scales of all workers
        cupy_recvbuf_sign = cupy.zeros(
            [self.size,
             cupy_sign_list_packed[self.rank].size],
            dtype=cupy_sign_list_packed[0].dtype)
        cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)

        # Communication Phase 1
        gather_start = time.time()
        if self.cuda_aware:
            self.gather_cuda(self.rank,
                             self.size,
                             self.comm,
                             cupy_sign_list_packed,
                             cupy_recvbuf_sign,
                             cupy_worker_scale,
                             cupy_recvbuf_scale)
        else:
            cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale = self.gather_host(
                self.rank,
                self.size,
                self.comm,
                cupy_sign_list_packed,
                cupy_recvbuf_sign,
                cupy_worker_scale,
                cupy_recvbuf_scale)
        gather_end = time.time()

        # Unpack each worker's signs, map them back to {-1.0, +1.0}, and
        # average the decompressed chunks (sign * scale) across workers to
        # form this rank's chunk of the server momentum
        cupy_unpacked_sign = (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(
            self.size,
            -1)
        cupy_recvbuf_sign = None
        unpacked_sign = self.compression_backend.cupy2torch(cupy_unpacked_sign).float()
        cupy_unpacked_sign = None
        unpacked_sign = unpacked_sign.add_(-0.5).mul_(2.0)
        worker_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale).mul_(
            1 / self.size)
        compensated_server_m = unpacked_sign.mul_(worker_scale).sum(0)
        unpacked_sign = None

        # Server-side error compensation and a second round of 1-bit compression
        compensated_server_m.add_(server_error)
        server_scale = torch.norm(compensated_server_m) / np.sqrt(
            compensated_server_m.numel())
        sign_server_m = compensated_server_m.sign().add_(1).bool()
        sign_server_m = sign_server_m.float()
        sign_server_m.add_(-0.5).mul_(2.0)
        server_error.set_(compensated_server_m - server_scale * sign_server_m)
        sign_server_m = None

        compensated_server_m.sign_()
        compensated_server_m = compensated_server_m.add_(1).bool()
        cupy_server_scale = self.compression_backend.torch2cupy(server_scale)
        cupy_compensated_server_m = self.compression_backend.torch2cupy(
            compensated_server_m)
        compensated_server_m = None

        cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
            cupy_compensated_server_m,
            1)

        cupy_recvbuf_sign_server = cupy.zeros(
            [self.size,
             cupy_server_sign_packed[0].size],
            dtype=cupy_sign_list_packed[0].dtype)
        cupy_recvbuf_scale_server = cupy.zeros([self.size, 1],
                                               dtype=cupy_worker_scale.dtype)

        # Communication Phase 2
        if self.cuda_aware:
            self.allgather_cuda(self.comm,
                                cupy_server_sign_packed[0],
                                cupy_recvbuf_sign_server,
                                cupy_server_scale,
                                cupy_recvbuf_scale_server)
        else:
            cupy_server_sign_packed[0], cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server = self.allgather_host(
                self.comm,
                cupy_server_sign_packed[0],
                cupy_recvbuf_sign_server,
                cupy_server_scale,
                cupy_recvbuf_scale_server)

        # Unpack the server signs gathered from all ranks, rebuild the full
        # averaged tensor, and trim the padding added at the start
        cupy_server_unpacked_sign = (cupy.unpackbits(
            cupy_recvbuf_sign_server.flatten())).reshape(self.size,
                                                         -1)
        cupy_recvbuf_sign_server = None

        server_unpacked_sign = self.compression_backend.cupy2torch(
            cupy_server_unpacked_sign)
        cupy_server_unpacked_sign = None

        server_unpacked_sign = server_unpacked_sign.float().add_(-0.5).mul_(2.0)
        server_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale_server)
        buffer_m = server_unpacked_sign.mul_(server_scale).flatten()[0:original_size]

        return buffer_m
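
A minimal usage sketch (not part of the diff above), in the spirit of the MPI perf test added in the commits: it assumes an mpirun launch with one GPU per process, maps local_rank directly to the MPI rank, and sizes the error buffers the way compressed_allreduce expects (worker_error matching the padded input tensor, server_error matching one per-rank chunk). The tensor sizes and variable names below are illustrative assumptions, not taken from this PR.

# Sketch only: assumes mpirun with one GPU per rank and cuda_aware=False.
import torch
from mpi4py import MPI
from deepspeed.runtime.comm.mpi import MpiBackend

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
world_size = comm.Get_size()
local_rank = rank  # assumption: one process per GPU on a single node

backend = MpiBackend(cuda_aware=False)

size_per_rank = 1024  # illustrative; keeps each bit-packed chunk byte-aligned
numel = world_size * size_per_rank
buffer_m = torch.randn(numel, device='cuda:{}'.format(local_rank))
worker_error = torch.zeros(numel, device=buffer_m.device)
server_error = torch.zeros(numel // world_size, device=buffer_m.device)

averaged = backend.compressed_allreduce(buffer_m, worker_error, server_error, local_rank)
if rank == 0:
    print('compressed allreduce output numel:', averaged.numel())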