Merged
30 commits
c5ce9d8
proof of concept for elastic training using pytorch
Jun 9, 2022
5635052
Add command line options for elastic training
Jun 16, 2022
9b5e72c
Remove functionAgent
Jun 16, 2022
4881ce8
Add NCCL BLOCKING ERROR flag to elastic training
Jun 16, 2022
9f1c997
transient change
Jun 16, 2022
183b6bf
Added DS elastic agent
Jun 17, 2022
4d024f0
Cleanup
Jun 30, 2022
1ab3104
pass environment variables to worker processes
Jul 1, 2022
f55b767
Enable elastic checkpoint for scale down in elastic training
Jul 1, 2022
e37c761
added detection of master addr and port on rank 0
Jul 2, 2022
a41757d
fixed formatting
Jul 5, 2022
95f70ce
Merge latest master to elasticity branch
Jul 5, 2022
f637569
add launch.py and elastic_agent.py files to skip list in torchdist check
Jul 5, 2022
8aafc3a
add pytorch dependency for elastic training
Jul 6, 2022
b0a8802
add function for checking pytorch version
Jul 6, 2022
96d678b
added kill command for pdsh when SIGINT is received
Jul 10, 2022
3ac5119
re-enable elastic checkpoint assertion
Jul 14, 2022
a23594e
Merge branch 'staging-ft-elastic-v1' into arpan/elasticity
aj-prime Jul 14, 2022
4f9c535
Add support for variable batch size
Jul 19, 2022
d995fb3
Fix elasticity V2, enable pipeline parallelism in Elastic Training, a…
Jul 22, 2022
706ebce
updated elastic unit test
Jul 22, 2022
9063a94
added an assertion for model-parallel support and added code to prote…
Jul 23, 2022
f4ace71
modified elastic training unit test, added config options for elastic…
Jul 28, 2022
bb4a7f3
resolved conflicts
Jul 28, 2022
6ed8066
fixed a typo
Jul 28, 2022
f2405bd
fixed test_elastic
Jul 28, 2022
66205a1
removed extra imports
Jul 28, 2022
7e601b3
renamed min and max nodes arguments
Jul 29, 2022
6de2cd8
use default elastic ID
Jul 29, 2022
ae26b52
expose elastic run id as an env variable
Jul 29, 2022
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -39,7 +39,7 @@ repos:
name: check-torchdist
entry: ./scripts/check-torchdist.py
language: script
exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py)
exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py)
# Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm

- repo: https://github.com/codespell-project/codespell
1 change: 0 additions & 1 deletion deepspeed/constants.py
@@ -14,6 +14,5 @@
# To make an attempt at backwards compatibility with THD, we use an
# extraordinarily high default timeout, given that THD did not have timeouts.
default_pg_timeout = timedelta(minutes=30)

INFERENCE_GENERIC_MODE = 'generic'
INFERENCE_SPECIALIZED_MODE = 'specialized'
4 changes: 4 additions & 0 deletions deepspeed/elasticity/__init__.py
@@ -1 +1,5 @@
from .elasticity import compute_elastic_config, elasticity_enabled, ensure_immutable_elastic_config
from .utils import is_torch_elastic_compatible
from .constants import ENABLED, ENABLED_DEFAULT, ELASTICITY
if is_torch_elastic_compatible():
from .elastic_agent import DSElasticAgent
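The guarded import above only exposes DSElasticAgent when the installed torch ships the torch.distributed.elastic agent API. The diff does not include utils.is_torch_elastic_compatible itself, so the following is only a minimal sketch of what such a version gate could look like; the 1.11 cutoff is an assumption, not something stated in this PR.

import torch


def is_torch_elastic_compatible():
    # Hypothetical sketch: treat torch >= 1.11 as "elastic compatible".
    # The exact cutoff is an assumption, not taken from the diff.
    major, minor = (int(part) for part in torch.__version__.split('.')[:2])
    return (major, minor) >= (1, 11)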
15 changes: 15 additions & 0 deletions deepspeed/elasticity/config.py
@@ -79,6 +79,7 @@ def __init__(self, param_dict):

self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT)
self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT)

if self.min_gpus < 1 or self.max_gpus < 1:
raise ElasticityConfigError(
"Elasticity min/max gpus must be > 0, "
@@ -88,6 +89,20 @@ def __init__(self, param_dict):
"Elasticity min_gpus cannot be greater than max_gpus, "
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")

self.model_parallel_size = param_dict.get(MODEL_PARLLEL_SIZE,
MODEL_PARLLEL_SIZE_DEFAULT)
if self.model_parallel_size < 1:
raise ElasticityConfigError(
"Model-Parallel size cannot be less than 1, "
f"given model-parallel size: {self.model_parallel_size}")

self.num_gpus_per_node = param_dict.get(NUM_GPUS_PER_NODE,
NUM_GPUS_PER_NODE_DEFAULT)
if self.num_gpus_per_node < 1:
raise ElasticityConfigError(
"Number of GPUs per node cannot be less than 1, "
f"given number of GPUs per node: {self.num_gpus_per_node}")

self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT)
if self.min_time < 0:
raise ElasticityConfigError(
8 changes: 7 additions & 1 deletion deepspeed/elasticity/constants.py
@@ -27,7 +27,7 @@
ELASTICITY = 'elasticity'

# Current elasticity version
LATEST_ELASTICITY_VERSION = 0.1
LATEST_ELASTICITY_VERSION = 0.2

ENABLED = 'enabled'
ENABLED_DEFAULT = False
@@ -46,6 +46,12 @@
MAX_GPUS = 'max_gpus'
MAX_GPUS_DEFAULT = 10000

NUM_GPUS_PER_NODE = 'num_gpus_per_node'
NUM_GPUS_PER_NODE_DEFAULT = 1

MODEL_PARLLEL_SIZE = "model_parallel_size"
MODEL_PARLLEL_SIZE_DEFAULT = 1

# Minimum running time (minutes) before the scheduler will scale us, 0 implies it's unknown
MIN_TIME = "min_time"
MIN_TIME_DEFAULT = 0
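Taken together with the validation added in config.py above, the new keys give the elasticity section a shape like the hedged illustration below. Only keys visible in this diff are used, and the values are placeholders rather than anything taken from the PR.

# Hypothetical "elasticity" section of a DeepSpeed config, written as the
# Python dict that config.py's __init__ receives as param_dict; values are
# illustrative only.
elasticity_params = {
    "enabled": True,
    "min_gpus": 8,             # must be >= 1 and <= max_gpus
    "max_gpus": 64,            # must be >= 1
    "num_gpus_per_node": 8,    # added alongside the 0.2 version bump, must be >= 1
    "model_parallel_size": 2,  # added alongside the 0.2 version bump, must be >= 1
    "min_time": 20,            # minutes; 0 means the scheduling window is unknown
}
# Setting, e.g., "model_parallel_size": 0 here would trigger the
# ElasticityConfigError raised in the config.py hunk above.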
187 changes: 187 additions & 0 deletions deepspeed/elasticity/elastic_agent.py
@@ -0,0 +1,187 @@
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from typing import Any, Dict, Optional, Tuple
from datetime import datetime
from torch.distributed.elastic.agent.server.api import log, _get_socket_with_port
from torch.distributed.elastic.metrics import put_metric
from torch.distributed.elastic.agent.server.api import (
RunResult,
WorkerGroup,
WorkerSpec,
WorkerState,
)
from torch.distributed import Store
import time
import os
from torch.distributed.elastic.multiprocessing import start_processes
from torch.distributed.elastic.utils import macros
import shutil
import copy
from contextlib import closing
import subprocess


class DSElasticAgent(LocalElasticAgent):
def __init__(
self,
spec: WorkerSpec,
env: Dict,
start_method="spawn",
exit_barrier_timeout: float = 300,
log_dir: Optional[str] = None,
):
super().__init__(spec, start_method, exit_barrier_timeout, log_dir)
self.ds_env = env

@staticmethod
def _set_master_addr_port(store: Store,
master_addr: Optional[str],
master_port: Optional[int]):
if master_port is None:
sock = _get_socket_with_port()
with closing(sock):
master_port = sock.getsockname()[1]

if master_addr is None:
# master_addr = _get_fq_hostname()
result = subprocess.check_output("hostname -I", shell=True)
master_addr = result.decode('utf-8').split()[0]

store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8"))

def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
spec = worker_group.spec
store = worker_group.store
assert store is not None
master_addr, master_port = super()._get_master_addr_port(store)
restart_count = spec.max_restarts - self._remaining_restarts

use_agent_store = spec.rdzv_handler.get_backend() == "static"

args: Dict[int, Tuple] = {}
envs: Dict[int, Dict[str, str]] = {}
for worker in worker_group.workers:
local_rank = worker.local_rank

worker_env_ds = copy.deepcopy(self.ds_env)
worker_env_elastic = {
"LOCAL_RANK": str(local_rank),
"RANK": str(worker.global_rank),
"GROUP_RANK": str(worker_group.group_rank),
"ROLE_RANK": str(worker.role_rank),
"ROLE_NAME": spec.role,
"LOCAL_WORLD_SIZE": str(spec.local_world_size),
"WORLD_SIZE": str(worker.world_size),
"GROUP_WORLD_SIZE": str(worker_group.group_world_size),
"ROLE_WORLD_SIZE": str(worker.role_world_size),
"MASTER_ADDR": master_addr,
"MASTER_PORT": str(master_port),
"TORCHELASTIC_RESTART_COUNT": str(restart_count),
"TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
"TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
"TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
"NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING",
str(1)),
}
worker_env_ds.update(worker_env_elastic)
if "OMP_NUM_THREADS" in os.environ:
worker_env_ds["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]

envs[local_rank] = worker_env_ds
worker_args = list(spec.args)
worker_args = macros.substitute(worker_args, str(local_rank))
args[local_rank] = tuple(worker_args)

# scaling events do not count towards restarts (gets same attempt #)
# remove existing log dir if this restart is due to a scaling event
attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
shutil.rmtree(attempt_log_dir, ignore_errors=True)
os.makedirs(attempt_log_dir)

assert spec.entrypoint is not None
self._pcontext = start_processes(
name=spec.role,
entrypoint=spec.entrypoint,
args=args,
envs=envs,
log_dir=attempt_log_dir,
start_method=self._start_method,
redirects=spec.redirects,
tee=spec.tee,
)

return self._pcontext.pids()

def _invoke_run(self, role: str = "default") -> RunResult:
# NOTE: currently only works for a single role

spec = self._worker_group.spec
role = spec.role

log.info(
f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")

self._initialize_workers(self._worker_group)
monitor_interval = spec.monitor_interval
rdzv_handler = spec.rdzv_handler

participants = rdzv_handler._state_holder.state.participants

while True:
assert self._worker_group.state != WorkerState.INIT
time.sleep(monitor_interval)
run_result = self._monitor_workers(self._worker_group)
state = run_result.state
self._worker_group.state = state

expire_time = datetime.utcnow() - (
rdzv_handler._settings.keep_alive_interval *
rdzv_handler._settings.keep_alive_max_attempt)
_dead_nodes = [
node for node,
last_heartbeat in
rdzv_handler._state_holder.state.last_heartbeats.items()
if last_heartbeat < expire_time
]

put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
put_metric(f"workers.{role}.{state.name.lower()}", 1)

if state == WorkerState.SUCCEEDED:
log.info(
f"[{role}] worker group successfully finished."
f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
)
self._exit_barrier()
return run_result
elif state in {
WorkerState.UNHEALTHY,
WorkerState.FAILED
} or len(participants) > len(rdzv_handler._state_holder.state.participants):
if self._remaining_restarts > 0:
log.info(
f"[{role}] Worker group {state.name}. "
f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
f" will restart worker group")
self._remaining_restarts -= 1
# rdzv_handler._state_holder.state.restart = False
self._restart_workers(self._worker_group)
participants = rdzv_handler._state_holder.state.participants

else:
self._stop_workers(self._worker_group)
self._worker_group.state = WorkerState.FAILED
self._exit_barrier()
return run_result
elif state == WorkerState.HEALTHY:
# membership changes do not count as retries
num_nodes_waiting = rdzv_handler.num_nodes_waiting()
group_rank = self._worker_group.group_rank
if num_nodes_waiting > 0:
log.info(f"[{role}] Detected {num_nodes_waiting} "
f"new nodes from group_rank={group_rank}; "
f"will restart worker group")
self._restart_workers(self._worker_group)
participants = rdzv_handler._state_holder.state.participants
else:
raise Exception(f"[{role}] Worker group in {state.name} state")
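To show how the pieces above fit together, here is a hedged sketch of how a DSElasticAgent might be driven directly: build a rendezvous handler and a WorkerSpec, then hand both to the agent. The entrypoint, endpoint, run id, and node/GPU counts below are illustrative placeholders, not values taken from this PR; DeepSpeed's own launch.py (added to the torchdist skip list above) is the authoritative wiring.

# Hypothetical wiring, for illustration only: train(), "localhost:29400",
# "ds-elastic-demo", and the node/worker counts are placeholders.
import os

import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic.agent.server.api import WorkerSpec
from torch.distributed.elastic.rendezvous import RendezvousParameters

from deepspeed.elasticity import DSElasticAgent


def train():
    # Placeholder worker entrypoint; real training code would read RANK,
    # LOCAL_RANK, MASTER_ADDR, MASTER_PORT, TORCHELASTIC_RUN_ID, etc.
    print("rank", os.environ["RANK"], "of", os.environ["WORLD_SIZE"])


if __name__ == "__main__":
    rdzv_params = RendezvousParameters(
        backend="c10d",              # _invoke_run reads the c10d handler's
        endpoint="localhost:29400",  # participant and heartbeat state
        run_id="ds-elastic-demo",
        min_nodes=1,
        max_nodes=4,
    )
    spec = WorkerSpec(
        role="trainer",
        local_world_size=2,          # workers per node
        entrypoint=train,
        args=(),
        rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_params),
        max_restarts=3,
        monitor_interval=5,
    )
    # The extra env dict is what _start_workers copies into every worker's
    # environment before layering the LOCAL_RANK/RANK/... variables on top.
    agent = DSElasticAgent(spec, env=dict(os.environ))
    result = agent.run()
    print("worker group finished:", result.state)

Compared to the stock LocalElasticAgent, the two DeepSpeed-specific pieces visible in the diff are the extra env dict merged into every worker's environment and the _invoke_run loop, which restarts the worker group both when the rendezvous participant list shrinks and when num_nodes_waiting reports new nodes waiting to join.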