diff --git a/deepspeed/utils/distributed.py b/deepspeed/utils/distributed.py index e70f00b440bb..c9722af21c24 100644 --- a/deepspeed/utils/distributed.py +++ b/deepspeed/utils/distributed.py @@ -77,9 +77,9 @@ def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True) os.environ['MASTER_PORT'])) if torch.distributed.is_initialized(): - assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank()) - assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format( - world_size, dist.get_world_size()) + assert torch.distributed.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank()) + assert torch.distributed.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format( + world_size, torch.distributed.get_world_size()) def in_aml():