diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index df68ddcf5192..590353f3bad3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -39,7 +39,7 @@ repos:
         name: check-torchdist
         entry: ./scripts/check-torchdist.py
         language: script
-        exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py)
+        exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py)
         # Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm

 - repo: https://github.com/codespell-project/codespell
diff --git a/deepspeed/constants.py b/deepspeed/constants.py
index 9576c9c078d5..9e7f511e443f 100644
--- a/deepspeed/constants.py
+++ b/deepspeed/constants.py
@@ -14,6 +14,5 @@
 # To make an attempt at backwards compatibility with THD, we use an
 # extraordinarily high default timeout, given that THD did not have timeouts.
 default_pg_timeout = timedelta(minutes=30)
-
 INFERENCE_GENERIC_MODE = 'generic'
 INFERENCE_SPECIALIZED_MODE = 'specialized'
diff --git a/deepspeed/elasticity/__init__.py b/deepspeed/elasticity/__init__.py
index be517de7df93..6ad9f20cf936 100644
--- a/deepspeed/elasticity/__init__.py
+++ b/deepspeed/elasticity/__init__.py
@@ -1 +1,5 @@
 from .elasticity import compute_elastic_config, elasticity_enabled, ensure_immutable_elastic_config
+from .utils import is_torch_elastic_compatible
+from .constants import ENABLED, ENABLED_DEFAULT, ELASTICITY
+if is_torch_elastic_compatible():
+    from .elastic_agent import DSElasticAgent
diff --git a/deepspeed/elasticity/config.py b/deepspeed/elasticity/config.py
index 67db58c70e71..ffbce7028e03 100644
--- a/deepspeed/elasticity/config.py
+++ b/deepspeed/elasticity/config.py
@@ -79,6 +79,7 @@ def __init__(self, param_dict):
         self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT)
         self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT)
+
         if self.min_gpus < 1 or self.max_gpus < 1:
             raise ElasticityConfigError(
                 "Elasticity min/max gpus must be > 0, "
@@ -88,6 +89,20 @@ def __init__(self, param_dict):
                 "Elasticity min_gpus cannot be greater than max_gpus, "
                 f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")

+        self.model_parallel_size = param_dict.get(MODEL_PARLLEL_SIZE,
+                                                  MODEL_PARLLEL_SIZE_DEFAULT)
+        if self.model_parallel_size < 1:
+            raise ElasticityConfigError(
+                "Model-Parallel size cannot be less than 1, "
+                f"given model-parallel size: {self.model_parallel_size}")
+
+        self.num_gpus_per_node = param_dict.get(NUM_GPUS_PER_NODE,
+                                                NUM_GPUS_PER_NODE_DEFAULT)
+        if self.num_gpus_per_node < 1:
+            raise ElasticityConfigError(
+                "Number of GPUs per node cannot be less than 1, "
+                f"given number of GPUs per node: {self.num_gpus_per_node}")
+
         self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT)
         if self.min_time < 0:
             raise ElasticityConfigError(
diff --git a/deepspeed/elasticity/constants.py b/deepspeed/elasticity/constants.py
index cf428b70a0f4..eb40edc84711 100644
--- a/deepspeed/elasticity/constants.py
+++ b/deepspeed/elasticity/constants.py
@@ -27,7 +27,7 @@
 ELASTICITY = 'elasticity'

 # Current elasticity version
-LATEST_ELASTICITY_VERSION = 0.1
+LATEST_ELASTICITY_VERSION = 0.2

 ENABLED = 'enabled'
 ENABLED_DEFAULT = False
@@ -46,6 +46,12 @@
 MAX_GPUS = 'max_gpus'
 MAX_GPUS_DEFAULT = 10000

+NUM_GPUS_PER_NODE = 'num_gpus_per_node'
+NUM_GPUS_PER_NODE_DEFAULT = 1
+
+MODEL_PARLLEL_SIZE = "model_parallel_size"
+MODEL_PARLLEL_SIZE_DEFAULT = 1
+
 # Minimum running time (minutes) before the scheduler will scale us, 0 implies it's unknown
 MIN_TIME = "min_time"
 MIN_TIME_DEFAULT = 0
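For reference, a DeepSpeed config using the new v0.2 elasticity keys added above might look like the following sketch. All values are illustrative placeholders; the key names map to the constants introduced in this patch (e.g. `MODEL_PARLLEL_SIZE` is the string `"model_parallel_size"`).

```python
# Illustrative elasticity v0.2 section; values are placeholders, not recommendations.
ds_config = {
    "elasticity": {
        "enabled": True,
        "max_train_batch_size": 2000,
        "micro_batch_sizes": [2, 4, 6],
        "min_gpus": 1,
        "max_gpus": 10000,
        "version": 0.2,
        # New in v0.2: lets the batch-size search run at node granularity and
        # account for tensor/model parallelism.
        "num_gpus_per_node": 8,
        "model_parallel_size": 2,
    }
}
```

With `model_parallel_size` 2 and 8 GPUs per node, v0.2 treats each node as 4 data-parallel ranks when searching for compatible batch sizes.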
diff --git a/deepspeed/elasticity/elastic_agent.py b/deepspeed/elasticity/elastic_agent.py
new file mode 100644
index 000000000000..8d0981612725
--- /dev/null
+++ b/deepspeed/elasticity/elastic_agent.py
@@ -0,0 +1,187 @@
+from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
+from typing import Any, Dict, Optional, Tuple
+from datetime import datetime
+from torch.distributed.elastic.agent.server.api import log, _get_socket_with_port
+from torch.distributed.elastic.metrics import put_metric
+from torch.distributed.elastic.agent.server.api import (
+    RunResult,
+    WorkerGroup,
+    WorkerSpec,
+    WorkerState,
+)
+from torch.distributed import Store
+import time
+import os
+from torch.distributed.elastic.multiprocessing import start_processes
+from torch.distributed.elastic.utils import macros
+import shutil
+import copy
+from contextlib import closing
+import subprocess
+
+
+class DSElasticAgent(LocalElasticAgent):
+    def __init__(
+        self,
+        spec: WorkerSpec,
+        env: Dict,
+        start_method="spawn",
+        exit_barrier_timeout: float = 300,
+        log_dir: Optional[str] = None,
+    ):
+        super().__init__(spec, start_method, exit_barrier_timeout, log_dir)
+        self.ds_env = env
+
+    @staticmethod
+    def _set_master_addr_port(store: Store,
+                              master_addr: Optional[str],
+                              master_port: Optional[int]):
+        if master_port is None:
+            sock = _get_socket_with_port()
+            with closing(sock):
+                master_port = sock.getsockname()[1]
+
+        if master_addr is None:
+            # master_addr = _get_fq_hostname()
+            result = subprocess.check_output("hostname -I", shell=True)
+            master_addr = result.decode('utf-8').split()[0]
+
+        store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
+        store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8"))
+
+    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
+        spec = worker_group.spec
+        store = worker_group.store
+        assert store is not None
+        master_addr, master_port = super()._get_master_addr_port(store)
+        restart_count = spec.max_restarts - self._remaining_restarts
+
+        use_agent_store = spec.rdzv_handler.get_backend() == "static"
+
+        args: Dict[int, Tuple] = {}
+        envs: Dict[int, Dict[str, str]] = {}
+        for worker in worker_group.workers:
+            local_rank = worker.local_rank
+
+            worker_env_ds = copy.deepcopy(self.ds_env)
+            worker_env_elastic = {
+                "LOCAL_RANK": str(local_rank),
+                "RANK": str(worker.global_rank),
+                "GROUP_RANK": str(worker_group.group_rank),
+                "ROLE_RANK": str(worker.role_rank),
+                "ROLE_NAME": spec.role,
+                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
+                "WORLD_SIZE": str(worker.world_size),
+                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
+                "ROLE_WORLD_SIZE": str(worker.role_world_size),
+                "MASTER_ADDR": master_addr,
+                "MASTER_PORT": str(master_port),
+                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
+                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
+                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
+                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
+                "NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING",
+                                                       str(1)),
+            }
+            worker_env_ds.update(worker_env_elastic)
+            if "OMP_NUM_THREADS" in os.environ:
+                worker_env_ds["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
+
+            envs[local_rank] = worker_env_ds
+            worker_args = list(spec.args)
+            worker_args = macros.substitute(worker_args, str(local_rank))
+            args[local_rank] = tuple(worker_args)
+
+        # scaling events do not count towards restarts (gets same attempt #)
+        # remove existing log dir if this restart is due to a scaling event
+        attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
+        shutil.rmtree(attempt_log_dir, ignore_errors=True)
+        os.makedirs(attempt_log_dir)
+
+        assert spec.entrypoint is not None
+        self._pcontext = start_processes(
+            name=spec.role,
+            entrypoint=spec.entrypoint,
+            args=args,
+            envs=envs,
+            log_dir=attempt_log_dir,
+            start_method=self._start_method,
+            redirects=spec.redirects,
+            tee=spec.tee,
+        )
+
+        return self._pcontext.pids()
+
+    def _invoke_run(self, role: str = "default") -> RunResult:
+        # NOTE: currently only works for a single role
+
+        spec = self._worker_group.spec
+        role = spec.role
+
+        log.info(
+            f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")
+
+        self._initialize_workers(self._worker_group)
+        monitor_interval = spec.monitor_interval
+        rdzv_handler = spec.rdzv_handler
+
+        participants = rdzv_handler._state_holder.state.participants
+
+        while True:
+            assert self._worker_group.state != WorkerState.INIT
+            time.sleep(monitor_interval)
+            run_result = self._monitor_workers(self._worker_group)
+            state = run_result.state
+            self._worker_group.state = state
+
+            expire_time = datetime.utcnow() - (
+                rdzv_handler._settings.keep_alive_interval *
+                rdzv_handler._settings.keep_alive_max_attempt)
+            _dead_nodes = [
+                node for node,
+                last_heartbeat in
+                rdzv_handler._state_holder.state.last_heartbeats.items()
+                if last_heartbeat < expire_time
+            ]
+
+            put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
+            put_metric(f"workers.{role}.{state.name.lower()}", 1)
+
+            if state == WorkerState.SUCCEEDED:
+                log.info(
+                    f"[{role}] worker group successfully finished."
+                    f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
+                )
+                self._exit_barrier()
+                return run_result
+            elif state in {
+                    WorkerState.UNHEALTHY,
+                    WorkerState.FAILED
+            } or len(participants) > len(rdzv_handler._state_holder.state.participants):
+                if self._remaining_restarts > 0:
+                    log.info(
+                        f"[{role}] Worker group {state.name}. "
+                        f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
+                        f" will restart worker group")
+                    self._remaining_restarts -= 1
+                    # rdzv_handler._state_holder.state.restart = False
+                    self._restart_workers(self._worker_group)
+                    participants = rdzv_handler._state_holder.state.participants
+
+                else:
+                    self._stop_workers(self._worker_group)
+                    self._worker_group.state = WorkerState.FAILED
+                    self._exit_barrier()
+                    return run_result
+            elif state == WorkerState.HEALTHY:
+                # membership changes do not count as retries
+                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
+                group_rank = self._worker_group.group_rank
+                if num_nodes_waiting > 0:
+                    log.info(f"[{role}] Detected {num_nodes_waiting} "
+                             f"new nodes from group_rank={group_rank}; "
+                             f"will restart worker group")
+                    self._restart_workers(self._worker_group)
+                    participants = rdzv_handler._state_holder.state.participants
+            else:
+                raise Exception(f"[{role}] Worker group in {state.name} state")
diff --git a/deepspeed/elasticity/elasticity.py b/deepspeed/elasticity/elasticity.py
index c17dab0319d9..17a8b6ecf394 100644
--- a/deepspeed/elasticity/elasticity.py
+++ b/deepspeed/elasticity/elasticity.py
@@ -4,7 +4,7 @@
 import os
 import json
 import numpy as np
-
+import math
 from packaging import version as pkg_version

 from .config import ElasticityConfig, ElasticityConfigError, ElasticityError, \
@@ -91,7 +91,6 @@ def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus):
                 valid_gpus.append(i)
     valid_gpus = set(valid_gpus)
     valid_gpus = sorted(list(valid_gpus))
-    logger.info(f"Valid GPUs: {valid_gpus}")
     return valid_gpus


@@ -171,6 +170,70 @@ def _get_compatible_gpus_v01(micro_batches,
     return final_batch_size, valid_gpus


+def _get_compatible_gpus_v02(micro_batches,
+                             max_acceptable_batch_size,
+                             current_num_gpus,
+                             min_gpus=None,
+                             max_gpus=None,
+                             prefer_larger=True,
+                             num_gpus_per_node=1,
+                             model_parallel_size=1):
+    '''
+    Returns:
+        final_batch_size
+        valid_gpus
+        micro-batch size
+    '''
+    if num_gpus_per_node % model_parallel_size != 0:
+        raise ElasticityError(
+            f"In Elasticity v0.2, number of GPUs per node:" \
+            f"{num_gpus_per_node} should be divisible by " \
+            f"model parallel size {model_parallel_size}")
+
+    def get_microbatch(final_batch_size):
+        candidate_microbatch = None
+
+        for micro_batch in micro_batches:
+            if final_batch_size // current_num_gpus % micro_batch == 0:
+                if candidate_microbatch == None:
+                    candidate_microbatch = micro_batch
+                if prefer_larger and candidate_microbatch < micro_batch:
+                    candidate_microbatch = micro_batch
+        return candidate_microbatch
+
+    dp_size_per_node = num_gpus_per_node // model_parallel_size
+
+    final_batch_size, valid_world_size = _get_compatible_gpus_v01(micro_batches,
+                   int(max_acceptable_batch_size/dp_size_per_node),
+                   int(min_gpus/num_gpus_per_node),
+                   int(max_gpus/num_gpus_per_node),  # Passing number of max nodes as Elasticity v2 works at node level
+                   prefer_larger=prefer_larger)
+
+    final_batch_size = int(final_batch_size) * dp_size_per_node
+    valid_dp_world_size = [i * dp_size_per_node for i in valid_world_size]
+    if current_num_gpus // model_parallel_size in valid_dp_world_size:
+        candidate_microbatch = get_microbatch(final_batch_size)
+        return final_batch_size, valid_dp_world_size, candidate_microbatch
+
+    current_dp_size = (current_num_gpus / num_gpus_per_node) * dp_size_per_node
+    candidate_batch_sizes = []
+    for micro_batch in micro_batches:
+        min_batch_size = micro_batch * current_dp_size
+
+        factor = math.floor(max_acceptable_batch_size / float(min_batch_size))
+        candidate_batch_sizes.append(factor * min_batch_size)
+
+    used_microbatch = None
+    if prefer_larger:
+        candidate_batch_size = max(candidate_batch_sizes)
+    else:
+        candidate_batch_size = min(candidate_batch_sizes)
+
+    candidate_microbatch = get_microbatch(candidate_batch_size)
+
+    return candidate_batch_size, [int(current_dp_size)], candidate_microbatch
+
+
 def _compatible_ds_version_check(target_deepspeed_version: str):
     min_version = pkg_version.parse(MINIMUM_DEEPSPEED_VERSION)
     target_version = pkg_version.parse(target_deepspeed_version)
@@ -221,7 +284,10 @@ def ensure_immutable_elastic_config(runtime_elastic_config_dict: dict):
                 "guarantee resource scheduler will scale this job using compatible GPU counts.")


-def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world_size=0):
+def compute_elastic_config(ds_config: dict,
+                           target_deepspeed_version: str,
+                           world_size=0,
+                           return_microbatch=False):
     """Core deepspeed elasticity API. Given an elastic config (similar to the example below)
     DeepSpeed will compute a total train batch size corresponding valid GPU count list that
     provides a high level of elasticity. Elasticity in this case means we are safe to scale
@@ -248,8 +314,9 @@ def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world
         target_deepspeed_version (str): When called from scheduling
             infrastructure we want to ensure that the target deepspeed version is
             compatible with the elasticity version used in the backend.
-        world_size (int, optional): Intended/current world size, will do some sanity
+        world_size (int, optional): Intended/current DP world size, will do some sanity
            checks to ensure world size is actually valid with the config.
+        return_microbatch (bool, optional): whether to return micro batch size or not.

     Raises:
         ElasticityConfigError: Missing required elasticity config or elasticity disabled
@@ -275,6 +342,13 @@ def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world
             "('enabled':true) if running an elastic training job.")

     elastic_config = ElasticityConfig(elastic_config_dict)
+    model_parallel_size = elastic_config.model_parallel_size
+    num_gpus_per_node = elastic_config.num_gpus_per_node
+
+    if model_parallel_size > 1 and float(elastic_config.version) != 0.2:
+        raise ElasticityConfigError(f"Elasticity V{elastic_config.version} " \
+            f"does not support model-parallel training. Given model-parallel size: " \
+            f"{model_parallel_size}")

     if float(elastic_config.version) > LATEST_ELASTICITY_VERSION:
         raise ElasticityConfigError("Attempting to run elasticity version " \
@@ -295,10 +369,39 @@ def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world
             prefer_larger=elastic_config.prefer_larger_batch_size)
         # ensure batch size is int dtype
         final_batch_size = int(final_batch_size)
+    elif float(elastic_config.version) == 0.2:
+        if world_size != 0:
+            current_num_gpus = world_size
+        else:
+            if "WORLD_SIZE" in os.environ and \
+                os.getenv('WORLD_SIZE').isnumeric():
+                current_num_gpus = int(os.getenv('WORLD_SIZE'))
+            else:
+                WORLD_SIZE = os.getenv('WORLD_SIZE')
+                raise ElasticityConfigError(
+                    'Elasticity V 0.2 needs WORLD_SIZE '\
+                    'to compute valid batch size. '\
+                    'Either give it as argument to function compute_elastic_config '\
+                    'or set it as an environment variable. '\
+                    f'Value of WORLD_SIZE as environment variable is {WORLD_SIZE}')
+
+        final_batch_size, valid_gpus, candidate_microbatch_size = _get_compatible_gpus_v02(
+            micro_batches=elastic_config.micro_batches,
+            max_acceptable_batch_size=elastic_config.max_acceptable_batch_size,
+            current_num_gpus=current_num_gpus,
+            min_gpus=elastic_config.min_gpus,
+            max_gpus=elastic_config.max_gpus,
+            prefer_larger=elastic_config.prefer_larger_batch_size,
+            num_gpus_per_node=num_gpus_per_node,
+            model_parallel_size=model_parallel_size)
+        # ensure batch size is int dtype
+        final_batch_size = int(final_batch_size)
     else:
         raise NotImplementedError(
             f"Unable to find elastic logic for version: {elastic_config.version}")

+    logger.info(f"Valid World Size (GPUs / Model Parallel Size): {valid_gpus}")
+
     if world_size > 0:
         if world_size not in valid_gpus:
             raise ElasticityIncompatibleWorldSize(f"World size ({world_size}) is not valid " \
@@ -315,4 +418,19 @@ def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world
             f" micro_batches={elastic_config.micro_batches}."
         return final_batch_size, valid_gpus, micro_batch_size

+    if return_microbatch:
+        # Pick a valid micro batch size
+        if float(elastic_config.version) == 0.2:
+            return final_batch_size, valid_gpus, candidate_microbatch_size
+        else:
+            micro_batch_size = None
+            for mbsz in sorted(list(set(elastic_config.micro_batches)), reverse=True):
+                if final_batch_size // world_size % mbsz == 0:
+                    micro_batch_size = mbsz
+                    break
+            assert micro_batch_size is not None, "Unable to find divisible micro batch size" \
+                f" world_size={world_size}, final_batch_size={final_batch_size}, and " \
+                f" micro_batches={elastic_config.micro_batches}."
+            return final_batch_size, valid_gpus, micro_batch_size
+
     return final_batch_size, valid_gpus
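The new v0.2 path can be exercised directly, mirroring the `test_model_parallel_v2_valid` unit test added at the end of this diff. All config values below are illustrative; when `world_size` is not passed, v0.2 reads the current GPU count from the `WORLD_SIZE` environment variable.

```python
# Illustrative use of the v0.2 path added above; values are placeholders.
import os
import deepspeed

ds_config = {
    "elasticity": {
        "enabled": True,
        "max_train_batch_size": 2000,
        "micro_batch_sizes": [2, 4, 6],
        "min_gpus": 8,
        "max_gpus": 1024,
        "version": 0.2,
        "model_parallel_size": 4,
        "num_gpus_per_node": 8,
    }
}

os.environ["WORLD_SIZE"] = "16"  # normally set by the launcher
final_batch_size, valid_world_sizes, micro_batch = \
    deepspeed.elasticity.compute_elastic_config(
        ds_config=ds_config,
        target_deepspeed_version=deepspeed.__version__,
        return_microbatch=True)
```

Per the logging line added in `compute_elastic_config`, the returned `valid_world_sizes` are expressed in units of GPUs divided by the model-parallel size.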
diff --git a/deepspeed/elasticity/utils.py b/deepspeed/elasticity/utils.py
new file mode 100644
index 000000000000..a1001c6b3806
--- /dev/null
+++ b/deepspeed/elasticity/utils.py
@@ -0,0 +1,14 @@
+import torch
+
+
+def is_torch_elastic_compatible():
+    '''
+        Helper to lookup torch version. Elastic training is
+        introduced in 1.11.x
+    '''
+    TORCH_MAJOR = int(torch.__version__.split('.')[0])
+    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    if TORCH_MAJOR == 1 and TORCH_MINOR >= 11:
+        return True
+    else:
+        return False
diff --git a/deepspeed/launcher/constants.py b/deepspeed/launcher/constants.py
index 15b262342d47..2ae22ea6c7da 100644
--- a/deepspeed/launcher/constants.py
+++ b/deepspeed/launcher/constants.py
@@ -7,3 +7,5 @@

 MVAPICH_LAUNCHER = 'mvapich'
 MVAPICH_TMP_HOSTFILE = '/tmp/deepspeed_mvapich_hostfile'
+
+ELASTIC_TRAINING_ID_DEFAULT = "123456789"
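The helper above gates every elastic-specific import in this patch, so callers can also use it as a defensive pre-flight check; a hypothetical sketch:

```python
# Hypothetical pre-flight check before requesting an elastic launch.
from deepspeed.elasticity import is_torch_elastic_compatible

if not is_torch_elastic_compatible():
    raise RuntimeError(
        "DeepSpeed elastic training needs torch.distributed.elastic (torch 1.11+ per the check above)")
```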
diff --git a/deepspeed/launcher/launch.py b/deepspeed/launcher/launch.py
index d37a970c5e3a..f39530552055 100755
--- a/deepspeed/launcher/launch.py
+++ b/deepspeed/launcher/launch.py
@@ -17,11 +17,13 @@
 import signal
 import psutil
 from collections import defaultdict
+from typing import Dict
 from argparse import ArgumentParser, REMAINDER
-
 from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
 from ..nebula.constants import DLTS_POD_ENV_PATH
 from ..utils import logger
+from ..elasticity import is_torch_elastic_compatible
+from .constants import ELASTIC_TRAINING_ID_DEFAULT

 PID_FILE_BASEPATH = "/tmp"

@@ -66,6 +68,20 @@ def parse_args():
                         help="Skip prepending the training script with "
                         "'python' - just execute it directly.")

+    parser.add_argument("--enable_elastic_training",
+                        action="store_true",
+                        help="Enable elastic training support.")
+
+    parser.add_argument("--min_elastic_nodes",
+                        type=int,
+                        default=-1,
+                        help="Min number of nodes in elastic training.")
+
+    parser.add_argument("--max_elastic_nodes",
+                        type=int,
+                        default=-1,
+                        help="Max number of nodes in elastic training.")
+
     parser.add_argument("--no_local_rank",
                         action="store_true",
                         help="Do not pass local_rank as an argument when calling "
@@ -160,6 +176,12 @@ def main():
         with open(pid_file, 'w') as fd:
             fd.write(f"{launcher_pid}")

+    if not is_torch_elastic_compatible():
+        if args.enable_elastic_training:
+            logger.info(f"Disabling elastic training support as \
+                    PyTorch version should be greater than 1.11.x")
+            args.enable_elastic_training = False
+
     if os.path.exists(DLTS_POD_ENV_PATH):
         with open(DLTS_POD_ENV_PATH) as file:
             lines = file.readlines()
@@ -173,13 +195,48 @@ def main():
     processes = []
     cmd = []

-    for local_rank in range(0, num_local_procs):
-        # each process's rank
-        dist_rank = global_rank_mapping[local_node][local_rank]
-        current_env["RANK"] = str(dist_rank)
-        current_env["LOCAL_RANK"] = str(local_rank)
-
-        # spawn the processes
+    if not args.enable_elastic_training:
+        for local_rank in range(0, num_local_procs):
+            # each process's rank
+            dist_rank = global_rank_mapping[local_node][local_rank]
+            current_env["RANK"] = str(dist_rank)
+            current_env["LOCAL_RANK"] = str(local_rank)
+
+            # spawn the processes
+            cmd = []
+            if not args.no_python:
+                cmd = [sys.executable, "-u"]
+                if args.module:
+                    cmd.append("-m")
+            else:
+                if args.module:
+                    raise ValueError("Don't use both the '--no_python' flag"
+                                     " and the '--module' flag at the same time.")
+            cmd.append(args.training_script)
+            # A user may not want to pass local_rank as a keyword arg so we make this optional.
+            if not args.no_local_rank:
+                cmd.append(f"--local_rank={local_rank}")
+            cmd += args.training_script_args
+
+            process = subprocess.Popen(cmd, env=current_env)
+            processes.append(process)
+    else:
+        from ..elasticity import DSElasticAgent
+        from torch.distributed.elastic.rendezvous import RendezvousParameters
+        from torch.distributed.elastic.agent.server.api import WorkerSpec
+        import torch.distributed.elastic.rendezvous.registry as rdzv_registry
+        from torch.distributed.elastic.multiprocessing import Std
+
+        if args.min_elastic_nodes == -1:
+            args.min_elastic_nodes = 1
+        if args.max_elastic_nodes == -1:
+            args.max_elastic_nodes = args.nnodes
+        assert args.max_elastic_nodes > 0 and args.min_elastic_nodes > 0 , "Max and Min nodes should be positive"
+
+        current_env["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
+
+        # Get config and arguments
         cmd = []
         if not args.no_python:
             cmd = [sys.executable, "-u"]
@@ -190,13 +247,36 @@ def main():
             if args.module:
                 cmd.append("-m")
         else:
             if args.module:
                 raise ValueError("Don't use both the '--no_python' flag"
                                  " and the '--module' flag at the same time.")
         cmd.append(args.training_script)
-        # A user may not want to pass local_rank as a keyword arg so we make this optional.
-        if not args.no_local_rank:
-            cmd.append(f"--local_rank={local_rank}")
         cmd += args.training_script_args
-
-        process = subprocess.Popen(cmd, env=current_env)
-        processes.append(process)
+        cmd_args = cmd[1:]
+
+        rdzv_configs: Dict[str, str] = {'timeout': 100}
+        run_id = os.environ.get("ELASTIC_RUN_ID", ELASTIC_TRAINING_ID_DEFAULT)
+
+        # Creating config for rendezvous class
+        rdzv_parameters = RendezvousParameters(backend='c10d',
+                                               endpoint=args.master_addr + ":" +
+                                               str(args.master_port),
+                                               run_id=run_id,
+                                               min_nodes=args.min_elastic_nodes,
+                                               max_nodes=args.max_elastic_nodes,
+                                               **rdzv_configs)
+
+        spec = WorkerSpec(
+            role='trainer',
+            local_world_size=num_local_procs,
+            entrypoint=cmd[0],
+            args=cmd[1:],
+            rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
+            max_restarts=100,
+            monitor_interval=5,
+            redirects=Std.from_str("0"),
+            tee=Std.from_str("0"),
+            master_addr=None,
+            master_port=None,
+        )
+        agent = DSElasticAgent(spec, current_env)
+        agent.run()

     sig_names = {2: "SIGINT", 15: "SIGTERM"}
     last_return_code = None
diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py
index 6fb187cfde23..7c2828e75fc7 100644
--- a/deepspeed/launcher/multinode_runner.py
+++ b/deepspeed/launcher/multinode_runner.py
@@ -94,8 +94,16 @@ def get_cmd(self, environment, active_resources):
             deepspeed_launch.append("--no_local_rank")
         if self.args.save_pid:
             deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
+        if self.args.elastic_training:
+            deepspeed_launch.append("--enable_elastic_training")
+            deepspeed_launch.append(f"--max_elastic_nodes={self.args.max_elastic_nodes}")
+            deepspeed_launch.append(f"--min_elastic_nodes={self.args.min_elastic_nodes}")
+
+        cmd_to_search = [i + "\\" for i in deepspeed_launch[2:6]]
+
+        kill_command = pdsh_cmd_args + ["pkill -f ", " ".join(cmd_to_search)[:-2]]

         return pdsh_cmd_args + deepspeed_launch + [self.user_script
-                                                   ] + self.user_arguments
+                                                   ] + self.user_arguments, kill_command


 class OpenMPIRunner(MultiNodeRunner):
diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py
index cf3e98dc25bb..f35fb3994ac4 100755
--- a/deepspeed/launcher/runner.py
+++ b/deepspeed/launcher/runner.py
@@ -14,7 +14,8 @@
 import subprocess
 import collections
 from copy import deepcopy
-
+import signal
+import time
 import torch.cuda

 from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner
@@ -77,6 +78,18 @@ def parse_args(args=None):
                         help="Total number of worker nodes to run on, this will use "
                         "the top N hosts from the given hostfile.")

+    parser.add_argument("--min_elastic_nodes",
+                        type=int,
+                        default=-1,
+                        help="Minimum number of nodes to run elastic training on. "
+                        "Default is 1 when elastic training is enabled")
+
+    parser.add_argument("--max_elastic_nodes",
+                        type=int,
+                        default=-1,
+                        help="Maximum number of nodes to run elastic training on. "
+                        "Default is num_nodes when elastic training is enabled")
+
     parser.add_argument("--num_gpus",
                         type=int,
                         default=-1,
@@ -148,6 +161,10 @@ def parse_args(args=None):
         help="Run DeepSpeed autotuner to discover optimal configuration parameters "
         "before running job.")

+    parser.add_argument("--elastic_training",
+                        action="store_true",
+                        help="Enable elastic training support in DeepSpeed.")
+
     parser.add_argument("user_script",
                         type=str,
                         help="User script to launch, followed by any required "
@@ -316,9 +333,27 @@ def run_autotuning(args, active_resources):
     tuner.run_after_tuning()


+def parse_num_nodes(str_num_nodes: str, elastic_training: bool):
+    node_list = str_num_nodes.split(":")
+
+    if len(node_list) == 1:
+        min_nodes, max_nodes = int(node_list[0]), -1
+    elif len(node_list) == 2 and elastic_training:
+        min_nodes, max_nodes = int(node_list[0]), int(node_list[1])
+    elif len(node_list) == 2 and not elastic_training:
+        raise RuntimeError("MIN:MAX format is only supported in elastic training")
+    else:
+        raise RuntimeError("num_nodes {} is not in MIN:MAX format".format(str_num_nodes))
+
+    return min_nodes, max_nodes
+
+
 def main(args=None):
     args = parse_args(args)

+    if args.elastic_training:
+        assert args.master_addr != "", "Master Addr is required when elastic training is enabled"
+
     resource_pool = fetch_hostfile(args.hostfile)

     # respect CUDA_VISIBLE_DEVICES for a single node and no explicit resource filters
@@ -397,6 +432,9 @@ def main(args=None):
             updated_active_resources[hostname] = list(range(args.num_gpus))
         active_resources = updated_active_resources

+    if args.elastic_training:
+        assert not args.no_local_rank, "--no_local_rank argument is not supported in Elastic training"
+
     # encode world info as base64 to make it easier to pass via command line
     world_info_base64 = encode_world_info(active_resources)

@@ -420,6 +458,10 @@ def main(args=None):
             deepspeed_launch.append("--no_local_rank")
         if args.save_pid:
             deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
+        if args.elastic_training:
+            deepspeed_launch.append("--enable_elastic_training")
+            deepspeed_launch.append(f"--max_elastic_nodes={args.max_elastic_nodes}")
+            deepspeed_launch.append(f"--min_elastic_nodes={args.min_elastic_nodes}")
         cmd = deepspeed_launch + [args.user_script] + args.user_args
     else:
         args.launcher = args.launcher.lower()
@@ -454,11 +496,26 @@ def main(args=None):
             key, val = var.split('=', maxsplit=1)
             runner.add_export(key, val)

-        cmd = runner.get_cmd(env, active_resources)
+        if args.launcher == PDSH_LAUNCHER:
+            cmd, kill_cmd = runner.get_cmd(env, active_resources)
+        else:
+            cmd = runner.get_cmd(env, active_resources)

     logger.info(f"cmd = {' '.join(cmd)}")
     result = subprocess.Popen(cmd, env=env)

+    def sigkill_handler(signum, frame):
+        result.send_signal(signal.SIGINT)
+        time.sleep(0.1)
+        result.send_signal(signal.SIGTERM)
+        result_kill = subprocess.Popen(kill_cmd, env=env)
+        result_kill.wait()
+        time.sleep(1)
+        sys.exit(1)
+
+    if args.launcher == PDSH_LAUNCHER:
+        signal.signal(signal.SIGINT, sigkill_handler)
+
     result.wait()

     # In case of failure must propagate the error-condition back to the caller (usually shell). The
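Elastic runs are requested through the new `deepspeed --elastic_training --min_elastic_nodes N --max_elastic_nodes M ...` flags added above. The `parse_num_nodes` helper additionally accepts a `MIN:MAX` range when elastic training is on; the exact CLI argument it is wired to is not shown in this hunk, so the calls below are purely illustrative of its behavior:

```python
# Illustrative calls to the parse_num_nodes helper added above.
min_nodes, max_nodes = parse_num_nodes("4", elastic_training=False)   # -> (4, -1)
min_nodes, max_nodes = parse_num_nodes("2:8", elastic_training=True)  # -> (2, 8)
# "2:8" with elastic_training=False raises RuntimeError.
```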
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index 0794e4525343..e0aa105be8f3 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -41,6 +41,10 @@
     ELASTICITY,
     IGNORE_NON_ELASTIC_BATCH_INFO,
     IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT,
+    MODEL_PARLLEL_SIZE,
+    MODEL_PARLLEL_SIZE_DEFAULT,
+    NUM_GPUS_PER_NODE,
+    NUM_GPUS_PER_NODE_DEFAULT,
 )

 from ..profiling.config import DeepSpeedFlopsProfilerConfig
@@ -726,6 +730,21 @@ def __init__(self, config: Union[str, dict], mpu=None):
             # Ensure the resource scheduler saw the same elastic config we are using at runtime
             ensure_immutable_elastic_config(runtime_elastic_config_dict=elastic_dict)

+            self.elastic_model_parallel_size = elastic_dict.get(
+                MODEL_PARLLEL_SIZE,
+                MODEL_PARLLEL_SIZE_DEFAULT)
+            if self.elastic_model_parallel_size < 1:
+                raise ElasticityConfigError(
+                    "Model-Parallel size cannot be less than 1, "
+                    f"given model-parallel size: {self.elastic_model_parallel_size}")
+
+            self.num_gpus_per_node = elastic_dict.get(NUM_GPUS_PER_NODE,
+                                                      NUM_GPUS_PER_NODE_DEFAULT)
+            if self.num_gpus_per_node < 1:
+                raise ElasticityConfigError(
+                    "Number of GPUs per node cannot be less than 1, "
+                    f"given number of GPUs per node: {self.num_gpus_per_node}")
+
             ignore_non_elastic_batch_info = elastic_dict.get(
                 IGNORE_NON_ELASTIC_BATCH_INFO,
                 IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 9d1b8b6aac74..3e6ed4b174ac 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -264,9 +264,11 @@ def __init__(
         see_memory_usage(f"DeepSpeed Engine: After args sanity test",
                          force=self.memory_breakdown())
         if mpu is not None:
-            assert not self.elasticity_enabled(), (
-                "Elasticity is not currently supported" " with model parallelism."
-            )
+            if self.elasticity_enabled():
+                if not self.is_elastic_model_parallel_supported():
+                    assert not self.elasticity_enabled(), (
+                        "Elasticity is not currently supported" " with model parallelism."
+                    )

         self._set_distributed_vars(args)

@@ -470,6 +472,14 @@ def checkpoint_tag_validation_fail(self):
     def elasticity_enabled(self):
         return self._config.elasticity_enabled

+    def is_elastic_model_parallel_supported(self):
+        if self.elasticity_enabled():
+            # Add code for finding number of GPUs per node automatically
+            if self._config.num_gpus_per_node % self._config.elastic_model_parallel_size == 0:
+                return True
+            else:
+                return False
+
     def pld_enabled(self):
         return self._config.pld_enabled
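A minimal sketch of the rule the engine check above enforces (a standalone restatement under the same assumptions, not the engine API itself): elasticity with model parallelism is only allowed when each node holds a whole number of model-parallel groups.

```python
# Restates the divisibility rule used by is_elastic_model_parallel_supported() above.
def elastic_mp_supported(num_gpus_per_node: int, model_parallel_size: int) -> bool:
    return num_gpus_per_node % model_parallel_size == 0

assert elastic_mp_supported(8, 4)        # two MP groups per node: allowed
assert not elastic_mp_supported(8, 16)   # an MP group would span nodes: rejected
```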
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 0f3443d980f8..33edc2db1a6a 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -70,8 +70,10 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
         # used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB
         self.pipeline_enable_backward_allreduce = True

-        assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
-            " with pipeline parallelism."
+        if self.elasticity_enabled():
+            if not self.is_elastic_model_parallel_supported():
+                assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
+                    " with pipeline parallelism."

         # pipeline step for logging
         self.log_batch_step_id = -1
diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md
index 11541aee9761..6118bece5272 100755
--- a/docs/_pages/config-json.md
+++ b/docs/_pages/config-json.md
@@ -1048,6 +1048,38 @@ Example of **csv_monitor** configuration:
 }
 ```

+### Elastic Training Config (V0.1 and V0.2)
+
+```json
+  "elasticity": {
+    "enabled": true,
+    "max_train_batch_size": 2000,
+    "micro_batch_sizes": [2,4,6],
+    "min_gpus": 1,
+    "max_gpus": 10000,
+    "min_time": 20,
+    "version": 0.2,
+    "ignore_non_elastic_batch_info": false,
+    "num_gpus_per_node": 16,
+    "model_parallel_size": 1
+  }
+```
+
+| Field | Description | Default |
+| ----- | ----------- | ------- |
+| `enabled` | Enables computation of global batch size in elastic training. | false |
+| `max_train_batch_size` | Max acceptable batch size that can be used in training. | 2000 |
+| `micro_batch_sizes` | Acceptable micro batch sizes, same as `train_micro_batch_size_per_gpu`. | [2,4,6] |
+| `min_gpus` | Min number of GPUs to search over when computing highly composite batch size in v0.1 and v0.2. | 1 |
+| `max_gpus` | Max number of GPUs to search over when computing highly composite batch size in v0.1 and v0.2. | 10000 |
+| `min_time` | Minimum running time (minutes) before the scheduler will scale again (only used in v0.1). 0 implies it's unknown. | 0 |
+| `prefer_large_batch` | When finding a suitable batch size, attempt to find one that is closest to the max train batch size given. | true |
+| `version` | Version of elastic logic to use. | 0.2 |
+| `ignore_non_elastic_batch_info` | Ignore all batch info provided outside the elastic config. To reduce confusion, we require all batch related info to be given in the elastic config only. | false |
+| `num_gpus_per_node` | Number of GPUs per node. Used by v0.2 to support model-parallel training (only used by v0.2). | 1 |
+| `model_parallel_size` | Tensor or model parallel size (only used by v0.2). | 1 |
+
+
 ### Communication Logging
diff --git a/tests/unit/test_elastic.py b/tests/unit/test_elastic.py
index 9f5d1f0d06bd..4ed2c0dd0c95 100644
--- a/tests/unit/test_elastic.py
+++ b/tests/unit/test_elastic.py
@@ -2,6 +2,7 @@
 import deepspeed
 from .common import distributed_test
 from deepspeed.git_version_info import version as ds_version
+import os
 from .simple_model import SimpleModel, args_from_dict

 base_ds_config = {
@@ -78,7 +79,7 @@ def test_invalid_world_size():

 def test_future_elastic_version():
     ds_config = base_ds_config.copy()
-    ds_config['elasticity']['version'] = '0.2'
+    ds_config['elasticity']['version'] = '0.3'
     with pytest.raises(deepspeed.elasticity.config.ElasticityError):
         deepspeed.elasticity.compute_elastic_config(ds_config=ds_config,
                                                     target_deepspeed_version=ds_version)
@@ -107,6 +108,42 @@ def test_empty_config():
                                                     target_deepspeed_version=ds_version)


+def test_model_parallel_v1_invalid():
+    ds_config = base_ds_config.copy()
+    ds_config["elasticity"]["model_parallel_size"] = 4
+    ds_config["elasticity"]["num_gpus_per_node"] = 8
+    ds_config["elasticity"]["version"] = 0.1
+
+    with pytest.raises(deepspeed.elasticity.config.ElasticityError):
+        deepspeed.elasticity.compute_elastic_config(ds_config=ds_config,
+                                                    target_deepspeed_version=ds_version)
+
+
+def test_model_parallel_v2_invalid():
+    ds_config = base_ds_config.copy()
+    ds_config["elasticity"]["model_parallel_size"] = 16
+    ds_config["elasticity"]["num_gpus_per_node"] = 8
+    ds_config["elasticity"]["version"] = 0.2
+
+    with pytest.raises(deepspeed.elasticity.config.ElasticityError):
+        deepspeed.elasticity.compute_elastic_config(ds_config=ds_config,
+                                                    target_deepspeed_version=ds_version,
+                                                    world_size=16)
+
+
+def test_model_parallel_v2_valid():
+
+    ds_config = base_ds_config.copy()
+    ds_config["elasticity"]["model_parallel_size"] = 4
+    ds_config["elasticity"]["num_gpus_per_node"] = 8
+    ds_config["elasticity"]["version"] = 0.2
+
+    os.environ["WORLD_SIZE"] = str(16)
+    deepspeed.elasticity.compute_elastic_config(ds_config=ds_config,
+                                                target_deepspeed_version=ds_version)
+    os.environ.pop("WORLD_SIZE")
+
+
 @pytest.mark.parametrize('key, value', [('micro_batch_sizes', [1,