diff --git a/deepspeed/nebula/config.py b/deepspeed/nebula/config.py
index f9928d66147c..4622d66952f8 100644
--- a/deepspeed/nebula/config.py
+++ b/deepspeed/nebula/config.py
@@ -3,51 +3,56 @@
 Licensed under the MIT license.
 """

-from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
-from deepspeed.nebula.constants import *
-
-
-class DeepSpeedNebulaConfig(DeepSpeedConfigObject):
-    def __init__(self, param_dict):
-        super(DeepSpeedNebulaConfig, self).__init__()
-
-        self.enabled = None
-        self.persistent_storage_path = None
-        self.persistent_time_interval = None
-        self.num_of_version_in_retention = None
-        self.enable_nebula_load = None
-
-        if NEBULA in param_dict.keys():
-            nebula_dict = param_dict[NEBULA]
-        else:
-            nebula_dict = {}
-
-        self._initialize(nebula_dict)
-
-    def _initialize(self, nebula_dict):
-        self.enabled = get_scalar_param(nebula_dict,
-                                        NEBULA_ENABLED,
-                                        NEBULA_ENABLED_DEFAULT)
-
-        self.load_path = get_scalar_param(nebula_dict,
-                                          NEBULA_LOAD_PATH,
-                                          NEBULA_LOAD_PATH_DEFAULT)
-
-        self.enable_nebula_load = get_scalar_param(nebula_dict,
-                                                   NEBULA_ENABLE_NEBULA_LOAD,
-                                                   NEBULA_ENABLE_NEBULA_LOAD_DEFAULT)
-
-        self.persistent_storage_path = get_scalar_param(
-            nebula_dict,
-            NEBULA_PERSISTENT_STORAGE_PATH,
-            NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT)
-
-        self.persistent_time_interval = get_scalar_param(
-            nebula_dict,
-            NEBULA_PERSISTENT_TIME_INTERVAL,
-            NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT)
-
-        self.num_of_version_in_retention = get_scalar_param(
-            nebula_dict,
-            NEBULA_NUM_OF_VERSION_IN_RETENTION,
-            NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT)
+from pydantic import Field, validator
+from deepspeed.runtime.config_utils import DeepSpeedConfigModel
+
+# TODO: remove once main deepspeed config uses pydantic
+NEBULA = "nebula"
+
+
+def get_nebula_config(param_dict):
+    nebula_config_dict = param_dict.get(NEBULA, {})
+    return DeepSpeedNebulaConfig(**nebula_config_dict)
+
+
+class DeepSpeedNebulaConfig(DeepSpeedConfigModel):
+    """ Sets parameters for the Nebula checkpoint engine. """
+
+    enabled: bool = False
+    """ Enable or disable the Nebula checkpoint engine. """
+
+    load_path: str = None
+    """
+    To resume from a previous checkpoint saved by Nebula, set `load_path` to
+    the parent folder of that checkpoint. If `load_path` is None,
+    `persistent_storage_path` is used as the load path.
+    """
+
+    persistent_storage_path: str = None
+    """ Nebula saves checkpoints under `load_path` asynchronously. """
+
+    persistent_time_interval: int = Field(None, gt=0)
+    """ Time interval that triggers Nebula persistence. """
+
+    num_of_version_in_retention: int = Field(2, gt=0)
+    """
+    Number of checkpoint versions kept in memory. For example, if the value is
+    `2` and checkpoints `1` and `2` are ready, then when checkpoint `3` arrives,
+    checkpoint `1` is removed once it has been persisted to disk.
+    """
+
+    enable_nebula_load: bool = True
+    """
+    Sometimes a user wants to load a checkpoint saved by raw torch. Nebula
+    cannot load a torch checkpoint directly because the two use different
+    folder structures, even though the saved data are byte-for-byte identical.
+    In that case, set `enable_nebula_load` to False to disable Nebula load and
+    fall back to the original DeepSpeed loading path, i.e. set the value of
+    "--load".
+ """ + @validator("persistent_storage_path") + def load_path_check(cls, field_value, values): + if values["load_path"] is None: + values["load_path"] = field_value + return field_value diff --git a/deepspeed/nebula/constants.py b/deepspeed/nebula/constants.py index 0e66fa8d1536..31a85df7a6f4 100644 --- a/deepspeed/nebula/constants.py +++ b/deepspeed/nebula/constants.py @@ -3,63 +3,6 @@ Licensed under the MIT license. """ -######################################### -# nebula -######################################### -# Nebula. By default, this feature is not enabled. -# Users can configure in ds_config.json as below example: -NEBULA_FORMAT = ''' -nebula should be enabled as: -"session_params": { - "nebula": { - "enabled": true, - "persistent_storage_path": "/foo/bar", - "persistent_time_interval": 100, - "num_of_version_in_retention": 2, - "enable_nebula_load": true - } -} -''' - -NEBULA = "nebula" - -NEBULA_ENABLED = "enabled" -NEBULA_ENABLED_DEFAULT = False - -# There is a case where customer want to load the checkpoint saved -# by raw torch. Because nebula cannot load torch checkpoint directly -# as they have different folder structures to bring the gap for -# loading(the data are totaly same in bytes for torch and enbula s -# aving). -# In this case, we must disable nebula load to use raw torch load. -# Customer can just set NEBULA_ENABLE_NEBULA_LOAD to False. Then use -# original way of deepspeed to load, i.e. set the value of "--load". -NEBULA_ENABLE_NEBULA_LOAD = "enable_nebula_load" -NEBULA_ENABLE_NEBULA_LOAD_DEFAULT = True - -# When you want to resume the previous checkpoint saved by nebula, -# you can set NEBULA_LOAD_PATH as the parent folder of checkpoint. -# If NEBULA_LOAD_PATH is None, the NEBULA_PERSISTENT_STORAGE_PATH -# will be the default path to load. -NEBULA_LOAD_PATH = "nebula_load_path" -NEBULA_LOAD_PATH_DEFAULT = None - -# Nebula will save the checkpoint under NEBULA_LOAD_PATH in the -# asynchronous way. -NEBULA_PERSISTENT_STORAGE_PATH = "persistent_storage_path" -NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT = None - -# Time interval to trigger the nebula persistence. -NEBULA_PERSISTENT_TIME_INTERVAL = "persistent_time_interval" -NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT = 100 - -# Checkpoint number which will be kept in memory. Let us say, -# if the value is 2. Then we have checkpoints 1 and 2 are ready -# now. When it comes to checkpoint 3, the 1 will be removed if -# 1 has been persisted to disk. 
-NEBULA_NUM_OF_VERSION_IN_RETENTION = "num_of_version_in_retention"
-NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT = 2
-
 # Neubla envs
 NEBULA_EXPORT_ENVS = [
     'DLTS_JOB_ID',
diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py
index 4cdd0856e7b7..ad5e12d5d70b 100644
--- a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py
+++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py
@@ -5,7 +5,6 @@
 from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
     CheckpointEngine
 from deepspeed.utils import logger, log_dist
-from deepspeed.nebula.constants import *


 def _get_tag_from_path(path):
@@ -19,14 +18,13 @@ def __init__(self, config_params=None):
         self.tag_flag = None
         self.enable_nebula_load = config_params.enable_nebula_load
         self.nebula_load_path = config_params.load_path
-        if self.nebula_load_path is None:
-            self.nebula_load_path = config_params.persistent_storage_path

         nebula_config_params = {
-            NEBULA_PERSISTENT_STORAGE_PATH: config_params.persistent_storage_path,
-            NEBULA_PERSISTENT_TIME_INTERVAL: config_params.persistent_time_interval,
-            NEBULA_NUM_OF_VERSION_IN_RETENTION:
-            config_params.num_of_version_in_retention,
+            key: getattr(config_params,
+                         key)
+            for key in ("persistent_storage_path",
+                        "persistent_time_interval",
+                        "num_of_version_in_retention")
         }
         torch_nebula.init(**nebula_config_params)
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index 3e5d2cfff81c..efc437a5fd9f 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -51,7 +51,7 @@
 from ..profiling.config import DeepSpeedFlopsProfilerConfig
 from ..autotuning.config import DeepSpeedAutotuningConfig
-from ..nebula.config import DeepSpeedNebulaConfig
+from ..nebula.config import get_nebula_config
 from ..compression.config import get_compression_config, get_quantize_enabled
 from ..compression.constants import *
@@ -914,7 +914,7 @@ def _initialize_params(self, param_dict):

         self.dataloader_drop_last = get_dataloader_drop_last(param_dict)

-        self.nebula_config = DeepSpeedNebulaConfig(param_dict)
+        self.nebula_config = get_nebula_config(param_dict)

     def _batch_assertion(self):
diff --git a/docs/code-docs/source/model-checkpointing.rst b/docs/code-docs/source/model-checkpointing.rst
index c797943dd662..97dde782e09e 100644
--- a/docs/code-docs/source/model-checkpointing.rst
+++ b/docs/code-docs/source/model-checkpointing.rst
@@ -11,7 +11,6 @@ Saving Training Checkpoints
 ---------------------------

 .. autofunction:: deepspeed.DeepSpeedEngine.save_checkpoint
-
 ZeRO Checkpoint fp32 Weights Recovery
 -------------------------------------

@@ -22,3 +21,12 @@ DeepSpeed provides routines for extracting fp32 weights from the saved ZeRO chec

 .. autofunction:: deepspeed.utils.zero_to_fp32.load_state_dict_from_zero_checkpoint
 .. autofunction:: deepspeed.utils.zero_to_fp32.convert_zero_checkpoint_to_fp32_state_dict
+
+Nebula Checkpoint Engine
+------------------------
+
+DeepSpeed supports the Nebula checkpoint engine. The behavior of this
+checkpoint engine can be controlled with the following configuration class.
+
+.. _DeepSpeedNebulaConfig:
+.. autopydantic_model:: deepspeed.nebula.config.DeepSpeedNebulaConfig
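
Below is a minimal, hypothetical usage sketch of the pydantic-based config introduced in this diff. The ds_config dict and its values (including the path) are illustrative only and not part of the PR; the comment about load_path assumes the load_path_check validator above falls back to persistent_storage_path when load_path is unset.

    # Illustrative sketch: build a Nebula config from a made-up ds_config dict.
    from deepspeed.nebula.config import get_nebula_config

    ds_config = {
        "nebula": {
            "enabled": True,
            "persistent_storage_path": "/tmp/nebula_ckpts",  # hypothetical path
            "persistent_time_interval": 100,
            "num_of_version_in_retention": 2,
            "enable_nebula_load": True,
        }
    }

    nebula_config = get_nebula_config(ds_config)
    print(nebula_config.enabled)                  # True
    print(nebula_config.persistent_storage_path)  # /tmp/nebula_ckpts
    # load_path was not set, so the validator is expected to fall back to
    # persistent_storage_path here.
    print(nebula_config.load_path)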