ref: adding compute environments (1/n) #3837
pytorch_lightning/cluster_environments/__init__.py (new file, +3 lines):

```python
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
```
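These re-exports let callers import the environments from the package root rather than from the individual modules. A minimal usage sketch, assuming this PR's layout is installed:

```python
# Assumes the package layout added in this PR.
from pytorch_lightning.cluster_environments import SLURMEnvironment

env = SLURMEnvironment(world_size=4)
print(env.world_size())  # -> 4
```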
pytorch_lightning/cluster_environments/cluster_environment.py (new file, +13 lines):

```python
class ClusterEnvironment:

    def __init__(self, world_size):
        self._world_size = world_size

    def master_address(self):
        pass

    def master_port(self):
        pass

    def world_size(self):
        return self._world_size
```

Review comment (on the file in general): the licence header is missing.
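To make the base-class contract concrete, here is a hypothetical subclass; the `LocalEnvironment` name and its defaults are assumptions for illustration, not part of the PR:

```python
import os

from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment


class LocalEnvironment(ClusterEnvironment):
    """Hypothetical single-machine environment, for illustration only."""

    def master_address(self):
        # one machine, so the loopback address is always correct
        return "127.0.0.1"

    def master_port(self):
        # honor an externally set MASTER_PORT, else fall back to the
        # same default the SLURM/TorchElastic environments use below
        return int(os.environ.get("MASTER_PORT", 12910))
```

`world_size()` is inherited unchanged from the base class.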
pytorch_lightning/cluster_environments/slurm_environment.py (new file, +66 lines):

```python
import os
import re

from pytorch_lightning import _logger as log
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment


class SLURMEnvironment(ClusterEnvironment):

    def __init__(self, world_size):
        super().__init__(world_size)

    def master_address(self):
        # figure out the root node addr
        try:
            root_node = os.environ["SLURM_NODELIST"].split(" ")[0]
        except Exception:
            root_node = "127.0.0.1"

        root_node = self._resolve_root_node_address(root_node)
        os.environ["MASTER_ADDR"] = root_node
        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
        return root_node

    def master_port(self):
        # -----------------------
        # SLURM JOB = PORT number
        # -----------------------
        # this way every process knows what port to use
        try:
            # use the last 4 numbers in the job id as the id
            default_port = os.environ["SLURM_JOB_ID"]
            default_port = default_port[-4:]

            # all ports should be in the 10k+ range
            default_port = int(default_port) + 15000

        except Exception:
            default_port = 12910

        # -----------------------
        # PORT NUMBER = MASTER_PORT
        # -----------------------
        # in case the user passed it in
        try:
            default_port = os.environ["MASTER_PORT"]
        except Exception:
            os.environ["MASTER_PORT"] = str(default_port)

        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

        return default_port

    def world_size(self):
        return self._world_size

    def _resolve_root_node_address(self, root_node):
        if '[' in root_node:
            name, numbers = root_node.split('[', maxsplit=1)
            number = numbers.split(',', maxsplit=1)[0]
            if '-' in number:
                number = number.split('-')[0]

            number = re.sub('[^0-9]', '', number)
            root_node = name + number

        return root_node
```

Review comment (on `master_address`): shall it rather be a property?
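A short worked example of the two derivations above. It assumes the PR's module path, the `SLURM_JOB_ID` value is illustrative, and `_resolve_root_node_address` is called directly purely for demonstration:

```python
import os

from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment

env = SLURMEnvironment(world_size=2)

# node-list resolution: the first entry of a bracketed range wins
assert env._resolve_root_node_address("node[1-5,10]") == "node1"
assert env._resolve_root_node_address("gpu042") == "gpu042"  # unchanged

# port derivation: last 4 digits of the job id, shifted into the 10k+ range
os.environ["SLURM_JOB_ID"] = "9871234"   # illustrative job id
os.environ.pop("MASTER_PORT", None)      # ensure no user override
assert env.master_port() == 1234 + 15000  # -> 16234
```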
pytorch_lightning/cluster_environments/torchelastic_environment.py (new file, +34 lines):

```python
import os

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment


class TorchElasticEnvironment(ClusterEnvironment):

    def __init__(self, world_size):
        super().__init__(world_size)

    def master_address(self):
        if "MASTER_ADDR" not in os.environ:
            rank_zero_warn(
                "MASTER_ADDR environment variable is not defined. Set as localhost"
            )
            os.environ["MASTER_ADDR"] = "127.0.0.1"
        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
        master_address = os.environ.get('MASTER_ADDR')
        return master_address

    def master_port(self):
        if "MASTER_PORT" not in os.environ:
            rank_zero_warn(
                "MASTER_PORT environment variable is not defined. Set as 12910"
            )
            os.environ["MASTER_PORT"] = "12910"
        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

        port = os.environ.get('MASTER_PORT')
        return port

    def world_size(self):
        return os.environ.get('WORLD_SIZE', None)
```

Review comment (on `world_size`): this probably also does not change during training, so a property?
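A property-based variant along the lines both reviewers suggest might look like the sketch below; the class name is hypothetical and this is illustrative only, not part of the PR:

```python
import os


class PropertyEnvironmentSketch:
    """Illustrative only: the PR itself keeps plain methods."""

    @property
    def master_address(self):
        # resolved once per process and stable afterwards
        return os.environ.get("MASTER_ADDR", "127.0.0.1")

    @property
    def world_size(self):
        # fixed for the lifetime of the job
        return os.environ.get("WORLD_SIZE", None)
```

The open question in the thread is whether that is worth changing every call site from `env.world_size()` to `env.world_size`.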
Review comment (on the PR): can we rename this package to just `environments`? TorchElastic is not a true cluster, and there may be other non-cluster environments.