From d88e591d60c747ba7e6d136d12cd1986be946a63 Mon Sep 17 00:00:00 2001 From: trajepl Date: Thu, 30 Jun 2022 16:29:41 +0800 Subject: [PATCH 01/25] enable checkpoint engine --- deepspeed/__init__.py | 32 ++++++- deepspeed/constants.py | 1 + deepspeed/launcher/launch.py | 13 ++- deepspeed/launcher/runner.py | 25 ++++- .../runtime/checkpoint_engine/__init__.py | 0 .../checkpoint_engine/checkpoint_engine.py | 24 +++++ .../nebula_checkpoint_engine.py | 50 ++++++++++ deepspeed/runtime/engine.py | 93 +++++++++++++------ deepspeed/runtime/pipe/engine.py | 18 +++- deepspeed/runtime/pipe/module.py | 17 +++- deepspeed/runtime/state_dict_factory.py | 67 ++++++++++--- 11 files changed, 284 insertions(+), 56 deletions(-) create mode 100644 deepspeed/runtime/checkpoint_engine/__init__.py create mode 100644 deepspeed/runtime/checkpoint_engine/checkpoint_engine.py create mode 100644 deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 25229bde06d8..e01d2e9d7563 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -60,7 +60,9 @@ def initialize(args=None, dist_init_required: Optional[bool] = None, collate_fn=None, config=None, - config_params=None): + config_params=None, + enable_nebula=None, + nebula_config_params=None): """Initialize the DeepSpeed Engine. 
Arguments: @@ -128,7 +130,9 @@ def initialize(args=None, dist_init_required=dist_init_required, collate_fn=collate_fn, config=config, - config_params=config_params) + config_params=config_params, + enable_nebula=enable_nebula, + nebula_config_params=nebula_config_params) else: assert mpu is None, "mpu must be None with pipeline parallelism" engine = PipelineEngine(args=args, @@ -141,7 +145,9 @@ def initialize(args=None, dist_init_required=dist_init_required, collate_fn=collate_fn, config=config, - config_params=config_params) + config_params=config_params, + enable_nebula=enable_nebula, + nebula_config_params=nebula_config_params) return_items = [ engine, @@ -200,6 +206,26 @@ def _add_core_arguments(parser): "Run via MPI, this will attempt to discover the necessary variables to initialize torch " "distributed from the MPI environment") + group.add_argument( + '--nebula', + default=False, + action='store_true', + help= + "Save checkpoint via torch_nebula.save, this will attempt to save the time from torch.save" + ) + + group.add_argument('--persistent_storage_path', + default=None, + help="Iter3 path for persistence") + + group.add_argument('--persistent_time_interval', + default=None, + help="Time interval for tier3 saving") + + group.add_argument('--num_of_version_in_retention', + default=2, + help="File numbers to be remained") + return parser diff --git a/deepspeed/constants.py b/deepspeed/constants.py index 9576c9c078d5..8b4cd415d1e4 100644 --- a/deepspeed/constants.py +++ b/deepspeed/constants.py @@ -17,3 +17,4 @@ INFERENCE_GENERIC_MODE = 'generic' INFERENCE_SPECIALIZED_MODE = 'specialized' +DLTS_POD_ENV_PATH = '/dlts-runtime/env/pod.env' diff --git a/deepspeed/launcher/launch.py b/deepspeed/launcher/launch.py index 06321cbf90b3..731c8a216119 100755 --- a/deepspeed/launcher/launch.py +++ b/deepspeed/launcher/launch.py @@ -18,7 +18,7 @@ from collections import defaultdict from argparse import ArgumentParser, REMAINDER -from ..constants import 
TORCH_DISTRIBUTED_DEFAULT_PORT +from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, DLTS_POD_ENV_PATH from ..utils import logger PID_FILE_BASEPATH = "/tmp" @@ -143,6 +143,17 @@ def main(): with open(pid_file, 'w') as fd: fd.write(f"{launcher_pid}") + if os.path.exists(DLTS_POD_ENV_PATH): + with open(DLTS_POD_ENV_PATH) as file: + lines = file.readlines() + lines = [line.rstrip() for line in lines] + for line in lines: + if line.startswith('export FC_TASKROLE_NAME') or line.startswith( + 'export FC_TASK_INDEX'): + key_val = line.split()[1] + key, val = key_val.split('=') + current_env[key] = val + processes = [] cmd = [] for local_rank in range(0, num_local_procs): diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index be25715660db..6cb1d3faf619 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -25,7 +25,30 @@ from ..autotuning import Autotuner DLTS_HOSTFILE = "/job/hostfile" -EXPORT_ENVS = ["NCCL", "PYTHON", "MV2", "UCX"] +EXPORT_ENVS = [ + 'NCCL', + 'PYTHON', + 'MV2', + 'UCX', + 'DLTS_JOB_ID', + 'DLTS_NUM_WORKER', + 'NEBULA_PERSISTENT_STORAGE_PATH', + 'NEBULA_PERSISTENT_TIME_INTERVAL', + 'AML_RUN_ID', + 'AZUREML_RUN_TOKEN', + 'AZUREML_WORKSPACE_SCOPE', + 'AZUREML_EXPERIMENT_SCOPE', + 'AZUREML_RUN_HISTORY_SERVICE_ENDPOINT', + 'AZUREML_RUN_ID', + 'NEBULA_MEMORY_BUFFER_SIZE', + 'AZUREML_PARAMETER_ITPJOB_NAME', + 'FC_TASKROLE_NAME', + 'FC_TASK_INDEX', + 'MASTER_HOST', + 'LOCAL_HOST', + 'AZUREML_BLOB_ACCOUNT_NAME', + 'AZUREML_BLOB_ACCOUNT_KEY' +] DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env" DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.'] PDSH_MAX_FAN_OUT = 1024 diff --git a/deepspeed/runtime/checkpoint_engine/__init__.py b/deepspeed/runtime/checkpoint_engine/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py new file mode 100644 index 
000000000000..bdd2ed381f38 --- /dev/null +++ b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py @@ -0,0 +1,24 @@ +import torch +from deepspeed.utils import logger + + +class CheckpointEngine(object): + + def __init__(self): + return + + def save(self, state_dict, path: str, tag: str): + logger.info(f"Saving {path} under tag{tag}...") + torch.save(state_dict, path) + logger.info(f"Saved {path} under tag{tag}.") + return None + + def load(self, path: str, tag: str = None, persist_path: str = None, map_location=None): + logger.info(f"Loading {path} under tag{tag} from {persist_path}...") + partition = torch.load(path, map_location=map_location) + logger.info(f"Loaded {path} under tag{tag} from {persist_path}.") + return partition + + def commit(self, tag): + print(f"Checkpoint {tag} is ready now!") + return True diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py new file mode 100644 index 000000000000..a99a2b84d83f --- /dev/null +++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py @@ -0,0 +1,50 @@ +import os + +from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \ + CheckpointEngine +from deepspeed.utils import logger + +try: + import torch_nebula +except ImportError: + logger.warning( + "Warning: cannot find the package of torch-nebula. Will use the legacy way for checkpoint management." + ) + torch_nebula = None + + +class NebulaCheckpointEngine(CheckpointEngine): + + def __init__(self, config_params=None): + torch_nebula.init(**config_params) + + def save(self, state_dict, path: str, tag: str): + logger.info(f"[Nebula] Saving {path} under tag{tag}...") + partititon_name = os.path.basename(path) + + # -2 means: customer needs to explicitly tell nebula + # current checkpoint is complete by commit method. 
+ checkpoint = torch_nebula.Checkpoint(tag, -2) + checkpoint.save(partititon_name, state_dict) + logger.info(f"[Nebula] Saved {path} under tag{tag}.") + return None + + def load(self, path: str, tag: str = None, persist_path: str = None, map_location=None): + logger.info(f"[Nebula] Loading {path} under tag{tag} from {persist_path}...") + partititon_name = os.path.basename(path) + checkpoint = None + if tag is None: + checkpoint = torch_nebula.get_latest_checkpoint(persist_path=persist_path) + if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''): + logger.warning(f"Unable to find latest valid checkpoint from Nebula!") + return None + else: + checkpoint = torch_nebula.get_checkpoint(tag=tag, persist_path=persist_path) + partition = checkpoint.load(partititon_name) + logger.info(f"[Nebula] Loaded {path} under tag{tag} from {persist_path}.") + return partition + + def commit(self, tag): + checkpoint = torch_nebula.Checkpoint(tag, -2) + commit_rls = checkpoint.commit() + return commit_rls diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 892394e049ad..f5167dbfcc42 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -48,7 +48,7 @@ from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT from deepspeed.runtime.sparse_tensor import SparseTensor -from deepspeed.runtime import lr_schedules +from deepspeed.runtime import checkpoint_engine, lr_schedules from deepspeed.utils import groups from deepspeed.runtime.utils import get_grad_norm from deepspeed.utils import logger, log_dist, instrument_w_nvtx @@ -60,6 +60,7 @@ from deepspeed.runtime.utils import clip_grad_norm_ from deepspeed.runtime.eigenvalue import Eigenvalue from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler +from deepspeed.runtime.checkpoint_engine.checkpoint_engine import CheckpointEngine from .pipe.module import PipelineModule from .utils import ensure_directory_exists, get_ma_status @@ -188,6 
+189,8 @@ def __init__( config=None, config_params=None, dont_change_device=False, + enable_nebula=None, + nebula_config_params=None, ): super(DeepSpeedEngine, self).__init__() self.dont_change_device = dont_change_device @@ -221,6 +224,10 @@ def __init__( self._global_grad_norm = None self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend. + self.checkpoint_engine = None + self.enable_nebula = enable_nebula + self.nebula_config_params = nebula_config_params + global dist from deepspeed import comm as dist self._is_gradient_accumulation_boundary = None @@ -767,6 +774,13 @@ def _configure_lr_scheduler(self, client_lr_scheduler): log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0]) def _configure_checkpointing(self, dist_init_required): + if self.enable_nebula: + from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \ + NebulaCheckpointEngine + self.checkpoint_engine = NebulaCheckpointEngine( + config_params=self.nebula_config_params) + else: + self.checkpoint_engine = CheckpointEngine() dp_rank = self.global_rank if self.mpu: @@ -2287,7 +2301,9 @@ def load_moe_state_dict(checkpoint_path, old_moe_load, model=None, mpu=None, - num_experts=1): + num_experts=1, + persist_path=None, + checkpoint_engine=CheckpointEngine()): if old_moe_load: expp_rank = groups._get_expert_data_parallel_rank( groups._get_max_expert_size_name()) @@ -2297,13 +2313,15 @@ def load_moe_state_dict(checkpoint_path, groups._get_max_expert_size_name()) for local_expert_id in range(num_local_experts): global_expert_id = expp_rank * num_local_experts + local_expert_id - expert_state_dict = torch.load(DeepSpeedEngine._get_expert_ckpt_name( + expert_state_dict = checkpoint_engine.load(DeepSpeedEngine._get_expert_ckpt_name( checkpoint_path, -1, # -1 means ignore layer_id global_expert_id, tag, mpu), - map_location=torch.device('cpu')) + map_location=torch.device('cpu'), + tag=tag, + persist_path=persist_path) # Updating global -> local expert 
ids moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' @@ -2323,14 +2341,16 @@ def load_moe_state_dict(checkpoint_path, # loop all local_experts for local_expert_id in range(num_local_experts): global_expert_id = expp_rank * num_local_experts + local_expert_id - expert_state_dict = torch.load( + expert_state_dict = checkpoint_engine.load( DeepSpeedEngine._get_expert_ckpt_name( checkpoint_path, moe_layer_id, global_expert_id, tag, mpu), - map_location=torch.device('cpu')) + map_location=torch.device('cpu'), + tag=tag, + persist_path=persist_path) # print(expert_state_dict.keys()) # Updating global -> local expert ids moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' @@ -2442,7 +2462,9 @@ def load_checkpoint(self, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False, - custom_load_fn=None): + custom_load_fn=None, + enable_nebula_load=True, + nebula_load_path_tier3=None): """Load training checkpoint Arguments: @@ -2465,6 +2487,10 @@ def load_checkpoint(self, ``load_checkpoint()`` wants a pristine model. If insisting to do so, please reinitialize engine before ``load_checkpoint()``. 
""" + checkpoint_engine_tmp = self.checkpoint_engine + if self.enable_nebula and enable_nebula_load == False: + self.checkpoint_engine = CheckpointEngine() + self.persist_path = nebula_load_path_tier3 if tag is None: latest_path = os.path.join(load_dir, "latest") @@ -2502,6 +2528,9 @@ def load_checkpoint(self, if self.zero_optimization_partition_weights(): self.optimizer.checkpoint_event_epilogue() + self.checkpoint_engine = checkpoint_engine_tmp + self.persist_path = None + return load_path, client_states def _load_checkpoint(self, @@ -2543,7 +2572,9 @@ def _load_checkpoint(self, old_moe_load=old_moe_load, model=self.module, mpu=self.mpu, - num_experts=self.num_experts) + num_experts=self.num_experts, + persist_path=self.persist_path, + checkpoint_engine=self.checkpoint_engine) self.load_module_state_dict(state_dict=checkpoint['module'], strict=load_module_strict, @@ -2560,8 +2591,11 @@ def _load_checkpoint(self, largest_group_name = groups._get_max_expert_size_name() expp_rank = groups._get_expert_parallel_rank(largest_group_name) optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank) - optim_checkpoint = torch.load(optim_load_path, - map_location=torch.device('cpu')) + optim_checkpoint = self.checkpoint_engine.load( + optim_load_path, + map_location=torch.device('cpu'), + tag=tag, + persist_path=self.persist_path) else: optim_checkpoint = checkpoint @@ -2711,14 +2745,17 @@ def _get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode): return zero_ckpt_names - def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names): + def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names, tag=None): zero_sd_list = [] for i, ckpt_name in enumerate(zero_ckpt_names): _state = None # Fully load state for current rank if self.zero_elastic_checkpoint() or dist.get_rank( group=self.optimizer.dp_process_group) == i: - _state = torch.load(ckpt_name, map_location='cpu') + _state = self.checkpoint_engine.load(ckpt_name, + map_location='cpu', + 
tag=tag, + persist_path=self.persist_path) else: _state = {OPTIMIZER_STATE_DICT: None} zero_sd_list.append(_state) @@ -2743,7 +2780,7 @@ def _get_all_zero_checkpoints(self, load_dir, tag): logger.warn( f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine' ) - return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names) + return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names, tag) return None @@ -2818,6 +2855,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True) # Save latest checkpoint tag dist.barrier() + self.checkpoint_engine.commit(tag) if save_latest and self.global_rank == 0: with open(os.path.join(save_dir, 'latest'), 'w') as fd: fd.write(tag) @@ -2887,7 +2925,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): global_expert_id, tag, self.mpu) - torch.save(expert_state_dict, moe_save_path) + self.checkpoint_engine.save(expert_state_dict, + moe_save_path, + tag=tag) moe_layer_id += 1 self._curr_ckpt_path = os.path.join(save_dir, tag) @@ -2908,9 +2948,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None } - with open(self._get_optimizer_ckpt_name(save_dir, tag, expp_rank), 'wb') as fd: - torch.save(optimizer_state, fd) - fd.flush() + # TODO: why use BufferedWriter not the path + file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank) + self.checkpoint_engine.save(optimizer_state, file_path, tag=tag) # get non-moe parameters model_state_dict = self._get_non_moe_state_dict(self.module_state_dict()) @@ -2940,9 +2980,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): } state.update(client_state) logger.info(f'Saving model checkpoint: {save_path}') - with open(save_path, 'wb') as fd: - torch.save(state, fd) - fd.flush() + self.checkpoint_engine.save(state, save_path, tag=tag) self._curr_save_path = None def _create_checkpoint_file(self, 
save_dir, tag, zero_checkpoint): @@ -2995,7 +3033,7 @@ def _save_checkpoint(self, save_dir, tag, client_state={}): state.update(client_state) log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1]) - torch.save(state, save_path) + self.checkpoint_engine.save(state, save_path, tag=tag) self._curr_save_path = None def _get_buffer_names(self): @@ -3078,9 +3116,8 @@ def _save_zero_checkpoint(self, save_path, tag): zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(), ds_config=self.config, ds_version=version) - with open(zero_checkpoint_name, 'wb') as fd: - torch.save(zero_sd, fd) - fd.flush() + self.checkpoint_engine.save(zero_sd, zero_checkpoint_name, tag=tag) + if self.global_rank == 0: self._copy_recovery_script(save_path) ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero' @@ -3157,12 +3194,12 @@ def get_layer_state_dict(module, prefix=""): return state_dict - def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"): + def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin", tag=None): """has been renamed to save_16bit_model, keeping this around for backwards compatibility""" - return self.save_16bit_model(save_dir, save_filename) + return self.save_16bit_model(save_dir, save_filename, tag=tag) - def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"): + def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin", tag=None): r"""Save 16bit model weights This method saves the 16bit model weights at the desired destination. 
@@ -3199,6 +3236,6 @@ def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"): if dist.get_rank() == 0: os.makedirs(save_dir, exist_ok=True) logger.info(f"Saving model weights to {path}") - torch.save(state_dict, path) + self.checkpoint_engine.save(state_dict, path, tag=tag) return True diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 94add6f9c8e4..74cd3f5f4ebe 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -1306,7 +1306,7 @@ def mem_status(self, msg, print_rank=-1, reset_max=False): f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)' ) - def module_state_dict(self): + def module_state_dict(self, tag=None): """Override hack to save a pipe model and return the directory path of the save. This method should only be called by DeepSpeed's ``save_checkpoint()``. The @@ -1320,10 +1320,16 @@ def module_state_dict(self): assert self._curr_ckpt_path is not None, \ "PipelineEngine expects module_state_dict() to be called from save_checkpoint()" - self.module.save_state_dict(self._curr_ckpt_path) + self.module.save_state_dict(self._curr_ckpt_path, + tag=tag, + checkpoint_engine=self.checkpoint_engine) return None - def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None): + def load_module_state_dict(self, + state_dict, + strict=True, + custom_load_fn=None, + tag=None): """Override hack to instead use a directory path. This is important because pipeline models checkpoint by layer instead of rank. 
@@ -1339,7 +1345,11 @@ def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None): super().load_module_state_dict(state_dict, strict) return - self.module.load_state_dir(load_dir=self._curr_ckpt_path, strict=strict) + self.module.load_state_dir(load_dir=self._curr_ckpt_path, + strict=strict, + tag=tag, + persist_path=self.persist_path, + checkpoint_engine=self.checkpoint_engine) # A map of PipeInstruction types to methods. Each method will be executed with the # kwargs provided to the PipeInstruction from the scheduler. diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index f47806ee8673..9f687418a889 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -6,6 +6,7 @@ from collections import defaultdict from functools import partial +from deepspeed.runtime.checkpoint_engine.checkpoint_engine import CheckpointEngine import torch import torch.nn as nn @@ -563,7 +564,7 @@ def ckpt_layer_path_list(self, ckpt_dir, local_layer_idx): ckpt_files.sort() return ckpt_files - def save_state_dict(self, save_dir): + def save_state_dict(self, save_dir, tag=None, checkpoint_engine=None): if self._grid.data_parallel_id != 0: return @@ -584,9 +585,14 @@ def save_state_dict(self, save_dir): {k: v.clone() for k, v in orig_state_dict.items()}) - torch.save(final_state_dict, model_ckpt_path) - - def load_state_dir(self, load_dir, strict=True): + checkpoint_engine.save(final_state_dict, model_ckpt_path, tag) + + def load_state_dir(self, + load_dir, + strict=True, + tag=None, + persist_path=None, + checkpoint_engine=CheckpointEngine()): for idx, layer in enumerate(self.forward_funcs): # Functions, etc. 
will not have state_dicts if not hasattr(layer, 'load_state_dict'): @@ -597,7 +603,8 @@ def load_state_dir(self, load_dir, strict=True): mp_rank = self._grid.get_slice_parallel_rank() mp_world_size = self._grid.get_slice_parallel_world_size() - sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list, version=2.0) + sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list, version=2.0, tag=tag,\ + persist_path=persist_path, checkpoint_engine=checkpoint_engine) load_path, checkpoint, _ = sd_loader.load(mp_world_size, mp_rank, module_key=None, is_pipe_parallel=True) layer.load_state_dict(checkpoint) diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index 09887aaa275c..2470c2a42cf4 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -8,7 +8,10 @@ import collections import json from abc import ABC, abstractmethod + from deepspeed.utils import logger +from deepspeed.runtime.checkpoint_engine.checkpoint_engine import CheckpointEngine + from .weight_quantizer import WeightQuantization AUTO_MODULE_KEY = 'auto' @@ -25,19 +28,37 @@ def get_sd_loader_json(json_file): return SDLoaderFactory.get_sd_loader(ckpt_list, sd_type, version) @staticmethod - def get_sd_loader(ckpt_list, sd_type='Megatron', version=None): + def get_sd_loader(ckpt_list, + sd_type='Megatron', + version=None, + tag=None, + persist_path=None, + checkpoint_engine=None): if sd_type == 'Megatron': - return MegatronSDLoader(ckpt_list, version) + return MegatronSDLoader(ckpt_list, + version, + tag, + persist_path, + checkpoint_engine) else: assert False, '{} checkpoint type is not supported'.format(sd_type) class SDLoaderBase(ABC): - def __init__(self, ckpt_list, version): + def __init__(self, + ckpt_list, + version, + tag=None, + persist_path=None, + checkpoint_engine=None): self.module_key = None self.ckpt_list = ckpt_list self.check_ckpt_list() self.version = version + self.tag = tag + self.persist_path = 
persist_path + self.checkpoint_engine = CheckpointEngine( + ) if checkpoint_engine is None else checkpoint_engine def load(self, mp_world_size, @@ -79,7 +100,8 @@ def load(self, if num_ckpt == mp_world_size: assert os.path.exists(load_path) #logger.info(f'rank: {mp_rank} loading checkpoint: {load_path}') - sd = torch.load(load_path, map_location=lambda storage, loc: storage) + sd = self.checkpoint_engine.load(load_path, map_location=lambda storage, \ + loc: storage, tag=self.tag, persist_path=self.persist_path) if quantize: quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping, @@ -108,9 +130,11 @@ def get_merge_state_dicts(self, mp_world_size, mp_rank): logger.info(f"mp_rank: {mp_rank}, ckpt_list: {ckpt_list}") sd_list = [ - torch.load(ckpt, - map_location=lambda storage, - loc: storage) for ckpt in ckpt_list + self.checkpoint_engine.load(ckpt, + tag=self.tag, + persist_path=self.persist_path, + map_location=lambda storage, + loc: storage) for ckpt in ckpt_list ] return sd_list @@ -126,9 +150,11 @@ def get_split_state_dict(self, mp_world_size, mp_rank): f"mp_rank: {mp_rank}, ckpt_list: {self.ckpt_list[ckpt_index]}, offset: {ckpt_offset}" ) - sd = torch.load(self.ckpt_list[ckpt_index], - map_location=lambda storage, - loc: storage) + sd = self.checkpoint_engine.load(self.ckpt_list[ckpt_index], + tag=self.tag, + persist_path=self.persist_path, + map_location=lambda storage, + loc: storage) return sd, num_to_split, ckpt_offset @@ -161,7 +187,11 @@ def check_ckpt_list(self): #logger.info(f'checkpoint file list: {self.ckpt_list}') assert len(self.ckpt_list) > 0 - sd = torch.load(self.ckpt_list[0], map_location=lambda storage, loc: storage) + sd = self.checkpoint_engine.load(self.ckpt_list[0], + tag=self.tag, + persist_path=self.persist_path, + map_location=lambda storage, + loc: storage) # check checkpoint count is same with saved mp_world_size if 'mp_world_size' in sd.keys(): @@ -193,8 +223,13 @@ def sanity_check(self, ckpt_file_name): class 
MegatronSDLoader(SDLoaderBase): - def __init__(self, ckpt_list, version): - super().__init__(ckpt_list, version) + def __init__(self, + ckpt_list, + version, + tag=None, + persist_path=None, + checkpoint_engine=None): + super().__init__(ckpt_list, version, tag, persist_path, checkpoint_engine) """ ## Q/K/V data need special processing key: transformer.layers.0.attention.query_key_value.weight, shape: torch.Size([3192, 4256]) @@ -431,7 +466,11 @@ def sanity_check(self, ckpt_file_name): "mlp.dense_h_to_4h.bias" ] - sd = torch.load(ckpt_file_name, map_location=lambda storage, loc: storage) + sd = self.checkpoint_engine.load(ckpt_file_name, + tag=self.tag, + persist_path=self.persist_path, + map_location=lambda storage, + loc: storage) # partial_key is a sub-string of one key in the sd def check_key_exist(partial_key, sd): From 07e59d664814a8301f1cf39765ef10d8ca97623a Mon Sep 17 00:00:00 2001 From: trajepl Date: Mon, 11 Jul 2022 16:25:18 +0800 Subject: [PATCH 02/25] seprated nebula config --- deepspeed/__init__.py | 32 +---------- deepspeed/launcher/runner.py | 27 +-------- deepspeed/nebula/config.py | 42 ++++++++++++++ deepspeed/nebula/constants.py | 56 +++++++++++++++++++ .../checkpoint_engine/checkpoint_engine.py | 7 ++- .../nebula_checkpoint_engine.py | 25 +++++---- deepspeed/runtime/config.py | 3 + deepspeed/runtime/engine.py | 10 +--- 8 files changed, 129 insertions(+), 73 deletions(-) create mode 100644 deepspeed/nebula/config.py create mode 100644 deepspeed/nebula/constants.py diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index e01d2e9d7563..25229bde06d8 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -60,9 +60,7 @@ def initialize(args=None, dist_init_required: Optional[bool] = None, collate_fn=None, config=None, - config_params=None, - enable_nebula=None, - nebula_config_params=None): + config_params=None): """Initialize the DeepSpeed Engine. 
Arguments: @@ -130,9 +128,7 @@ def initialize(args=None, dist_init_required=dist_init_required, collate_fn=collate_fn, config=config, - config_params=config_params, - enable_nebula=enable_nebula, - nebula_config_params=nebula_config_params) + config_params=config_params) else: assert mpu is None, "mpu must be None with pipeline parallelism" engine = PipelineEngine(args=args, @@ -145,9 +141,7 @@ def initialize(args=None, dist_init_required=dist_init_required, collate_fn=collate_fn, config=config, - config_params=config_params, - enable_nebula=enable_nebula, - nebula_config_params=nebula_config_params) + config_params=config_params) return_items = [ engine, @@ -206,26 +200,6 @@ def _add_core_arguments(parser): "Run via MPI, this will attempt to discover the necessary variables to initialize torch " "distributed from the MPI environment") - group.add_argument( - '--nebula', - default=False, - action='store_true', - help= - "Save checkpoint via torch_nebula.save, this will attempt to save the time from torch.save" - ) - - group.add_argument('--persistent_storage_path', - default=None, - help="Iter3 path for persistence") - - group.add_argument('--persistent_time_interval', - default=None, - help="Time interval for tier3 saving") - - group.add_argument('--num_of_version_in_retention', - default=2, - help="File numbers to be remained") - return parser diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index 6cb1d3faf619..cf3e98dc25bb 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -20,35 +20,14 @@ from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT +from ..nebula.constants import NEBULA_EXPORT_ENVS from ..utils import logger from ..autotuning import Autotuner DLTS_HOSTFILE = "/job/hostfile" -EXPORT_ENVS = [ - 'NCCL', - 'PYTHON', - 'MV2', - 'UCX', - 'DLTS_JOB_ID', - 
'DLTS_NUM_WORKER', - 'NEBULA_PERSISTENT_STORAGE_PATH', - 'NEBULA_PERSISTENT_TIME_INTERVAL', - 'AML_RUN_ID', - 'AZUREML_RUN_TOKEN', - 'AZUREML_WORKSPACE_SCOPE', - 'AZUREML_EXPERIMENT_SCOPE', - 'AZUREML_RUN_HISTORY_SERVICE_ENDPOINT', - 'AZUREML_RUN_ID', - 'NEBULA_MEMORY_BUFFER_SIZE', - 'AZUREML_PARAMETER_ITPJOB_NAME', - 'FC_TASKROLE_NAME', - 'FC_TASK_INDEX', - 'MASTER_HOST', - 'LOCAL_HOST', - 'AZUREML_BLOB_ACCOUNT_NAME', - 'AZUREML_BLOB_ACCOUNT_KEY' -] +EXPORT_ENVS = ['NCCL', 'PYTHON', 'MV2', 'UCX'] +EXPORT_ENVS += NEBULA_EXPORT_ENVS DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env" DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.'] PDSH_MAX_FAN_OUT = 1024 diff --git a/deepspeed/nebula/config.py b/deepspeed/nebula/config.py new file mode 100644 index 000000000000..2cdab173d7b1 --- /dev/null +++ b/deepspeed/nebula/config.py @@ -0,0 +1,42 @@ +""" +Copyright (c) Microsoft Corporation +Licensed under the MIT license. +""" + +from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject +from deepspeed.nebula.constants import * + + +class DeepSpeedNebulaConfig(DeepSpeedConfigObject): + def __init__(self, param_dict): + super(DeepSpeedNebulaConfig, self).__init__() + + self.enabled = None + self.persistent_storage_path = None + self.persistent_time_interval = None + self.num_of_version_in_retention = None + + if NEBULA in param_dict.keys(): + nebula_dict = param_dict[NEBULA] + else: + nebula_dict = {} + + self._initialize(nebula_dict) + + def _initialize(self, nebula_dict): + self.enabled = get_scalar_param(nebula_dict, + NEBULA_ENABLED, + NEBULA_ENABLED_DEFAULT) + + self.persistent_storage_path = get_scalar_param(nebula_dict, + NEBULA_PERSISTENT_STORAGE_PATH, + NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT) + + self.persistent_time_interval = get_scalar_param(nebula_dict, + NEBULA_PERSISTENT_TIME_INTERVAL, + NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT) + + self.num_of_version_in_retention = get_scalar_param(nebula_dict, + 
NEBULA_NUM_OF_VERSION_IN_RETENTION, + NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT) + diff --git a/deepspeed/nebula/constants.py b/deepspeed/nebula/constants.py new file mode 100644 index 000000000000..458b09f05a22 --- /dev/null +++ b/deepspeed/nebula/constants.py @@ -0,0 +1,56 @@ +""" +Copyright (c) Microsoft Corporation +Licensed under the MIT license. +""" + +######################################### +# nebula +######################################### +# Nebula. By default, this feature is not enabled. +# Users can configure in ds_config.json as below example: +NEBULA_FORMAT = ''' +nebula should be enabled as: +"session_params": { + "nebula": { + "enabled": true, + "persistent_storage_path": "/foo/bar", + "persistent_time_interval": 100, + "num_of_version_in_retention": 2 + } +} +''' + +NEBULA = "nebula" + +NEBULA_ENABLED = "enabled" +NEBULA_ENABLED_DEFAULT = False + +NEBULA_PERSISTENT_STORAGE_PATH = "persistent_storage_path" +NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT = None + +NEBULA_PERSISTENT_TIME_INTERVAL = "persistent_time_interval" +NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT = 100 + +NEBULA_NUM_OF_VERSION_IN_RETENTION = "num_of_version_in_retention" +NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT = 2 + +NEBULA_EXPORT_ENVS = [ + 'DLTS_JOB_ID', + 'DLTS_NUM_WORKER', + 'NEBULA_PERSISTENT_STORAGE_PATH', + 'NEBULA_PERSISTENT_TIME_INTERVAL', + 'AML_RUN_ID', + 'AZUREML_RUN_TOKEN', + 'AZUREML_WORKSPACE_SCOPE', + 'AZUREML_EXPERIMENT_SCOPE', + 'AZUREML_RUN_HISTORY_SERVICE_ENDPOINT', + 'AZUREML_RUN_ID', + 'NEBULA_MEMORY_BUFFER_SIZE', + 'AZUREML_PARAMETER_ITPJOB_NAME', + 'FC_TASKROLE_NAME', + 'FC_TASK_INDEX', + 'MASTER_HOST', + 'LOCAL_HOST', + 'AZUREML_BLOB_ACCOUNT_NAME', + 'AZUREML_BLOB_ACCOUNT_KEY' +] diff --git a/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py index bdd2ed381f38..84d6005a8520 100644 --- a/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py +++ 
b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py @@ -3,7 +3,6 @@ class CheckpointEngine(object): - def __init__(self): return @@ -13,7 +12,11 @@ def save(self, state_dict, path: str, tag: str): logger.info(f"Saved {path} under tag{tag}.") return None - def load(self, path: str, tag: str = None, persist_path: str = None, map_location=None): + def load(self, + path: str, + tag: str = None, + persist_path: str = None, + map_location=None): logger.info(f"Loading {path} under tag{tag} from {persist_path}...") partition = torch.load(path, map_location=map_location) logger.info(f"Loaded {path} under tag{tag} from {persist_path}.") diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py index a99a2b84d83f..cf919296d97a 100644 --- a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py @@ -1,22 +1,21 @@ import os +import torch_nebula from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \ CheckpointEngine from deepspeed.utils import logger - -try: - import torch_nebula -except ImportError: - logger.warning( - "Warning: cannot find the package of torch-nebula. Will use the legacy way for checkpoint management." 
- ) - torch_nebula = None +from deepspeed.nebula.constants import * class NebulaCheckpointEngine(CheckpointEngine): - def __init__(self, config_params=None): - torch_nebula.init(**config_params) + nebula_config_params = { + NEBULA_PERSISTENT_STORAGE_PATH: config_params.persistent_storage_path, + NEBULA_PERSISTENT_TIME_INTERVAL: config_params.persistent_storage_path, + NEBULA_NUM_OF_VERSION_IN_RETENTION: + config_params.num_of_version_in_retention, + } + torch_nebula.init(**nebula_config_params) def save(self, state_dict, path: str, tag: str): logger.info(f"[Nebula] Saving {path} under tag{tag}...") @@ -29,7 +28,11 @@ def save(self, state_dict, path: str, tag: str): logger.info(f"[Nebula] Saved {path} under tag{tag}.") return None - def load(self, path: str, tag: str = None, persist_path: str = None, map_location=None): + def load(self, + path: str, + tag: str = None, + persist_path: str = None, + map_location=None): logger.info(f"[Nebula] Loading {path} under tag{tag} from {persist_path}...") partititon_name = os.path.basename(path) checkpoint = None diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 4571cbdf7056..f5370bc3a602 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -45,6 +45,7 @@ from ..profiling.config import DeepSpeedFlopsProfilerConfig from ..autotuning.config import DeepSpeedAutotuningConfig +from ..nebula.config import DeepSpeedNebulaConfig from .swap_tensor.aio_config import get_aio_config @@ -950,6 +951,8 @@ def _initialize_params(self, param_dict): self.dataloader_drop_last = get_dataloader_drop_last(param_dict) + self.nebula_config = DeepSpeedNebulaConfig(param_dict) + def _batch_assertion(self): train_batch = self.train_batch_size diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index f5167dbfcc42..a8ea422a6389 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -189,8 +189,6 @@ def __init__( config=None, config_params=None, 
dont_change_device=False, - enable_nebula=None, - nebula_config_params=None, ): super(DeepSpeedEngine, self).__init__() self.dont_change_device = dont_change_device @@ -225,8 +223,6 @@ def __init__( self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend. self.checkpoint_engine = None - self.enable_nebula = enable_nebula - self.nebula_config_params = nebula_config_params global dist from deepspeed import comm as dist @@ -774,11 +770,11 @@ def _configure_lr_scheduler(self, client_lr_scheduler): log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0]) def _configure_checkpointing(self, dist_init_required): - if self.enable_nebula: + if self.config.nebula_config.enabled: from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \ NebulaCheckpointEngine self.checkpoint_engine = NebulaCheckpointEngine( - config_params=self.nebula_config_params) + config_params=self.config.nebula_config) else: self.checkpoint_engine = CheckpointEngine() @@ -2488,7 +2484,7 @@ def load_checkpoint(self, before ``load_checkpoint()``. 
""" checkpoint_engine_tmp = self.checkpoint_engine - if self.enable_nebula and enable_nebula_load == False: + if self.config.nebula_config.enabled and enable_nebula_load == False: self.checkpoint_engine = CheckpointEngine() self.persist_path = nebula_load_path_tier3 From 4cbdfe6fe3434d1f5ac6b3e706becfd91b52bb7e Mon Sep 17 00:00:00 2001 From: trajepl Date: Mon, 11 Jul 2022 16:31:19 +0800 Subject: [PATCH 03/25] add __init__.py for nebula importing --- deepspeed/nebula/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 deepspeed/nebula/__init__.py diff --git a/deepspeed/nebula/__init__.py b/deepspeed/nebula/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From 1f2f40c3b82909d78caf085e6a1d6e8b6405e0cc Mon Sep 17 00:00:00 2001 From: trajepl Date: Mon, 11 Jul 2022 16:35:46 +0800 Subject: [PATCH 04/25] linter fix --- deepspeed/nebula/config.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/deepspeed/nebula/config.py b/deepspeed/nebula/config.py index 2cdab173d7b1..374e3f85bef1 100644 --- a/deepspeed/nebula/config.py +++ b/deepspeed/nebula/config.py @@ -28,15 +28,17 @@ def _initialize(self, nebula_dict): NEBULA_ENABLED, NEBULA_ENABLED_DEFAULT) - self.persistent_storage_path = get_scalar_param(nebula_dict, - NEBULA_PERSISTENT_STORAGE_PATH, - NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT) - - self.persistent_time_interval = get_scalar_param(nebula_dict, - NEBULA_PERSISTENT_TIME_INTERVAL, - NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT) - - self.num_of_version_in_retention = get_scalar_param(nebula_dict, - NEBULA_NUM_OF_VERSION_IN_RETENTION, - NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT) - + self.persistent_storage_path = get_scalar_param( + nebula_dict, + NEBULA_PERSISTENT_STORAGE_PATH, + NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT) + + self.persistent_time_interval = get_scalar_param( + nebula_dict, + NEBULA_PERSISTENT_TIME_INTERVAL, + NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT) + + 
self.num_of_version_in_retention = get_scalar_param( + nebula_dict, + NEBULA_NUM_OF_VERSION_IN_RETENTION, + NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT) From d90014595b5ea0818067fbc857afe1943533d7c0 Mon Sep 17 00:00:00 2001 From: trajepl Date: Mon, 11 Jul 2022 16:45:03 +0800 Subject: [PATCH 05/25] fix: ds_config is None --- deepspeed/runtime/engine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index a8ea422a6389..875a9c442ad8 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -770,7 +770,7 @@ def _configure_lr_scheduler(self, client_lr_scheduler): log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0]) def _configure_checkpointing(self, dist_init_required): - if self.config.nebula_config.enabled: + if self.config is not None and self.config.nebula_config.enabled: from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \ NebulaCheckpointEngine self.checkpoint_engine = NebulaCheckpointEngine( @@ -2484,7 +2484,9 @@ def load_checkpoint(self, before ``load_checkpoint()``. 
""" checkpoint_engine_tmp = self.checkpoint_engine - if self.config.nebula_config.enabled and enable_nebula_load == False: + if not self.config is None and \ + self.config.nebula_config.enabled and \ + enable_nebula_load == False: self.checkpoint_engine = CheckpointEngine() self.persist_path = nebula_load_path_tier3 From b44832b04b8fa58929dafb160c34f1e72f7229f7 Mon Sep 17 00:00:00 2001 From: trajepl Date: Mon, 11 Jul 2022 17:42:34 +0800 Subject: [PATCH 06/25] fix: ds config --- deepspeed/runtime/engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 875a9c442ad8..d95d9065ad3f 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -770,11 +770,11 @@ def _configure_lr_scheduler(self, client_lr_scheduler): log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0]) def _configure_checkpointing(self, dist_init_required): - if self.config is not None and self.config.nebula_config.enabled: + if self._config is not None and self._config.nebula_config.enabled: from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \ NebulaCheckpointEngine self.checkpoint_engine = NebulaCheckpointEngine( - config_params=self.config.nebula_config) + config_params=self._config.nebula_config) else: self.checkpoint_engine = CheckpointEngine() @@ -2485,7 +2485,7 @@ def load_checkpoint(self, """ checkpoint_engine_tmp = self.checkpoint_engine if not self.config is None and \ - self.config.nebula_config.enabled and \ + self._config.nebula_config.enabled and \ enable_nebula_load == False: self.checkpoint_engine = CheckpointEngine() self.persist_path = nebula_load_path_tier3 From e4a57bdf57eb702f5c6c4f14606e1dbb5b169441 Mon Sep 17 00:00:00 2001 From: trajepl Date: Mon, 11 Jul 2022 20:18:58 +0800 Subject: [PATCH 07/25] fix: get sd loader fix --- deepspeed/runtime/checkpoint_engine/checkpoint_engine.py | 2 +- deepspeed/runtime/engine.py | 4 +++- 
deepspeed/runtime/state_dict_factory.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py index 84d6005a8520..c44e9d757920 100644 --- a/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py @@ -4,7 +4,7 @@ class CheckpointEngine(object): def __init__(self): - return + pass def save(self, state_dict, path: str, tag: str): logger.info(f"Saving {path} under tag{tag}...") diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index d95d9065ad3f..5cf69b090356 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2543,7 +2543,9 @@ def _load_checkpoint(self, from deepspeed.runtime.state_dict_factory import SDLoaderFactory ckpt_list = self._get_all_ckpt_names(load_dir, tag) - sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list) + sd_loader = SDLoaderFactory.get_sd_loader( + ckpt_list, + checkpoint_engine=self.checkpoint_engine) is_pipe_parallel = isinstance(self.module, PipelineModule) diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index 2470c2a42cf4..d266319fec1a 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -53,12 +53,12 @@ def __init__(self, checkpoint_engine=None): self.module_key = None self.ckpt_list = ckpt_list - self.check_ckpt_list() self.version = version self.tag = tag self.persist_path = persist_path self.checkpoint_engine = CheckpointEngine( ) if checkpoint_engine is None else checkpoint_engine + self.check_ckpt_list() def load(self, mp_world_size, From d70bcd160f3f74fbc50310f21ca54aca221c9f5c Mon Sep 17 00:00:00 2001 From: trajepl Date: Tue, 12 Jul 2022 15:40:18 +0800 Subject: [PATCH 08/25] align the API with torch raw code --- deepspeed/nebula/config.py | 5 +++ deepspeed/nebula/constants.py | 3 ++ 
.../checkpoint_engine/checkpoint_engine.py | 17 ++++---- .../nebula_checkpoint_engine.py | 40 ++++++++++++------- deepspeed/runtime/engine.py | 33 +++++---------- deepspeed/runtime/pipe/engine.py | 2 - deepspeed/runtime/pipe/module.py | 9 ++--- deepspeed/runtime/state_dict_factory.py | 22 +--------- 8 files changed, 57 insertions(+), 74 deletions(-) diff --git a/deepspeed/nebula/config.py b/deepspeed/nebula/config.py index 374e3f85bef1..313dcd952f4f 100644 --- a/deepspeed/nebula/config.py +++ b/deepspeed/nebula/config.py @@ -28,6 +28,11 @@ def _initialize(self, nebula_dict): NEBULA_ENABLED, NEBULA_ENABLED_DEFAULT) + self.load_path = get_scalar_param( + nebula_dict, + NEBULA_LOAD_PATH, + NEBULA_LOAD_PATH_DEFAULT) + self.persistent_storage_path = get_scalar_param( nebula_dict, NEBULA_PERSISTENT_STORAGE_PATH, diff --git a/deepspeed/nebula/constants.py b/deepspeed/nebula/constants.py index 458b09f05a22..eba56b05beac 100644 --- a/deepspeed/nebula/constants.py +++ b/deepspeed/nebula/constants.py @@ -25,6 +25,9 @@ NEBULA_ENABLED = "enabled" NEBULA_ENABLED_DEFAULT = False +NEBULA_LOAD_PATH = "persistent_storage_path" +NEBULA_LOAD_PATH_DEFAULT = None + NEBULA_PERSISTENT_STORAGE_PATH = "persistent_storage_path" NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT = None diff --git a/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py index c44e9d757920..fd5847299658 100644 --- a/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py @@ -3,23 +3,20 @@ class CheckpointEngine(object): + def __init__(self): pass - def save(self, state_dict, path: str, tag: str): - logger.info(f"Saving {path} under tag{tag}...") + def save(self, state_dict, path: str): + logger.info(f"Saving {path}...") torch.save(state_dict, path) - logger.info(f"Saved {path} under tag{tag}.") + logger.info(f"Saved {path}.") return None - def load(self, - path: str, - tag: str = None, - 
persist_path: str = None, - map_location=None): - logger.info(f"Loading {path} under tag{tag} from {persist_path}...") + def load(self, path: str, map_location=None): + logger.info(f"Loading checkpoint from {path}...") partition = torch.load(path, map_location=map_location) - logger.info(f"Loaded {path} under tag{tag} from {persist_path}.") + logger.info(f"Loaded checkpoint from {path}.") return partition def commit(self, tag): diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py index cf919296d97a..d56537eb3524 100644 --- a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py @@ -7,44 +7,56 @@ from deepspeed.nebula.constants import * +def _get_tag_from_path(path): + return os.path.basename(os.path.dirname(path)) + + class NebulaCheckpointEngine(CheckpointEngine): + def __init__(self, config_params=None): + self.nebula_load_path = config_params.load_path + if self.nebula_load_path is None: + self.nebula_load_path = config_params.persistent_storage_path + nebula_config_params = { NEBULA_PERSISTENT_STORAGE_PATH: config_params.persistent_storage_path, - NEBULA_PERSISTENT_TIME_INTERVAL: config_params.persistent_storage_path, + NEBULA_PERSISTENT_TIME_INTERVAL: config_params.persistent_time_interval, NEBULA_NUM_OF_VERSION_IN_RETENTION: config_params.num_of_version_in_retention, } torch_nebula.init(**nebula_config_params) - def save(self, state_dict, path: str, tag: str): - logger.info(f"[Nebula] Saving {path} under tag{tag}...") + def save(self, state_dict, path: str): + tag = _get_tag_from_path(path) partititon_name = os.path.basename(path) + logger.info(f"[Nebula] Saving {partititon_name} under tag{tag}...") # -2 means: customer needs to explicitly tell nebula # current checkpoint is complete by commit methond. 
checkpoint = torch_nebula.Checkpoint(tag, -2) checkpoint.save(partititon_name, state_dict) - logger.info(f"[Nebula] Saved {path} under tag{tag}.") + logger.info(f"[Nebula] Saved {partititon_name} under tag{tag}.") return None - def load(self, - path: str, - tag: str = None, - persist_path: str = None, - map_location=None): - logger.info(f"[Nebula] Loading {path} under tag{tag} from {persist_path}...") + def load(self, path: str, map_location=None): + tag = _get_tag_from_path(path) partititon_name = os.path.basename(path) + logger.info( + f"[Nebula] Loading {path} under tag{tag} from {self.nebula_load_path}...") + checkpoint = None if tag is None: - checkpoint = torch_nebula.get_latest_checkpoint(persist_path=persist_path) + checkpoint = torch_nebula.get_latest_checkpoint( + persist_path=self.nebula_load_path) if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''): logger.warning(f"Unable to find latest valid checkpoint from Nebula!") return None else: - checkpoint = torch_nebula.get_checkpoint(tag=tag, persist_path=persist_path) - partition = checkpoint.load(partititon_name) - logger.info(f"[Nebula] Loaded {path} under tag{tag} from {persist_path}.") + checkpoint = torch_nebula.get_checkpoint(tag=tag, + persist_path=self.nebula_load_path) + partition = checkpoint.load(partititon_name, map_location=map_location) + logger.info( + f"[Nebula] Loaded {path} under tag{tag} from {self.nebula_load_path}.") return partition def commit(self, tag): diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 5cf69b090356..e6b3cb5da63a 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -48,7 +48,7 @@ from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT from deepspeed.runtime.sparse_tensor import SparseTensor -from deepspeed.runtime import checkpoint_engine, lr_schedules +from deepspeed.runtime import lr_schedules from deepspeed.utils import groups from deepspeed.runtime.utils import get_grad_norm from 
deepspeed.utils import logger, log_dist, instrument_w_nvtx @@ -2298,7 +2298,6 @@ def load_moe_state_dict(checkpoint_path, model=None, mpu=None, num_experts=1, - persist_path=None, checkpoint_engine=CheckpointEngine()): if old_moe_load: expp_rank = groups._get_expert_data_parallel_rank( @@ -2315,9 +2314,7 @@ def load_moe_state_dict(checkpoint_path, global_expert_id, tag, mpu), - map_location=torch.device('cpu'), - tag=tag, - persist_path=persist_path) + map_location=torch.device('cpu')) # Updating global -> local expert ids moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' @@ -2344,9 +2341,7 @@ def load_moe_state_dict(checkpoint_path, global_expert_id, tag, mpu), - map_location=torch.device('cpu'), - tag=tag, - persist_path=persist_path) + map_location=torch.device('cpu')) # print(expert_state_dict.keys()) # Updating global -> local expert ids moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' @@ -2573,7 +2568,6 @@ def _load_checkpoint(self, model=self.module, mpu=self.mpu, num_experts=self.num_experts, - persist_path=self.persist_path, checkpoint_engine=self.checkpoint_engine) self.load_module_state_dict(state_dict=checkpoint['module'], @@ -2593,9 +2587,7 @@ def _load_checkpoint(self, optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank) optim_checkpoint = self.checkpoint_engine.load( optim_load_path, - map_location=torch.device('cpu'), - tag=tag, - persist_path=self.persist_path) + map_location=torch.device('cpu')) else: optim_checkpoint = checkpoint @@ -2753,9 +2745,7 @@ def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names, tag=None): if self.zero_elastic_checkpoint() or dist.get_rank( group=self.optimizer.dp_process_group) == i: _state = self.checkpoint_engine.load(ckpt_name, - map_location='cpu', - tag=tag, - persist_path=self.persist_path) + map_location='cpu',) else: _state = {OPTIMIZER_STATE_DICT: None} zero_sd_list.append(_state) @@ -2926,8 +2916,7 @@ def _save_moe_checkpoint(self, save_dir, tag, 
client_state={}): tag, self.mpu) self.checkpoint_engine.save(expert_state_dict, - moe_save_path, - tag=tag) + moe_save_path) moe_layer_id += 1 self._curr_ckpt_path = os.path.join(save_dir, tag) @@ -2950,7 +2939,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): } # TODO: why use BufferedWriter not the path file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank) - self.checkpoint_engine.save(optimizer_state, file_path, tag=tag) + self.checkpoint_engine.save(optimizer_state, file_path) # get non-moe parameters model_state_dict = self._get_non_moe_state_dict(self.module_state_dict()) @@ -2980,7 +2969,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): } state.update(client_state) logger.info(f'Saving model checkpoint: {save_path}') - self.checkpoint_engine.save(state, save_path, tag=tag) + self.checkpoint_engine.save(state, save_path) self._curr_save_path = None def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint): @@ -3033,7 +3022,7 @@ def _save_checkpoint(self, save_dir, tag, client_state={}): state.update(client_state) log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1]) - self.checkpoint_engine.save(state, save_path, tag=tag) + self.checkpoint_engine.save(state, save_path) self._curr_save_path = None def _get_buffer_names(self): @@ -3116,7 +3105,7 @@ def _save_zero_checkpoint(self, save_path, tag): zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(), ds_config=self.config, ds_version=version) - self.checkpoint_engine.save(zero_sd, zero_checkpoint_name, tag=tag) + self.checkpoint_engine.save(zero_sd, zero_checkpoint_name) if self.global_rank == 0: self._copy_recovery_script(save_path) @@ -3236,6 +3225,6 @@ def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin", tag=None if dist.get_rank() == 0: os.makedirs(save_dir, exist_ok=True) logger.info(f"Saving model weights to {path}") - self.checkpoint_engine.save(state_dict, path, tag=tag) + 
self.checkpoint_engine.save(state_dict, path) return True diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 74cd3f5f4ebe..179d97051aec 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -1347,8 +1347,6 @@ def load_module_state_dict(self, self.module.load_state_dir(load_dir=self._curr_ckpt_path, strict=strict, - tag=tag, - persist_path=self.persist_path, checkpoint_engine=self.checkpoint_engine) # A map of PipeInstruction types to methods. Each method will be executed with the diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 9f687418a889..191a8a05fb10 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -564,7 +564,7 @@ def ckpt_layer_path_list(self, ckpt_dir, local_layer_idx): ckpt_files.sort() return ckpt_files - def save_state_dict(self, save_dir, tag=None, checkpoint_engine=None): + def save_state_dict(self, save_dir, checkpoint_engine=None): if self._grid.data_parallel_id != 0: return @@ -585,13 +585,11 @@ def save_state_dict(self, save_dir, tag=None, checkpoint_engine=None): {k: v.clone() for k, v in orig_state_dict.items()}) - checkpoint_engine.save(final_state_dict, model_ckpt_path, tag) + checkpoint_engine.save(final_state_dict, model_ckpt_path) def load_state_dir(self, load_dir, strict=True, - tag=None, - persist_path=None, checkpoint_engine=CheckpointEngine()): for idx, layer in enumerate(self.forward_funcs): # Functions, etc. 
will not have state_dicts @@ -603,8 +601,7 @@ def load_state_dir(self, mp_rank = self._grid.get_slice_parallel_rank() mp_world_size = self._grid.get_slice_parallel_world_size() - sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list, version=2.0, tag=tag,\ - persist_path=persist_path, checkpoint_engine=checkpoint_engine) + sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list, version=2.0, checkpoint_engine=checkpoint_engine) load_path, checkpoint, _ = sd_loader.load(mp_world_size, mp_rank, module_key=None, is_pipe_parallel=True) layer.load_state_dict(checkpoint) diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index d266319fec1a..347ba9a82017 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -31,14 +31,10 @@ def get_sd_loader_json(json_file): def get_sd_loader(ckpt_list, sd_type='Megatron', version=None, - tag=None, - persist_path=None, checkpoint_engine=None): if sd_type == 'Megatron': return MegatronSDLoader(ckpt_list, version, - tag, - persist_path, checkpoint_engine) else: assert False, '{} checkpoint type is not supported'.format(sd_type) @@ -48,14 +44,10 @@ class SDLoaderBase(ABC): def __init__(self, ckpt_list, version, - tag=None, - persist_path=None, checkpoint_engine=None): self.module_key = None self.ckpt_list = ckpt_list self.version = version - self.tag = tag - self.persist_path = persist_path self.checkpoint_engine = CheckpointEngine( ) if checkpoint_engine is None else checkpoint_engine self.check_ckpt_list() @@ -101,7 +93,7 @@ def load(self, assert os.path.exists(load_path) #logger.info(f'rank: {mp_rank} loading checkpoint: {load_path}') sd = self.checkpoint_engine.load(load_path, map_location=lambda storage, \ - loc: storage, tag=self.tag, persist_path=self.persist_path) + loc: storage) if quantize: quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping, @@ -131,8 +123,6 @@ def get_merge_state_dicts(self, mp_world_size, mp_rank): 
logger.info(f"mp_rank: {mp_rank}, ckpt_list: {ckpt_list}") sd_list = [ self.checkpoint_engine.load(ckpt, - tag=self.tag, - persist_path=self.persist_path, map_location=lambda storage, loc: storage) for ckpt in ckpt_list ] @@ -151,8 +141,6 @@ def get_split_state_dict(self, mp_world_size, mp_rank): ) sd = self.checkpoint_engine.load(self.ckpt_list[ckpt_index], - tag=self.tag, - persist_path=self.persist_path, map_location=lambda storage, loc: storage) @@ -188,8 +176,6 @@ def check_ckpt_list(self): assert len(self.ckpt_list) > 0 sd = self.checkpoint_engine.load(self.ckpt_list[0], - tag=self.tag, - persist_path=self.persist_path, map_location=lambda storage, loc: storage) @@ -226,10 +212,8 @@ class MegatronSDLoader(SDLoaderBase): def __init__(self, ckpt_list, version, - tag=None, - persist_path=None, checkpoint_engine=None): - super().__init__(ckpt_list, version, tag, persist_path, checkpoint_engine) + super().__init__(ckpt_list, version, checkpoint_engine) """ ## Q/K/V data need special processing key: transformer.layers.0.attention.query_key_value.weight, shape: torch.Size([3192, 4256]) @@ -467,8 +451,6 @@ def sanity_check(self, ckpt_file_name): ] sd = self.checkpoint_engine.load(ckpt_file_name, - tag=self.tag, - persist_path=self.persist_path, map_location=lambda storage, loc: storage) From 5d987a086c1ec90d6b9702bee58ebbf6aa09b1b1 Mon Sep 17 00:00:00 2001 From: trajepl Date: Tue, 12 Jul 2022 15:49:58 +0800 Subject: [PATCH 09/25] linter fix --- deepspeed/nebula/config.py | 7 +++---- .../runtime/checkpoint_engine/checkpoint_engine.py | 1 - .../checkpoint_engine/nebula_checkpoint_engine.py | 1 - deepspeed/runtime/engine.py | 9 +++++---- deepspeed/runtime/pipe/module.py | 5 ++++- deepspeed/runtime/state_dict_factory.py | 14 +++----------- 6 files changed, 15 insertions(+), 22 deletions(-) diff --git a/deepspeed/nebula/config.py b/deepspeed/nebula/config.py index 313dcd952f4f..2b2cf9a83d8e 100644 --- a/deepspeed/nebula/config.py +++ b/deepspeed/nebula/config.py @@ -28,10 
+28,9 @@ def _initialize(self, nebula_dict): NEBULA_ENABLED, NEBULA_ENABLED_DEFAULT) - self.load_path = get_scalar_param( - nebula_dict, - NEBULA_LOAD_PATH, - NEBULA_LOAD_PATH_DEFAULT) + self.load_path = get_scalar_param(nebula_dict, + NEBULA_LOAD_PATH, + NEBULA_LOAD_PATH_DEFAULT) self.persistent_storage_path = get_scalar_param( nebula_dict, diff --git a/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py index fd5847299658..43ac8b5a8732 100644 --- a/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py @@ -3,7 +3,6 @@ class CheckpointEngine(object): - def __init__(self): pass diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py index d56537eb3524..5ff663fc8064 100644 --- a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py @@ -12,7 +12,6 @@ def _get_tag_from_path(path): class NebulaCheckpointEngine(CheckpointEngine): - def __init__(self, config_params=None): self.nebula_load_path = config_params.load_path if self.nebula_load_path is None: diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index ed8793658752..27ffe4cd4c58 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2746,8 +2746,10 @@ def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names, tag=None): # Fully load state for current rank if self.zero_elastic_checkpoint() or dist.get_rank( group=self.optimizer.dp_process_group) == i: - _state = self.checkpoint_engine.load(ckpt_name, - map_location='cpu',) + _state = self.checkpoint_engine.load( + ckpt_name, + map_location='cpu', + ) else: _state = {OPTIMIZER_STATE_DICT: None} zero_sd_list.append(_state) @@ -2917,8 +2919,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): global_expert_id, tag, 
self.mpu) - self.checkpoint_engine.save(expert_state_dict, - moe_save_path) + self.checkpoint_engine.save(expert_state_dict, moe_save_path) moe_layer_id += 1 self._curr_ckpt_path = os.path.join(save_dir, tag) diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 191a8a05fb10..ded70bdcf79d 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -601,7 +601,10 @@ def load_state_dir(self, mp_rank = self._grid.get_slice_parallel_rank() mp_world_size = self._grid.get_slice_parallel_world_size() - sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list, version=2.0, checkpoint_engine=checkpoint_engine) + sd_loader = SDLoaderFactory.get_sd_loader( + model_ckpt_list, + version=2.0, + checkpoint_engine=checkpoint_engine) load_path, checkpoint, _ = sd_loader.load(mp_world_size, mp_rank, module_key=None, is_pipe_parallel=True) layer.load_state_dict(checkpoint) diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index 347ba9a82017..9cd6c6a9c3ea 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -33,18 +33,13 @@ def get_sd_loader(ckpt_list, version=None, checkpoint_engine=None): if sd_type == 'Megatron': - return MegatronSDLoader(ckpt_list, - version, - checkpoint_engine) + return MegatronSDLoader(ckpt_list, version, checkpoint_engine) else: assert False, '{} checkpoint type is not supported'.format(sd_type) class SDLoaderBase(ABC): - def __init__(self, - ckpt_list, - version, - checkpoint_engine=None): + def __init__(self, ckpt_list, version, checkpoint_engine=None): self.module_key = None self.ckpt_list = ckpt_list self.version = version @@ -209,10 +204,7 @@ def sanity_check(self, ckpt_file_name): class MegatronSDLoader(SDLoaderBase): - def __init__(self, - ckpt_list, - version, - checkpoint_engine=None): + def __init__(self, ckpt_list, version, checkpoint_engine=None): super().__init__(ckpt_list, version, 
checkpoint_engine) """ ## Q/K/V data need special processing From a04a81a8e97065f7832db551045b281f91bc5eb4 Mon Sep 17 00:00:00 2001 From: trajepl Date: Tue, 12 Jul 2022 15:55:04 +0800 Subject: [PATCH 10/25] remove duplicate tag params --- deepspeed/runtime/engine.py | 10 +++++----- deepspeed/runtime/pipe/engine.py | 9 ++------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 27ffe4cd4c58..33df454e640d 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2739,7 +2739,7 @@ def _get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode): return zero_ckpt_names - def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names, tag=None): + def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names): zero_sd_list = [] for i, ckpt_name in enumerate(zero_ckpt_names): _state = None @@ -2774,7 +2774,7 @@ def _get_all_zero_checkpoints(self, load_dir, tag): logger.warn( f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine' ) - return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names, tag) + return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names) return None @@ -3186,12 +3186,12 @@ def get_layer_state_dict(module, prefix=""): return state_dict - def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin", tag=None): + def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"): """has been renamed to save_16bit_model, keeping this around for backwards compatibility""" - return self.save_16bit_model(save_dir, save_filename, tag=tag) + return self.save_16bit_model(save_dir, save_filename) - def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin", tag=None): + def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"): r"""Save 16bit model weights This method saves the 16bit model weights at the desired destination. 
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 179d97051aec..2835efcebcba 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -1306,7 +1306,7 @@ def mem_status(self, msg, print_rank=-1, reset_max=False): f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)' ) - def module_state_dict(self, tag=None): + def module_state_dict(self): """Override hack to save a pipe model and return the directory path of the save. This method should only be called by DeepSpeed's ``save_checkpoint()``. The @@ -1321,15 +1321,10 @@ def module_state_dict(self, tag=None): "PipelineEngine expects module_state_dict() to be called from save_checkpoint()" self.module.save_state_dict(self._curr_ckpt_path, - tag=tag, checkpoint_engine=self.checkpoint_engine) return None - def load_module_state_dict(self, - state_dict, - strict=True, - custom_load_fn=None, - tag=None): + def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None): """Override hack to instead use a directory path. This is important because pipeline models checkpoint by layer instead of rank. 
From 4b42bc232ec4e698d32bfd9a135f3087e4c567ed Mon Sep 17 00:00:00 2001 From: trajepl Date: Thu, 21 Jul 2022 17:08:08 +0800 Subject: [PATCH 11/25] make checkpoint_engine as required args --- deepspeed/inference/engine.py | 7 +++++-- deepspeed/runtime/pipe/module.py | 7 ++----- deepspeed/runtime/state_dict_factory.py | 16 ++++++++-------- tests/unit/test_checkpointing.py | 6 ++++-- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index b5e68aaf297f..8f9ca86410a0 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -9,6 +9,7 @@ from torch.nn.modules import Module from packaging import version as pkg_version +from deepspeed.runtime.checkpoint_engine.checkpoint_engine import CheckpointEngine from ..runtime.state_dict_factory import SDLoaderFactory from ..runtime.weight_quantizer import WeightQuantization @@ -90,6 +91,7 @@ def __init__(self, self.expert_mp_group = expert_mp_group self.enable_cuda_graph = enable_cuda_graph self.cuda_graph_created = False + self.checkpoint_engine = CheckpointEngine() self._init_quantization_setting(quantization_setting) if enable_cuda_graph: @@ -299,9 +301,10 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None): tag = fd.read().strip() ckpt_list = self._get_all_ckpt_names(load_dir, tag) - sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list) + sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine) else: - sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir) + sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir, + self.checkpoint_engine) mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index ded70bdcf79d..9618fb666ada 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -564,7 +564,7 @@ def ckpt_layer_path_list(self, ckpt_dir, local_layer_idx): 
ckpt_files.sort() return ckpt_files - def save_state_dict(self, save_dir, checkpoint_engine=None): + def save_state_dict(self, save_dir, checkpoint_engine): if self._grid.data_parallel_id != 0: return @@ -587,10 +587,7 @@ def save_state_dict(self, save_dir, checkpoint_engine=None): v in orig_state_dict.items()}) checkpoint_engine.save(final_state_dict, model_ckpt_path) - def load_state_dir(self, - load_dir, - strict=True, - checkpoint_engine=CheckpointEngine()): + def load_state_dir(self, load_dir, checkpoint_engine, strict=True): for idx, layer in enumerate(self.forward_funcs): # Functions, etc. will not have state_dicts if not hasattr(layer, 'load_state_dict'): diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index 9cd6c6a9c3ea..480849407a67 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -19,19 +19,19 @@ class SDLoaderFactory: @staticmethod - def get_sd_loader_json(json_file): + def get_sd_loader_json(json_file, checkpoint_engine): with open(json_file) as f: data = json.load(f) sd_type = data['type'] ckpt_list = data['checkpoints'] version = data['version'] - return SDLoaderFactory.get_sd_loader(ckpt_list, sd_type, version) + return SDLoaderFactory.get_sd_loader(ckpt_list, + checkpoint_engine, + sd_type, + version) @staticmethod - def get_sd_loader(ckpt_list, - sd_type='Megatron', - version=None, - checkpoint_engine=None): + def get_sd_loader(ckpt_list, checkpoint_engine, sd_type='Megatron', version=None): if sd_type == 'Megatron': return MegatronSDLoader(ckpt_list, version, checkpoint_engine) else: @@ -39,7 +39,7 @@ def get_sd_loader(ckpt_list, class SDLoaderBase(ABC): - def __init__(self, ckpt_list, version, checkpoint_engine=None): + def __init__(self, ckpt_list, version, checkpoint_engine): self.module_key = None self.ckpt_list = ckpt_list self.version = version @@ -204,7 +204,7 @@ def sanity_check(self, ckpt_file_name): class MegatronSDLoader(SDLoaderBase): - 
def __init__(self, ckpt_list, version, checkpoint_engine=None): + def __init__(self, ckpt_list, version, checkpoint_engine): super().__init__(ckpt_list, version, checkpoint_engine) """ ## Q/K/V data need special processing diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index ddac8a3dcd02..56151a431b1b 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -7,6 +7,7 @@ from deepspeed.utils import groups from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer +from deepspeed.runtime.checkpoint_engine.checkpoint_engine import CheckpointEngine from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer from deepspeed.runtime.pipe.topology import * @@ -728,13 +729,14 @@ def _test(save_folder, num_stages): def test_checkpoint_pipe_module(base_topo, test_topo, tmpdir): @distributed_test(world_size=4) def _test(base_topo, test_topo, save_folder): + checkpoint_engine = CheckpointEngine() base_model = LinearStackPipe(topology=base_topo) - base_model.save_state_dict(save_folder) + base_model.save_state_dict(save_folder, checkpoint_engine=checkpoint_engine) dist.barrier() test_model = LinearStackPipe(topology=test_topo) - test_model.load_state_dir(save_folder) + test_model.load_state_dir(save_folder, checkpoint_engine=checkpoint_engine) # Base and test can have different lengths, so make sure we map from the # smaller to larger model From 19063988eeab469e35d8173aa1743d4fba074d4e Mon Sep 17 00:00:00 2001 From: trajepl Date: Thu, 21 Jul 2022 17:55:02 +0800 Subject: [PATCH 12/25] fix args --- deepspeed/inference/engine.py | 5 ++++- deepspeed/runtime/engine.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index f08fea85beef..dd04d41c321c 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ 
-1,6 +1,7 @@ ''' Copyright 2021 The Microsoft DeepSpeed Team ''' +from deepspeed.runtime.checkpoint_engine import checkpoint_engine import torch import os @@ -421,10 +422,12 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None): state_dict=checkpoint[self._choose_module_key(checkpoint)], old_moe_load=old_moe_load, model=self.module, - mpu=self.mpu) + mpu=self.mpu, + checkpoint_engine=self.checkpoint_engine) self.module.load_state_dict( state_dict=checkpoint[self._choose_module_key(checkpoint)], + checkpoint_engine=self.checkpoint_engine, strict=load_module_strict) def _choose_module_key(self, sd): diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index ae1c664701cf..aa470dd77fb4 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2398,7 +2398,9 @@ def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None): if custom_load_fn: custom_load_fn(src=state_dict, dst=self.module) else: - self.module.load_state_dict(state_dict, strict=strict) + self.module.load_state_dict(state_dict, + checkpoint_engine=self.checkpoint_engine, + strict=strict) def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode): return f'{"bf16_" if bf16_mode else ""}zero_pp_rank_{dp_rank}' From 432e7c67261544d5d733d2596e518bcf0e388f08 Mon Sep 17 00:00:00 2001 From: trajepl Date: Thu, 21 Jul 2022 22:00:20 +0800 Subject: [PATCH 13/25] extract parameters out to config --- deepspeed/nebula/config.py | 10 ++++++++++ deepspeed/nebula/constants.py | 6 ++++++ deepspeed/runtime/engine.py | 22 +++++++++++++--------- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/deepspeed/nebula/config.py b/deepspeed/nebula/config.py index 2b2cf9a83d8e..f82f4418e8a6 100644 --- a/deepspeed/nebula/config.py +++ b/deepspeed/nebula/config.py @@ -15,6 +15,8 @@ def __init__(self, param_dict): self.persistent_storage_path = None self.persistent_time_interval = None self.num_of_version_in_retention = None + self.enable_nebula_load = None 
+ self.load_path_tier3 = None if NEBULA in param_dict.keys(): nebula_dict = param_dict[NEBULA] @@ -46,3 +48,11 @@ def _initialize(self, nebula_dict): nebula_dict, NEBULA_NUM_OF_VERSION_IN_RETENTION, NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT) + + self.enable_nebula_load = get_scalar_param(nebula_dict, + NEBULA_ENABLE_NEBULA_LOAD, + NEBULA_ENABLE_NEBULA_LOAD_DEFAULT) + + self.nebula_load_path_tier3 = get_scalar_param(nebula_dict, + NEBULA_LOAD_PATH_TIER3, + NEBULA_LOAD_PATH_TIER3_DEFAULT) diff --git a/deepspeed/nebula/constants.py b/deepspeed/nebula/constants.py index eba56b05beac..bee82ffae089 100644 --- a/deepspeed/nebula/constants.py +++ b/deepspeed/nebula/constants.py @@ -37,6 +37,12 @@ NEBULA_NUM_OF_VERSION_IN_RETENTION = "num_of_version_in_retention" NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT = 2 +NEBULA_ENABLE_NEBULA_LOAD = "nebula_enable_nebula_load" +NEBULA_ENABLE_NEBULA_LOAD_DEFAULT = True + +NEBULA_LOAD_PATH_TIER3 = "nebula_load_path_tier3" +NEBULA_LOAD_PATH_TIER3_DEFAULT = None + NEBULA_EXPORT_ENVS = [ 'DLTS_JOB_ID', 'DLTS_NUM_WORKER', diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index aa470dd77fb4..188b7eb0f243 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -801,10 +801,16 @@ def _configure_lr_scheduler(self, client_lr_scheduler): def _configure_checkpointing(self, dist_init_required): if self._config is not None and self._config.nebula_config.enabled: - from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \ - NebulaCheckpointEngine - self.checkpoint_engine = NebulaCheckpointEngine( - config_params=self._config.nebula_config) + try: + from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \ + NebulaCheckpointEngine + self.checkpoint_engine = NebulaCheckpointEngine( + config_params=self._config.nebula_config) + except ImportError as err: + logger.error( + f"No torch_nebula was found! Will fall back to torch.save. 
Details: {err}" + ) + self.checkpoint_engine = CheckpointEngine() else: self.checkpoint_engine = CheckpointEngine() @@ -2496,9 +2502,7 @@ def load_checkpoint(self, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False, - custom_load_fn=None, - enable_nebula_load=True, - nebula_load_path_tier3=None): + custom_load_fn=None): """Load training checkpoint Arguments: load_dir: Required. Directory to load the checkpoint from @@ -2520,9 +2524,9 @@ def load_checkpoint(self, checkpoint_engine_tmp = self.checkpoint_engine if not self.config is None and \ self._config.nebula_config.enabled and \ - enable_nebula_load == False: + self._config.nebula_config.enable_nebula_load == False: self.checkpoint_engine = CheckpointEngine() - self.persist_path = nebula_load_path_tier3 + self.persist_path = self._config.nebula_config.load_path_tier3 if tag is None: latest_tag = "latest_universal" if self.load_universal_checkpoint( From 7dbb6d8aa6cd155f9dd077e025628425663973dd Mon Sep 17 00:00:00 2001 From: trajepl Date: Fri, 22 Jul 2022 01:25:19 +0800 Subject: [PATCH 14/25] fix: load state dict --- deepspeed/runtime/engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 188b7eb0f243..bb3b0915ab90 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2404,8 +2404,7 @@ def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None): if custom_load_fn: custom_load_fn(src=state_dict, dst=self.module) else: - self.module.load_state_dict(state_dict, - checkpoint_engine=self.checkpoint_engine, + self.module.load_state_dict(state_dict, # TODO strict=strict) def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode): From e912e31bd70e20c713e2419d713674247196e691 Mon Sep 17 00:00:00 2001 From: trajepl Date: Fri, 22 Jul 2022 07:56:53 +0800 Subject: [PATCH 15/25] separate load engine --- deepspeed/runtime/engine.py | 22 ++++++++++------------ 
deepspeed/runtime/pipe/engine.py | 2 +- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index bb3b0915ab90..878d5560ad6d 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -245,6 +245,7 @@ def __init__( self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend. self.checkpoint_engine = None + self.checkpoint_load_engine = None global dist from deepspeed import comm as dist @@ -800,19 +801,22 @@ def _configure_lr_scheduler(self, client_lr_scheduler): log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0]) def _configure_checkpointing(self, dist_init_required): + self.checkpoint_engine = CheckpointEngine() + self.checkpoint_load_engine = self.checkpoint_engine + if self._config is not None and self._config.nebula_config.enabled: try: from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \ NebulaCheckpointEngine self.checkpoint_engine = NebulaCheckpointEngine( config_params=self._config.nebula_config) + if self._config.nebula_config.enable_nebula_load: + self.checkpoint_load_engine = self.checkpoint_engine except ImportError as err: logger.error( f"No torch_nebula was found! Will fall back to torch.save. Details: {err}" ) self.checkpoint_engine = CheckpointEngine() - else: - self.checkpoint_engine = CheckpointEngine() dp_rank = self.global_rank if self.mpu: @@ -2520,11 +2524,6 @@ def load_checkpoint(self, ``load_checkpoint()`` wants a pristine model. If insisting to do so, please reinitialize engine before ``load_checkpoint()``. 
""" - checkpoint_engine_tmp = self.checkpoint_engine - if not self.config is None and \ - self._config.nebula_config.enabled and \ - self._config.nebula_config.enable_nebula_load == False: - self.checkpoint_engine = CheckpointEngine() self.persist_path = self._config.nebula_config.load_path_tier3 if tag is None: @@ -2570,7 +2569,6 @@ def load_checkpoint(self, if self.zero_optimization_partition_weights(): self.optimizer.checkpoint_event_epilogue() - self.checkpoint_engine = checkpoint_engine_tmp self.persist_path = None return load_path, client_states @@ -2589,7 +2587,7 @@ def _load_checkpoint(self, ckpt_list = self._get_all_ckpt_names(load_dir, tag) sd_loader = SDLoaderFactory.get_sd_loader( ckpt_list, - checkpoint_engine=self.checkpoint_engine) + checkpoint_engine=self.checkpoint_load_engine) is_pipe_parallel = isinstance(self.module, PipelineModule) @@ -2617,7 +2615,7 @@ def _load_checkpoint(self, model=self.module, mpu=self.mpu, num_experts=self.num_experts, - checkpoint_engine=self.checkpoint_engine) + checkpoint_engine=self.checkpoint_load_engine) if not self.load_universal_checkpoint(): self.load_module_state_dict(state_dict=checkpoint['module'], strict=load_module_strict, @@ -2634,7 +2632,7 @@ def _load_checkpoint(self, largest_group_name = groups._get_max_expert_size_name() expp_rank = groups._get_expert_parallel_rank(largest_group_name) optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank) - optim_checkpoint = self.checkpoint_engine.load( + optim_checkpoint = self.checkpoint_load_engine.load( optim_load_path, map_location=torch.device('cpu')) else: @@ -2803,7 +2801,7 @@ def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names): # Fully load state for current rank if self.zero_elastic_checkpoint() or dist.get_rank( group=self.optimizer.dp_process_group) == i: - _state = self.checkpoint_engine.load( + _state = self.checkpoint_load_engine.load( ckpt_name, map_location='cpu', ) diff --git a/deepspeed/runtime/pipe/engine.py 
b/deepspeed/runtime/pipe/engine.py index 2835efcebcba..8711e0f909cb 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -1342,7 +1342,7 @@ def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None): self.module.load_state_dir(load_dir=self._curr_ckpt_path, strict=strict, - checkpoint_engine=self.checkpoint_engine) + checkpoint_engine=self.checkpoint_load_engine) # A map of PipeInstruction types to methods. Each method will be executed with the # kwargs provided to the PipeInstruction from the scheduler. From 7fc279bff7f8ac617d4a1746bb6cfa3adc87e6ce Mon Sep 17 00:00:00 2001 From: trajepl Date: Fri, 22 Jul 2022 08:40:20 +0800 Subject: [PATCH 16/25] linter fix --- deepspeed/runtime/engine.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 878d5560ad6d..730f6c1829ed 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2608,14 +2608,15 @@ def _load_checkpoint(self, old_moe_load = False if not isinstance(checkpoint['num_experts'], list): old_moe_load = True - DeepSpeedEngine.load_moe_state_dict(load_dir, - tag, - state_dict=checkpoint['module'], - old_moe_load=old_moe_load, - model=self.module, - mpu=self.mpu, - num_experts=self.num_experts, - checkpoint_engine=self.checkpoint_load_engine) + DeepSpeedEngine.load_moe_state_dict( + load_dir, + tag, + state_dict=checkpoint['module'], + old_moe_load=old_moe_load, + model=self.module, + mpu=self.mpu, + num_experts=self.num_experts, + checkpoint_engine=self.checkpoint_load_engine) if not self.load_universal_checkpoint(): self.load_module_state_dict(state_dict=checkpoint['module'], strict=load_module_strict, From c70c8187f7391bef3f76702665558efc4824b0c2 Mon Sep 17 00:00:00 2001 From: trajepl Date: Fri, 22 Jul 2022 11:09:40 +0800 Subject: [PATCH 17/25] extract checkpoint engine to abstract calss --- deepspeed/inference/engine.py | 4 +-- 
.../checkpoint_engine/checkpoint_engine.py | 24 +++++++-------- .../nebula_checkpoint_engine.py | 22 +++++++++----- .../torch_checkpoint_engine.py | 29 +++++++++++++++++++ deepspeed/runtime/engine.py | 9 +++--- deepspeed/runtime/pipe/module.py | 2 +- deepspeed/runtime/state_dict_factory.py | 4 +-- tests/unit/test_checkpointing.py | 4 +-- 8 files changed, 66 insertions(+), 32 deletions(-) create mode 100644 deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index dd04d41c321c..57c7f704fb0b 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -11,7 +11,7 @@ from torch.nn.modules import Module from packaging import version as pkg_version -from deepspeed.runtime.checkpoint_engine.checkpoint_engine import CheckpointEngine +from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine from ..runtime.state_dict_factory import SDLoaderFactory from ..runtime.weight_quantizer import WeightQuantization @@ -96,7 +96,7 @@ def __init__(self, self.expert_mp_group = expert_mp_group self.enable_cuda_graph = enable_cuda_graph self.cuda_graph_created = False - self.checkpoint_engine = CheckpointEngine() + self.checkpoint_engine = TorchCheckpointEngine() self._init_quantization_setting(quantization_setting) if enable_cuda_graph: diff --git a/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py index 43ac8b5a8732..67c888815242 100644 --- a/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py @@ -1,23 +1,19 @@ -import torch -from deepspeed.utils import logger +class CheckpointEngine(object): + # init checkpoint engine for save/load + def __init__(self, config_params=None): + pass -class CheckpointEngine(object): - def __init__(self): + def create(self, tag): + # create checkpoint on give tag for save/load. 
pass def save(self, state_dict, path: str): - logger.info(f"Saving {path}...") - torch.save(state_dict, path) - logger.info(f"Saved {path}.") - return None + pass def load(self, path: str, map_location=None): - logger.info(f"Loading checkpoint from {path}...") - partition = torch.load(path, map_location=map_location) - logger.info(f"Loaded checkpoint from {path}.") - return partition + pass def commit(self, tag): - print(f"Checkpoint {tag} is ready now!") - return True + # to tell checkpoint services if all files are readys. + pass diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py index 5ff663fc8064..3a80bdda7e06 100644 --- a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py @@ -12,7 +12,9 @@ def _get_tag_from_path(path): class NebulaCheckpointEngine(CheckpointEngine): + def __init__(self, config_params=None): + self.checkpoint = None self.nebula_load_path = config_params.load_path if self.nebula_load_path is None: self.nebula_load_path = config_params.persistent_storage_path @@ -25,15 +27,17 @@ def __init__(self, config_params=None): } torch_nebula.init(**nebula_config_params) + def create(self, tag): + logger.info(f"[Nebula] Start Checkpoint for tag:{tag}") + # -2 means: customer needs to explicitly tell nebula + # current checkpoint is complete by commit methond. + self.checkpoint = torch_nebula.Checkpoint(tag, -2) + def save(self, state_dict, path: str): tag = _get_tag_from_path(path) partititon_name = os.path.basename(path) logger.info(f"[Nebula] Saving {partititon_name} under tag{tag}...") - - # -2 means: customer needs to explicitly tell nebula - # current checkpoint is complete by commit methond. 
- checkpoint = torch_nebula.Checkpoint(tag, -2) - checkpoint.save(partititon_name, state_dict) + self.checkpoint.save(partititon_name, state_dict) logger.info(f"[Nebula] Saved {partititon_name} under tag{tag}.") return None @@ -59,6 +63,10 @@ def load(self, path: str, map_location=None): return partition def commit(self, tag): - checkpoint = torch_nebula.Checkpoint(tag, -2) - commit_rls = checkpoint.commit() + # nebula commit will be call when all files under give tag are ready to be persisted in the async way. + logger.info(f"[Nebula] all files for {tag} are saved in tier1. It is ready to start persisting") + commit_rls = self.checkpoint.commit() + if not commit_rls: + logger.error(f"[Nebula] failed to commit the checkpoint, please check the log.") + return False return commit_rls diff --git a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py new file mode 100644 index 000000000000..176640775178 --- /dev/null +++ b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py @@ -0,0 +1,29 @@ +import torch +from deepspeed.utils import logger +from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \ + CheckpointEngine + + +class TorchCheckpointEngine(CheckpointEngine): + + def __init__(self, config_params=None): + super().__init__(self, config_params) + + def create(self, tag): + print(f"[Torch] Checkpoint {tag} is begin to save!") + + def save(self, state_dict, path: str): + logger.info(f"[Torch] Saving {path}...") + torch.save(state_dict, path) + logger.info(f"[Torch] Saved {path}.") + return None + + def load(self, path: str, map_location=None): + logger.info(f"[Torch] Loading checkpoint from {path}...") + partition = torch.load(path, map_location=map_location) + logger.info(f"[Torch] Loaded checkpoint from {path}.") + return partition + + def commit(self, tag): + print(f"[Torch] Checkpoint {tag} is ready now!") + return True diff --git a/deepspeed/runtime/engine.py 
b/deepspeed/runtime/engine.py index 730f6c1829ed..95efca324eaf 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -82,7 +82,7 @@ from deepspeed.runtime.utils import clip_grad_norm_ from deepspeed.runtime.eigenvalue import Eigenvalue from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler -from deepspeed.runtime.checkpoint_engine.checkpoint_engine import CheckpointEngine +from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine from .pipe.module import PipelineModule from .utils import ensure_directory_exists, get_ma_status @@ -801,7 +801,7 @@ def _configure_lr_scheduler(self, client_lr_scheduler): log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0]) def _configure_checkpointing(self, dist_init_required): - self.checkpoint_engine = CheckpointEngine() + self.checkpoint_engine = TorchCheckpointEngine() self.checkpoint_load_engine = self.checkpoint_engine if self._config is not None and self._config.nebula_config.enabled: @@ -816,7 +816,7 @@ def _configure_checkpointing(self, dist_init_required): logger.error( f"No torch_nebula was found! Will fall back to torch.save. 
Details: {err}" ) - self.checkpoint_engine = CheckpointEngine() + self.checkpoint_engine = TorchCheckpointEngine() dp_rank = self.global_rank if self.mpu: @@ -2349,7 +2349,7 @@ def load_moe_state_dict(checkpoint_path, model=None, mpu=None, num_experts=1, - checkpoint_engine=CheckpointEngine()): + checkpoint_engine=TorchCheckpointEngine()): if old_moe_load: expp_rank = groups._get_expert_data_parallel_rank( groups._get_max_expert_size_name()) @@ -2881,6 +2881,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True) # Ensure tag is a string tag = str(tag) + self.checkpoint_engine.create(tag) # Ensure checkpoint tag is consistent across ranks self._checkpoint_tag_validation(tag) diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 9618fb666ada..7ad4da25920b 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -6,7 +6,7 @@ from collections import defaultdict from functools import partial -from deepspeed.runtime.checkpoint_engine.checkpoint_engine import CheckpointEngine +from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine import torch import torch.nn as nn diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index 0e6f8f49847b..6097e8baa004 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from deepspeed.utils import logger -from deepspeed.runtime.checkpoint_engine.checkpoint_engine import CheckpointEngine +from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine from .weight_quantizer import WeightQuantization @@ -45,7 +45,7 @@ def __init__(self, ckpt_list, version, checkpoint_engine): self.module_key = None self.ckpt_list = ckpt_list self.version = version - self.checkpoint_engine = CheckpointEngine( + self.checkpoint_engine = TorchCheckpointEngine( ) 
if checkpoint_engine is None else checkpoint_engine self.check_ckpt_list() diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index 005af9498fb9..86345400c637 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -7,7 +7,7 @@ from deepspeed.utils import groups from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer -from deepspeed.runtime.checkpoint_engine.checkpoint_engine import CheckpointEngine +from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer from deepspeed.runtime.pipe.topology import * @@ -736,7 +736,7 @@ def _test(save_folder, num_stages): def test_checkpoint_pipe_module(base_topo, test_topo, tmpdir): @distributed_test(world_size=4) def _test(base_topo, test_topo, save_folder): - checkpoint_engine = CheckpointEngine() + checkpoint_engine = TorchCheckpointEngine() base_model = LinearStackPipe(topology=base_topo) base_model.save_state_dict(save_folder, checkpoint_engine=checkpoint_engine) From e6dd7943b0669a1a84589b241329cfc2b0864d08 Mon Sep 17 00:00:00 2001 From: trajepl Date: Fri, 22 Jul 2022 11:37:06 +0800 Subject: [PATCH 18/25] linter fix --- .../runtime/checkpoint_engine/nebula_checkpoint_engine.py | 8 +++++--- .../runtime/checkpoint_engine/torch_checkpoint_engine.py | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py index 3a80bdda7e06..7eb65f290dad 100644 --- a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py @@ -12,7 +12,6 @@ def _get_tag_from_path(path): class NebulaCheckpointEngine(CheckpointEngine): - def __init__(self, config_params=None): 
self.checkpoint = None self.nebula_load_path = config_params.load_path @@ -64,9 +63,12 @@ def load(self, path: str, map_location=None): def commit(self, tag): # nebula commit will be call when all files under give tag are ready to be persisted in the async way. - logger.info(f"[Nebula] all files for {tag} are saved in tier1. It is ready to start persisting") + logger.info( + f"[Nebula] all files for {tag} are saved in tier1. It is ready to start persisting" + ) commit_rls = self.checkpoint.commit() if not commit_rls: - logger.error(f"[Nebula] failed to commit the checkpoint, please check the log.") + logger.error( + f"[Nebula] failed to commit the checkpoint, please check the log.") return False return commit_rls diff --git a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py index 176640775178..8bda42873500 100644 --- a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py @@ -5,7 +5,6 @@ class TorchCheckpointEngine(CheckpointEngine): - def __init__(self, config_params=None): super().__init__(self, config_params) From 1efd2ce079107e838e6740eda0853c82048039e8 Mon Sep 17 00:00:00 2001 From: trajepl Date: Fri, 22 Jul 2022 13:28:27 +0800 Subject: [PATCH 19/25] construct function args fix --- deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py | 1 + deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py index 7eb65f290dad..39b9f6ffac7b 100644 --- a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py @@ -13,6 +13,7 @@ def _get_tag_from_path(path): class NebulaCheckpointEngine(CheckpointEngine): def __init__(self, 
config_params=None): + super().__init__(config_params) self.checkpoint = None self.nebula_load_path = config_params.load_path if self.nebula_load_path is None: diff --git a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py index 8bda42873500..ae46c877a43d 100644 --- a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py @@ -6,7 +6,7 @@ class TorchCheckpointEngine(CheckpointEngine): def __init__(self, config_params=None): - super().__init__(self, config_params) + super().__init__(config_params) def create(self, tag): print(f"[Torch] Checkpoint {tag} is begin to save!") From dce0fb5182cde6991eaa510ed9739695ee4665bb Mon Sep 17 00:00:00 2001 From: trajepl Date: Mon, 25 Jul 2022 11:08:19 +0800 Subject: [PATCH 20/25] add docs for dev/customers --- deepspeed/nebula/config.py | 13 ++----- deepspeed/nebula/constants.py | 34 +++++++++++++---- deepspeed/runtime/checkpoint_engine/README.md | 37 +++++++++++++++++++ deepspeed/runtime/engine.py | 3 -- 4 files changed, 67 insertions(+), 20 deletions(-) create mode 100644 deepspeed/runtime/checkpoint_engine/README.md diff --git a/deepspeed/nebula/config.py b/deepspeed/nebula/config.py index f82f4418e8a6..f9928d66147c 100644 --- a/deepspeed/nebula/config.py +++ b/deepspeed/nebula/config.py @@ -16,7 +16,6 @@ def __init__(self, param_dict): self.persistent_time_interval = None self.num_of_version_in_retention = None self.enable_nebula_load = None - self.load_path_tier3 = None if NEBULA in param_dict.keys(): nebula_dict = param_dict[NEBULA] @@ -34,6 +33,10 @@ def _initialize(self, nebula_dict): NEBULA_LOAD_PATH, NEBULA_LOAD_PATH_DEFAULT) + self.enable_nebula_load = get_scalar_param(nebula_dict, + NEBULA_ENABLE_NEBULA_LOAD, + NEBULA_ENABLE_NEBULA_LOAD_DEFAULT) + self.persistent_storage_path = get_scalar_param( nebula_dict, NEBULA_PERSISTENT_STORAGE_PATH, @@ -48,11 +51,3 @@ def 
_initialize(self, nebula_dict): nebula_dict, NEBULA_NUM_OF_VERSION_IN_RETENTION, NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT) - - self.enable_nebula_load = get_scalar_param(nebula_dict, - NEBULA_ENABLE_NEBULA_LOAD, - NEBULA_ENABLE_NEBULA_LOAD_DEFAULT) - - self.nebula_load_path_tier3 = get_scalar_param(nebula_dict, - NEBULA_LOAD_PATH_TIER3, - NEBULA_LOAD_PATH_TIER3_DEFAULT) diff --git a/deepspeed/nebula/constants.py b/deepspeed/nebula/constants.py index bee82ffae089..6c4fadbd6965 100644 --- a/deepspeed/nebula/constants.py +++ b/deepspeed/nebula/constants.py @@ -15,7 +15,8 @@ "enabled": true, "persistent_storage_path": "/foo/bar", "persistent_time_interval": 100, - "num_of_version_in_retention": 2 + "num_of_version_in_retention": 2, + "enable_nebula_load": true } } ''' @@ -25,24 +26,41 @@ NEBULA_ENABLED = "enabled" NEBULA_ENABLED_DEFAULT = False -NEBULA_LOAD_PATH = "persistent_storage_path" +# There is a case where customers want to load a checkpoint saved +# by raw torch. Nebula cannot load a torch checkpoint directly, as +# the two use different folder structures, which creates a gap for +# loading (although the saved data are totally the same in bytes +# for torch and nebula). +# In this case, we must disable nebula load to use raw torch load. +# Customers can just set NEBULA_ENABLE_NEBULA_LOAD to False, then use +# the original deepspeed way to load, i.e. set the value of "--load". +NEBULA_ENABLE_NEBULA_LOAD = "enable_nebula_load" +NEBULA_ENABLE_NEBULA_LOAD_DEFAULT = True + +# When you want to resume the previous checkpoint saved by nebula, +# you can set NEBULA_LOAD_PATH as the parent folder of the checkpoint. +# If NEBULA_LOAD_PATH is None, the NEBULA_PERSISTENT_STORAGE_PATH +# will be the default path to load. +NEBULA_LOAD_PATH = "nebula_load_path" NEBULA_LOAD_PATH_DEFAULT = None +# Nebula will save the checkpoint under NEBULA_LOAD_PATH in an +# asynchronous way. 
NEBULA_PERSISTENT_STORAGE_PATH = "persistent_storage_path" NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT = None +# Time interval to trigger the nebula persistence. NEBULA_PERSISTENT_TIME_INTERVAL = "persistent_time_interval" NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT = 100 +# Checkpoint number which will be kept in memory. Let us say, +# if the value is 2. Then we have checkpoints 1 and 2 are ready +# now. When it comes to checkpoint 3, the 1 will be removed if +# 1 has been persisted to disk. NEBULA_NUM_OF_VERSION_IN_RETENTION = "num_of_version_in_retention" NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT = 2 -NEBULA_ENABLE_NEBULA_LOAD = "nebula_enable_nebula_load" -NEBULA_ENABLE_NEBULA_LOAD_DEFAULT = True - -NEBULA_LOAD_PATH_TIER3 = "nebula_load_path_tier3" -NEBULA_LOAD_PATH_TIER3_DEFAULT = None - +# Neubla envs NEBULA_EXPORT_ENVS = [ 'DLTS_JOB_ID', 'DLTS_NUM_WORKER', diff --git a/deepspeed/runtime/checkpoint_engine/README.md b/deepspeed/runtime/checkpoint_engine/README.md new file mode 100644 index 000000000000..fbb3702a68d7 --- /dev/null +++ b/deepspeed/runtime/checkpoint_engine/README.md @@ -0,0 +1,37 @@ +# Checkpoint Engine + + +The `CheckpointEngine` was designed to modularized the checkpoint serialization. In this way, we can simply replace/refine the checkpoint serialization methods. + +### Interface for `CheckpointEngine` + +Basically, for checkpoint management(save/load by deepspeed with the given tag), the `CheckpointEngine` will: + + 1. To make preliminaries ready by call `create(tag)`. For `torch`, we can just log some extra info as `torch` can directly call `save/load` without other preparation. + + 2. After the `create(tag)`, deepspeed can call `save/load` to persist files into disk/memory/etc. + + 3. When all the files for a tag are ready, deepspeed engine will call `commit()` to tell the checkpoint engine current checkpoint is complete. For original torch, it also plays the role of logger. 
+ + +```python +class CheckpointEngine(object): + # init checkpoint engine for save/load + def __init__(self, config_params=None): + pass + + def create(self, tag): + # create checkpoint on give tag for save/load. + pass + + def save(self, state_dict, path: str): + pass + + def load(self, path: str, map_location=None): + pass + + def commit(self, tag): + # to tell checkpoint services if all files are readys. + pass + +``` diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 95efca324eaf..77cf42923075 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2524,7 +2524,6 @@ def load_checkpoint(self, ``load_checkpoint()`` wants a pristine model. If insisting to do so, please reinitialize engine before ``load_checkpoint()``. """ - self.persist_path = self._config.nebula_config.load_path_tier3 if tag is None: latest_tag = "latest_universal" if self.load_universal_checkpoint( @@ -2569,8 +2568,6 @@ def load_checkpoint(self, if self.zero_optimization_partition_weights(): self.optimizer.checkpoint_event_epilogue() - self.persist_path = None - return load_path, client_states def _load_checkpoint(self, From bb5bb7c6bcc79c564c8ee1aa01c0f44644b06c77 Mon Sep 17 00:00:00 2001 From: trajepl Date: Mon, 25 Jul 2022 11:13:14 +0800 Subject: [PATCH 21/25] linter fix --- deepspeed/runtime/checkpoint_engine/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/checkpoint_engine/README.md b/deepspeed/runtime/checkpoint_engine/README.md index fbb3702a68d7..a19f54889802 100644 --- a/deepspeed/runtime/checkpoint_engine/README.md +++ b/deepspeed/runtime/checkpoint_engine/README.md @@ -1,7 +1,7 @@ # Checkpoint Engine -The `CheckpointEngine` was designed to modularized the checkpoint serialization. In this way, we can simply replace/refine the checkpoint serialization methods. +The `CheckpointEngine` was designed to modularized the checkpoint serialization. 
In this way, we can simply replace/refine the checkpoint serialization methods. ### Interface for `CheckpointEngine` @@ -12,7 +12,7 @@ Basically, for checkpoint management(save/load by deepspeed with the given tag), 2. After the `create(tag)`, deepspeed can call `save/load` to persist files into disk/memory/etc. 3. When all the files for a tag are ready, deepspeed engine will call `commit()` to tell the checkpoint engine current checkpoint is complete. For original torch, it also plays the role of logger. - + ```python class CheckpointEngine(object): From 3e8c238fd42846ccb8b34c2d0c10745bdb241928 Mon Sep 17 00:00:00 2001 From: trajepl Date: Tue, 26 Jul 2022 11:04:14 +0800 Subject: [PATCH 22/25] remove load engine --- deepspeed/constants.py | 1 - deepspeed/launcher/launch.py | 3 ++- deepspeed/nebula/constants.py | 3 +++ .../checkpoint_engine/nebula_checkpoint_engine.py | 12 ++++++++++-- deepspeed/runtime/engine.py | 12 ++++-------- deepspeed/runtime/pipe/engine.py | 2 +- 6 files changed, 20 insertions(+), 13 deletions(-) diff --git a/deepspeed/constants.py b/deepspeed/constants.py index 8b4cd415d1e4..9576c9c078d5 100644 --- a/deepspeed/constants.py +++ b/deepspeed/constants.py @@ -17,4 +17,3 @@ INFERENCE_GENERIC_MODE = 'generic' INFERENCE_SPECIALIZED_MODE = 'specialized' -DLTS_POD_ENV_PATH = '/dlts-runtime/env/pod.env' diff --git a/deepspeed/launcher/launch.py b/deepspeed/launcher/launch.py index 731c8a216119..c7113c13f53d 100755 --- a/deepspeed/launcher/launch.py +++ b/deepspeed/launcher/launch.py @@ -18,7 +18,8 @@ from collections import defaultdict from argparse import ArgumentParser, REMAINDER -from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, DLTS_POD_ENV_PATH +from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT +from ..nebula.constants import DLTS_POD_ENV_PATH from ..utils import logger PID_FILE_BASEPATH = "/tmp" diff --git a/deepspeed/nebula/constants.py b/deepspeed/nebula/constants.py index 6c4fadbd6965..0e66fa8d1536 100644 --- 
a/deepspeed/nebula/constants.py +++ b/deepspeed/nebula/constants.py @@ -81,3 +81,6 @@ 'AZUREML_BLOB_ACCOUNT_NAME', 'AZUREML_BLOB_ACCOUNT_KEY' ] + +# ITP env files +DLTS_POD_ENV_PATH = '/dlts-runtime/env/pod.env' diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py index 39b9f6ffac7b..99cda5cb9904 100644 --- a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py @@ -1,9 +1,10 @@ import os +import torch import torch_nebula from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \ CheckpointEngine -from deepspeed.utils import logger +from deepspeed.utils import logger, log_dist from deepspeed.nebula.constants import * @@ -15,6 +16,7 @@ class NebulaCheckpointEngine(CheckpointEngine): def __init__(self, config_params=None): super().__init__(config_params) self.checkpoint = None + self.enable_nebula_load = config_params.enable_nebula_load self.nebula_load_path = config_params.load_path if self.nebula_load_path is None: self.nebula_load_path = config_params.persistent_storage_path @@ -28,7 +30,7 @@ def __init__(self, config_params=None): torch_nebula.init(**nebula_config_params) def create(self, tag): - logger.info(f"[Nebula] Start Checkpoint for tag:{tag}") + log_dist(f"[Nebula] Start Checkpoint for tag:{tag}", ranks=[0]) # -2 means: customer needs to explicitly tell nebula # current checkpoint is complete by commit methond. self.checkpoint = torch_nebula.Checkpoint(tag, -2) @@ -42,6 +44,12 @@ def save(self, state_dict, path: str): return None def load(self, path: str, map_location=None): + if not self.enable_nebula_load: + logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path}...") + partition = torch.load(path, map_location=map_location) + logger.info(f"[Nebula] Disable nebula load. 
Loaded checkpoint from {path}...") + return partition + tag = _get_tag_from_path(path) partititon_name = os.path.basename(path) logger.info( diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 77cf42923075..6634a89698f5 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -245,7 +245,6 @@ def __init__( self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend. self.checkpoint_engine = None - self.checkpoint_load_engine = None global dist from deepspeed import comm as dist @@ -802,7 +801,6 @@ def _configure_lr_scheduler(self, client_lr_scheduler): def _configure_checkpointing(self, dist_init_required): self.checkpoint_engine = TorchCheckpointEngine() - self.checkpoint_load_engine = self.checkpoint_engine if self._config is not None and self._config.nebula_config.enabled: try: @@ -810,8 +808,6 @@ def _configure_checkpointing(self, dist_init_required): NebulaCheckpointEngine self.checkpoint_engine = NebulaCheckpointEngine( config_params=self._config.nebula_config) - if self._config.nebula_config.enable_nebula_load: - self.checkpoint_load_engine = self.checkpoint_engine except ImportError as err: logger.error( f"No torch_nebula was found! Will fall back to torch.save. 
Details: {err}" @@ -2584,7 +2580,7 @@ def _load_checkpoint(self, ckpt_list = self._get_all_ckpt_names(load_dir, tag) sd_loader = SDLoaderFactory.get_sd_loader( ckpt_list, - checkpoint_engine=self.checkpoint_load_engine) + checkpoint_engine=self.checkpoint_engine) is_pipe_parallel = isinstance(self.module, PipelineModule) @@ -2613,7 +2609,7 @@ def _load_checkpoint(self, model=self.module, mpu=self.mpu, num_experts=self.num_experts, - checkpoint_engine=self.checkpoint_load_engine) + checkpoint_engine=self.checkpoint_engine) if not self.load_universal_checkpoint(): self.load_module_state_dict(state_dict=checkpoint['module'], strict=load_module_strict, @@ -2630,7 +2626,7 @@ def _load_checkpoint(self, largest_group_name = groups._get_max_expert_size_name() expp_rank = groups._get_expert_parallel_rank(largest_group_name) optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank) - optim_checkpoint = self.checkpoint_load_engine.load( + optim_checkpoint = self.checkpoint_engine.load( optim_load_path, map_location=torch.device('cpu')) else: @@ -2799,7 +2795,7 @@ def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names): # Fully load state for current rank if self.zero_elastic_checkpoint() or dist.get_rank( group=self.optimizer.dp_process_group) == i: - _state = self.checkpoint_load_engine.load( + _state = self.checkpoint_engine.load( ckpt_name, map_location='cpu', ) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 8711e0f909cb..2835efcebcba 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -1342,7 +1342,7 @@ def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None): self.module.load_state_dir(load_dir=self._curr_ckpt_path, strict=strict, - checkpoint_engine=self.checkpoint_load_engine) + checkpoint_engine=self.checkpoint_engine) # A map of PipeInstruction types to methods. 
Each method will be executed with the # kwargs provided to the PipeInstruction from the scheduler. From a5c889748bfc5e4cc667d7996d8d54288f741f7a Mon Sep 17 00:00:00 2001 From: trajepl Date: Tue, 26 Jul 2022 11:11:24 +0800 Subject: [PATCH 23/25] print->log_dist --- .../nebula_checkpoint_engine.py | 6 ++++-- .../torch_checkpoint_engine.py | 6 +++--- deepspeed/runtime/engine.py | 17 ++++++++--------- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py index 99cda5cb9904..90978bdcdddf 100644 --- a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py @@ -45,9 +45,11 @@ def save(self, state_dict, path: str): def load(self, path: str, map_location=None): if not self.enable_nebula_load: - logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path}...") + logger.info( + f"[Nebula] Disable nebula load. Loading checkpoint from {path}...") partition = torch.load(path, map_location=map_location) - logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path}...") + logger.info( + f"[Nebula] Disable nebula load. 
Loaded checkpoint from {path}...") return partition tag = _get_tag_from_path(path) diff --git a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py index ae46c877a43d..9b4942f0a01f 100644 --- a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py @@ -1,5 +1,5 @@ import torch -from deepspeed.utils import logger +from deepspeed.utils import logger, log_dist from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \ CheckpointEngine @@ -9,7 +9,7 @@ def __init__(self, config_params=None): super().__init__(config_params) def create(self, tag): - print(f"[Torch] Checkpoint {tag} is begin to save!") + log_dist(f"[Torch] Checkpoint {tag} is begin to save!", ranks=[0]) def save(self, state_dict, path: str): logger.info(f"[Torch] Saving {path}...") @@ -24,5 +24,5 @@ def load(self, path: str, map_location=None): return partition def commit(self, tag): - print(f"[Torch] Checkpoint {tag} is ready now!") + logger.info(f"[Torch] Checkpoint {tag} is ready now!") return True diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 6634a89698f5..81a3b0e3b732 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2601,15 +2601,14 @@ def _load_checkpoint(self, old_moe_load = False if not isinstance(checkpoint['num_experts'], list): old_moe_load = True - DeepSpeedEngine.load_moe_state_dict( - load_dir, - tag, - state_dict=checkpoint['module'], - old_moe_load=old_moe_load, - model=self.module, - mpu=self.mpu, - num_experts=self.num_experts, - checkpoint_engine=self.checkpoint_engine) + DeepSpeedEngine.load_moe_state_dict(load_dir, + tag, + state_dict=checkpoint['module'], + old_moe_load=old_moe_load, + model=self.module, + mpu=self.mpu, + num_experts=self.num_experts, + checkpoint_engine=self.checkpoint_engine) if not self.load_universal_checkpoint(): 
self.load_module_state_dict(state_dict=checkpoint['module'], strict=load_module_strict, From 44d687b9d3f2a220ff9b0e7365799f823037407a Mon Sep 17 00:00:00 2001 From: trajepl Date: Tue, 26 Jul 2022 11:57:12 +0800 Subject: [PATCH 24/25] linter fix --- deepspeed/inference/engine.py | 1 - deepspeed/runtime/pipe/module.py | 1 - 2 files changed, 2 deletions(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 57c7f704fb0b..af9cf57ee2a9 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -1,7 +1,6 @@ ''' Copyright 2021 The Microsoft DeepSpeed Team ''' -from deepspeed.runtime.checkpoint_engine import checkpoint_engine import torch import os diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 7ad4da25920b..466809b9d3ce 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -6,7 +6,6 @@ from collections import defaultdict from functools import partial -from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine import torch import torch.nn as nn From 82ad297a832a149aa779631827e171c8d088f3aa Mon Sep 17 00:00:00 2001 From: trajepl Date: Tue, 26 Jul 2022 20:08:46 +0800 Subject: [PATCH 25/25] add tag flag to distinguish the loading order --- .../runtime/checkpoint_engine/nebula_checkpoint_engine.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py index 90978bdcdddf..b7ff5cdf9829 100644 --- a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py @@ -16,6 +16,7 @@ class NebulaCheckpointEngine(CheckpointEngine): def __init__(self, config_params=None): super().__init__(config_params) self.checkpoint = None + self.tag_flag = None self.enable_nebula_load = config_params.enable_nebula_load 
self.nebula_load_path = config_params.load_path if self.nebula_load_path is None: @@ -44,7 +45,10 @@ def save(self, state_dict, path: str): return None def load(self, path: str, map_location=None): - if not self.enable_nebula_load: + tag = _get_tag_from_path(path) + first_load_flag = self.tag_flag is None or self.tag_flag == tag + if not self.enable_nebula_load and first_load_flag: + self.tag_flag = tag logger.info( f"[Nebula] Disable nebula load. Loading checkpoint from {path}...") partition = torch.load(path, map_location=map_location) @@ -52,7 +56,6 @@ def load(self, path: str, map_location=None): f"[Nebula] Disable nebula load. Loaded checkpoint from {path}...") return partition - tag = _get_tag_from_path(path) partititon_name = os.path.basename(path) logger.info( f"[Nebula] Loading {path} under tag{tag} from {self.nebula_load_path}...")