diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index 6d377e5dc6a2..b602d9d10958 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -669,6 +669,19 @@ def get_checkpoint_tag_validation_mode(checkpoint_params):
         )
 
 
+def get_checkpoint_parallel_write_pipeline(checkpoint_params):
+    par_write_params = checkpoint_params.get(CHECKPOINT_PARALLEL_WRITE, {})
+    par_write_pipeline = par_write_params.get(
+        CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE,
+        CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT)
+    if par_write_pipeline in [True, False]:
+        return par_write_pipeline
+    else:
+        raise DeepSpeedConfigError(
+            "checkpoint::parallel_write::pipeline_stage "
+            f"value of '{par_write_pipeline}' is invalid, expecting: true or false")
+
+
 def get_dataloader_drop_last(param_dict):
     return get_scalar_param(param_dict,
                             DATALOADER_DROP_LAST,
@@ -887,6 +900,8 @@ def _initialize_params(self, param_dict):
         self.load_universal_checkpoint = checkpoint_params.get(
             LOAD_UNIVERSAL_CHECKPOINT,
             LOAD_UNIVERSAL_CHECKPOINT_DEFAULT)
+        par_write_pipe = get_checkpoint_parallel_write_pipeline(checkpoint_params)
+        self.checkpoint_parallel_write_pipeline = par_write_pipe
 
         self.aio_config = get_aio_config(param_dict)
 
diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py
index da36a7199470..c6a257e77d73 100755
--- a/deepspeed/runtime/constants.py
+++ b/deepspeed/runtime/constants.py
@@ -367,6 +367,9 @@ class ValidationMode:
 # "checkpoint": {
 #   tag_validation=["Ignore"|"Warn"|"Fail"]
 #   load_universal=false
+#   parallel_write: {
+#     pipeline_stage: [True|False]
+#   }
 # }
 CHECKPOINT = "checkpoint"
 CHECKPOINT_TAG_VALIDATION = "tag_validation"
@@ -380,6 +383,10 @@ class ValidationMode:
 LOAD_UNIVERSAL_CHECKPOINT = "load_universal"
 LOAD_UNIVERSAL_CHECKPOINT_DEFAULT = False
 
+CHECKPOINT_PARALLEL_WRITE = "parallel_write"
+CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE = "pipeline_stage"
+CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT = False
+
 #########################################
 # Drop the last incomplete Batch
 # #########################################
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 1f4331239f3b..0b54642abb0a 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -2924,7 +2924,11 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
             self._create_checkpoint_file(save_dir, tag, False)
             self._save_moe_checkpoint(save_dir, tag, client_state=client_state)
 
-        if self.save_non_zero_checkpoint:
+        # We distribute the task of saving layer checkpoint files among
+        # data parallel instances, so all procs should call _save_checkpoint.
+        # All procs then call module_state_dict(), but only procs of data
+        # parallel rank 0 save the general model params.
+        if not self.has_moe_layers:
             self._create_checkpoint_file(save_dir, tag, False)
             self._save_checkpoint(save_dir, tag, client_state=client_state)
 
@@ -3091,12 +3095,18 @@ def _create_zero_checkpoint_files(self, save_dir, tag):
     def _save_checkpoint(self, save_dir, tag, client_state={}):
 
         save_path = self._get_ckpt_name(save_dir, tag)
+
+        zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
+
         # A hack to save the checkpointing directory. Pipeline parallelism overrides
         # module_state_dict() and uses this path to save the model. module_state_dict()
-        # then instead just returns None.
+        # then instead just returns None. The module_state_dict() implementation in
+        # PipelineEngine expects the save path to be set in self._curr_ckpt_path.
         self._curr_ckpt_path = os.path.join(save_dir, tag)
-        zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
-        state = dict(module=self.module_state_dict(),
+        module = self.module_state_dict()
+        self._curr_ckpt_path = None
+
+        state = dict(module=module,
                      buffer_names=self._get_buffer_names(),
                      optimizer=self.optimizer.state_dict()
                      if self.optimizer and not zero_optimizer_state else None,
@@ -3114,9 +3124,9 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):
                      ds_version=version)
         state.update(client_state)
 
-        log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])
-        self.checkpoint_engine.save(state, save_path)
-        self._curr_save_path = None
+        if self.save_non_zero_checkpoint:
+            log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])
+            self.checkpoint_engine.save(state, save_path)
 
     def _get_buffer_names(self):
         buffer_names = []
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 33edc2db1a6a..1a57bb4e84a2 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -182,6 +182,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
         self.module.activation_checkpoint_interval = self._config.pipeline[
             'activation_checkpoint_interval']
 
+        self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline
+
         if self.is_last_stage():
             self.loss_model = self.module.loss_fn
 
diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index 03e1c413c950..294db38b3bfb 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -562,13 +562,28 @@ def ckpt_layer_path_list(self, ckpt_dir, local_layer_idx):
         return ckpt_files
 
     def save_state_dict(self, save_dir, checkpoint_engine):
-        if self._grid.data_parallel_id != 0:
-            return
+        # Processes having the same model parallel rank on different data parallel instances
+        # have identical layer weights. We can distribute the task of saving the layer weights
+        # among the data parallel ranks. For example, if a pipeline stage has 9 layers and
+        # if there are 2 data parallel instances, rank 0 will save the first 5 layers and
+        # rank 1 will save the last 4.
+        dp_rank = self._grid.data_parallel_id
+        dp_size = self._grid.data_parallel_size
+        num_layers = len(self.forward_funcs)
+        if self.checkpoint_parallel_write_pipeline:
+            # spread layers evenly across data parallel ranks
+            offsets = ds_utils.partition_uniform(num_layers, dp_size)
+            start, end = offsets[dp_rank], offsets[dp_rank + 1]
+        else:
+            # data parallel rank 0 writes all layers
+            if dp_rank != 0:
+                return
+            start, end = 0, num_layers
+        layer_list = self.forward_funcs[start:end]
 
         os.makedirs(save_dir, exist_ok=True)
-        layer_offset = self._local_start
-        for idx, layer in enumerate(self.forward_funcs):
-            model_ckpt_path = self.ckpt_layer_path(save_dir, idx)
+        for idx, layer in enumerate(layer_list):
+            model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
             if not hasattr(layer, 'state_dict'):
                 continue
             # We pass cloned tensors to torch.save() to avoid checkpoint bloat which occurs because torch.save()
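
The layer slicing added to PipelineModule.save_state_dict() relies on ds_utils.partition_uniform(num_layers, dp_size) returning dp_size + 1 offsets into self.forward_funcs. The sketch below is for illustration only: uniform_offsets and layers_for_rank are hypothetical stand-ins, not DeepSpeed's partition_uniform or the code in this diff, but they reproduce the split described in the new comment (9 layers across 2 data parallel ranks: rank 0 writes layers 0-4, rank 1 writes layers 5-8).

# Illustration only: hypothetical stand-ins showing how uniform partition offsets
# determine which layer checkpoint files each data parallel rank writes.
def uniform_offsets(num_items, num_parts):
    """Return num_parts + 1 offsets, spreading any remainder over the leading parts."""
    base, extra = divmod(num_items, num_parts)
    offsets = [0]
    for p in range(num_parts):
        offsets.append(offsets[-1] + base + (1 if p < extra else 0))
    return offsets


def layers_for_rank(num_layers, dp_rank, dp_size, parallel_write):
    """Mirror the start/end selection added to save_state_dict()."""
    if parallel_write:
        offsets = uniform_offsets(num_layers, dp_size)
        return offsets[dp_rank], offsets[dp_rank + 1]
    # default path: only data parallel rank 0 writes, and it writes every layer
    return (0, num_layers) if dp_rank == 0 else (0, 0)


assert uniform_offsets(9, 2) == [0, 5, 9]         # rank 0 -> 5 layers, rank 1 -> 4
assert layers_for_rank(9, 0, 2, True) == (0, 5)   # layers 0-4
assert layers_for_rank(9, 1, 2, True) == (5, 9)   # layers 5-8
assert layers_for_rank(9, 1, 2, False) == (0, 0)  # non-zero ranks write nothing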
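
From a user config, the nesting of the new option follows the keys added to constants.py (checkpoint -> parallel_write -> pipeline_stage). A minimal sketch, assuming a standard deepspeed.initialize() setup; everything outside the checkpoint section is a placeholder, not part of this change.

import deepspeed  # assumes a DeepSpeed build that includes this change

ds_config = {
    "train_batch_size": 16,  # placeholder value
    "checkpoint": {
        "parallel_write": {
            # True distributes layer checkpoint writes across data parallel ranks;
            # the default (False) keeps the old behavior of rank 0 writing everything.
            "pipeline_stage": True
        }
    },
}

# engine, _, _, _ = deepspeed.initialize(model=model,
#                                        model_parameters=model.parameters(),
#                                        config=ds_config)
# engine.save_checkpoint(save_dir="checkpoints", tag="step_100")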