diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 173f2587de70..8af9a12c2809 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -9,6 +9,7 @@
 from torch.nn.modules import Module
 from packaging import version as pkg_version
+from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
 from ..runtime.state_dict_factory import SDLoaderFactory
 from ..runtime.weight_quantizer import WeightQuantization
@@ -92,6 +93,7 @@ def __init__(self,
         self.expert_mp_group = expert_mp_group
         self.enable_cuda_graph = enable_cuda_graph
         self.cuda_graph_created = False
+        self.checkpoint_engine = TorchCheckpointEngine()
         self._init_quantization_setting(quantization_setting)
         if enable_cuda_graph:
@@ -376,9 +378,10 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
                 tag = fd.read().strip()
             ckpt_list = self._get_all_ckpt_names(load_dir, tag)
-            sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list)
+            sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
         else:
-            sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)
+            sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir,
+                                                           self.checkpoint_engine)
         if type(sd_loader) is list:
             self.sd = torch.load(sd_loader[0], map_location='cpu')
@@ -416,10 +419,12 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
                 state_dict=checkpoint[self._choose_module_key(checkpoint)],
                 old_moe_load=old_moe_load,
                 model=self.module,
-                mpu=self.mpu)
+                mpu=self.mpu,
+                checkpoint_engine=self.checkpoint_engine)
         self.module.load_state_dict(
             state_dict=checkpoint[self._choose_module_key(checkpoint)],
+            checkpoint_engine=self.checkpoint_engine,
             strict=load_module_strict)
     def _choose_module_key(self, sd):
diff --git a/deepspeed/launcher/launch.py b/deepspeed/launcher/launch.py
index 06321cbf90b3..c7113c13f53d 100755
--- a/deepspeed/launcher/launch.py
+++ b/deepspeed/launcher/launch.py
@@ -19,6 +19,7 @@
 from argparse import ArgumentParser, REMAINDER
 from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
+from ..nebula.constants import DLTS_POD_ENV_PATH
 from ..utils import logger
 PID_FILE_BASEPATH = "/tmp"
@@ -143,6 +144,17 @@ def main():
         with open(pid_file, 'w') as fd:
             fd.write(f"{launcher_pid}")
+    if os.path.exists(DLTS_POD_ENV_PATH):
+        with open(DLTS_POD_ENV_PATH) as file:
+            lines = file.readlines()
+            lines = [line.rstrip() for line in lines]
+            for line in lines:
+                if line.startswith('export FC_TASKROLE_NAME') or line.startswith(
+                        'export FC_TASK_INDEX'):
+                    key_val = line.split()[1]
+                    key, val = key_val.split('=')
+                    current_env[key] = val
+
     processes = []
     cmd = []
     for local_rank in range(0, num_local_procs):
diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py
index be25715660db..cf3e98dc25bb 100755
--- a/deepspeed/launcher/runner.py
+++ b/deepspeed/launcher/runner.py
@@ -20,12 +20,14 @@
 from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner
 from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER
 from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
+from ..nebula.constants import NEBULA_EXPORT_ENVS
 from ..utils import logger
 from ..autotuning import Autotuner
 DLTS_HOSTFILE = "/job/hostfile"
-EXPORT_ENVS = ["NCCL", "PYTHON", "MV2", "UCX"]
+EXPORT_ENVS = ['NCCL', 'PYTHON', 'MV2', 'UCX']
+EXPORT_ENVS += NEBULA_EXPORT_ENVS
 DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env"
 DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.']
 PDSH_MAX_FAN_OUT = 1024
