15 changes: 15 additions & 0 deletions deepspeed/runtime/config.py
@@ -669,6 +669,19 @@ def get_checkpoint_tag_validation_mode(checkpoint_params):
        )


def get_checkpoint_parallel_write_pipeline(checkpoint_params):
    par_write_params = checkpoint_params.get(CHECKPOINT_PARALLEL_WRITE, {})
    par_write_pipeline = par_write_params.get(
        CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE,
        CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT)
    if par_write_pipeline in [True, False]:
        return par_write_pipeline
    else:
        raise DeepSpeedConfigError(
            "checkpoint::parallel_write::pipeline_stage "
            f"value of '{par_write_pipeline}' is invalid, expecting: true or false")


def get_dataloader_drop_last(param_dict):
    return get_scalar_param(param_dict,
                            DATALOADER_DROP_LAST,
@@ -887,6 +900,8 @@ def _initialize_params(self, param_dict):
        self.load_universal_checkpoint = checkpoint_params.get(
            LOAD_UNIVERSAL_CHECKPOINT,
            LOAD_UNIVERSAL_CHECKPOINT_DEFAULT)
        par_write_pipe = get_checkpoint_parallel_write_pipeline(checkpoint_params)
        self.checkpoint_parallel_write_pipeline = par_write_pipe

        self.aio_config = get_aio_config(param_dict)

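As a quick sanity check of the new parser, the following self-contained sketch mirrors the logic above (the constants are copied from the accompanying constants.py change, and DeepSpeedConfigError is swapped for a plain ValueError so the snippet runs without DeepSpeed installed):

# Standalone sketch of the parsing/validation logic; not the DeepSpeed module itself.
CHECKPOINT_PARALLEL_WRITE = "parallel_write"
CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE = "pipeline_stage"
CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT = False


def get_checkpoint_parallel_write_pipeline(checkpoint_params):
    par_write_params = checkpoint_params.get(CHECKPOINT_PARALLEL_WRITE, {})
    par_write_pipeline = par_write_params.get(
        CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE,
        CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT)
    if par_write_pipeline in [True, False]:
        return par_write_pipeline
    raise ValueError(  # DeepSpeedConfigError in the real code
        "checkpoint::parallel_write::pipeline_stage "
        f"value of '{par_write_pipeline}' is invalid, expecting: true or false")


assert get_checkpoint_parallel_write_pipeline({}) is False   # option absent -> default
assert get_checkpoint_parallel_write_pipeline(
    {"parallel_write": {"pipeline_stage": True}}) is True    # option enabled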
7 changes: 7 additions & 0 deletions deepspeed/runtime/constants.py
@@ -367,6 +367,9 @@ class ValidationMode:
# "checkpoint": {
# tag_validation=["Ignore"|"Warn"|"Fail"]
# load_universal=false
# parallel_write: {
# pipeline_stage: [True|False]
# }
# }
CHECKPOINT = "checkpoint"
CHECKPOINT_TAG_VALIDATION = "tag_validation"
@@ -380,6 +383,10 @@ class ValidationMode:
LOAD_UNIVERSAL_CHECKPOINT = "load_universal"
LOAD_UNIVERSAL_CHECKPOINT_DEFAULT = False

CHECKPOINT_PARALLEL_WRITE = "parallel_write"
CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE = "pipeline_stage"
CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT = False

#########################################
# Drop the last incomplete Batch
# #########################################
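For illustration, a user-facing DeepSpeed config fragment that exercises these new keys could look like the sketch below; only the "checkpoint" block corresponds to this PR, and the other keys are generic placeholders:

import json

# Hypothetical ds_config fragment; dump it to ds_config.json (or pass the dict directly)
# to enable sharded layer writes across data parallel ranks for pipeline checkpoints.
ds_config = {
    "train_batch_size": 8,             # placeholder, unrelated to this PR
    "checkpoint": {
        "tag_validation": "Warn",
        "load_universal": False,
        "parallel_write": {
            "pipeline_stage": True     # CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE
        }
    }
}

print(json.dumps(ds_config, indent=2))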
24 changes: 17 additions & 7 deletions deepspeed/runtime/engine.py
@@ -2924,7 +2924,11 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
            self._create_checkpoint_file(save_dir, tag, False)
            self._save_moe_checkpoint(save_dir, tag, client_state=client_state)

        if self.save_non_zero_checkpoint:
        # We distribute the task of saving layer checkpoint files among
        # data parallel instances, so all procs should call _save_checkpoint.
        # All procs then call module_state_dict(), but only procs of data
        # parallel rank 0 save the general model params.
        if not self.has_moe_layers:
            self._create_checkpoint_file(save_dir, tag, False)
            self._save_checkpoint(save_dir, tag, client_state=client_state)

@@ -3091,12 +3095,18 @@ def _create_zero_checkpoint_files(self, save_dir, tag):
    def _save_checkpoint(self, save_dir, tag, client_state={}):

        save_path = self._get_ckpt_name(save_dir, tag)

        zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()

        # A hack to save the checkpointing directory. Pipeline parallelism overrides
        # module_state_dict() and uses this path to save the model. module_state_dict()
        # then instead just returns None.
        # then instead just returns None. The module_state_dict() implementation in
        # PipelineEngine expects the save path to be set in self._curr_ckpt_path.
        self._curr_ckpt_path = os.path.join(save_dir, tag)
        zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
        state = dict(module=self.module_state_dict(),
        module = self.module_state_dict()
        self._curr_ckpt_path = None

        state = dict(module=module,
                     buffer_names=self._get_buffer_names(),
                     optimizer=self.optimizer.state_dict()
                     if self.optimizer and not zero_optimizer_state else None,
@@ -3114,9 +3124,9 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):
                     ds_version=version)
        state.update(client_state)

        log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])
        self.checkpoint_engine.save(state, save_path)
        self._curr_save_path = None
        if self.save_non_zero_checkpoint:
            log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])
            self.checkpoint_engine.save(state, save_path)

    def _get_buffer_names(self):
        buffer_names = []
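To make the `_curr_ckpt_path` comment concrete, here is a toy illustration of the pattern (the class, layer file names, and directory layout are invented for the example; this is not PipelineEngine's actual implementation):

import os
import tempfile
import torch
import torch.nn as nn


class ToyPipelineStage:
    def __init__(self, layers):
        self.layers = list(layers)       # nn.Module layers owned by this stage
        self._curr_ckpt_path = None      # set by the caller before saving, then cleared

    def module_state_dict(self):
        # Write one file per layer into the stashed checkpoint directory and return
        # None, so the aggregate checkpoint stores None under 'module'.
        assert self._curr_ckpt_path is not None, "checkpoint path must be set first"
        os.makedirs(self._curr_ckpt_path, exist_ok=True)
        for idx, layer in enumerate(self.layers):
            torch.save(layer.state_dict(),
                       os.path.join(self._curr_ckpt_path, f"layer_{idx:02d}.pt"))
        return None


stage = ToyPipelineStage([nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)])
ckpt_dir = os.path.join(tempfile.mkdtemp(), "global_step0")
stage._curr_ckpt_path = ckpt_dir         # the "hack": stash the save path on the object
module = stage.module_state_dict()       # writes layer_00.pt .. layer_02.pt
stage._curr_ckpt_path = None             # clear it again, as the diff now does
print(module, sorted(os.listdir(ckpt_dir)))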
2 changes: 2 additions & 0 deletions deepspeed/runtime/pipe/engine.py
@@ -182,6 +182,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
            self.module.activation_checkpoint_interval = self._config.pipeline[
                'activation_checkpoint_interval']

        self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline

        if self.is_last_stage():
            self.loss_model = self.module.loss_fn

25 changes: 20 additions & 5 deletions deepspeed/runtime/pipe/module.py
@@ -562,13 +562,28 @@ def ckpt_layer_path_list(self, ckpt_dir, local_layer_idx):
        return ckpt_files

    def save_state_dict(self, save_dir, checkpoint_engine):
        if self._grid.data_parallel_id != 0:
            return
        # Processes having the same model parallel rank on different data parallel instances
        # have identical layer weights. We can distribute the task of saving the layer weights
        # among the data parallel ranks. For example, if a pipeline stage has 9 layers and
        # if there are 2 data parallel instances, rank 0 will save the first 5 layers and
        # rank 1 will save the last 4.
        dp_rank = self._grid.data_parallel_id
        dp_size = self._grid.data_parallel_size
        num_layers = len(self.forward_funcs)
        if self.checkpoint_parallel_write_pipeline:
            # spread layers evenly across data parallel ranks
            offsets = ds_utils.partition_uniform(num_layers, dp_size)
            start, end = offsets[dp_rank], offsets[dp_rank + 1]
        else:
            # data parallel rank 0 writes all layers
            if dp_rank != 0:
                return
            start, end = 0, num_layers
        layer_list = self.forward_funcs[start:end]

        os.makedirs(save_dir, exist_ok=True)
        layer_offset = self._local_start
        for idx, layer in enumerate(self.forward_funcs):
            model_ckpt_path = self.ckpt_layer_path(save_dir, idx)
        for idx, layer in enumerate(layer_list):
            model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
            if not hasattr(layer, 'state_dict'):
                continue
            # We pass cloned tensors to torch.save() to avoid checkpoint bloat which occurs because torch.save()
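The layer partitioning above comes from ds_utils.partition_uniform. As a stand-in (the helper below is a local approximation, not the actual DeepSpeed implementation), the offsets for the 9-layers/2-ranks example in the comment work out as follows:

# Local approximation of a uniform partition: split num_items into num_parts contiguous
# chunks whose sizes differ by at most one, returning num_parts + 1 boundary offsets.
def partition_uniform(num_items, num_parts):
    base, extra = divmod(num_items, num_parts)
    offsets = [0]
    for p in range(num_parts):
        offsets.append(offsets[-1] + base + (1 if p < extra else 0))
    return offsets


offsets = partition_uniform(9, 2)                 # [0, 5, 9]
for dp_rank in range(2):
    start, end = offsets[dp_rank], offsets[dp_rank + 1]
    print(f"data parallel rank {dp_rank} saves layers {start}..{end - 1}")
# rank 0 saves layers 0..4 (5 layers), rank 1 saves layers 5..8 (4 layers),
# matching the example in the save_state_dict comment.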