
Clean up environment access in plugins (Lightning-AI#6941)
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Committed by 3 people on Apr 13, 2021
Parent: e8b48f1 · Commit: f851bcc
Showing 23 changed files with 732 additions and 94 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -240,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))


- Fixed process rank not being available right away after `Trainer` instantiation ([#6941](https://github.com/PyTorchLightning/pytorch-lightning/pull/6941))


## [1.2.7] - 2021-04-06

### Fixed
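
As a quick illustration of what the fix above means in practice (an editorial sketch, not part of the diff): rank attributes on the Trainer are populated during construction instead of only once training starts. The Trainer arguments below are placeholders for a 1.3-era two-GPU DDP run.

# Editorial sketch, not part of this commit. With the plugins calling
# set_world_ranks() in their constructors, rank information is available
# as soon as the Trainer is built, before .fit() is called.
import pytorch_lightning as pl

trainer = pl.Trainer(gpus=2, accelerator="ddp")  # placeholder arguments
print(trainer.global_rank, trainer.local_rank)   # already populated here
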
17 changes: 16 additions & 1 deletion pytorch_lightning/plugins/environments/cluster_environment.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod


class ClusterEnvironment:
@@ -24,9 +25,23 @@ def master_address(self):
def master_port(self):
pass

@abstractmethod
def world_size(self) -> int:
-        return self._world_size
+        """ The number of processes across all devices and nodes. """

@abstractmethod
def set_world_size(self, size: int) -> None:
pass

@abstractmethod
def global_rank(self) -> int:
""" The rank (index) of the currently running process across all nodes and devices. """

@abstractmethod
def set_global_rank(self, rank: int) -> None:
pass

@abstractmethod
def local_rank(self) -> int:
pass
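
For context, an editorial sketch (not part of the commit) of how a custom environment could implement the expanded interface. The import path matches the file above; the MY_* environment variable names are invented for illustration, and creates_children() mirrors the concrete environments elsewhere in this commit.

# Editorial sketch: a custom ClusterEnvironment wired to hypothetical MY_* variables.
import os

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment


class MyClusterEnvironment(ClusterEnvironment):

    def creates_children(self) -> bool:
        return True  # an external launcher starts one process per device

    def master_address(self) -> str:
        return os.environ.get("MY_MASTER_ADDR", "127.0.0.1")

    def master_port(self) -> int:
        return int(os.environ.get("MY_MASTER_PORT", 29500))

    def world_size(self) -> int:
        return int(os.environ["MY_WORLD_SIZE"])

    def set_world_size(self, size: int) -> None:
        pass  # fixed by the external launcher, so ignore

    def global_rank(self) -> int:
        return int(os.environ["MY_RANK"])

    def set_global_rank(self, rank: int) -> None:
        pass  # fixed by the external launcher, so ignore

    def local_rank(self) -> int:
        return int(os.environ.get("MY_LOCAL_RANK", 0))

    def node_rank(self) -> int:
        return int(os.environ.get("MY_NODE_RANK", 0))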

83 changes: 83 additions & 0 deletions pytorch_lightning/plugins/environments/lightning_environment.py
@@ -0,0 +1,83 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import socket

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_only


class LightningEnvironment(ClusterEnvironment):
"""
The default environment used by Lightning for a single node or free cluster (not managed).
The master process must be launched by the user and Lightning will spawn new
worker processes for distributed training, either in a single node or across multiple nodes.
If the master address and port are not provided, the default environment will choose them
automatically. It is recommended to use this default environment for single-node distributed
training as it provides the most convenient way to launch the training script.
"""

def __init__(self):
super().__init__()
self._master_port = None
self._global_rank: int = 0
self._world_size: int = 1

def creates_children(self) -> bool:
return False

def master_address(self) -> str:
return os.environ.get("MASTER_ADDR", "127.0.0.1")

def master_port(self) -> int:
if self._master_port is None:
self._master_port = os.environ.get("MASTER_PORT", find_free_network_port())
return int(self._master_port)

def world_size(self) -> int:
return self._world_size

def set_world_size(self, size: int) -> None:
self._world_size = size

def global_rank(self) -> int:
return self._global_rank

def set_global_rank(self, rank: int) -> None:
self._global_rank = rank
rank_zero_only.rank = rank

def local_rank(self) -> int:
return int(os.environ.get("LOCAL_RANK", 0))

def node_rank(self) -> int:
group_rank = os.environ.get("GROUP_RANK", 0)
return int(os.environ.get("NODE_RANK", group_rank))


def find_free_network_port() -> int:
"""
Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real master node but
have to set the `MASTER_PORT` environment variable.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
s.close()
return port
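
Rough usage sketch (editorial, not part of the commit): on a machine without MASTER_ADDR/MASTER_PORT set, the default environment picks sensible values, and the training type plugin later pushes the ranks it computes into it.

# Editorial sketch of the LightningEnvironment added above.
from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment

env = LightningEnvironment()
print(env.master_address())  # "127.0.0.1" unless MASTER_ADDR is set
print(env.master_port())     # MASTER_PORT if set, otherwise a free port

env.set_world_size(4)        # called by the plugin, e.g. num_nodes * num_processes
env.set_global_rank(2)       # also updates rank_zero_only.rank
assert env.world_size() == 4 and env.global_rank() == 2
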
21 changes: 15 additions & 6 deletions pytorch_lightning/plugins/environments/slurm_environment.py
@@ -23,10 +23,10 @@

class SLURMEnvironment(ClusterEnvironment):

-    def __init__(self):
-        super().__init__()
+    def creates_children(self) -> bool:
+        return True

-    def master_address(self):
+    def master_address(self) -> str:
# figure out the root node addr
slurm_nodelist = os.environ.get("SLURM_NODELIST")
if slurm_nodelist:
@@ -66,10 +66,19 @@ def master_port(self):

return default_port

-    def world_size(self):
-        return self._world_size
+    def world_size(self) -> int:
+        return int(os.environ["SLURM_NTASKS"])

-    def local_rank(self):
+    def set_world_size(self, size: int) -> None:
+        log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
+
+    def global_rank(self) -> int:
+        return int(os.environ["SLURM_PROCID"])
+
+    def set_global_rank(self, rank: int) -> None:
+        log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
+
+    def local_rank(self) -> int:
return int(os.environ['SLURM_LOCALID'])

def node_rank(self):
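
Editorial illustration (not part of the commit) of which SLURM variables the accessors above read; the values fake a two-node, four-task-per-node srun job that SLURM itself would normally describe.

import os

from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment

# pretend SLURM exported these for task 5 of 8
os.environ.update({
    "SLURM_NTASKS": "8",   # read by world_size()
    "SLURM_PROCID": "5",   # read by global_rank()
    "SLURM_LOCALID": "1",  # read by local_rank()
})

env = SLURMEnvironment()
assert (env.world_size(), env.global_rank(), env.local_rank()) == (8, 5, 1)
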
20 changes: 17 additions & 3 deletions pytorch_lightning/plugins/environments/torchelastic_environment.py
@@ -23,8 +23,11 @@

class TorchElasticEnvironment(ClusterEnvironment):

-    def __init__(self):
-        super().__init__()
+    @staticmethod
+    def is_using_torchelastic() -> bool:
+        """ Returns ``True`` if the current process was launched using the torchelastic command. """
+        required_env_vars = ("RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
+        return all(v in os.environ for v in required_env_vars)

def master_address(self):
if "MASTER_ADDR" not in os.environ:
@@ -46,7 +49,18 @@ def master_port(self):
def world_size(self):
return os.environ.get('WORLD_SIZE')

-    def local_rank(self):
+    def set_world_size(self, size: int) -> None:
+        log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
+
+    def global_rank(self) -> int:
+        return int(os.environ["RANK"])
+
+    def set_global_rank(self, rank: int) -> None:
+        log.debug(
+            "TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
+        )
+
+    def local_rank(self) -> int:
return int(os.environ['LOCAL_RANK'])

def node_rank(self) -> int:
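
Editorial sketch (not part of the commit) of how the new static helper could drive environment selection; the real logic lives in Lightning's accelerator connector and is more involved than this.

from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment
from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment

if TorchElasticEnvironment.is_using_torchelastic():
    # RANK, GROUP_RANK, LOCAL_RANK and LOCAL_WORLD_SIZE were set by the launcher
    env = TorchElasticEnvironment()
else:
    env = LightningEnvironment()
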
49 changes: 38 additions & 11 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -74,8 +74,8 @@ def __init__(
self._ddp_kwargs = kwargs
self._has_spawned_children = False
self.task_idx = None
-        self.node_rank = 0
        self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
+        self.set_world_ranks()

@property
def root_device(self):
@@ -171,18 +171,46 @@ def _call_children_scripts(self):
delay = np.random.uniform(1, 5, 1)[0]
sleep(delay)

def setup_distributed(self):
# TODO: check if needed
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
seed_everything(int(seed))

# determine which process we are and world size
self.set_world_ranks()

# set warning rank
rank_zero_only.rank = self.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
self.init_ddp_connection()

# on world_size=0 let everyone know training is starting
if self.is_global_zero and not torch.distributed.is_initialized():
log.info("-" * 100)
log.info(f"distributed_backend={self.distributed_backend}")
log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
log.info("-" * 100)

# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device

def _check_can_spawn_children(self):
if self._has_spawned_children:
raise RuntimeError(
"You tried to run `.fit` or `.test` multiple times in the same script."
" This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
)

-    def set_world_ranks(self):
-        self.local_rank = self.task_idx
-        self.node_rank = self.cluster_environment.node_rank()
-        self.global_rank = self.node_rank * self.num_processes + self.local_rank
-        self.world_size = self.num_nodes * self.num_processes
+    def set_world_ranks(self) -> None:
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+            rank_zero_only.rank = self.cluster_environment.global_rank()

def pre_configure_ddp(self):
# if unset, default `find_unused_parameters` `True`
@@ -215,12 +243,11 @@ def determine_ddp_device_ids(self):
return None
return [self.root_device.index]

-    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
-        # TODO: From where to get cluster environment?
-        os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
+    def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Optional[int] = None) -> None:
+        global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
+        world_size = world_size if world_size is not None else self.cluster_environment.world_size()
+        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
-        os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())

if not torch.distributed.is_initialized():
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
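
Editorial worked example (not part of the commit) of the rank bookkeeping in set_world_ranks() above, for a hypothetical job with 2 nodes and 4 processes per node.

num_nodes, num_processes = 2, 4

for node_rank in range(num_nodes):
    for local_rank in range(num_processes):
        global_rank = node_rank * num_processes + local_rank   # pushed via set_global_rank()
        world_size = num_nodes * num_processes                 # pushed via set_world_size()
        print(f"node {node_rank}, local {local_rank} -> global {global_rank} of {world_size}")
# e.g. node 1, local 2 -> global 6 of 8
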
14 changes: 10 additions & 4 deletions pytorch_lightning/plugins/training_type/ddp2.py
@@ -19,6 +19,14 @@

class DDP2Plugin(DDPPlugin):

@property
def global_rank(self) -> int:
return self.node_rank

@property
def world_size(self) -> int:
return self.num_nodes

def setup(self, model):
self._model = model
# set the task idx
@@ -48,7 +56,5 @@ def distributed_sampler_kwargs(self):
return distributed_sampler_kwargs

def set_world_ranks(self):
-        self.local_rank = self.task_idx
-        self.node_rank = self.cluster_environment.node_rank()
-        self.global_rank = self.node_rank
-        self.world_size = self.num_nodes
+        self.cluster_environment.set_global_rank(self.node_rank)
+        self.cluster_environment.set_world_size(self.num_nodes)
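
For comparison, an editorial note (not part of the diff): DDP2 runs a single process per node (with DP across that node's devices), so the overrides above make the global rank equal to the node rank and the world size equal to the number of nodes.

# Editorial illustration for a hypothetical 3-node DDP2 job.
num_nodes = 3
for node_rank in range(num_nodes):
    global_rank = node_rank  # what the global_rank property returns
    world_size = num_nodes   # what the world_size property returns
    print(f"node {node_rank} -> rank {global_rank} of {world_size}")
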
28 changes: 17 additions & 11 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -59,9 +59,14 @@ def __init__(
self.sync_batchnorm = sync_batchnorm
self._ddp_kwargs = kwargs
self.dist = LightningDistributed()
-        self.num_processes = len(parallel_devices)
-        self.node_rank = 0
+        self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
        self.mp_queue = None
+        self._local_rank = 0
+        self.set_world_ranks()

@property
def local_rank(self) -> int:
return self._local_rank

def __getstate__(self):
""" Makes this plugin pickleable without destroying the queue in the current process. """
@@ -90,12 +95,12 @@ def setup(self, model):
smp = mp.get_context("spawn")
self.mp_queue = smp.SimpleQueue()

-    def set_world_ranks(self, process_idx):
-        self.local_rank = process_idx
-        self.node_rank = self.cluster_environment.node_rank()
-        self.task_idx = self.cluster_local_rank
-        self.global_rank = self.node_rank * self.num_processes + self.local_rank
-        self.world_size = self.num_nodes * self.num_processes
+    def set_world_ranks(self, process_idx: int = 0) -> None:
+        self._local_rank = process_idx
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+            rank_zero_only.rank = self.cluster_environment.global_rank()

@property
def mp_spawn_kwargs(self):
@@ -198,11 +203,12 @@ def configure_ddp(self):
**self._ddp_kwargs,
)

-    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
+    def init_ddp_connection(self, global_rank: Optional[int], world_size: Optional[int]) -> None:
        # TODO: this code is duplicated in DDP and DDPSpawn, make this a function
-        os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
+        global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
+        world_size = world_size if world_size is not None else self.cluster_environment.world_size()
+        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
-        os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())

if not torch.distributed.is_initialized():
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
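
Editorial sketch (not part of the commit) of what init_ddp_connection() boils down to in plain torch.distributed terms; the helper name is invented, and "gloo" is used only to keep the sketch CPU-friendly, whereas the plugin chooses the backend itself.

import os

import torch.distributed as dist


def init_connection_sketch(master_address: str, master_port: int, global_rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = master_address
    os.environ["MASTER_PORT"] = str(master_port)
    if not dist.is_initialized():
        dist.init_process_group("gloo", rank=global_rank, world_size=world_size)
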
16 changes: 16 additions & 0 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -27,6 +27,22 @@ class DataParallelPlugin(ParallelPlugin):
def __init__(self, parallel_devices: Optional[List[torch.device]]):
super().__init__(parallel_devices=parallel_devices, cluster_environment=None)

@property
def global_rank(self) -> int:
return 0

@property
def local_rank(self) -> int:
return 0

@property
def node_rank(self) -> int:
return 0

@property
def world_size(self) -> int:
return 1

def setup(self, model):
# model needs to be moved to the device before it is wrapped
model.to(self.root_device)