diff --git a/deepspeed/nebula/__init__.py b/deepspeed/nebula/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/deepspeed/nebula/config.py b/deepspeed/nebula/config.py
new file mode 100644
index 000000000000..f9928d66147c
--- /dev/null
+++ b/deepspeed/nebula/config.py
@@ -0,0 +1,53 @@
+"""
+Copyright (c) Microsoft Corporation
+Licensed under the MIT license.
+"""
+
+from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
+from deepspeed.nebula.constants import *
+
+
+class DeepSpeedNebulaConfig(DeepSpeedConfigObject):
+    def __init__(self, param_dict):
+        super(DeepSpeedNebulaConfig, self).__init__()
+
+        self.enabled = None
+        self.persistent_storage_path = None
+        self.persistent_time_interval = None
+        self.num_of_version_in_retention = None
+        self.enable_nebula_load = None
+
+        if NEBULA in param_dict.keys():
+            nebula_dict = param_dict[NEBULA]
+        else:
+            nebula_dict = {}
+
+        self._initialize(nebula_dict)
+
+    def _initialize(self, nebula_dict):
+        self.enabled = get_scalar_param(nebula_dict,
+                                        NEBULA_ENABLED,
+                                        NEBULA_ENABLED_DEFAULT)
+
+        self.load_path = get_scalar_param(nebula_dict,
+                                          NEBULA_LOAD_PATH,
+                                          NEBULA_LOAD_PATH_DEFAULT)
+
+        self.enable_nebula_load = get_scalar_param(nebula_dict,
+                                                   NEBULA_ENABLE_NEBULA_LOAD,
+                                                   NEBULA_ENABLE_NEBULA_LOAD_DEFAULT)
+
+        self.persistent_storage_path = get_scalar_param(
+            nebula_dict,
+            NEBULA_PERSISTENT_STORAGE_PATH,
+            NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT)
+
+        self.persistent_time_interval = get_scalar_param(
+            nebula_dict,
+            NEBULA_PERSISTENT_TIME_INTERVAL,
+            NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT)
+
+        self.num_of_version_in_retention = get_scalar_param(
+            nebula_dict,
+            NEBULA_NUM_OF_VERSION_IN_RETENTION,
+            NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT)
diff --git a/deepspeed/nebula/constants.py b/deepspeed/nebula/constants.py
new file mode 100644
index 000000000000..0e66fa8d1536
--- /dev/null
+++ b/deepspeed/nebula/constants.py
@@ -0,0 +1,86 @@
+"""
+Copyright (c) Microsoft Corporation
+Licensed under the MIT license.
+"""
+
+#########################################
+# nebula
+#########################################
+# Nebula. By default, this feature is not enabled.
+# Users can configure it in ds_config.json as in the example below:
+NEBULA_FORMAT = '''
+nebula should be enabled as:
+"session_params": {
+  "nebula": {
+    "enabled": true,
+    "persistent_storage_path": "/foo/bar",
+    "persistent_time_interval": 100,
+    "num_of_version_in_retention": 2,
+    "enable_nebula_load": true
+  }
+}
+'''
+
+NEBULA = "nebula"
+
+NEBULA_ENABLED = "enabled"
+NEBULA_ENABLED_DEFAULT = False
+
+# There is a case where a user wants to load a checkpoint that was
+# saved with raw torch. Nebula cannot load a torch checkpoint directly
+# because the two use different folder structures, even though the
+# saved data are exactly the same in bytes for torch and nebula.
+# In this case, nebula load must be disabled so that raw torch load
+# is used instead.
+# Simply set NEBULA_ENABLE_NEBULA_LOAD to False and then load the
+# usual DeepSpeed way, i.e. set the value of "--load".
+NEBULA_ENABLE_NEBULA_LOAD = "enable_nebula_load"
+NEBULA_ENABLE_NEBULA_LOAD_DEFAULT = True
+
+# When you want to resume from a previous checkpoint saved by nebula,
+# set NEBULA_LOAD_PATH to the parent folder of that checkpoint.
+# If NEBULA_LOAD_PATH is None, NEBULA_PERSISTENT_STORAGE_PATH
+# will be used as the default load path.
+NEBULA_LOAD_PATH = "nebula_load_path"
+NEBULA_LOAD_PATH_DEFAULT = None
+
+# Nebula will save the checkpoint under NEBULA_LOAD_PATH in an
+# asynchronous way.
+NEBULA_PERSISTENT_STORAGE_PATH = "persistent_storage_path"
+NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT = None
+
+# Time interval to trigger the nebula persistence.
+NEBULA_PERSISTENT_TIME_INTERVAL = "persistent_time_interval"
+NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT = 100
+
+# Number of checkpoint versions kept in memory. For example, if the
+# value is 2 and checkpoints 1 and 2 are ready, then when checkpoint 3
+# arrives, checkpoint 1 will be removed once it has been persisted
+# to disk.
+NEBULA_NUM_OF_VERSION_IN_RETENTION = "num_of_version_in_retention"
+NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT = 2
+
+# Nebula envs
+NEBULA_EXPORT_ENVS = [
+    'DLTS_JOB_ID',
+    'DLTS_NUM_WORKER',
+    'NEBULA_PERSISTENT_STORAGE_PATH',
+    'NEBULA_PERSISTENT_TIME_INTERVAL',
+    'AML_RUN_ID',
+    'AZUREML_RUN_TOKEN',
+    'AZUREML_WORKSPACE_SCOPE',
+    'AZUREML_EXPERIMENT_SCOPE',
+    'AZUREML_RUN_HISTORY_SERVICE_ENDPOINT',
+    'AZUREML_RUN_ID',
+    'NEBULA_MEMORY_BUFFER_SIZE',
+    'AZUREML_PARAMETER_ITPJOB_NAME',
+    'FC_TASKROLE_NAME',
+    'FC_TASK_INDEX',
+    'MASTER_HOST',
+    'LOCAL_HOST',
+    'AZUREML_BLOB_ACCOUNT_NAME',
+    'AZUREML_BLOB_ACCOUNT_KEY'
+]
+
+# ITP env files
+DLTS_POD_ENV_PATH = '/dlts-runtime/env/pod.env'
diff --git a/deepspeed/runtime/checkpoint_engine/README.md b/deepspeed/runtime/checkpoint_engine/README.md
new file mode 100644
index 000000000000..a19f54889802
--- /dev/null
+++ b/deepspeed/runtime/checkpoint_engine/README.md
@@ -0,0 +1,37 @@
+# Checkpoint Engine
+
+
+The `CheckpointEngine` was designed to modularize checkpoint serialization. In this way, we can simply replace or refine the checkpoint serialization backend.
+
+### Interface for `CheckpointEngine`
+
+Basically, for checkpoint management (save/load by DeepSpeed with a given tag), the `CheckpointEngine` will:
+
+ 1. Make preliminaries ready by calling `create(tag)`. For `torch`, we can just log some extra info, as `torch` can directly call `save/load` without other preparation.
+
+ 2. After `create(tag)`, DeepSpeed can call `save/load` to persist files into disk/memory/etc.
+
+ 3. When all the files for a tag are ready, the DeepSpeed engine will call `commit()` to tell the checkpoint engine that the current checkpoint is complete. For original torch, it also plays the role of a logger.
+
+
+```python
+class CheckpointEngine(object):
+    # init checkpoint engine for save/load
+    def __init__(self, config_params=None):
+        pass
+
+    def create(self, tag):
+        # create checkpoint on the given tag for save/load.
+        pass
+
+    def save(self, state_dict, path: str):
+        pass
+
+    def load(self, path: str, map_location=None):
+        pass
+
+    def commit(self, tag):
+        # tell checkpoint services that all files are ready.
+        pass
+
+```
diff --git a/deepspeed/runtime/checkpoint_engine/__init__.py b/deepspeed/runtime/checkpoint_engine/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py
new file mode 100644
index 000000000000..67c888815242
--- /dev/null
+++ b/deepspeed/runtime/checkpoint_engine/checkpoint_engine.py
@@ -0,0 +1,19 @@
+class CheckpointEngine(object):
+
+    # init checkpoint engine for save/load
+    def __init__(self, config_params=None):
+        pass
+
+    def create(self, tag):
+        # create checkpoint on the given tag for save/load.
+        pass
+
+    def save(self, state_dict, path: str):
+        pass
+
+    def load(self, path: str, map_location=None):
+        pass
+
+    def commit(self, tag):
+        # tell checkpoint services that all files are ready.
+        pass
diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py
new file mode 100644
index 000000000000..b7ff5cdf9829
--- /dev/null
+++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py
@@ -0,0 +1,88 @@
+import os
+import torch
+import torch_nebula
+
+from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
+    CheckpointEngine
+from deepspeed.utils import logger, log_dist
+from deepspeed.nebula.constants import *
+
+
+def _get_tag_from_path(path):
+    return os.path.basename(os.path.dirname(path))
+
+
+class NebulaCheckpointEngine(CheckpointEngine):
+    def __init__(self, config_params=None):
+        super().__init__(config_params)
+        self.checkpoint = None
+        self.tag_flag = None
+        self.enable_nebula_load = config_params.enable_nebula_load
+        self.nebula_load_path = config_params.load_path
+        if self.nebula_load_path is None:
+            self.nebula_load_path = config_params.persistent_storage_path
+
+        nebula_config_params = {
+            NEBULA_PERSISTENT_STORAGE_PATH: config_params.persistent_storage_path,
+            NEBULA_PERSISTENT_TIME_INTERVAL: config_params.persistent_time_interval,
+            NEBULA_NUM_OF_VERSION_IN_RETENTION:
+            config_params.num_of_version_in_retention,
+        }
+        torch_nebula.init(**nebula_config_params)
+
+    def create(self, tag):
+        log_dist(f"[Nebula] Start Checkpoint for tag:{tag}", ranks=[0])
+        # -2 means: the caller must explicitly tell nebula that the
+        # current checkpoint is complete via the commit method.
+        self.checkpoint = torch_nebula.Checkpoint(tag, -2)
+
+    def save(self, state_dict, path: str):
+        tag = _get_tag_from_path(path)
+        partition_name = os.path.basename(path)
+        logger.info(f"[Nebula] Saving {partition_name} under tag {tag}...")
+        self.checkpoint.save(partition_name, state_dict)
+        logger.info(f"[Nebula] Saved {partition_name} under tag {tag}.")
+        return None
+
+    def load(self, path: str, map_location=None):
+        tag = _get_tag_from_path(path)
+        first_load_flag = self.tag_flag is None or self.tag_flag == tag
+        if not self.enable_nebula_load and first_load_flag:
+            self.tag_flag = tag
+            logger.info(
+                f"[Nebula] Nebula load disabled. Loading checkpoint from {path}...")
+            partition = torch.load(path, map_location=map_location)
+            logger.info(
+                f"[Nebula] Nebula load disabled. Loaded checkpoint from {path}.")
+            return partition
+
+        partition_name = os.path.basename(path)
+        logger.info(
+            f"[Nebula] Loading {path} under tag {tag} from {self.nebula_load_path}...")
+
+        checkpoint = None
+        if tag is None:
+            checkpoint = torch_nebula.get_latest_checkpoint(
+                persist_path=self.nebula_load_path)
+            if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''):
+                logger.warning(f"Unable to find latest valid checkpoint from Nebula!")
+                return None
+        else:
+            checkpoint = torch_nebula.get_checkpoint(tag=tag,
+                                                     persist_path=self.nebula_load_path)
+        partition = checkpoint.load(partition_name, map_location=map_location)
+        logger.info(
+            f"[Nebula] Loaded {path} under tag {tag} from {self.nebula_load_path}.")
+        return partition
+
+    def commit(self, tag):
+        # nebula commit is called when all files under the given tag are ready to be persisted asynchronously.
+        logger.info(
+            f"[Nebula] All files for {tag} are saved in tier1. It is ready to start persisting"
+        )
+        commit_rls = self.checkpoint.commit()
+        if not commit_rls:
+            logger.error(
+                f"[Nebula] Failed to commit the checkpoint, please check the log.")
+            return False
+        return commit_rls
diff --git a/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py
new file mode 100644
index 000000000000..9b4942f0a01f
--- /dev/null
+++ b/deepspeed/runtime/checkpoint_engine/torch_checkpoint_engine.py
@@ -0,0 +1,28 @@
+import torch
+from deepspeed.utils import logger, log_dist
+from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
+    CheckpointEngine
+
+
+class TorchCheckpointEngine(CheckpointEngine):
+    def __init__(self, config_params=None):
+        super().__init__(config_params)
+
+    def create(self, tag):
+        log_dist(f"[Torch] Checkpoint {tag} is about to be saved!", ranks=[0])
+
+    def save(self, state_dict, path: str):
+        logger.info(f"[Torch] Saving {path}...")
+        torch.save(state_dict, path)
+        logger.info(f"[Torch] Saved {path}.")
+        return None
+
+    def load(self, path: str, map_location=None):
+        logger.info(f"[Torch] Loading checkpoint from {path}...")
+        partition = torch.load(path, map_location=map_location)
+        logger.info(f"[Torch] Loaded checkpoint from {path}.")
+        return partition
+
+    def commit(self, tag):
+        logger.info(f"[Torch] Checkpoint {tag} is ready now!")
+        return True
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index 24ea5a01efbd..0794e4525343 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -45,6 +45,8 @@
 from ..profiling.config import DeepSpeedFlopsProfilerConfig
 from ..autotuning.config import DeepSpeedAutotuningConfig
