107 changes: 56 additions & 51 deletions deepspeed/runtime/engine.py
@@ -121,25 +121,31 @@ def __init__(self,
self.loaded_checkpoint_dp_world_size = None
self.enable_backward_allreduce = True
self.progressive_layer_drop = None
self.dist_backend = "nccl"

if dist_init_required is None:
dist_init_required = not dist.is_initialized()

if self._in_aml():
self._set_environment_variables_for_nccl_backend(args)
else:
self._mpi_check(args, dist_init_required)
if dist_init_required is False:
assert (dist.is_initialized()==True), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"

self.dist_backend = "nccl"
if dist_init_required:
if not dist.is_initialized():
logger.info("Initializing torch distributed with backend: {}".format(
self.dist_backend))
dist.init_process_group(backend=self.dist_backend)
# DeepSpeed will initialize torch distributed only if the user has not already initialized it.
if dist_init_required and not dist.is_initialized():
# discover using mpi4py if user specifies the flag
if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
# if in Azure ML environment and user specified this flag, notify the user to remove the flag.
if self._in_aml():
logger.warning(
"Please remove the --deepspeed_mpi flag if running on AzureML.")
self._mpi_check(args, dist_init_required)
else:
logger.warning(
"Was given dist_init_required=True but detected that torch"
"distributed was already initialized, cannot initialize twice.")
# detect if we are in Azure ML environment
if self._in_aml():
self._set_environment_variables_for_nccl_backend(args)

logger.info("Initializing torch distributed with backend: {}".format(
self.dist_backend))
dist.init_process_group(backend=self.dist_backend)

self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
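
For reference, a minimal usage sketch of the behaviour introduced above (illustrative only, not part of this diff): DeepSpeed now initializes torch.distributed only when the user has not already done so, while dist_init_required=False merely asserts that a process group already exists. get_args() and MyModel below are hypothetical placeholders.

import deepspeed
import torch.distributed as dist

args = get_args()   # hypothetical argparse namespace (--local_rank, --deepspeed_mpi, ...)
model = MyModel()   # hypothetical torch.nn.Module

# Option A: let the engine initialize torch.distributed (the default path above).
engine, optimizer, _, _ = deepspeed.initialize(args=args,
                                               model=model,
                                               model_parameters=model.parameters())

# Option B: initialize torch.distributed yourself and pass dist_init_required=False;
# the engine then only checks that the process group is already up.
# dist.init_process_group(backend="nccl")
# engine, optimizer, _, _ = deepspeed.initialize(args=args,
#                                                model=model,
#                                                model_parameters=model.parameters(),
#                                                dist_init_required=False)
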
@@ -203,7 +209,7 @@ def __init__(self,
self.unflatten = util_ops.unflatten

def _in_aml(self):
# read and environment variable to detect if we are using an Azure ML environment
# read AzureML environment variable to detect if we are using an Azure ML environment
if 'AZUREML_EXPERIMENT_ID' in os.environ:
return True
else:
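
The check above keys Azure ML detection off a single environment variable; the rest of the function is collapsed in this diff. A minimal standalone equivalent, shown only for illustration:

import os

def in_aml() -> bool:
    # Azure ML training jobs set AZUREML_EXPERIMENT_ID in the environment.
    return 'AZUREML_EXPERIMENT_ID' in os.environ
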
@@ -246,43 +252,42 @@ def _set_environment_variables_for_nccl_backend(self,
os.environ['MASTER_PORT']))

def _mpi_check(self, args, dist_init_required):
if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
from mpi4py import MPI
import subprocess
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
world_size = comm.Get_size()

master_addr = None
if rank == 0:
hostname_cmd = ["hostname -I"]
result = subprocess.check_output(hostname_cmd, shell=True)
master_addr = result.decode('utf-8').split()[0]
master_addr = comm.bcast(master_addr, root=0)

# Determine local rank by assuming hostnames are unique
proc_name = MPI.Get_processor_name()
all_procs = comm.allgather(proc_name)
local_rank = sum([i == proc_name for i in all_procs[:rank]])

os.environ['RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
args.local_rank = local_rank
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = TORCH_DISTRIBUTED_DEFAULT_PORT

logger.info(
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
args.local_rank,
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))

if not dist_init_required and dist.is_initialized():
assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
world_size, dist.get_world_size())
from mpi4py import MPI
import subprocess
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
world_size = comm.Get_size()

master_addr = None
if rank == 0:
hostname_cmd = ["hostname -I"]
result = subprocess.check_output(hostname_cmd, shell=True)
master_addr = result.decode('utf-8').split()[0]
master_addr = comm.bcast(master_addr, root=0)

# Determine local rank by assuming hostnames are unique
proc_name = MPI.Get_processor_name()
all_procs = comm.allgather(proc_name)
local_rank = sum([i == proc_name for i in all_procs[:rank]])

os.environ['RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
args.local_rank = local_rank
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = TORCH_DISTRIBUTED_DEFAULT_PORT

logger.info(
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
args.local_rank,
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))

if not dist_init_required and dist.is_initialized():
assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
world_size, dist.get_world_size())

def pld_enabled(self):
return self._config.pld_enabled
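
For readers unfamiliar with the mpi4py pattern used in _mpi_check, here is a self-contained sketch of the same discovery logic, runnable under mpirun (e.g. mpirun -np 4 python discover.py). The literal port below is a placeholder for TORCH_DISTRIBUTED_DEFAULT_PORT and is an assumption, not taken from this diff.

import os
import subprocess
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
world_size = comm.Get_size()

# Rank 0 resolves its first IP address and broadcasts it as the master address.
master_addr = None
if rank == 0:
    master_addr = subprocess.check_output("hostname -I", shell=True).decode("utf-8").split()[0]
master_addr = comm.bcast(master_addr, root=0)

# Local rank = how many lower-numbered ranks share this node's hostname.
proc_name = MPI.Get_processor_name()
all_procs = comm.allgather(proc_name)
local_rank = sum(p == proc_name for p in all_procs[:rank])

# Export the variables that torch.distributed's env:// init method expects.
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = "29500"  # placeholder; DeepSpeed uses TORCH_DISTRIBUTED_DEFAULT_PORT

print(f"rank={rank} local_rank={local_rank} world_size={world_size} "
      f"master={master_addr}:{os.environ['MASTER_PORT']}")
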