2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -39,7 +39,7 @@ repos:
name: check-torchdist
entry: ./scripts/check-torchdist.py
language: script
exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py)
exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py)
# Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm

- repo: https://github.com/codespell-project/codespell
1 change: 0 additions & 1 deletion deepspeed/constants.py
@@ -14,6 +14,5 @@
# To make an attempt at backwards compatibility with THD, we use an
# extraordinarily high default timeout, given that THD did not have timeouts.
default_pg_timeout = timedelta(minutes=30)

INFERENCE_GENERIC_MODE = 'generic'
INFERENCE_SPECIALIZED_MODE = 'specialized'
4 changes: 4 additions & 0 deletions deepspeed/elasticity/__init__.py
@@ -1 +1,5 @@
from .elasticity import compute_elastic_config, elasticity_enabled, ensure_immutable_elastic_config
from .utils import is_torch_elastic_compatible
from .constants import ENABLED, ENABLED_DEFAULT, ELASTICITY
if is_torch_elastic_compatible():
from .elastic_agent import DSElasticAgent
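
Because the import is guarded, DSElasticAgent only exists on installs whose torch ships torch.distributed.elastic. A minimal sketch of how downstream code might mirror that guard (the None fallback is illustrative, not part of this PR):

```python
# Sketch only: import DSElasticAgent defensively, since deepspeed.elasticity
# exports it only when is_torch_elastic_compatible() reports support.
from deepspeed.elasticity import is_torch_elastic_compatible

if is_torch_elastic_compatible():
    from deepspeed.elasticity import DSElasticAgent
else:
    DSElasticAgent = None  # caller falls back to the non-elastic launch path
```
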
15 changes: 15 additions & 0 deletions deepspeed/elasticity/config.py
@@ -79,6 +79,7 @@ def __init__(self, param_dict):

self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT)
self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT)

if self.min_gpus < 1 or self.max_gpus < 1:
raise ElasticityConfigError(
"Elasticity min/max gpus must be > 0, "
@@ -88,6 +89,20 @@ def __init__(self, param_dict):
"Elasticity min_gpus cannot be greater than max_gpus, "
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")

self.model_parallel_size = param_dict.get(MODEL_PARLLEL_SIZE,
MODEL_PARLLEL_SIZE_DEFAULT)
if self.model_parallel_size < 1:
raise ElasticityConfigError(
"Model-Parallel size cannot be less than 1, "
f"given model-parallel size: {self.model_parallel_size}")

self.num_gpus_per_node = param_dict.get(NUM_GPUS_PER_NODE,
NUM_GPUS_PER_NODE_DEFAULT)
if self.num_gpus_per_node < 1:
raise ElasticityConfigError(
"Number of GPUs per node cannot be less than 1, "
f"given number of GPUs per node: {self.num_gpus_per_node}")

self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT)
if self.min_time < 0:
raise ElasticityConfigError(
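
For reference, a hedged sketch of what the elasticity section of a DeepSpeed config could look like with the two new knobs (key names come from constants.py below; the batch-size fields the elastic scheduler also consumes are omitted since this PR does not touch them):

```python
# Illustrative only: an "elasticity" block exercising the new fields.
# model_parallel_size and num_gpus_per_node must both be >= 1, otherwise
# ElasticityConfig raises ElasticityConfigError as in the validation above.
elasticity_section = {
    "enabled": True,
    "min_gpus": 1,
    "max_gpus": 10000,
    "min_time": 20,
    "model_parallel_size": 2,  # new in this PR
    "num_gpus_per_node": 8,    # new in this PR
    "version": 0.2,            # matches the bumped LATEST_ELASTICITY_VERSION
}
```
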
8 changes: 7 additions & 1 deletion deepspeed/elasticity/constants.py
@@ -27,7 +27,7 @@
ELASTICITY = 'elasticity'

# Current elasticity version
LATEST_ELASTICITY_VERSION = 0.1
LATEST_ELASTICITY_VERSION = 0.2

ENABLED = 'enabled'
ENABLED_DEFAULT = False
@@ -46,6 +46,12 @@
MAX_GPUS = 'max_gpus'
MAX_GPUS_DEFAULT = 10000

NUM_GPUS_PER_NODE = 'num_gpus_per_node'
NUM_GPUS_PER_NODE_DEFAULT = 1

MODEL_PARLLEL_SIZE = "model_parallel_size"
MODEL_PARLLEL_SIZE_DEFAULT = 1

# Minimum running time (minutes) before the scheduler will scale us, 0 implies it's unknown
MIN_TIME = "min_time"
MIN_TIME_DEFAULT = 0
187 changes: 187 additions & 0 deletions deepspeed/elasticity/elastic_agent.py
@@ -0,0 +1,187 @@
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from typing import Any, Dict, Optional, Tuple
from datetime import datetime
from torch.distributed.elastic.agent.server.api import log, _get_socket_with_port
from torch.distributed.elastic.metrics import put_metric
from torch.distributed.elastic.agent.server.api import (
RunResult,
WorkerGroup,
WorkerSpec,
WorkerState,
)
from torch.distributed import Store
import time
import os
from torch.distributed.elastic.multiprocessing import start_processes
from torch.distributed.elastic.utils import macros
import shutil
import copy
from contextlib import closing
import subprocess


class DSElasticAgent(LocalElasticAgent):
def __init__(
self,
spec: WorkerSpec,
env: Dict,
start_method="spawn",
exit_barrier_timeout: float = 300,
log_dir: Optional[str] = None,
):
super().__init__(spec, start_method, exit_barrier_timeout, log_dir)
self.ds_env = env

@staticmethod
def _set_master_addr_port(store: Store,
master_addr: Optional[str],
master_port: Optional[int]):
if master_port is None:
sock = _get_socket_with_port()
with closing(sock):
master_port = sock.getsockname()[1]

if master_addr is None:
# master_addr = _get_fq_hostname()
result = subprocess.check_output("hostname -I", shell=True)
master_addr = result.decode('utf-8').split()[0]

store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8"))

def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
spec = worker_group.spec
store = worker_group.store
assert store is not None
master_addr, master_port = super()._get_master_addr_port(store)
restart_count = spec.max_restarts - self._remaining_restarts

use_agent_store = spec.rdzv_handler.get_backend() == "static"

args: Dict[int, Tuple] = {}
envs: Dict[int, Dict[str, str]] = {}
for worker in worker_group.workers:
local_rank = worker.local_rank

worker_env_ds = copy.deepcopy(self.ds_env)
worker_env_elastic = {
"LOCAL_RANK": str(local_rank),
"RANK": str(worker.global_rank),
"GROUP_RANK": str(worker_group.group_rank),
"ROLE_RANK": str(worker.role_rank),
"ROLE_NAME": spec.role,
"LOCAL_WORLD_SIZE": str(spec.local_world_size),
"WORLD_SIZE": str(worker.world_size),
"GROUP_WORLD_SIZE": str(worker_group.group_world_size),
"ROLE_WORLD_SIZE": str(worker.role_world_size),
"MASTER_ADDR": master_addr,
"MASTER_PORT": str(master_port),
"TORCHELASTIC_RESTART_COUNT": str(restart_count),
"TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
"TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
"TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
"NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING",
str(1)),
}
worker_env_ds.update(worker_env_elastic)
if "OMP_NUM_THREADS" in os.environ:
worker_env_ds["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]

envs[local_rank] = worker_env_ds
worker_args = list(spec.args)
worker_args = macros.substitute(worker_args, str(local_rank))
args[local_rank] = tuple(worker_args)

# scaling events do not count towards restarts (gets same attempt #)
# remove existing log dir if this restart is due to a scaling event
attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
shutil.rmtree(attempt_log_dir, ignore_errors=True)
os.makedirs(attempt_log_dir)

assert spec.entrypoint is not None
self._pcontext = start_processes(
name=spec.role,
entrypoint=spec.entrypoint,
args=args,
envs=envs,
log_dir=attempt_log_dir,
start_method=self._start_method,
redirects=spec.redirects,
tee=spec.tee,
)

return self._pcontext.pids()

def _invoke_run(self, role: str = "default") -> RunResult:
# NOTE: currently only works for a single role

spec = self._worker_group.spec
role = spec.role

log.info(
f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")

self._initialize_workers(self._worker_group)
monitor_interval = spec.monitor_interval
rdzv_handler = spec.rdzv_handler

participants = rdzv_handler._state_holder.state.participants

while True:
assert self._worker_group.state != WorkerState.INIT
time.sleep(monitor_interval)
run_result = self._monitor_workers(self._worker_group)
state = run_result.state
self._worker_group.state = state

expire_time = datetime.utcnow() - (
rdzv_handler._settings.keep_alive_interval *
rdzv_handler._settings.keep_alive_max_attempt)
            _dead_nodes = [
                node for node, last_heartbeat in
                rdzv_handler._state_holder.state.last_heartbeats.items()
                if last_heartbeat < expire_time
            ]

put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
put_metric(f"workers.{role}.{state.name.lower()}", 1)

if state == WorkerState.SUCCEEDED:
log.info(
f"[{role}] worker group successfully finished."
f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
)
self._exit_barrier()
return run_result
elif state in {
WorkerState.UNHEALTHY,
WorkerState.FAILED
} or len(participants) > len(rdzv_handler._state_holder.state.participants):
if self._remaining_restarts > 0:
log.info(
f"[{role}] Worker group {state.name}. "
f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
f" will restart worker group")
self._remaining_restarts -= 1
# rdzv_handler._state_holder.state.restart = False
self._restart_workers(self._worker_group)
participants = rdzv_handler._state_holder.state.participants

else:
self._stop_workers(self._worker_group)
self._worker_group.state = WorkerState.FAILED
self._exit_barrier()
return run_result
elif state == WorkerState.HEALTHY:
# membership changes do not count as retries
num_nodes_waiting = rdzv_handler.num_nodes_waiting()
group_rank = self._worker_group.group_rank
if num_nodes_waiting > 0:
log.info(f"[{role}] Detected {num_nodes_waiting} "
f"new nodes from group_rank={group_rank}; "
f"will restart worker group")
self._restart_workers(self._worker_group)
participants = rdzv_handler._state_holder.state.participants
else:
raise Exception(f"[{role}] Worker group in {state.name} state")
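
For context, a hedged sketch of how this agent is expected to be driven: build a torch-elastic WorkerSpec, hand it to DSElasticAgent together with the environment dict to forward to every worker (the env argument), and call run(). The _invoke_run loop above reaches into rdzv_handler._state_holder, which only exists for torch's dynamic (c10d) rendezvous backend, so that backend is assumed here; the actual wiring lives in deepspeed/launcher/launch.py (also touched by this PR, per the pre-commit exclusion above) and may differ in detail.

```python
# Hedged usage sketch, not copied from this PR: the endpoint, run id and
# training script are placeholders.
import os
from torch.distributed.elastic.agent.server.api import WorkerSpec
from torch.distributed.elastic.multiprocessing import Std
from torch.distributed.elastic.rendezvous import RendezvousParameters
import torch.distributed.elastic.rendezvous.registry as rdzv_registry

from deepspeed.elasticity import DSElasticAgent

rdzv_params = RendezvousParameters(
    backend="c10d",                 # dynamic rendezvous, provides _state_holder
    endpoint="master-host:29400",   # placeholder rendezvous endpoint
    run_id="ds-elastic-demo",       # placeholder run id
    min_nodes=1,
    max_nodes=4,
)

spec = WorkerSpec(
    role="trainer",
    local_world_size=8,             # workers (GPUs) per node
    entrypoint="python",            # workers launched as binary processes
    args=("-u", "train.py"),        # placeholder training entrypoint
    rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_params),
    max_restarts=100,
    monitor_interval=5,
    redirects=Std.NONE,
    tee=Std.NONE,
)

# env is copied into every worker's environment by _start_workers (self.ds_env)
agent = DSElasticAgent(spec, env=dict(os.environ))
result = agent.run()
```

MASTER_ADDR and MASTER_PORT do not need to be set up front: the overridden _set_master_addr_port publishes them through the rendezvous store, picking a free port and, when no address is supplied, the first address reported by hostname -I.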