+from ..nebula.config import DeepSpeedNebulaConfig
+
 from ..compression.config import get_compression_config, get_quantize_enabled
 from ..compression.constants import *
 from .swap_tensor.aio_config import get_aio_config
@@ -864,6 +866,8 @@ def _initialize_params(self, param_dict):
         self.dataloader_drop_last = get_dataloader_drop_last(param_dict)
+        self.nebula_config = DeepSpeedNebulaConfig(param_dict)
+
     def _batch_assertion(self):
         train_batch = self.train_batch_size
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index f6e5da1aed73..9d1b8b6aac74 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -63,6 +63,7 @@
 from deepspeed.runtime.utils import clip_grad_norm_
 from deepspeed.runtime.eigenvalue import Eigenvalue
 from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler
+from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
 from .pipe.module import PipelineModule
 from .utils import ensure_directory_exists, get_ma_status
@@ -223,6 +224,8 @@ def __init__(
         self._global_grad_norm = None
         self.use_ds_comm = False  # False --> Use torch.dist, True --> Use ds.comm backend.
+        self.checkpoint_engine = None
+
         global dist
         from deepspeed import comm as dist
         self._is_gradient_accumulation_boundary = None
@@ -780,6 +783,19 @@ def _configure_lr_scheduler(self, client_lr_scheduler):
             log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0])
     def _configure_checkpointing(self, dist_init_required):
