diff --git a/deepspeed/runtime/comm/nccl.py b/deepspeed/runtime/comm/nccl.py
index 0ac2646bd0d7..94ea2a19bed2 100644
--- a/deepspeed/runtime/comm/nccl.py
+++ b/deepspeed/runtime/comm/nccl.py
@@ -12,8 +12,12 @@ class NcclBackend(object):
-    def __init__(self):
-        self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
+    def __init__(self, mpu=None):
+        if mpu is None:
+            self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
+        else:
+            self.mpu = mpu
+            self.world_group = self.mpu.get_data_parallel_group()
         self.rank = dist.get_rank(group=self.world_group)
         self.size = dist.get_world_size(group=self.world_group)
         self.compression_backend = CupyBackend()
@@ -92,9 +96,9 @@ def compressed_allreduce(self,
         # communication phase 1
         # gather_start = time.time()
         # Alltoall for sign
-        dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed))
+        dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
         # Allgather for scale
-        dist.all_gather(recvbuf_scale, worker_scale)
+        dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

         # gather_end = time.time()
@@ -151,8 +155,8 @@ def compressed_allreduce(self,
         ]

         # Communication Phase 2
-        dist.all_gather(recvbuf_sign_server, server_sign_packed[0])
-        dist.all_gather(recvbuf_scale_server, server_scale)
+        dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group)
+        dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)

         cupy_server_sign_packed = None
diff --git a/deepspeed/runtime/fp16/onebit/adam.py b/deepspeed/runtime/fp16/onebit/adam.py
index e3417fea9d6f..a15565b12edd 100644
--- a/deepspeed/runtime/fp16/onebit/adam.py
+++ b/deepspeed/runtime/fp16/onebit/adam.py
@@ -94,7 +94,7 @@ def __init__(self,
             assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 8, "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
             assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
             from deepspeed.runtime.comm.nccl import NcclBackend
-            self.comm_backend_handle = NcclBackend()
+            self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)

         elif self.comm_backend_name == 'mpi':
             from deepspeed.runtime.comm.mpi import MpiBackend
@@ -254,8 +254,10 @@ def step(self, closure=None, grads=None):

         if self.adam_freeze_key is False:
             if state['step'] >= self.freeze_step:
+                print('OneBitAdam - starting compressed communication')
                 self.adam_freeze_key = True
                 self.deepspeed.enable_backward_allreduce = False
+                self.deepspeed.pipeline_enable_backward_allreduce = False

         return loss

@@ -281,6 +283,7 @@ def load_state_dict(self, state_dict):
             if self.adam_freeze_key is True:
                 self.adam_freeze_key = False
                 self.deepspeed.enable_backward_allreduce = True
+                self.deepspeed.pipeline_enable_backward_allreduce = True
         else:
             if torch.distributed.get_rank() == 0:
                 print(
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 573dccce78a5..2c2196647bb4 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -54,6 +54,10 @@ def __init__(self, *super_args, **super_kwargs):

         # We schedule the all-reduces, so disable it in super().backward()
         self.enable_backward_allreduce = False
+
+        # used to disable the pipeline all-reduce when used with 1-bit adam
+        self.pipeline_enable_backward_allreduce = True
+
         assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
             " with pipeline parallelism."

@@ -220,7 +224,7 @@ def _exec_reduce_tied_grads(self):

     def _exec_reduce_grads(self):
         self._force_grad_boundary = True
-        if self.is_data_parallel:
+        if self.is_data_parallel and self.pipeline_enable_backward_allreduce:
             self.buffered_allreduce_fallback(
                 elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)
         self._force_grad_boundary = False
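
Reviewer note: the sketch below is illustrative only and not part of the patch. `ToyMPU` and `build_onebit_comm_backend` are hypothetical names used to show the two construction paths the new optional `mpu` argument enables; in the actual change, `OneBitAdam` simply passes `self.deepspeed.mpu` through to `NcclBackend`.

```python
# Illustrative sketch (not from this PR): shows how the optional `mpu` argument
# selects the process group that NcclBackend's collectives run over.
import torch.distributed as dist
from deepspeed.runtime.comm.nccl import NcclBackend


class ToyMPU:
    """Hypothetical model-parallel unit exposing the one accessor NcclBackend calls."""
    def __init__(self, data_parallel_group):
        self._dp_group = data_parallel_group

    def get_data_parallel_group(self):
        return self._dp_group


def build_onebit_comm_backend(data_parallel_ranks=None):
    # Assumes torch.distributed is already initialized.
    if data_parallel_ranks is None:
        # No mpu: the backend builds a group spanning every rank (pure data parallelism).
        return NcclBackend()
    # With an mpu: the backend reuses get_data_parallel_group(), so the compressed
    # all_to_all_single / all_gather calls only cross data-parallel replicas and
    # never pipeline- or model-parallel ranks.
    dp_group = dist.new_group(ranks=data_parallel_ranks)
    return NcclBackend(ToyMPU(dp_group))
```

The new `pipeline_enable_backward_allreduce` flag follows the same pattern: once `OneBitAdam` reaches `freeze_step` it sets the flag to `False`, so `_exec_reduce_grads` skips `buffered_allreduce_fallback` and gradients are exchanged only through the optimizer's compressed all-reduce; `load_state_dict` restores the flag when a pre-freeze checkpoint is loaded.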