[ddp] Support multi-node distributed execution under torchelastic #1811

Merged (1 commit) on May 13, 2020
CHANGELOG.md: 2 additions & 0 deletions

```diff
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added support for Pytorch elastic distributed launch environment ([#1811](https://github.com/PyTorchLightning/pytorch-lightning/pull/1811))
+
 - Added callback for logging learning rates ([#1498](https://github.com/PyTorchLightning/pytorch-lightning/pull/1498))

 - Added transfer learning example (for a binary classification task in computer vision) ([#1564](https://github.com/PyTorchLightning/pytorch-lightning/pull/1564))
```
pytorch_lightning/core/lightning.py: 4 additions & 3 deletions

```diff
@@ -943,11 +943,12 @@ def init_ddp_connection(
             os.environ['MASTER_PORT'] = '12910'
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

-        if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != world_size:
-            log.warning("WORLD_SIZE environment variable is not equal to the computed "
-                        "world size. Ignored.")
+        if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) != world_size:
+            log.warning(f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
+                        f"is not equal to the computed world size ({world_size}). Ignored.")

         torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
+        log.info(f"initializing proc_rank {proc_rank} world {world_size}")
         torch_distrib.init_process_group(torch_backend, rank=proc_rank, world_size=world_size)

     def configure_apex(
```
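As a side note on the `int()` cast above, a standalone illustration (not part of the diff): values in `os.environ` are always strings, so the old comparison reported a mismatch even when the sizes agreed.

```python
import os

os.environ['WORLD_SIZE'] = '8'   # environment values are always strings
world_size = 8                   # the world size Lightning computes as an int

print(os.environ['WORLD_SIZE'] != world_size)       # True:  old check warned spuriously
print(int(os.environ['WORLD_SIZE']) != world_size)  # False: new check compares numbers
```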
pytorch_lightning/trainer/distrib_data_parallel.py: 19 additions & 9 deletions

```diff
@@ -279,6 +279,25 @@ def configure_slurm_ddp(self, num_gpu_nodes):
         if self.is_slurm_managing_tasks:
             log.info('Multi-processing is handled by Slurm.')

+    def determine_ddp_node_rank(self):
+        if self.is_slurm_managing_tasks:
+            return int(os.environ['SLURM_NODEID'])
+
+        # torchelastic uses the envvar GROUP_RANK, whereas other systems(?) use NODE_RANK.
+        # otherwise use given node rank or default to node rank 0
+        env_vars = ['NODE_RANK', 'GROUP_RANK']
+        node_ids = [(k, os.environ.get(k, None)) for k in env_vars]
+        node_ids = [(k, v) for k, v in node_ids if v is not None]
+        if len(node_ids) == 0:
+            log.warning("No environment variable for node rank defined. Set as 0.")
+            return 0
+        if len(node_ids) > 1:
+            log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. "
+                        f"Using the first one.")
+        k, rank = node_ids.pop()
+        log.info(f"Using environment variable {k} for node rank ({rank}).")
+        return int(rank)
+
     def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
         if data_parallel_device_ids is None:
             return
@@ -305,15 +324,6 @@ def ddp_train(self, process_idx, model):
         :param cluster_obj:
         :return:
         """
-        # node rank using relative slurm id if under slurm management
-        # otherwise use given node rank or default to node rank 0
-        try:
-            node_id = os.environ['SLURM_NODEID'] if self.is_slurm_managing_tasks else os.environ['NODE_RANK']
-            self.node_rank = int(node_id)
-        except KeyError:
-            log.warning("SLURM_NODEID or NODE_RANK environment variable is not defined. Set as 0.")
-            self.node_rank = 0
-
         # show progressbar only on progress_rank 0
         if (self.node_rank != 0 or process_idx != 0) and self.progress_bar_callback is not None:
             self.progress_bar_callback.disable()
```
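For reference, a standalone sketch of how the new node-rank resolution is expected to behave under different launchers. Which launcher sets which variable is partly an assumption (the diff itself hedges with "other systems(?)"): Slurm provides SLURM_NODEID, torchelastic provides GROUP_RANK, and NODE_RANK is typically exported by the user or the cluster setup. When both NODE_RANK and GROUP_RANK are present, the real method logs a warning; this sketch just takes the first match.

```python
from typing import Mapping


def resolve_node_rank(env: Mapping[str, str], slurm_managed: bool) -> int:
    """Simplified mirror of determine_ddp_node_rank(), over an explicit dict."""
    if slurm_managed:
        return int(env['SLURM_NODEID'])
    for key in ('NODE_RANK', 'GROUP_RANK'):
        if key in env:
            return int(env[key])
    return 0  # nothing set: assume single-node training


# Slurm-managed job: the relative node id comes from Slurm itself
assert resolve_node_rank({'SLURM_NODEID': '2'}, slurm_managed=True) == 2
# torchelastic: one agent per node, its rank exported as GROUP_RANK
assert resolve_node_rank({'GROUP_RANK': '1'}, slurm_managed=False) == 1
# nothing defined: fall back to node rank 0
assert resolve_node_rank({}, slurm_managed=False) == 0
```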
pytorch_lightning/trainer/trainer.py: 5 additions & 2 deletions

```diff
@@ -483,8 +483,8 @@ def __init__(
         # init flags for SLURM+ddp to work
         self.proc_rank = 0
         self.world_size = 1
-        self.node_rank = 0
         self.configure_slurm_ddp(self.num_nodes)
+        self.node_rank = self.determine_ddp_node_rank()

         # nvidia setup
         self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)
@@ -796,11 +796,14 @@ def fit(
         if self.use_ddp2:
             task = int(os.environ['SLURM_LOCALID'])
             self.ddp_train(task, model)

         elif self.use_ddp:
             if self.is_slurm_managing_tasks:
                 task = int(os.environ['SLURM_LOCALID'])
                 self.ddp_train(task, model)
+            # torchelastic
+            elif 'WORLD_SIZE' in os.environ and 'GROUP_RANK' in os.environ:
```

Review thread attached to the new `elif` line:
Contributor:

I wonder what we need to do to get this working on Slurm. It could be as simple as using the LOCAL_RANK environment variable instead of SLURM_LOCALID.

I'll look into it.

Contributor:

Okay, so I looked into it and it's not that simple. Basically, much like distributed training, there are a few ways to initialize elastic training. However, because elastic training needs to own the processes to work, Slurm can't spawn them for it.

In distributed training you have the options:

  1. run `python -m torch.distributed.launch <train_script.py>`, which creates the processes for you
  2. start the processes yourself using `mp` (e.g. in the `else` of this `if` statement)
  3. let Slurm (or another scheduler) create the processes

In elastic training you have the options:

  1. run `python3 -m torchelastic.distributed.launch`, which creates the processes and handles the fault-tolerance and elastic workers.
  2. create an elastic agent such as `LocalElasticAgent`. This will spawn the elastic processes and manage them with the synchronous function `LocalElasticAgent.run()`.

This doesn't leave a particularly easy way to do Slurm, because the agent needs to spawn the processes. My guess is that you need to configure Slurm to run one process per node (i.e. ntasks-per-node=1) and then create the agent and worker processes as explained in option 2 at the beginning of training. You'd also need to set up the distributed key-value store backend (Etcd or Zeus). Luckily they've provided a helpful Python API for spawning an Etcd server.
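For concreteness, a rough sketch of option 2 above: one agent per node spawning its local workers and rendezvousing through etcd. It is written against the torch.distributed.elastic API that later absorbed torchelastic; the standalone torchelastic package discussed here had a similar but not identical surface, so the imports, constructor arguments, and the etcd endpoint below are assumptions to verify, not part of this PR.

```python
# Rough sketch, not from this PR: one elastic agent per node spawns and
# supervises the local trainer workers itself (option 2 above).
# Class and argument names follow torch.distributed.elastic and should be
# checked against the torchelastic version actually installed.
import os

from torch.distributed.elastic.agent.server.api import WorkerSpec
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler


def trainer_entrypoint() -> None:
    # Hypothetical worker function: build the LightningModule and Trainer here.
    # Each spawned worker reads LOCAL_RANK / GROUP_RANK / WORLD_SIZE from its
    # environment, which is exactly the code path this PR adds to fit().
    ...


def main() -> None:
    # Rendezvous via an etcd server assumed to be running at etcd-host:2379;
    # torchelastic also ships a Python helper for starting etcd, as noted above.
    rdzv_params = RendezvousParameters(
        backend="etcd",
        endpoint="etcd-host:2379",
        run_id=os.environ.get("JOB_ID", "lightning-elastic-demo"),  # must match on every node
        min_nodes=1,
        max_nodes=4,
    )
    spec = WorkerSpec(
        role="trainer",
        local_world_size=8,                      # workers (GPUs) per node
        entrypoint=trainer_entrypoint,
        args=(),
        rdzv_handler=get_rendezvous_handler(rdzv_params),
        max_restarts=3,
    )
    # run() blocks until the local workers finish, restarting them on failure
    LocalElasticAgent(spec, start_method="spawn").run()


if __name__ == "__main__":
    main()
```

Under Slurm, this would be the per-node entrypoint started with ntasks-per-node=1, as suggested above.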

Contributor (PR author):

@tullie I don't know what you mean. Lightning already works correctly under a Slurm-managed task environment. Do you mean having the same code for both PyTorch elastic and Slurm?

Contributor:

Yeah, this should work fine, no?
If this is the route we take with elastic, it means that something else created the process that called each script. Is that the expected behavior, @ashwinb @tullie? (I haven't used elastic yet.)

Contributor (PR author):

Yes, elastic launches agents on each node, and those agents manage the individual worker processes. Lightning's job in that case is to init its process group and to configure and run a single trainer worker. This is just like Slurm.

Contributor:

All I was saying is that this PR doesn't add support for PyTorch Elastic in a Slurm-managed environment. That's fine for now, but ideally they'd be able to work together in the future.

The `fit()` diff continues:

```diff
+                task = int(os.environ['LOCAL_RANK'])
+                self.ddp_train(task, model)
             else:
                 self.__set_random_port()
                 # track for predict
```
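To make the new code path concrete, here is a worker's-eye view of a torchelastic launch. The launch command and the two-node, four-GPU numbers are illustrative assumptions, not taken from this PR; the environment variable names are the ones the new `elif` relies on.

```python
# Sketch of what a single worker process sees when started by torchelastic,
# e.g. (launcher flags assumed from the torchelastic docs of the time):
#   python -m torchelastic.distributed.launch --nnodes=2 --nproc_per_node=4 \
#       --rdzv_id=job1 --rdzv_backend=etcd --rdzv_endpoint=etcd-host:2379 train.py
import os

# torchelastic exports these for every worker it spawns; fit() detects the
# WORLD_SIZE + GROUP_RANK combination and skips Lightning's own process spawning.
world_size = int(os.environ['WORLD_SIZE'])   # total workers across all nodes (8 here)
node_rank = int(os.environ['GROUP_RANK'])    # rank of this node's agent (0 or 1)
local_rank = int(os.environ['LOCAL_RANK'])   # this worker's index on its node (0..3)

# Lightning then calls self.ddp_train(local_rank, model): the already-running
# process becomes one DDP worker and init_ddp_connection() joins the process group.
print(f"node {node_rank}, local rank {local_rank}, world size {world_size}")
```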