+        self.checkpoint_engine = TorchCheckpointEngine()
+
+        if self._config is not None and self._config.nebula_config.enabled:
+            try:
+                from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \
+                    NebulaCheckpointEngine
+                self.checkpoint_engine = NebulaCheckpointEngine(
+                    config_params=self._config.nebula_config)
+            except ImportError as err:
+                logger.error(
+                    f"No torch_nebula was found! Will fall back to torch.save. Details: {err}"
+                )
+                self.checkpoint_engine = TorchCheckpointEngine()
         dp_rank = self.global_rank
         if self.mpu:
@@ -2318,7 +2334,8 @@ def load_moe_state_dict(checkpoint_path,
                             old_moe_load,
                             model=None,
                             mpu=None,
-                            num_experts=1):
+                            num_experts=1,
+                            checkpoint_engine=TorchCheckpointEngine()):
         if old_moe_load:
             expp_rank = groups._get_expert_data_parallel_rank(
                 groups._get_max_expert_size_name())
@@ -2328,7 +2345,7 @@ def load_moe_state_dict(checkpoint_path,
                 groups._get_max_expert_size_name())
             for local_expert_id in range(num_local_experts):
                 global_expert_id = expp_rank * num_local_experts + local_expert_id
-                expert_state_dict = torch.load(DeepSpeedEngine._get_expert_ckpt_name(
+                expert_state_dict = checkpoint_engine.load(DeepSpeedEngine._get_expert_ckpt_name(
                     checkpoint_path,
                     -1,  # -1 means ignore layer_id
                     global_expert_id,
@@ -2354,7 +2371,7 @@ def load_moe_state_dict(checkpoint_path,
             # loop all local_experts
             for local_expert_id in range(num_local_experts):
                 global_expert_id = expp_rank * num_local_experts + local_expert_id
-                expert_state_dict = torch.load(
+                expert_state_dict = checkpoint_engine.load(
                     DeepSpeedEngine._get_expert_ckpt_name(
                         checkpoint_path,
                         moe_layer_id,
@@ -2377,7 +2394,8 @@ def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None):
         if custom_load_fn:
             custom_load_fn(src=state_dict, dst=self.module)
         else:
-            self.module.load_state_dict(state_dict, strict=strict)
+            self.module.load_state_dict(state_dict,  # TODO
+                                        strict=strict)
     def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode):
         return f'{"bf16_" if bf16_mode else ""}zero_pp_rank_{dp_rank}'
@@ -2550,7 +2568,9 @@ def _load_checkpoint(self,
         from deepspeed.runtime.state_dict_factory import SDLoaderFactory
         ckpt_list = self._get_all_ckpt_names(load_dir, tag)
-        sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list)
+        sd_loader = SDLoaderFactory.get_sd_loader(
+            ckpt_list,
+            checkpoint_engine=self.checkpoint_engine)
         is_pipe_parallel = isinstance(self.module, PipelineModule)
@@ -2577,7 +2597,8 @@ def _load_checkpoint(self,
                 old_moe_load=old_moe_load,
                 model=self.module,
                 mpu=self.mpu,
-                num_experts=self.num_experts)
+                num_experts=self.num_experts,
+                checkpoint_engine=self.checkpoint_engine)
         if not self.load_universal_checkpoint():
             self.load_module_state_dict(state_dict=checkpoint['module'],
                                         strict=load_module_strict,
@@ -2594,8 +2615,9 @@ def _load_checkpoint(self,
             largest_group_name = groups._get_max_expert_size_name()
             expp_rank = groups._get_expert_parallel_rank(largest_group_name)
             optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank)
-            optim_checkpoint = torch.load(optim_load_path,
-                                          map_location=torch.device('cpu'))
+            optim_checkpoint = self.checkpoint_engine.load(
+                optim_load_path,
+                map_location=torch.device('cpu'))
         else:
             optim_checkpoint = checkpoint
@@ -2762,7 +2784,10 @@ def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names):
             # Fully load state for current rank
             if self.zero_elastic_checkpoint() or dist.get_rank(
                     group=self.optimizer.dp_process_group) == i:
-                _state = torch.load(ckpt_name, map_location='cpu')
+                _state = self.checkpoint_engine.load(
+                    ckpt_name,
+                    map_location='cpu',
+                )
             else:
                 _state = {OPTIMIZER_STATE_DICT: None}
             zero_sd_list.append(_state)
@@ -2838,6 +2863,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
         # Ensure tag is a string
         tag = str(tag)
+        self.checkpoint_engine.create(tag)
         # Ensure checkpoint tag is consistent across ranks
         self._checkpoint_tag_validation(tag)
@@ -2860,6 +2886,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
         # Save latest checkpoint tag
         dist.barrier()
+        self.checkpoint_engine.commit(tag)
         if save_latest and self.global_rank == 0:
             with open(os.path.join(save_dir, 'latest'), 'w') as fd:
                 fd.write(tag)
@@ -2929,7 +2956,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}):
                     global_expert_id,
                     tag,
                     self.mpu)
-                torch.save(expert_state_dict, moe_save_path)
+                self.checkpoint_engine.save(expert_state_dict, moe_save_path)
             moe_layer_id += 1
         self._curr_ckpt_path = os.path.join(save_dir, tag)
@@ -2950,9 +2977,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}):
             self.optimizer.state_dict()
            if self.optimizer and not self.zero_optimization() else None
         }
