diff --git a/src/accelerate/state.py b/src/accelerate/state.py index 6d3e7d3fffd..c846dd65d0f 100644 --- a/src/accelerate/state.py +++ b/src/accelerate/state.py @@ -130,8 +130,16 @@ def __init__( ), "DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source" self.distributed_type = DistributedType.DEEPSPEED if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl", **kwargs) + from .utils import compare_versions + self.backend = "nccl" + if compare_versions("deepspeed", ">", "0.6.5"): + from deepspeed import comm as dist + + dist.init_distributed(dist_backend=self.backend) + else: + torch.distributed.init_process_group(backend="nccl", **kwargs) + self.num_processes = torch.distributed.get_world_size() self.process_index = torch.distributed.get_rank() self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))