Refactored PTH DDP env vars creation in SLURM #2206

Merged: 7 commits, Sep 20, 2021
125 changes: 98 additions & 27 deletions ignite/distributed/comp_models/native.py
@@ -3,7 +3,7 @@
import subprocess
import warnings
from distutils.version import LooseVersion
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast

import torch
import torch.distributed as dist
@@ -210,21 +210,13 @@ def setup_env_vars(self, rank: Optional[int] = None, world_size: Optional[int] =

self._env_backup = os.environ.copy()

# check whether all necessary env vars are set or not
env_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE"]
all_env_vars_defined = [k in os.environ for k in env_vars]

if "SLURM_JOB_ID" in os.environ:
if any(all_env_vars_defined):
raise RuntimeError(
f"Defined env variables '{env_vars}' should not be specified with SLURM. Typically, this "
"happens when `torch.distributed.launch` or `torch.multiprocessing.spawn` are used. Please be "
"sure to use the `srun` command instead."
)
if rank is not None or world_size is not None:
raise ValueError("Arguments rank and world_size should not be specified with SLURM")
self._setup_env_in_slurm()
else:
env_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE"]
all_env_vars_defined = [k in os.environ for k in env_vars]
# check if all necessary env vars are set
# if partially defined raise an error
if any(all_env_vars_defined) and not all(all_env_vars_defined):
@@ -243,25 +235,23 @@ def setup_env_vars(self, rank: Optional[int] = None, world_size: Optional[int] =
self._master_port = int(os.environ["MASTER_PORT"])

def _setup_env_in_slurm(self) -> None:
for k in ["SLURM_JOB_ID", "SLURM_PROCID", "SLURM_LOCALID", "SLURM_NTASKS", "SLURM_JOB_NODELIST"]:
slurm_env_req_vars = [
"SLURM_JOB_ID",
"SLURM_PROCID",
"SLURM_LOCALID",
"SLURM_NTASKS",
"SLURM_JOB_NODELIST",
"SLURM_JOB_NUM_NODES",
]
for k in slurm_env_req_vars:
if k not in os.environ:
raise RuntimeError(f"SLURM distributed configuration is missing '{k}' in env variables")

os.environ["RANK"] = os.environ["SLURM_PROCID"]
os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
os.environ["WORLD_SIZE"] = os.environ["SLURM_NTASKS"]
# port should be the same over all process
slurm_port = os.environ["SLURM_JOB_ID"]
slurm_port = slurm_port[-4:]
os.environ["MASTER_PORT"] = str(int(slurm_port) + 15000)
try:
# use scontrol to expand hostname list
hostnames = subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]])
except FileNotFoundError:
# expand hostname list as scontrol
hostnames = " ".join(_expand_hostlist(os.environ["SLURM_JOB_NODELIST"])).encode("utf-8")
# master address is the first hostname of nodes list
os.environ["MASTER_ADDR"] = hostnames.split()[0].decode("utf-8")
ddp_vars = _setup_ddp_vars_from_slurm_env(cast(Dict, os.environ))

# define DDP env vars required by PTH:
for key, value in ddp_vars.items():
os.environ[key] = str(value)

def get_local_rank(self) -> int:
return cast(int, self._local_rank)
@@ -499,3 +489,84 @@ def _expand_hostlist(nodelist: str) -> List[str]:
result_hostlist += final_hostlist

return result_hostlist

def _setup_ddp_vars_from_slurm_env(environ: Dict[str, str]) -> Dict[str, Union[str, int]]:
"""Method to setup DDP env vars required by PyTorch from SLURM env
"""
# 1) Tools like enroot can have hooks to translate slurm env vars to RANK, LOCAL_RANK, WORLD_SIZE etc
# See https://github.com/NVIDIA/enroot/blob/v3.1.0/conf/hooks/extra/50-slurm-pytorch.sh
# 2) User can use torch.distributed.launch tool to schedule on N local GPUs using 1 node, 1 task by SLURM
# To cover case 1), let's ensure that defined RANK == SLURM_PROCID, LOCAL_RANK == SLURM_LOCALID,
# WORLD_SIZE == SLURM_NTASKS. We will use defined MASTER_ADDR and MASTER_PORT instead of defining
# them by our means
# To cover case 2), let's check that defined RANK >= SLURM_PROCID, LOCAL_RANK >= SLURM_LOCALID,
Inline review thread on the lines above:

sdesrozis (Contributor) commented on Sep 20, 2021:

If I understand correctly what is done, the idea is to ensure that in such a case the user didn't use srun by mistake:

srun python -m torch.distributed.launch ...

Therefore, every process should have a rank, local rank and world size greater than or equal to what is defined by SLURM.

Is that correct?

vfdev-5 (Collaborator, Author) replied:

Case 2 covers a use case like:

srun -N1 -n1 -G8 python -m torch.distributed.launch\
        --nproc_per_node=8 --nnodes=1 --node_rank=0 \
        --master_addr="localhost" --master_port=1234 \
        main.py
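
For concreteness, a minimal sketch of what one of the eight local workers sees in that launch and what the new helper resolves it to (values are illustrative and mirror the "1 node, 1 task + torch.distributed.launch" case in the parametrized test below):

from ignite.distributed.comp_models.native import _setup_ddp_vars_from_slurm_env

# Env inherited by local worker #2 of 8: SLURM describes a single task on one node,
# while torch.distributed.launch defines the actual DDP rank / world size / master.
environ = {
    "SLURM_PROCID": "0", "SLURM_LOCALID": "0", "SLURM_NTASKS": "1",
    "SLURM_JOB_NUM_NODES": "1", "SLURM_JOB_NODELIST": "c1", "SLURM_JOB_ID": "12345",
    "RANK": "2", "LOCAL_RANK": "2", "WORLD_SIZE": "8",
    "MASTER_ADDR": "localhost", "MASTER_PORT": "1234",
}
ddp_vars = _setup_ddp_vars_from_slurm_env(environ)
# -> {"RANK": 2, "LOCAL_RANK": 2, "WORLD_SIZE": 8,
#     "MASTER_ADDR": "localhost", "MASTER_PORT": 1234}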

sdesrozis (Contributor) replied on Sep 20, 2021:

For case 2):

RANK >= SLURM_PROCID and LOCAL_RANK >= SLURM_LOCALID mean that one process is spawned on the node by srun, but WORLD_SIZE >= SLURM_NTASKS sounds weird. SLURM_NTASKS is the maximum number of tasks checked by SLURM. If WORLD_SIZE is greater than that value, the scheduler should kill the process, because more resources than allocated for the job are being used. I think it could be an issue when using gloo.

What do you think?

sdesrozis (Contributor) added:

Otherwise, it works on my side.

vfdev-5 (Collaborator, Author) replied on Sep 20, 2021:

In the example above, srun -N1 -n1 allocates SLURM_NTASKS=1, but the launcher creates a world size of 8. That's why WORLD_SIZE >= SLURM_NTASKS. Am I missing something?

See here as well: https://www.hpcworkshops.com/08-ml-on-parallelcluster/03-distributed-data-parallel.html

sdesrozis (Contributor) replied:

Yes, but the SLURM scheduler can kill the job if more resources than allocated are used. I suppose it depends on how the scheduler is configured. Imagine you schedule a job defining 4 tasks and in fact use 8; it could be a big issue in production. Anyway, I think that does not really matter. This is a user constraint we can't handle.

# WORLD_SIZE >= SLURM_NTASKS, SLURM_JOB_NUM_NODES == 1

ddp_vars: Dict[str, Union[str, int, None]] = {
"RANK": int(environ["SLURM_PROCID"]),
"LOCAL_RANK": int(environ["SLURM_LOCALID"]),
"WORLD_SIZE": int(environ["SLURM_NTASKS"]),
"MASTER_ADDR": None,
"MASTER_PORT": None,
}

pth_ddp_env_vars = {key: environ.get(key, None) for key in ddp_vars}
defined_pth_ddp_env_vars = [v is not None for v in pth_ddp_env_vars.values()]
if all(defined_pth_ddp_env_vars):
nnodes = int(environ["SLURM_JOB_NUM_NODES"])
if nnodes > 1:
# ensure that all pth_ddp_env_vars are consistent with slurm vars
for key in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
slurm_var = cast(int, ddp_vars[key])
pth_var = int(cast(str, pth_ddp_env_vars[key]))
if slurm_var != pth_var:
raise RuntimeError(
"Environment variable defined for PyTorch Distributed context is inconsistent with "
f"equivalent SLURM env variable. {key}: {pth_var} vs {slurm_var}\n"
f"SLURM vars: {ddp_vars}\n"
f"PTH vars: {pth_ddp_env_vars}\n"
)
else:
# ensure that PTH RANK >= SLURM_PROCID, PTH LOCAL_RANK >= SLURM_LOCALID,
# PTH WORLD_SIZE >= SLURM_NTASKS
for key in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
slurm_var = cast(int, ddp_vars[key])
pth_var = int(cast(str, pth_ddp_env_vars[key]))
if pth_var < slurm_var:
raise RuntimeError(
"Environment variable defined for PyTorch Distributed context is "
"inconsistent with equivalent SLURM env variable. "
f"We expect that {key}: {pth_var} >= {slurm_var}\n"
f"SLURM vars: {ddp_vars}\n"
f"PTH vars: {pth_ddp_env_vars}\n"
)
ddp_vars[key] = pth_var
# set up MASTER_ADDR and MASTER_PORT from PTH
ddp_vars["MASTER_ADDR"] = cast(str, pth_ddp_env_vars["MASTER_ADDR"])
ddp_vars["MASTER_PORT"] = int(cast(str, pth_ddp_env_vars["MASTER_PORT"]))
elif any(defined_pth_ddp_env_vars):
# Let's warn the user about PTH env variables that we could not take into account
warnings.warn(
"We detected the following env variables: "
f"{[(k, v) for k, v in pth_ddp_env_vars.items() if v is not None]},\n"
"but will not take them into account as the following env vars are missing:"
f"{[k for k, v in pth_ddp_env_vars.items() if v is None]},\n"
)

if ddp_vars["MASTER_ADDR"] is None:
try:
# use scontrol to expand hostname list
hostnames = subprocess.check_output(["scontrol", "show", "hostnames", environ["SLURM_JOB_NODELIST"]])
except FileNotFoundError:
# expand hostname list as scontrol
hostnames = " ".join(_expand_hostlist(environ["SLURM_JOB_NODELIST"])).encode("utf-8")
# master address is the first hostname of nodes list
ddp_vars["MASTER_ADDR"] = str(hostnames.split()[0].decode("utf-8"))

if ddp_vars["MASTER_PORT"] is None:
# port should be the same over all processes
slurm_port = environ["SLURM_JOB_ID"]
slurm_port = slurm_port[-4:]
ddp_vars["MASTER_PORT"] = int(slurm_port) + 15000

return cast(Dict[str, Union[str, int]], ddp_vars)
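
For the plain SLURM path (none of RANK/LOCAL_RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are predefined by a PyTorch launcher), the helper derives the master address from the first hostname of SLURM_JOB_NODELIST and the master port from the last 4 digits of SLURM_JOB_ID plus 15000. A minimal usage sketch, with values taken from the first parametrized test case below:

from ignite.distributed.comp_models.native import _setup_ddp_vars_from_slurm_env

environ = {
    "SLURM_PROCID": "1", "SLURM_LOCALID": "1", "SLURM_NTASKS": "2",
    "SLURM_JOB_NUM_NODES": "1", "SLURM_JOB_NODELIST": "c1", "SLURM_JOB_ID": "12345",
}
ddp_vars = _setup_ddp_vars_from_slurm_env(environ)
# MASTER_ADDR = "c1" (first hostname of the expanded node list),
# MASTER_PORT = int("2345") + 15000 = 17345
# -> {"RANK": 1, "LOCAL_RANK": 1, "WORLD_SIZE": 2, "MASTER_ADDR": "c1", "MASTER_PORT": 17345}
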
122 changes: 119 additions & 3 deletions tests/ignite/distributed/comp_models/test_native.py
@@ -9,7 +9,7 @@
if not has_native_dist_support:
pytest.skip("Skip if no native dist support", allow_module_level=True)
else:
from ignite.distributed.comp_models.native import _expand_hostlist, _NativeDistModel
from ignite.distributed.comp_models.native import _expand_hostlist, _NativeDistModel, _setup_ddp_vars_from_slurm_env


# tests from https://github.com/LLNL/py-hostlist/blob/master/hostlist/unittest_hostlist.py
@@ -125,17 +125,20 @@ def test__native_dist_model_create_from_backend_bad_slurm_config():
os.environ["SLURM_LOCALID"] = "0"
os.environ["SLURM_NTASKS"] = "1"
os.environ["SLURM_JOB_NODELIST"] = "localhost"
os.environ["SLURM_JOB_NUM_NODES"] = "1"

os.environ["RANK"] = "1"

with pytest.raises(RuntimeError, match=r"Defined env variables"):
_NativeDistModel.create_from_backend(backend="gloo", timeout=timedelta(seconds=10))
with pytest.warns(UserWarning, match=r"We detected the following env variables"):
model = _NativeDistModel.create_from_backend(backend="gloo", timeout=timedelta(seconds=10))
model.finalize()

del os.environ["SLURM_JOB_ID"]
del os.environ["SLURM_PROCID"]
del os.environ["SLURM_LOCALID"]
del os.environ["SLURM_NTASKS"]
del os.environ["SLURM_JOB_NODELIST"]
del os.environ["SLURM_JOB_NUM_NODES"]
del os.environ["RANK"]


@@ -239,6 +242,7 @@ def _test__native_dist_model_create_from_backend_slurm(local_rank, rank, world_s
os.environ["SLURM_LOCALID"] = str(local_rank)
os.environ["SLURM_NTASKS"] = str(world_size)
os.environ["SLURM_JOB_NODELIST"] = "localhost"
os.environ["SLURM_JOB_NUM_NODES"] = "1"

model = _NativeDistModel.create_from_backend(backend=backend, timeout=timeout)

@@ -268,6 +272,7 @@ def _test__native_dist_model_create_from_backend_slurm(local_rank, rank, world_s
del os.environ["SLURM_LOCALID"]
del os.environ["SLURM_NTASKS"]
del os.environ["SLURM_JOB_NODELIST"]
del os.environ["SLURM_JOB_NUM_NODES"]

assert "MASTER_ADDR" not in os.environ
assert "MASTER_PORT" not in os.environ
@@ -531,3 +536,114 @@ def test__native_dist_model_init_method_is_not_none(world_size, local_rank, get_

with pytest.raises(ValueError, match=r"Both rank and world_size should be provided"):
_NativeDistModel.create_from_backend(backend="gloo", rank=local_rank, init_method=init_method)


@pytest.mark.parametrize(
"environ, expected",
[
# fmt: off
# usual SLURM env
(

{
"SLURM_PROCID": "1", "SLURM_LOCALID": "1", "SLURM_NTASKS": "2", "SLURM_JOB_NUM_NODES": "1",
"SLURM_JOB_NODELIST": "c1", "SLURM_JOB_ID": "12345",
},
[1, 1, 2, "c1", 17345]
),
# usual SLURM env mnode
(
{
"SLURM_PROCID": "5", "SLURM_LOCALID": "1", "SLURM_NTASKS": "8", "SLURM_JOB_NUM_NODES": "2",
"SLURM_JOB_NODELIST": "c1, c2", "SLURM_JOB_ID": "12345",
},
[5, 1, 8, "c1", 17345]
),
# usual SLURM env 1 node, 1 task + torch.distributed.launch
(
{
"SLURM_PROCID": "0", "SLURM_LOCALID": "0", "SLURM_NTASKS": "1", "SLURM_JOB_NUM_NODES": "1",
"SLURM_JOB_NODELIST": "c1", "SLURM_JOB_ID": "12345",
"MASTER_ADDR": "127.0.0.1", "MASTER_PORT": "2233", "RANK": "2", "LOCAL_RANK": "2", "WORLD_SIZE": "8",
},
[2, 2, 8, "127.0.0.1", 2233]
),
# usual SLURM env + enroot's pytorch hook
(
{
"SLURM_PROCID": "3", "SLURM_LOCALID": "3", "SLURM_NTASKS": "4", "SLURM_JOB_NUM_NODES": "1",
"SLURM_JOB_NODELIST": "c1", "SLURM_JOB_ID": "12345",
"MASTER_ADDR": "c1", "MASTER_PORT": "12233", "RANK": "3", "LOCAL_RANK": "3", "WORLD_SIZE": "4",
},
[3, 3, 4, "c1", 12233]
),
# usual SLURM env mnode + enroot's pytorch hook
(
{
"SLURM_PROCID": "3", "SLURM_LOCALID": "1", "SLURM_NTASKS": "4", "SLURM_JOB_NUM_NODES": "2",
"SLURM_JOB_NODELIST": "c1, c2", "SLURM_JOB_ID": "12345",
"MASTER_ADDR": "c1", "MASTER_PORT": "12233", "RANK": "3", "LOCAL_RANK": "1", "WORLD_SIZE": "4"
},
[3, 1, 4, "c1", 12233]
),
# fmt: on
],
)
def test__setup_ddp_vars_from_slurm_env(environ, expected):
ddp_keys = ["RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]
ddp_vars = _setup_ddp_vars_from_slurm_env(environ)
for key, value in zip(ddp_keys, expected):
assert key in ddp_vars
assert ddp_vars[key] == value


def test__setup_ddp_vars_from_slurm_env_bad_configs():
with pytest.raises(
RuntimeError, match=r"Environment variable defined for PyTorch Distributed context is inconsistent"
):
environ = {
"SLURM_PROCID": "3",
"SLURM_LOCALID": "1",
"SLURM_NTASKS": "4",
"SLURM_JOB_NUM_NODES": "2",
"SLURM_JOB_NODELIST": "c1, c2",
"SLURM_JOB_ID": "12345",
"MASTER_ADDR": "another-addr",
"MASTER_PORT": "12233",
"RANK": "1",
"LOCAL_RANK": "1",
"WORLD_SIZE": "2",
}
_setup_ddp_vars_from_slurm_env(environ)

with pytest.raises(
RuntimeError, match=r"Environment variable defined for PyTorch Distributed context is inconsistent"
):
environ = {
"SLURM_PROCID": "1",
"SLURM_LOCALID": "1",
"SLURM_NTASKS": "4",
"SLURM_JOB_NUM_NODES": "1",
"SLURM_JOB_NODELIST": "c1",
"SLURM_JOB_ID": "12345",
"MASTER_ADDR": "another-addr",
"MASTER_PORT": "12233",
"RANK": "1",
"LOCAL_RANK": "1",
"WORLD_SIZE": "2",
}
_setup_ddp_vars_from_slurm_env(environ)

with pytest.warns(UserWarning, match=r"We detected the following env variables"):
environ = {
"SLURM_PROCID": "3",
"SLURM_LOCALID": "1",
"SLURM_NTASKS": "4",
"SLURM_JOB_NUM_NODES": "2",
"SLURM_JOB_NODELIST": "c1, c2",
"SLURM_JOB_ID": "12345",
"RANK": "1",
"LOCAL_RANK": "1",
"WORLD_SIZE": "2",
}
_setup_ddp_vars_from_slurm_env(environ)