-        with open(self._get_optimizer_ckpt_name(save_dir, tag, expp_rank), 'wb') as fd:
-            torch.save(optimizer_state, fd)
-            fd.flush()
+        # TODO: why use BufferedWriter not the path
+        file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)
+        self.checkpoint_engine.save(optimizer_state, file_path)
         # get non-moe parameters
         model_state_dict = self._get_non_moe_state_dict(self.module_state_dict())
@@ -2982,9 +3009,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}):
             }
             state.update(client_state)
             logger.info(f'Saving model checkpoint: {save_path}')
-            with open(save_path, 'wb') as fd:
-                torch.save(state, fd)
-                fd.flush()
+            self.checkpoint_engine.save(state, save_path)
         self._curr_save_path = None
     def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
@@ -3037,7 +3062,7 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):
         state.update(client_state)
         log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])
-        torch.save(state, save_path)
+        self.checkpoint_engine.save(state, save_path)
         self._curr_save_path = None
     def _get_buffer_names(self):
@@ -3118,9 +3143,8 @@ def _save_zero_checkpoint(self, save_path, tag):
         zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(),
                        ds_config=self.config,
                        ds_version=version)
-        with open(zero_checkpoint_name, 'wb') as fd:
-            torch.save(zero_sd, fd)
-            fd.flush()
+        self.checkpoint_engine.save(zero_sd, zero_checkpoint_name)
+
         if self.global_rank == 0:
             self._copy_recovery_script(save_path)
         ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero'
