diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 794adeb7ab00..173a51cda5de 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -4,14 +4,12 @@ name: Build # Controls when the action will run. on: - # Triggers the workflow on push or pull request events but only for the master branch push: - branches: [ master ] + paths-ignore: + - 'docs/**' pull_request: - branches: [ master ] - - # Allows you to run this workflow manually from the Actions tab - workflow_dispatch: + paths-ignore: + - 'docs/**' # A workflow run is made up of one or more jobs that can run sequentially or in parallel jobs: diff --git a/bin/ds_elastic b/bin/ds_elastic new file mode 100644 index 000000000000..ef92cbdab32d --- /dev/null +++ b/bin/ds_elastic @@ -0,0 +1,39 @@ +#!/usr/bin/env python + +import argparse +import json + +import deepspeed +from deepspeed.elasticity import compute_elastic_config + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', type=str, help="DeepSpeed config json") + parser.add_argument('-w', '--world-size', type=int, default=0, help="Intended/current world size") + args = parser.parse_args() + ds_config = json.load(open(args.config, 'r')) + + ds_version = deepspeed.__version__ + + elastic_config = ds_config['elasticity'] + print('------------------------------------------') + print("Elasticity config:") + print('------------------------------------------') + print(json.dumps(elastic_config, indent=4, sort_keys=True)) + + if args.world_size > 0: + final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version, world_size=args.world_size) + print('------------------------------------------') + print(f"Calculated results for world size {args.world_size}:") + print('------------------------------------------') + print(f'final_batch_size .... {final_batch_size}') + print(f'valid_gpus .......... {valid_gpus}') + print(f'micro_batch_size .... {micro_batch_size}') + else: + final_batch_size, valid_gpus = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version) + print('------------------------------------------') + print("Calculated results:") + print('------------------------------------------') + print(f'final_batch_size .... {final_batch_size}') + print(f'valid_gpus .......... 
{valid_gpus}') diff --git a/deepspeed/elasticity/__init__.py b/deepspeed/elasticity/__init__.py new file mode 100644 index 000000000000..be517de7df93 --- /dev/null +++ b/deepspeed/elasticity/__init__.py @@ -0,0 +1 @@ +from .elasticity import compute_elastic_config, elasticity_enabled, ensure_immutable_elastic_config diff --git a/deepspeed/elasticity/config.py b/deepspeed/elasticity/config.py new file mode 100644 index 000000000000..dda56d72882c --- /dev/null +++ b/deepspeed/elasticity/config.py @@ -0,0 +1,80 @@ +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" + +import json +from .constants import * + + +class ElasticityError(Exception): + """ + Base exception for all elasticity related errors + """ + pass + + +class ElasticityConfigError(ElasticityError): + """ + Elasticity configuration error + """ + pass + + +class ElasticityIncompatibleWorldSize(ElasticityError): + """ + Attempting to run a world size that is incompatible with a given elastic config + """ + pass + + +class ElasticityConfig: + """ + Elastic config object, constructed from a param dictionary that only contains elastic + config parameters, example below: + + If elasticity is enabled, user must specify (at least) max_train_batch_size + and micro_batch_sizes. + + { + "enabled": true, + "max_train_batch_size": 2000, + "micro_batch_sizes": [2,4,6], + "min_gpus": 1, + "max_gpus" : 10000 + "min_time": 20 + "ignore_non_elastic_batch_info": false + "version": 0.1 + } + """ + def __init__(self, param_dict): + self.enabled = param_dict.get(ENABLED, ENABLED_DEFAULT) + if self.enabled: + if MAX_ACCEPTABLE_BATCH_SIZE in param_dict: + self.max_acceptable_batch_size = param_dict[MAX_ACCEPTABLE_BATCH_SIZE] + else: + raise ElasticityConfigError( + f"Elasticity config missing {MAX_ACCEPTABLE_BATCH_SIZE}") + if MICRO_BATCHES in param_dict: + self.micro_batches = param_dict[MICRO_BATCHES] + else: + raise ElasticityConfigError(f"Elasticity config missing {MICRO_BATCHES}") + else: + self.max_acceptable_batch_size = param_dict.get( + MAX_ACCEPTABLE_BATCH_SIZE, + MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT) + self.micro_batches = param_dict.get(MICRO_BATCHES, MICRO_BATCHES_DEFAULT) + self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT) + self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT) + self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT) + self.version = param_dict.get(VERSION, VERSION_DEFAULT) + self.prefer_larger_batch_size = param_dict.get(PREFER_LARGER_BATCH, + PREFER_LARGER_BATCH_DEFAULT) + self.ignore_non_elastic_batch_info = param_dict.get( + IGNORE_NON_ELASTIC_BATCH_INFO, + IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT) + + def repr(self): + return self.__dict__ + + def __repr__(self): + return json.dumps(self.__dict__, sort_keys=True, indent=4) diff --git a/deepspeed/elasticity/constants.py b/deepspeed/elasticity/constants.py new file mode 100644 index 000000000000..7db563a83de2 --- /dev/null +++ b/deepspeed/elasticity/constants.py @@ -0,0 +1,74 @@ +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" + +######################################### +# Elasticity +######################################### +''' Elasticity Utility in DeepSpeed can be used to create highly elastic jobs compatible +with a large number of GPUs. 
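Elasticity builds on DeepSpeed's gradient accumulation: the global train batch size decomposes as micro_batch_size * gradient_accumulation_steps * data-parallel world size, so one batch size can stay fixed while the number of GPUs changes.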
For elastic jobs, DeepSpeed will provide a batch size that +can support a large number of GPUs based on the user specified parameters +''' +FORMAT = ''' +Elasticity should be enabled as: +"elasticity": { + "enabled": true, + "max_train_batch_size": 2000, + "micro_batch_sizes": [2,4,6], + "min_gpus": 1, + "max_gpus" : 10000 + "min_time": 20, + "prefer_larger_batch": true, + "ignore_non_elastic_batch_info": false, + "version": 0.1 +} +''' + +ELASTICITY = 'elasticity' + +# Current elasticity version +LATEST_ELASTICITY_VERSION = 0.1 + +ENABLED = 'enabled' +ENABLED_DEFAULT = False + +# Max acceptable train_batch_size +MAX_ACCEPTABLE_BATCH_SIZE = 'max_train_batch_size' +MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT = 2000 + +# Acceptable micro batch sizes, same as train_micro_batch_size_per_gpu +MICRO_BATCHES = 'micro_batch_sizes' +MICRO_BATCHES_DEFAULT = [2, 4, 6] + +# Min/max of GPUs to search over +MIN_GPUS = 'min_gpus' +MIN_GPUS_DEFAULT = 1 +MAX_GPUS = 'max_gpus' +MAX_GPUS_DEFAULT = 10000 + +# Minimum running time (minutes) before the scheduler will scale us +MIN_TIME = "min_time" +MIN_TIME_DEFAULT = "20" + +# When finding a suitable batch size, attempt to find one that is closest +# to the max train batch size given. +PREFER_LARGER_BATCH = 'prefer_larger_batch' +PREFER_LARGER_BATCH_DEFAULT = True + +# In order to reduce confusion, if elastic mode is enabled we +# require (via assert) that no batch info is set outside of the +# elastic config. You can turn off this assert via this config +# but keep in mind that all batch info defined outside the +# elastic mode *will be ignored*. +IGNORE_NON_ELASTIC_BATCH_INFO = 'ignore_non_elastic_batch_info' +IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT = False + +# Version of elastic logic to use +VERSION = "version" +VERSION_DEFAULT = LATEST_ELASTICITY_VERSION + +# Minimum deepspeed version to use elasticity +MINIMUM_DEEPSPEED_VERSION = "0.3.8" + +# Environment variable storing elastic config from resource scheduler +DEEPSPEED_ELASTICITY_CONFIG = "DEEPSPEED_ELASTICITY_CONFIG" diff --git a/deepspeed/elasticity/elasticity.py b/deepspeed/elasticity/elasticity.py new file mode 100644 index 000000000000..ae91877f5f24 --- /dev/null +++ b/deepspeed/elasticity/elasticity.py @@ -0,0 +1,334 @@ +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" +import os +import re +import json +import numpy as np + +from .config import ElasticityConfig, ElasticityConfigError, ElasticityError, \ + ElasticityIncompatibleWorldSize +from .constants import ELASTICITY, ENABLED, ENABLED_DEFAULT, LATEST_ELASTICITY_VERSION, \ + MINIMUM_DEEPSPEED_VERSION, IGNORE_NON_ELASTIC_BATCH_INFO, \ + IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT, DEEPSPEED_ELASTICITY_CONFIG +from ..git_version_info import version as __version__ +from ..utils import logger + +# Thirty eight smallest highly composite numbers. The list should +# be enough to support up to 720K batch size. +HCN_LIST = [ + 1, + 2, + 4, + 6, + 12, + 24, + 36, + 48, + 60, + 120, + 180, + 240, + 360, + 720, + 840, + 1260, + 1680, + 2520, + 5040, + 7560, + 10080, + 15120, + 20160, + 25200, + 27720, + 45360, + 50400, + 55440, + 83160, + 110880, + 166320, + 221760, + 277200, + 332640, + 498960, + 554400, + 665280, + 720720 +] + + +def get_candidate_batch_sizes(base_list, max_acceptable_batch_size): + candidate_batch_size = [] + + #brute force is fine here. 
We are working with very small lists
+    for base in base_list:
+        batch_size = base
+        for hcn in HCN_LIST:
+            new_batch_size = base * hcn
+            if new_batch_size > max_acceptable_batch_size:
+                break
+            batch_size = new_batch_size
+        candidate_batch_size.append(batch_size)
+    return list(set(candidate_batch_size))
+
+
+def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus):
+    valid_gpus = []
+    for micro_batch in micro_batches:
+        if batch_size % micro_batch == 0:
+
+            max_gpus = batch_size // micro_batch
+            if max_gpus >= min_valid_gpus and max_gpus <= max_valid_gpus:
+                valid_gpus.append(max_gpus)
+
+            for i in range(1, max_gpus // 2 + 1):
+                if max_gpus % i == 0:
+                    if i >= min_valid_gpus and i <= max_valid_gpus:
+                        valid_gpus.append(i)
+    valid_gpus = set(valid_gpus)
+    valid_gpus = sorted(list(valid_gpus))
+    return valid_gpus
+
+
+def get_best_candidates(candidate_batch_sizes,
+                        micro_batches,
+                        min_gpus,
+                        max_gpus,
+                        prefer_larger):
+
+    max_valid_gpus = 0
+    valid_gpus = None
+    final_batch_size = int(min(micro_batches))
+
+    for batch_size in candidate_batch_sizes:
+
+        current_valid_gpus = get_valid_gpus(batch_size,
+                                            micro_batches,
+                                            min_gpus,
+                                            max_gpus)
+
+        if (len(current_valid_gpus) > max_valid_gpus
+                or (len(current_valid_gpus) == max_valid_gpus and
+                    ((prefer_larger and batch_size > final_batch_size) or
+                     (not prefer_larger and batch_size < final_batch_size)))):
+            max_valid_gpus = len(current_valid_gpus)
+            valid_gpus = current_valid_gpus
+            final_batch_size = batch_size
+
+    return final_batch_size, valid_gpus
+
+
+def _get_compatible_gpus_v01(micro_batches,
+                             max_acceptable_batch_size,
+                             min_gpus=None,
+                             max_gpus=None,
+                             prefer_larger=True):
+    '''We use two heuristics to compute the batch size
+    1. We use the Lowest Common Multiple of the micro-batches
+    as the base batch size and scale it by a HCN such that the result is
+    the largest batch size less than max_acceptable_batch_size
+    2. We use each of the micro batches as a base and scale it
+    by a HCN such that the result is the largest batch size less than
+    max_acceptable_batch_size.
+
+    We then use brute force to count the number of compatible GPU counts for
+    each of the aforementioned cases, and return the batch size with the largest
+    number of compatible GPU counts in the min-max GPU range if provided, otherwise
+    we return the batch size with the largest number of total compatible GPU counts.
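+    For example (assuming the default min/max GPU bounds): with micro_batches=[2, 4, 6]
+    and max_acceptable_batch_size=2000, the LCM base is 12, the candidate batch sizes
+    come out to 1440 and 1680, and 1680 is selected because it is divisible by 32
+    distinct GPU counts (the divisors of 1680 // 2 = 840) versus 30 counts for 1440.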
+ + Returns: + final_batch_size + valid_gpus + ''' + + if min_gpus is None: + min_gpus = int(1) + + if max_gpus is None: + max_gpus = int(max_acceptable_batch_size / min(micro_batches)) + + assert all(mb <= max_acceptable_batch_size for mb in micro_batches ), \ + f"All micro batches must be less than \ + or equal to max_acceptable_batch_size: {max_acceptable_batch_size}" + + lcm = np.lcm.reduce(micro_batches) + + base_list = [] + base_list.extend(micro_batches) + base_list.append(lcm) + + candidate_batch_sizes = get_candidate_batch_sizes(base_list, + max_acceptable_batch_size) + + final_batch_size, valid_gpus = get_best_candidates( + candidate_batch_sizes, + micro_batches, + min_gpus, + max_gpus, + prefer_larger) + + return final_batch_size, valid_gpus + + +def _parse_version(version_str): + '''Parse a version string and extract the major and minor versions (and possibly patch version).''' + matched = re.search('^(\d+)\.(\d+)\.(\d+)', version_str) + if matched: + return int(matched.group(1)), int(matched.group(2)), int(matched.group(3)) + else: + matched = re.search('^(\d+)\.(\d+)', version_str) + assert matched != None, "Unable to parse version number, expecting" \ + f"major.minor[.patch] format but received {version_str}" + return int(matched.group(1)), int(matched.group(2)), 0 + + +def _compatible_ds_version_check(target_deepspeed_version: str): + min_major, min_minor, min_patch = _parse_version(MINIMUM_DEEPSPEED_VERSION) + trg_major, trg_minor, trg_patch = _parse_version(target_deepspeed_version) + + err_str = f"Target deepspeed version of {target_deepspeed_version} is not compatible " \ + f"with minimum version {MINIMUM_DEEPSPEED_VERSION} supporting elasticity." + if trg_major < min_major: + raise ElasticityError(err_str) + if trg_minor < min_minor: + raise ElasticityError(err_str) + if trg_patch < min_patch: + raise ElasticityError(err_str) + return True + + +def elasticity_enabled(ds_config: dict): + if ELASTICITY not in ds_config: + return False + return ds_config[ELASTICITY].get(ENABLED, ENABLED_DEFAULT) + + +def ensure_immutable_elastic_config(runtime_elastic_config_dict: dict): + """ + Ensure the resource scheduler saw the same elastic config we are using at runtime + """ + if DEEPSPEED_ELASTICITY_CONFIG in os.environ: + scheduler_elastic_config_dict = json.loads( + os.environ[DEEPSPEED_ELASTICITY_CONFIG]) + scheduler_elastic_config = ElasticityConfig(scheduler_elastic_config_dict) + runtime_elastic_config = ElasticityConfig(runtime_elastic_config_dict) + err_str = "Elastic config '{}={}' seen by resource scheduler does not match config passed to runtime {}={}" + if runtime_elastic_config.max_acceptable_batch_size != scheduler_elastic_config.max_acceptable_batch_size: + raise ElasticityConfigError( + err_str.format('max_acceptable_batch_size', + scheduler_elastic_config.max_acceptable_batch_size, + 'max_acceptable_batch_size', + runtime_elastic_config.max_acceptable_batch_size)) + if runtime_elastic_config.micro_batches != scheduler_elastic_config.micro_batches: + raise ElasticityConfigError( + err_str.format('micro_batches', + scheduler_elastic_config.micro_batches, + 'micro_batches', + runtime_elastic_config.micro_batches)) + if runtime_elastic_config.version != scheduler_elastic_config.version: + raise ElasticityConfigError( + err_str.format('version', + scheduler_elastic_config.version, + 'version', + runtime_elastic_config.version)) + else: + logger.warning("Unable to find DEEPSPEED_ELASTICITY_CONFIG environment variable, cannot " \ + "guarantee resource scheduler will 
scale this job using compatible GPU counts.") + + +def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world_size=0): + """Core deepspeed elasticity API. Given an elastic config (similar to the example below) + DeepSpeed will compute a total train batch size corresponding valid GPU count list that + provides a high level of elasticity. Elasticity in this case means we are safe to scale + the training job up/down across the GPU count list *without* any negative impacts on + training convergence. This is achievable primarily due to DeepSpeed's gradient accumulation + feature which allows us to decompose a global training batch size into: + micro-batch-size * gradient-accumulation-steps * world-size. + + "elasticity": { + "enabled": true, + "max_train_batch_size": 2000, + "micro_batch_sizes": [2,4,6], + "min_gpus": 1, + "max_gpus" : 10000 + "min_time": 20 + "version": 0.1 + } + + Intended to be called both by scheduling infrastructure and deepspeed runtime. + For the same `ds_config` we should return deterministic results. + + Args: + ds_config (dict): DeepSpeed config dictionary/json + target_deepspeed_version (str): When called from scheduling + infrastructure we want to ensure that the target deepspeed version is + compatible with the elasticity version used in the backend. + world_size (int, optional): Intended/current world size, will do some sanity + checks to ensure world size is actually valid with the config. + + Raises: + ElasticityConfigError: Missing required elasticity config or elasticity disabled + ElasticityError: If target deepspeed version is not compatible with current version + + Returns: + final_batch_size (int): total batch size used for training + valid_gpus (list(int)): list of valid GPU counts with this config + micro_batch_size (int, optional): if world_size is provided will return + specific micro batch size + """ + if not isinstance(ds_config, dict): + raise ValueError("Expected ds_config to be a dictionary but received " \ + f"a {type(ds_config)}, containing: {ds_config}") + + if ELASTICITY not in ds_config: + raise ElasticityConfigError(f"'{ELASTICITY}' is missing from config json," \ + " please add it if running an elastic training job.") + + elastic_config_dict = ds_config[ELASTICITY] + if not elastic_config_dict.get(ENABLED, ENABLED_DEFAULT): + raise ElasticityConfigError("Elasticity is disabled, please enable it " \ + "('enabled':true) if running an elastic training job.") + + elastic_config = ElasticityConfig(elastic_config_dict) + + if float(elastic_config.version) > LATEST_ELASTICITY_VERSION: + raise ElasticityConfigError("Attempting to run elasticity version " \ + f"{elastic_config.version} but runtime only supports up " \ + f"to {LATEST_ELASTICITY_VERSION}") + + # Ensure target deepspeed version works with intended elasticity version + if not _compatible_ds_version_check(target_deepspeed_version): + raise ElasticityError("Unable to run elasticity on target deepspeed version of" \ + f" {target_deepspeed_version}, currently {__version__}") + + if float(elastic_config.version) == 0.1: + final_batch_size, valid_gpus = _get_compatible_gpus_v01( + micro_batches=elastic_config.micro_batches, + max_acceptable_batch_size=elastic_config.max_acceptable_batch_size, + min_gpus=elastic_config.min_gpus, + max_gpus=elastic_config.max_gpus, + prefer_larger=elastic_config.prefer_larger_batch_size) + # ensure batch size is int dtype + final_batch_size = int(final_batch_size) + else: + raise NotImplementedError( + f"Unable to find elastic logic for 
version: {elastic_config.version}") + + if world_size > 0: + if world_size not in valid_gpus: + raise ElasticityIncompatibleWorldSize(f"World size ({world_size}) is not valid " \ + f"with the current list of valid GPU counts: {valid_gpus}") + + # Pick largest valid micro batch size + micro_batch_size = None + for mbsz in sorted(list(set(elastic_config.micro_batches)), reverse=True): + if final_batch_size // world_size % mbsz == 0: + micro_batch_size = mbsz + break + assert micro_batch_size is not None, "Unable to find divisible micro batch size" \ + f" world_size={world_size}, final_batch_size={final_batch_size}, and " \ + f" micro_batches={elastic_config.micro_batches}." + return final_batch_size, valid_gpus, micro_batch_size + + return final_batch_size, valid_gpus diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 4a56aafbc539..9d52dfe6d766 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -6,13 +6,21 @@ import torch import json import copy -from deepspeed.runtime.constants import * -from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE -from deepspeed.runtime.config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys -from deepspeed.runtime.zero.config import DeepSpeedZeroConfig -from deepspeed.runtime.zero.constants import * -from deepspeed.runtime.activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig -from deepspeed.utils import logger + +from .constants import * +from .fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE +from .config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys +from .zero.config import DeepSpeedZeroConfig +from .zero.constants import * +from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig + +from ..git_version_info import version as __version__ +from ..utils import logger + +from ..elasticity import elasticity_enabled, compute_elastic_config, ensure_immutable_elastic_config +from ..elasticity.config import ElasticityConfigError +from ..elasticity.constants import ELASTICITY, IGNORE_NON_ELASTIC_BATCH_INFO, \ + IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT TENSOR_CORE_ALIGN_SIZE = 8 @@ -504,6 +512,59 @@ def __init__(self, json_file, mpu=None, param_dict=None): self.global_rank = 0 self.world_size = 1 + # If elastic-mode enabled, update compute + update _param_dict + self.elasticity_enabled = elasticity_enabled(self._param_dict) + if self.elasticity_enabled: + logger.info("DeepSpeed elasticity support enabled") + final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config( + ds_config=self._param_dict, + target_deepspeed_version=__version__, + world_size=self.world_size) + + elastic_dict = self._param_dict[ELASTICITY] + + # Ensure the resource scheduler saw the same elastic config we are using at runtime + ensure_immutable_elastic_config(runtime_elastic_config_dict=elastic_dict) + + ignore_non_elastic_batch_info = elastic_dict.get( + IGNORE_NON_ELASTIC_BATCH_INFO, + IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT) + + if not ignore_non_elastic_batch_info: + batch_params = [ + TRAIN_BATCH_SIZE, + TRAIN_MICRO_BATCH_SIZE_PER_GPU, + GRADIENT_ACCUMULATION_STEPS + ] + if any(map(lambda t: t in self._param_dict, batch_params)): + raise ElasticityConfigError("One or more batch related parameters were found in your " \ + f"ds_config ({TRAIN_BATCH_SIZE}, {TRAIN_MICRO_BATCH_SIZE_PER_GPU}, and/or " \ + f"{GRADIENT_ACCUMULATION_STEPS}). 
These parameters *will not be used* since " \ + "elastic training is enabled, which takes control of these parameters. " \ + "If you want to supress this error (the parameters will be silently ignored) " \ + f"please set {IGNORE_NON_ELASTIC_BATCH_INFO}':true in your elasticity config.") + + # micro_bsz * world_size * gas = total_batch_size + # gas = total_batch_size // (micro_bsz * world_size) + gradient_accu_steps = final_batch_size // (micro_batch_size * + self.world_size) + + if TRAIN_BATCH_SIZE in self._param_dict: + logger.warning("[Elasticity] overriding training_batch_size: " \ + f"{self._param_dict[TRAIN_BATCH_SIZE]} -> {final_batch_size}") + if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self._param_dict: + logger.warning("[Elasticity] overriding train_micro_batch_size_per_gpu: " \ + f"{self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU]} -> {micro_batch_size}") + if GRADIENT_ACCUMULATION_STEPS in self._param_dict: + logger.warning("[Elasticity] overriding gradient_accumulation_steps: "\ + f"{self._param_dict[GRADIENT_ACCUMULATION_STEPS]} -> {gradient_accu_steps}") + + logger.info(f"[Elasticity] valid GPU counts: {valid_gpus}") + + self._param_dict[TRAIN_BATCH_SIZE] = final_batch_size + self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = micro_batch_size + self._param_dict[GRADIENT_ACCUMULATION_STEPS] = gradient_accu_steps + self._initialize_params(self._param_dict) self._configure_train_batch_size() self._do_sanity_check() diff --git a/deepspeed/runtime/config_utils.py b/deepspeed/runtime/config_utils.py index 38fdb647f61d..37f35692369b 100755 --- a/deepspeed/runtime/config_utils.py +++ b/deepspeed/runtime/config_utils.py @@ -13,6 +13,10 @@ def get_scalar_param(param_dict, param_name, param_default_value): return param_dict.get(param_name, param_default_value) +def get_list_param(param_dict, param_name, param_default_value): + return param_dict.get(param_name, param_default_value) + + def dict_raise_error_on_duplicate_keys(ordered_pairs): """Reject duplicate keys.""" d = dict((k, v) for k, v in ordered_pairs) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 49e1bedd3cfc..8f86469e1073 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -137,6 +137,10 @@ def __init__(self, self._configure_with_arguments(args, mpu) self._do_sanity_check() + if mpu is not None: + assert not self.elasticity_enabled(), "Elasticity is not currently supported" \ + " with model parallelism." + self._set_distributed_vars() if self.tensorboard_enabled() and self.global_rank == 0: @@ -194,6 +198,22 @@ def __init__(self, self.flatten = util_ops.flatten self.unflatten = util_ops.unflatten + def get_batch_info(self): + """ Get all training batch related settings. + + Returns: + train_batch_size (int): The effective training batch size. This is the amount of data + samples that leads to one step of model update. + train_micro_batch_size_per_gpu (int): Batch size to be processed by one GPU in one + step (without gradient accumulation). + gradient_accumulation_steps (int): Number of training steps to accumulate gradients + before averaging and applying them. 
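+
+            Example (illustrative), assuming engine is the model engine returned by
+            deepspeed.initialize():
+                batch_size, micro_batch_size, grad_acc_steps = engine.get_batch_info()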
+ """ + return self.train_batch_size, self.train_micro_batch_size_per_gpu, self.gradient_accumulation_steps + + def elasticity_enabled(self): + return self._config.elasticity_enabled + def pld_enabled(self): return self._config.pld_enabled @@ -1224,10 +1244,13 @@ def load_checkpoint(self, if tag is None: latest_path = os.path.join(load_dir, 'latest') - assert os.path.isfile(latest_path), f"Unable to find latest file at {latest_path}, if trying to load latest " \ - "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint." - with open(latest_path, 'r') as fd: - tag = fd.read().strip() + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + logger.warning(f"Unable to find latest file at {latest_path}, if trying to load latest " \ + "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint.") + return None, None load_path, client_states = self._load_checkpoint(load_dir, tag, diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 5c5d896dfc0d..87cc64950006 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -54,6 +54,8 @@ def __init__(self, *super_args, **super_kwargs): # We schedule the all-reduces, so disable it in super().backward() self.enable_backward_allreduce = False + assert not self.elasticity_enabled(), "Elasticity is not currently supported" \ + " with pipeline parallelism." # pipeline step for logging self.log_batch_step_id = -1 diff --git a/op_builder/builder.py b/op_builder/builder.py index 1f350065b4f6..4bdb9e036708 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -33,7 +33,9 @@ def installed_cuda_version(): def get_default_compute_capatabilities(): compute_caps = DEFAULT_COMPUTE_CAPABILITIES - if installed_cuda_version()[0] >= 11: + import torch.utils.cpp_extension + if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version( + )[0] >= 11: compute_caps += ";8.0;8.6" return compute_caps diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 5845cdff4452..9192befdd35c 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -3,3 +3,4 @@ torchvision>=0.4.0 tqdm tensorboardX==1.8 ninja +numpy diff --git a/setup.py b/setup.py index bf2ff9813537..19df040dcc88 100755 --- a/setup.py +++ b/setup.py @@ -184,7 +184,8 @@ def op_enabled(op_name): 'bin/deepspeed.pt', 'bin/ds', 'bin/ds_ssh', - 'bin/ds_report' + 'bin/ds_report', + 'bin/ds_elastic' ], classifiers=[ 'Programming Language :: Python :: 3.6', diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index 1fbcacfa2aa4..1cd817ebc561 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -757,7 +757,7 @@ def _helper(args, model, hidden_dim): model, _, _,_ = deepspeed.initialize(args=args, model=model, model_parameters=model.parameters()) - with pytest.raises(AssertionError): - model.load_checkpoint(tmpdir) + # should be no-op, since latest doesn't exist + model.load_checkpoint(tmpdir) _helper(args=args, model=model, hidden_dim=hidden_dim) diff --git a/tests/unit/test_elastic.py b/tests/unit/test_elastic.py new file mode 100644 index 000000000000..339112b1bc93 --- /dev/null +++ b/tests/unit/test_elastic.py @@ -0,0 +1,241 @@ +import pytest +import deepspeed +from common import distributed_test +from deepspeed.git_version_info import version as ds_version +from simple_model import 
SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
+
+base_ds_config = {
+    "elasticity": {
+        "enabled": True,
+        "max_train_batch_size": 10000,
+        "micro_batch_sizes": [8,
+                              12,
+                              16,
+                              17],
+        "min_gpus": 32,
+        "max_gpus": 1500,
+        "min_time": 20,
+        "version": 0.1
+    }
+}
+
+
+def test_basic_10k():
+    ds_config = base_ds_config.copy()
+    final_batch_size, valid_gpus = deepspeed.elasticity.compute_elastic_config(
+        ds_config=ds_config,
+        target_deepspeed_version=ds_version)
+
+    for gpu_num in valid_gpus:
+        assert final_batch_size % gpu_num == 0, f"Batch {final_batch_size} is not divisible by GPU count {gpu_num}"
+        batch_per_gpu = final_batch_size // gpu_num
+        found_valid_mb = False
+
+        for mb in ds_config['elasticity']['micro_batch_sizes']:
+            if batch_per_gpu % mb == 0:
+                found_valid_mb = True
+                break
+        assert found_valid_mb, "No valid mb found"
+
+    assert len(valid_gpus) == 23
+    assert final_batch_size == 9792
+
+
+def test_old_version():
+    ds_config = base_ds_config.copy()
+    with pytest.raises(deepspeed.elasticity.config.ElasticityError):
+        final_batch_size, valid_gpus = deepspeed.elasticity.compute_elastic_config(
+            ds_config=ds_config,
+            target_deepspeed_version="0.2")
+
+
+def test_disabled():
+    ds_config = base_ds_config.copy()
+    ds_config['elasticity']['enabled'] = False
+    with pytest.raises(deepspeed.elasticity.config.ElasticityError):
+        final_batch_size, valid_gpus = deepspeed.elasticity.compute_elastic_config(
+            ds_config=ds_config,
+            target_deepspeed_version=ds_version)
+
+
+def test_valid_world_size():
+    ds_config = base_ds_config.copy()
+    final_batch_size, valid_gpus, mbsize = deepspeed.elasticity.compute_elastic_config(
+        ds_config=ds_config,
+        target_deepspeed_version=ds_version,
+        world_size=64)
+    assert mbsize == 17
+
+
+def test_invalid_world_size():
+    ds_config = base_ds_config.copy()
+    with pytest.raises(deepspeed.elasticity.config.ElasticityIncompatibleWorldSize):
+        final_batch_size, valid_gpus, mbsize = deepspeed.elasticity.compute_elastic_config(
+            ds_config=ds_config,
+            target_deepspeed_version=ds_version,
+            world_size=128)
+
+
+def test_future_elastic_version():
+    ds_config = base_ds_config.copy()
+    ds_config['elasticity']['version'] = '0.2'
+    with pytest.raises(deepspeed.elasticity.config.ElasticityError):
+        deepspeed.elasticity.compute_elastic_config(ds_config=ds_config,
+                                                    target_deepspeed_version=ds_version)
+
+
+def test_missing_max_batch():
+    ds_config = base_ds_config.copy()
+    del ds_config['elasticity']['max_train_batch_size']
+    with pytest.raises(deepspeed.elasticity.config.ElasticityError):
+        deepspeed.elasticity.compute_elastic_config(ds_config=ds_config,
+                                                    target_deepspeed_version=ds_version)
+
+
+def test_missing_micro_batch():
+    ds_config = base_ds_config.copy()
+    del ds_config['elasticity']['micro_batch_sizes']
+    with pytest.raises(deepspeed.elasticity.config.ElasticityError):
+        deepspeed.elasticity.compute_elastic_config(ds_config=ds_config,
+                                                    target_deepspeed_version=ds_version)
+
+
+def test_empty_config():
+    ds_config = {"elasticity": {"enabled": True}}
+    with pytest.raises(deepspeed.elasticity.config.ElasticityError):
+        deepspeed.elasticity.compute_elastic_config(ds_config=ds_config,
+                                                    target_deepspeed_version=ds_version)
+
+
+def test_proper_mbsz():
+    ds_config = base_ds_config.copy()
+    ds_config["elasticity"]["max_train_batch_size"] = 32
+    ds_config["elasticity"]["micro_batch_sizes"] = [1, 2, 3, 7]
+    ds_config["elasticity"]["min_gpus"] = 1
+    final_batch_size, valid_gpus, mbsize = deepspeed.elasticity.compute_elastic_config(
+
ds_config=ds_config, + target_deepspeed_version=ds_version, + world_size=7) + assert mbsize == 3 + + +def test_non_elastic_batch_params(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "Lamb", + "params": { + "lr": 0.00015 + } + }, + "gradient_clipping": 1.0, + "elasticity": { + "enabled": True, + "max_train_batch_size": 4, + "micro_batch_sizes": [1, + 2, + 3, + 4], + "min_gpus": 1, + "max_gpus": 4, + "min_time": 20, + "version": 0.1 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim, empty_grad=False) + + @distributed_test(world_size=[1, 2]) + def _test_elastic(args, model, hidden_dim): + with pytest.raises(deepspeed.elasticity.config.ElasticityError): + model, _, _,_ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + + _test_elastic(args=args, model=model, hidden_dim=hidden_dim) + + +def test_non_elastic_batch_params_w_override(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "Lamb", + "params": { + "lr": 0.00015 + } + }, + "gradient_clipping": 1.0, + "elasticity": { + "enabled": True, + "max_train_batch_size": 4, + "micro_batch_sizes": [1, + 2, + 3, + 4], + "min_gpus": 1, + "max_gpus": 4, + "min_time": 20, + "version": 0.1, + "ignore_non_elastic_batch_info": True + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim, empty_grad=False) + + @distributed_test(world_size=[1, 2]) + def _test_elastic(args, model, hidden_dim): + model, _, _,_ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + + _test_elastic(args=args, model=model, hidden_dim=hidden_dim) + + +def test_elastic_config_changed(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "Lamb", + "params": { + "lr": 0.00015 + } + }, + "gradient_clipping": 1.0, + "elasticity": { + "enabled": True, + "max_train_batch_size": 4, + "micro_batch_sizes": [1, + 2, + 3, + 4], + "min_gpus": 1, + "max_gpus": 4, + "min_time": 20, + "version": 0.1, + "ignore_non_elastic_batch_info": True + } + } + import json, os + scheduler_elastic_config = config_dict.copy() + scheduler_elastic_config["elasticity"]["max_train_batch_size"] = 27 + os.environ['DEEPSPEED_ELASTICITY_CONFIG'] = json.dumps(scheduler_elastic_config) + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim, empty_grad=False) + + @distributed_test(world_size=[1, 2]) + def _test_elastic(args, model, hidden_dim): + with pytest.raises(deepspeed.elasticity.config.ElasticityError): + model, _, _,_ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + + _test_elastic(args=args, model=model, hidden_dim=hidden_dim) diff --git a/version.txt b/version.txt index 667843220966..940ac09aa677 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.8 +0.3.9
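
For reference, a minimal usage sketch of the elasticity API introduced above. It is illustrative only: it assumes a DeepSpeed install that includes this change, the config values simply mirror the FORMAT example in deepspeed/elasticity/constants.py, and world_size=8 is an arbitrary GPU count picked for the example.

import deepspeed
from deepspeed.elasticity import compute_elastic_config

# Illustrative elasticity block; values mirror the FORMAT example in constants.py.
ds_config = {
    "elasticity": {
        "enabled": True,
        "max_train_batch_size": 2000,
        "micro_batch_sizes": [2, 4, 6],
        "min_gpus": 1,
        "max_gpus": 10000,
        "min_time": 20,
        "version": 0.1
    }
}

# Scheduler-side call: no world size yet, returns the chosen global batch size and
# every GPU count that batch size is compatible with.
final_batch_size, valid_gpus = compute_elastic_config(
    ds_config=ds_config,
    target_deepspeed_version=deepspeed.__version__)

# Runtime-side call: with a concrete world size it also returns the micro batch size
# to use per GPU, or raises ElasticityIncompatibleWorldSize for an unsupported count.
final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(
    ds_config=ds_config,
    target_deepspeed_version=deepspeed.__version__,
    world_size=8)

print(final_batch_size, micro_batch_size, len(valid_gpus))

The new bin/ds_elastic entry point wraps the same calls: ds_elastic -c <config.json> -w 8 prints the elasticity block plus the computed batch size, valid GPU counts, and micro batch size for that world size.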