16 changes: 10 additions & 6 deletions deepspeed/runtime/comm/nccl.py
@@ -12,8 +12,12 @@


 class NcclBackend(object):
-    def __init__(self):
-        self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
+    def __init__(self, mpu=None):
+        if mpu is None:
+            self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
+        else:
+            self.mpu = mpu
+            self.world_group = self.mpu.get_data_parallel_group()
         self.rank = dist.get_rank(group=self.world_group)
         self.size = dist.get_world_size(group=self.world_group)
         self.compression_backend = CupyBackend()
@@ -92,9 +96,9 @@ def compressed_allreduce(self,
         # communication phase 1
         # gather_start = time.time()
         # Alltoall for sign
-        dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed))
+        dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
         # Allgather for scale
-        dist.all_gather(recvbuf_scale, worker_scale)
+        dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

         # gather_end = time.time()
@@ -151,8 +155,8 @@ def compressed_allreduce(self,
         ]

         # Communication Phase 2
-        dist.all_gather(recvbuf_sign_server, server_sign_packed[0])
-        dist.all_gather(recvbuf_scale_server, server_scale)
+        dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group)
+        dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)

         cupy_server_sign_packed = None
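With this change, every collective inside `NcclBackend` runs over the data-parallel process group whenever a model-parallel unit (`mpu`) is passed in, instead of over a group spanning all ranks. A minimal usage sketch, assuming `torch.distributed` is already initialized with the NCCL backend; `ToyMPU` is a hypothetical stand-in and not part of this PR:

    # Minimal sketch, assuming torch.distributed is already initialized.
    # `ToyMPU` is a hypothetical stand-in for DeepSpeed's model-parallel unit.
    import torch.distributed as dist
    from deepspeed.runtime.comm.nccl import NcclBackend

    class ToyMPU:
        def get_data_parallel_group(self):
            # A real mpu returns its data-parallel process group; the default
            # world group is used here only to keep the sketch self-contained.
            return dist.group.WORLD

    backend_default = NcclBackend()       # collectives run over a fresh world group
    backend_dp = NcclBackend(ToyMPU())    # collectives run over the mpu's data-parallel group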
5 changes: 4 additions & 1 deletion deepspeed/runtime/fp16/onebit/adam.py
@@ -94,7 +94,7 @@ def __init__(self,
             assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 8, "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
             assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
             from deepspeed.runtime.comm.nccl import NcclBackend
-            self.comm_backend_handle = NcclBackend()
+            self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)

         elif self.comm_backend_name == 'mpi':
             from deepspeed.runtime.comm.mpi import MpiBackend
@@ -254,8 +254,10 @@ def step(self, closure=None, grads=None):

         if self.adam_freeze_key is False:
             if state['step'] >= self.freeze_step:
+                print('OneBitAdam - starting compressed communication')
                 self.adam_freeze_key = True
                 self.deepspeed.enable_backward_allreduce = False
+                self.deepspeed.pipeline_enable_backward_allreduce = False

         return loss

@@ -281,6 +283,7 @@ def load_state_dict(self, state_dict):
         if self.adam_freeze_key is True:
             self.adam_freeze_key = False
             self.deepspeed.enable_backward_allreduce = True
+            self.deepspeed.pipeline_enable_backward_allreduce = True
         else:
             if torch.distributed.get_rank() == 0:
                 print(
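The optimizer-side changes mirror the new engine flag: once the warm-up ends at `freeze_step`, 1-bit Adam takes over gradient communication, so both the regular and the pipeline backward all-reduce are switched off (and re-enabled when a pre-freeze checkpoint is loaded). A rough sketch of the intended control flow, with `optimizer` and `engine` as illustrative stand-ins for the OnebitAdam instance and the DeepSpeed engine:

    # Illustrative control flow only, not the actual optimizer code.
    def maybe_switch_to_compressed(optimizer, engine, step):
        """After `freeze_step` warm-up steps, 1-bit Adam performs its own
        (compressed) gradient communication, so the engine must stop
        all-reducing gradients on its own."""
        if not optimizer.adam_freeze_key and step >= optimizer.freeze_step:
            print('OneBitAdam - starting compressed communication')
            optimizer.adam_freeze_key = True
            engine.enable_backward_allreduce = False           # regular data-parallel all-reduce
            engine.pipeline_enable_backward_allreduce = False  # pipeline-engine all-reduce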
6 changes: 5 additions & 1 deletion deepspeed/runtime/pipe/engine.py
@@ -54,6 +54,10 @@ def __init__(self, *super_args, **super_kwargs):

         # We schedule the all-reduces, so disable it in super().backward()
         self.enable_backward_allreduce = False
+
+        # used to disable the pipeline all-reduce when used with 1-bit adam
+        self.pipeline_enable_backward_allreduce = True
+
         assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
             " with pipeline parallelism."

@@ -220,7 +224,7 @@ def _exec_reduce_tied_grads(self):

     def _exec_reduce_grads(self):
         self._force_grad_boundary = True
-        if self.is_data_parallel:
+        if self.is_data_parallel and self.pipeline_enable_backward_allreduce:
             self.buffered_allreduce_fallback(
                 elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)
         self._force_grad_boundary = False
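The net effect in the pipeline engine is that the buffered all-reduce fallback only fires when the stage is data-parallel and no optimizer has claimed gradient communication. A condensed sketch of that gate, using only the attributes shown in this diff:

    # Condensed sketch of the reduction gate added to _exec_reduce_grads.
    def engine_should_allreduce(engine) -> bool:
        # True only when the stage replicates across data-parallel ranks AND
        # nothing else (e.g. 1-bit Adam in compressed mode) owns the all-reduce.
        return engine.is_data_parallel and engine.pipeline_enable_backward_allreduce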