@@ -3228,6 +3252,6 @@ def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"):
             if dist.get_rank() == 0:
                 os.makedirs(save_dir, exist_ok=True)
                 logger.info(f"Saving model weights to {path}")
-                torch.save(state_dict, path)
+                self.checkpoint_engine.save(state_dict, path)
         return True
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index a47f1e45b241..0f3443d980f8 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -1315,7 +1315,8 @@ def module_state_dict(self):
         assert self._curr_ckpt_path is not None, \
             "PipelineEngine expects module_state_dict() to be called from save_checkpoint()"
-        self.module.save_state_dict(self._curr_ckpt_path)
+        self.module.save_state_dict(self._curr_ckpt_path,
+                                    checkpoint_engine=self.checkpoint_engine)
         return None
     def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None):
@@ -1334,7 +1335,9 @@ def load_module_state_dict(self, state_dict, strict=True, custom_load_fn=None):
             super().load_module_state_dict(state_dict, strict)
             return
-        self.module.load_state_dir(load_dir=self._curr_ckpt_path, strict=strict)
+        self.module.load_state_dir(load_dir=self._curr_ckpt_path,
+                                   strict=strict,
+                                   checkpoint_engine=self.checkpoint_engine)
     # A map of PipeInstruction types to methods. Each method will be executed with the
     # kwargs provided to the PipeInstruction from the scheduler.
diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index ac35a9fa2cf8..03e1c413c950 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -561,7 +561,7 @@ def ckpt_layer_path_list(self, ckpt_dir, local_layer_idx):
         ckpt_files.sort()
         return ckpt_files
-    def save_state_dict(self, save_dir):
+    def save_state_dict(self, save_dir, checkpoint_engine):
         if self._grid.data_parallel_id != 0:
             return
@@ -582,9 +582,9 @@ def save_state_dict(self, save_dir):
                 {k: v.clone() for k, v in orig_state_dict.items()})
-            torch.save(final_state_dict, model_ckpt_path)
+            checkpoint_engine.save(final_state_dict, model_ckpt_path)
-    def load_state_dir(self, load_dir, strict=True):
+    def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
         for idx, layer in enumerate(self.forward_funcs):
             # Functions, etc. will not have state_dicts
             if not hasattr(layer, 'load_state_dict'):
@@ -595,7 +595,10 @@ def load_state_dir(self, load_dir, strict=True):
             mp_rank = self._grid.get_slice_parallel_rank()
             mp_world_size = self._grid.get_slice_parallel_world_size()
-            sd_loader = SDLoaderFactory.get_sd_loader(model_ckpt_list, version=2.0)
+            sd_loader = SDLoaderFactory.get_sd_loader(
+                model_ckpt_list,
+                version=2.0,
+                checkpoint_engine=checkpoint_engine)
             load_path, checkpoint, _ = sd_loader.load(mp_world_size, mp_rank, module_key=None, is_pipe_parallel=True)
             layer.load_state_dict(checkpoint)
diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py
index 35ccb8d5fd02..6097e8baa004 100755
--- a/deepspeed/runtime/state_dict_factory.py
+++ b/deepspeed/runtime/state_dict_factory.py
@@ -8,7 +8,10 @@
 import collections
 import json
 from abc import ABC, abstractmethod
+
 from deepspeed.utils import logger
+from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
+
 from .weight_quantizer import WeightQuantization
 AUTO_MODULE_KEY = 'auto'
@@ -16,7 +19,7 @@ class SDLoaderFactory:
     @staticmethod
-    def get_sd_loader_json(json_file):
+    def get_sd_loader_json(json_file, checkpoint_engine):
         with open(json_file) as f:
             data = json.load(f)
         sd_type = data['type']
@@ -24,22 +27,27 @@ def get_sd_loader_json(json_file):
         version = data['version']
         if 'BLOOM' in sd_type or 'Bloom' in sd_type:
             return ckpt_list
-        return SDLoaderFactory.get_sd_loader(ckpt_list, sd_type, version)
+        return SDLoaderFactory.get_sd_loader(ckpt_list,
+                                             checkpoint_engine,
+                                             sd_type,
+                                             version)
     @staticmethod
-    def get_sd_loader(ckpt_list, sd_type='Megatron', version=None):
+    def get_sd_loader(ckpt_list, checkpoint_engine, sd_type='Megatron', version=None):
         if sd_type == 'Megatron':
-            return MegatronSDLoader(ckpt_list, version)
+            return MegatronSDLoader(ckpt_list, version, checkpoint_engine)
         else:
             assert False, '{} checkpoint type is not supported'.format(sd_type)
 class SDLoaderBase(ABC):
-    def __init__(self, ckpt_list, version):
+    def __init__(self, ckpt_list, version, checkpoint_engine):
         self.module_key = None
         self.ckpt_list = ckpt_list
-        self.check_ckpt_list()
         self.version = version
