27 changes: 22 additions & 5 deletions deepspeed/runtime/engine.py
@@ -2009,11 +2009,14 @@ def backward(self,
return loss

def is_gradient_accumulation_boundary(self):
"""Query whether the current micro-batch is at the boundary of
"""
Query whether the current micro-batch is at the boundary of
gradient accumulation, and thus will trigger gradient reductions and
an optimizer step.

Returns:
bool: whether the current step is a gradient accumulation boundary.

"""
if self._is_gradient_accumulation_boundary is None:
return (self.micro_steps + 1) % \
@@ -2022,7 +2025,8 @@ def is_gradient_accumulation_boundary(self):
return self._is_gradient_accumulation_boundary
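A minimal usage sketch for this query (assumed names: ``engine`` from ``deepspeed.initialize()`` and ``data_loader`` yielding micro-batches; neither is part of this diff):

    for step, batch in enumerate(data_loader):
        loss = engine(batch)   # forward pass on one micro-batch
        engine.backward(loss)  # gradients reduce only on boundaries
        if engine.is_gradient_accumulation_boundary():
            # This micro-batch triggers gradient reduction and an
            # optimizer step, so it is a safe point to log or eval.
            print(f"step {step}: loss = {loss.item():.4f}")
        engine.step()          # optimizer steps only on boundaries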

def set_gradient_accumulation_boundary(self, is_boundary):
"""Manually overrides the DeepSpeed engine's gradient accumulation boundary state, this is an optional
"""
Manually overrides the DeepSpeed engine's gradient accumulation boundary state. This is an optional
feature and should be used with care. The state should be set to the intended
value before each forward/backward pass. The final forward/backward pass should have the
boundary state set to True. This style allows client code to only call engine.step() once after all
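A minimal sketch of the manual-boundary pattern described above, assuming ``engine`` comes from ``deepspeed.initialize()`` and ``micro_batches`` is an illustrative list of inputs:

    num_micro = len(micro_batches)
    for i, micro_batch in enumerate(micro_batches):
        # Only the final pass should reduce gradients and step.
        engine.set_gradient_accumulation_boundary(i == num_micro - 1)
        loss = engine(micro_batch)
        engine.backward(loss)
    engine.step()  # called once, after the final (boundary) backward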
@@ -2714,7 +2718,9 @@ def load_checkpoint(self,
load_lr_scheduler_states=True,
load_module_only=False,
custom_load_fn=None):
"""Load training checkpoint
"""
Load training checkpoint

Arguments:
load_dir: Required. Directory to load the checkpoint from
tag: Checkpoint tag used as a unique identifier for the checkpoint; if not provided, will attempt to load the tag from the 'latest' file
@@ -2723,14 +2729,17 @@
load_lr_scheduler_states: Optional. Boolean to load the learning rate scheduler states from the checkpoint.
load_module_only: Optional. Boolean to load only the model weights from the checkpoint, e.g. for warm-starting.
custom_load_fn: Optional. Custom model load function.

Returns:
A tuple of ``load_path`` and ``client_state``.
*``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed.
*``client_state``: State dictionary used for loading required training states in the client code.

Important: under ZeRO3, one cannot load a checkpoint with ``engine.load_checkpoint()`` right
after ``engine.save_checkpoint()``. This is because ``engine.module`` is partitioned, and
``load_checkpoint()`` expects a pristine model. If you insist on doing so, please reinitialize the
engine before calling ``load_checkpoint()``.

"""

if tag is None:
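A hedged usage sketch for this API (the ``checkpoints`` directory and the ``step`` key are hypothetical client choices; ``engine`` comes from ``deepspeed.initialize()``):

    load_path, client_state = engine.load_checkpoint(
        load_dir="checkpoints",
        load_optimizer_states=True,
        load_lr_scheduler_states=True,
    )
    if load_path is None:
        start_step = 0  # loading failed or no checkpoint; start fresh
    else:
        # client_state is whatever dict save_checkpoint() was given.
        start_step = client_state.get("step", 0)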
@@ -3062,7 +3071,8 @@ def _checkpoint_tag_validation(self, tag):
logger.warning(msg)

def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True):
r"""Save training checkpoint
"""Save training checkpoint

Arguments:
save_dir: Required. Directory for saving the checkpoint
tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is
@@ -3073,6 +3083,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
because each process needs to save its master weights and scheduler+optimizer states. This
method will hang waiting to synchronize with other processes if it's called just for the
process with rank 0.

"""
if self.zero_optimization_partition_weights():
# Prepare for checkpoint save by ensuring all parameters are partitioned
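A matching save-side sketch; per the docstring, every rank must make this call (``step`` is a hypothetical client counter):

    # All processes call this, not just rank 0; otherwise the engine
    # hangs waiting to synchronize.
    engine.save_checkpoint(
        save_dir="checkpoints",
        client_state={"step": step},  # round-trips via load_checkpoint
    )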
@@ -3467,17 +3478,23 @@ def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
return self.save_16bit_model(save_dir, save_filename)

def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"):
r"""Save 16bit model weights
"""
Save 16bit model weights

This method saves the 16bit model weights at the desired destination.

Arguments:
save_dir: Required. Directory for saving the model
save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``

Returns:
``True`` when a model has been saved, ``False`` otherwise. It will not be saved if
``stage3_gather_16bit_weights_on_model_save`` is ``False``.

Important: all processes must call this method, not just the process with rank 0. This is
because the processes need to work in sync to gather the weights. This method will hang
waiting to synchronize with other processes if it's called just for the process with rank 0.

"""

path = os.path.join(save_dir, save_filename)
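A usage sketch under the constraints stated above (``export_dir`` is a hypothetical path; all ranks must participate):

    # Returns False unless the ZeRO-3 config enables
    # stage3_gather_16bit_weights_on_model_save.
    saved = engine.save_16bit_model("export_dir")
    if not saved:
        print("16bit weights not saved; enable the gather flag in the config")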