add async + distopt to sft (#7018)
Signed-off-by: MaximumEntropy <[email protected]>
MaximumEntropy committed Jul 12, 2023
1 parent 41d8477 commit e87985d
Showing 1 changed file with 14 additions and 0 deletions.
@@ -296,6 +296,15 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
        tensor_shape = [seq_length, get_micro_batch_size(), self.cfg.hidden_size]
        data_iter = get_iterator_k_split(batch, get_num_microbatches())

        # handle asynchronous grad reduction
        no_sync_func = None
        grad_sync_func = None
        param_sync_func = None
        if not forward_only and self.with_distributed_adam:
            no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_o2,)
            grad_sync_func = self.reduce_overlap_gradients
            param_sync_func = self.sync_overlap_parameters

        fwd_bwd_function = get_forward_backward_func()

        losses_reduced_per_micro_batch = fwd_bwd_function(
@@ -309,6 +318,11 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
            grad_scaler=self.trainer.precision_plugin.scaler.scale if self.cfg.precision == 16 else None,
            sequence_parallel=self.cfg.get('sequence_parallel', False),
            enable_autocast=self.enable_autocast,
            no_sync_func=no_sync_func,
            grad_sync_func=grad_sync_func,
            param_sync_func=param_sync_func,
            overlap_p2p_comm=self.cfg.get('overlap_p2p_comm', False),
            batch_p2p_comm=self.cfg.get('batch_p2p_comm', True),
        )

        # only the last stages of the pipeline return losses
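Note on the change: the diff wires the distributed-optimizer hooks (no_sync_func, grad_sync_func, param_sync_func) and the pipeline P2P flags (overlap_p2p_comm, batch_p2p_comm) into the SFT fwd_bwd_step, so gradient reduction and parameter syncing can be deferred and overlapped with compute across microbatches. For readers unfamiliar with the deferral pattern, below is a minimal illustrative sketch of the same idea using plain PyTorch DDP's no_sync() context manager; it is not the NeMo/Megatron API used in this commit, and the model, optimizer, microbatch, and loss_fn names are hypothetical placeholders.

# Illustrative sketch (not part of this commit): defer the gradient
# all-reduce until the last microbatch with PyTorch DDP's no_sync(),
# the same deferral idea that no_sync_func enables above.
from contextlib import nullcontext

from torch.nn.parallel import DistributedDataParallel as DDP


def train_step(ddp_model: DDP, optimizer, microbatches, loss_fn):
    optimizer.zero_grad()
    num_micro = len(microbatches)
    for i, (inputs, targets) in enumerate(microbatches):
        is_last = i == num_micro - 1
        # Skip the gradient all-reduce for every microbatch except the last;
        # gradients accumulate locally and are reduced once at the end, so the
        # reduction can overlap with the remaining backward work.
        sync_ctx = nullcontext() if is_last else ddp_model.no_sync()
        with sync_ctx:
            loss = loss_fn(ddp_model(inputs), targets) / num_micro
            loss.backward()
    optimizer.step()

In the NeMo path above, self._optimizer.no_sync(...) from the distributed optimizer plays roughly the role of ddp_model.no_sync(), while grad_sync_func and param_sync_func give the pipeline schedule explicit points at which to launch the deferred gradient reductions and parameter syncs so they overlap with compute.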