+        self.checkpoint_engine = TorchCheckpointEngine(
+        ) if checkpoint_engine is None else checkpoint_engine
+        self.check_ckpt_list()
     def load(self,
              mp_world_size,
@@ -81,7 +89,8 @@ def load(self,
         if num_ckpt == mp_world_size:
             assert os.path.exists(load_path)
             #logger.info(f'rank: {mp_rank} loading checkpoint: {load_path}')
-            sd = torch.load(load_path, map_location=lambda storage, loc: storage)
+            sd = self.checkpoint_engine.load(load_path, map_location=lambda storage, \
+                loc: storage)
             if quantize:
                 quantizer = WeightQuantization(mlp_extra_grouping=mlp_extra_grouping,
@@ -110,9 +119,9 @@ def get_merge_state_dicts(self, mp_world_size, mp_rank):
         logger.info(f"mp_rank: {mp_rank}, ckpt_list: {ckpt_list}")
         sd_list = [
-            torch.load(ckpt,
-                       map_location=lambda storage,
-                       loc: storage) for ckpt in ckpt_list
+            self.checkpoint_engine.load(ckpt,
+                                        map_location=lambda storage,
+                                        loc: storage) for ckpt in ckpt_list
         ]
         return sd_list
@@ -128,9 +137,9 @@ def get_split_state_dict(self, mp_world_size, mp_rank):
             f"mp_rank: {mp_rank}, ckpt_list: {self.ckpt_list[ckpt_index]}, offset: {ckpt_offset}"
         )
-        sd = torch.load(self.ckpt_list[ckpt_index],
-                        map_location=lambda storage,
-                        loc: storage)
+        sd = self.checkpoint_engine.load(self.ckpt_list[ckpt_index],
+                                         map_location=lambda storage,
+                                         loc: storage)
         return sd, num_to_split, ckpt_offset
@@ -163,7 +172,9 @@ def check_ckpt_list(self):
         #logger.info(f'checkpoint file list: {self.ckpt_list}')
         assert len(self.ckpt_list) > 0
-        sd = torch.load(self.ckpt_list[0], map_location=lambda storage, loc: storage)
+        sd = self.checkpoint_engine.load(self.ckpt_list[0],
+                                         map_location=lambda storage,
+                                         loc: storage)
         # check checkpoint count is same with saved mp_world_size
         if 'mp_world_size' in sd.keys():
@@ -195,8 +206,8 @@ def sanity_check(self, ckpt_file_name):
 class MegatronSDLoader(SDLoaderBase):
-    def __init__(self, ckpt_list, version):
-        super().__init__(ckpt_list, version)
+    def __init__(self, ckpt_list, version, checkpoint_engine):
+        super().__init__(ckpt_list, version, checkpoint_engine)
     """
     ## Q/K/V data need special processing
     key: transformer.layers.0.attention.query_key_value.weight, shape: torch.Size([3192, 4256])
@@ -433,7 +444,9 @@ def sanity_check(self, ckpt_file_name):
             "mlp.dense_h_to_4h.bias"
         ]
-        sd = torch.load(ckpt_file_name, map_location=lambda storage, loc: storage)
+        sd = self.checkpoint_engine.load(ckpt_file_name,
+                                         map_location=lambda storage,
+                                         loc: storage)
         # partial_key is a sub-string of one key in the sd
         def check_key_exist(partial_key, sd):
diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py
index e55fc30fa56c..7174ae0a0a63 100755
--- a/tests/unit/test_checkpointing.py
+++ b/tests/unit/test_checkpointing.py
@@ -2,6 +2,7 @@
 from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
 from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
+from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
 from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
 from deepspeed.runtime.pipe.topology import *
@@ -727,13 +728,14 @@ def _test(save_folder, num_stages):
 def test_checkpoint_pipe_module(base_topo, test_topo, tmpdir):
     @distributed_test(world_size=4)
     def _test(base_topo, test_topo, save_folder):
+        checkpoint_engine = TorchCheckpointEngine()
         base_model = LinearStackPipe(topology=base_topo)
-        base_model.save_state_dict(save_folder)
+        base_model.save_state_dict(save_folder, checkpoint_engine=checkpoint_engine)
         dist.barrier()
         test_model = LinearStackPipe(topology=test_topo)
-        test_model.load_state_dir(save_folder)
+        test_model.load_state_dir(save_folder, checkpoint_engine=checkpoint_engine)
         # Base and test can have different lengths, so make sure we map from the
         # smaller to larger model