diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 78e347e84aaf..b62a76e54202 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2009,11 +2009,14 @@ def backward(self, return loss def is_gradient_accumulation_boundary(self): - """Query whether the current micro-batch is at the boundary of + """ + Query whether the current micro-batch is at the boundary of gradient accumulation, and thus will trigger gradient reductions and an optimizer step. + Returns: bool: if the current step is a gradient accumulation boundary. + """ if self._is_gradient_accumulation_boundary is None: return (self.micro_steps + 1) % \ @@ -2022,7 +2025,8 @@ def is_gradient_accumulation_boundary(self): return self._is_gradient_accumulation_boundary def set_gradient_accumulation_boundary(self, is_boundary): - """Manually overrides the DeepSpeed engine's gradient accumulation boundary state, this is an optional + """ + Manually overrides the DeepSpeed engine's gradient accumulation boundary state, this is an optional feature and should be used with care. The state should be set before to the intended value before each forward/backward. The final fordward/backward should have the boundary state set to True. This style allows client code to only call engine.step() once after all @@ -2714,7 +2718,9 @@ def load_checkpoint(self, load_lr_scheduler_states=True, load_module_only=False, custom_load_fn=None): - """Load training checkpoint + """ + Load training checkpoint + Arguments: load_dir: Required. Directory to load the checkpoint from tag: Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file @@ -2723,14 +2729,17 @@ def load_checkpoint(self, load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint. load_module_only: Optional. Boolean to load only the model weights from the checkpoint. Ex. warmstarting. custom_load_fn: Optional. Custom model load function. + Returns: A tuple of ``load_path`` and ``client_state``. *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed. *``client_state``: State dictionary used for loading required training states in the client code. + Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right after ``engine.save_checkpoint()``. It is because ``engine.module`` is partitioned, and ``load_checkpoint()`` wants a pristine model. If insisting to do so, please reinitialize engine before ``load_checkpoint()``. + """ if tag is None: @@ -3062,7 +3071,8 @@ def _checkpoint_tag_validation(self, tag): logger.warning(msg) def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True): - r"""Save training checkpoint + """Save training checkpoint + Arguments: save_dir: Required. Directory for saving the checkpoint tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is @@ -3073,6 +3083,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True) because each process needs to save its master weights and scheduler+optimizer states. This method will hang waiting to synchronize with other processes if it's called just for the process with rank 0. + """ if self.zero_optimization_partition_weights(): # Prepare for checkpoint save by ensuring all parameters are partitioned @@ -3467,17 +3478,23 @@ def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"): return self.save_16bit_model(save_dir, save_filename) def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"): - r"""Save 16bit model weights + """ + Save 16bit model weights + This method saves the 16bit model weights at the desired destination. + Arguments: save_dir: Required. Directory for saving the model save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin`` + Returns: ``True`` when a model has been saved, ``False`` otherwise. It will not be saved if stage3_gather_16bit_weights_on_model_save is ``False``. + Important: all processes must call this method and not just the process with rank 0. It is because the processes need to work in sync to gather the weights. This method will hang waiting to synchronize with other processes if it's called just for the process with rank 0. + """ path = os.path.join(save_dir, save_filename)