fix: gradient accumulation in DP #906

Closed
tongyx361 wants to merge 18 commits into verl-project:main from tongyx361:tyx/fix/grad-acc-dp


Conversation


@tongyx361 tongyx361 commented Apr 3, 2025

Motivation

Gradient accumulation should yield the same loss (and thus the same gradients) as computing the whole mini-batch in a single pass. But verl's original implementation is only compatible with seq-mean loss aggregation:

mini_loss_to_acc = micro_agg_loss * (len(micro_batch) / self.config.ppo_mini_batch_size)

while verl uses token-mean loss aggregation by default.

For more background, please refer to:

  1. https://huggingface.co/blog/gradient_accumulation
  2. https://unsloth.ai/blog/gradient
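To make the mismatch concrete, here is a small numeric sketch (plain Python, not verl code): under token-mean aggregation, weighting each micro-batch only by its sequence count fails to reproduce the full mini-batch token mean whenever micro-batches contain different numbers of valid tokens.

```python
# Illustrative sketch: one mini-batch split into two micro-batches with
# different numbers of valid (response) tokens.
per_token_losses = [
    [1.0, 1.0, 1.0, 1.0],  # micro-batch 1: 4 valid tokens
    [3.0, 3.0],            # micro-batch 2: 2 valid tokens
]

# Reference: token-mean over the whole mini-batch at once.
all_tokens = [t for mb in per_token_losses for t in mb]
full_token_mean = sum(all_tokens) / len(all_tokens)  # (4*1 + 2*3) / 6 = 10/6

# Naive accumulation: average the per-micro-batch token means with equal
# weight, which is what a seq-mean-style reweighting effectively does here.
naive = sum(sum(mb) / len(mb) for mb in per_token_losses) / len(per_token_losses)
# (1.0 + 3.0) / 2 = 2.0  -- does NOT match 10/6

# Token-count-aware accumulation: weight each micro-batch's token mean by its
# share of the mini-batch's valid tokens.
total_tokens = len(all_tokens)
fixed = sum((sum(mb) / len(mb)) * (len(mb) / total_tokens)
            for mb in per_token_losses)
# matches full_token_mean
```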

Related Issue/Comment(s)

#623 (comment)

Summary

This PR fixes the mismatch between training with and without gradient accumulation by making the accumulation reweighting adapt to the loss aggregation mode.

Core Code to Review

  1. Calculate the number of loss tokens in every mini-batch (mini_batch_loss_token_nums):
def calc_mini_batch_loss_token_nums(batch: DataProto, traj_mini_bsz: int, num_dp_ranks: int) -> list[int]:
    """
    NOTE: Be compatible with
    
    1. verl.workers.fsdp_workers.ActorRolloutRefWorker.update_actor
    2. verl.workers.fsdp_workers.CriticWorker.update_critic

    TODO: Calculate separate numbers if adopting different strategies for actor and critic
    """
    response_mask = compute_response_mask(response_ids=batch.batch['responses'],
                                          attention_mask=batch.batch['attention_mask'])

    traj_bsz = len(batch.batch)
    num_mini_batches = (traj_bsz + traj_mini_bsz - 1) // traj_mini_bsz
    traj_mini_bsz_per_rank = traj_mini_bsz // num_dp_ranks

    mini_batch_loss_token_nums = []
    mini_batch_loss_token_nums = []
    for mini_idx in range(num_mini_batches):
        mini_batch_traj_idxs = []
        for dp_rank in range(num_dp_ranks):
            # Each DP rank owns a contiguous shard of trajectories; within its
            # shard, the rank advances by traj_mini_bsz_per_rank per mini-batch.
            rank_start_traj_idx = int(traj_bsz / num_dp_ranks * dp_rank)
            rank_end_traj_idx = int(traj_bsz / num_dp_ranks * (dp_rank + 1))
            start_traj_idx = rank_start_traj_idx + mini_idx * traj_mini_bsz_per_rank
            end_traj_idx = min(start_traj_idx + traj_mini_bsz_per_rank, rank_end_traj_idx)
            mini_batch_traj_idxs.extend(range(start_traj_idx, end_traj_idx))
        mini_batch_resp_mask = response_mask[mini_batch_traj_idxs]
        mini_batch_loss_token_num = mini_batch_resp_mask.sum()
        mini_batch_loss_token_nums.append(mini_batch_loss_token_num)

    return mini_batch_loss_token_nums
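The per-mini-batch counting can be exercised with a simplified, self-contained sketch. A plain list of per-trajectory valid-token counts stands in for the response-mask row sums (DataProto and compute_response_mask are verl-specific and omitted); the helper name and the assumption that each DP rank advances through its contiguous shard one mini-batch slice at a time are mine, not verl's.

```python
def mini_batch_loss_token_nums_sketch(traj_token_counts, traj_mini_bsz, num_dp_ranks):
    """Simplified stand-in: traj_token_counts[i] is the number of valid loss
    tokens in trajectory i (the row sum of the response mask in the real code)."""
    traj_bsz = len(traj_token_counts)
    num_mini_batches = (traj_bsz + traj_mini_bsz - 1) // traj_mini_bsz
    traj_mini_bsz_per_rank = traj_mini_bsz // num_dp_ranks

    totals = []
    for mini_idx in range(num_mini_batches):
        total = 0
        for dp_rank in range(num_dp_ranks):
            # Contiguous shard owned by this rank, advanced per mini-batch.
            rank_start = traj_bsz * dp_rank // num_dp_ranks
            rank_end = traj_bsz * (dp_rank + 1) // num_dp_ranks
            start = rank_start + mini_idx * traj_mini_bsz_per_rank
            end = min(start + traj_mini_bsz_per_rank, rank_end)
            total += sum(traj_token_counts[start:end])
        totals.append(total)
    return totals

# Example: 8 trajectories, mini-batch size 4, 2 DP ranks.
# Rank 0 owns trajectories 0-3, rank 1 owns 4-7; each contributes 2 per mini-batch.
counts = [3, 5, 2, 4, 6, 1, 7, 2]
nums = mini_batch_loss_token_nums_sketch(counts, traj_mini_bsz=4, num_dp_ranks=2)
```

Every valid token is counted exactly once across mini-batches, so the totals sum back to the full batch's token count.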
  2. Reweight the micro-batch-aggregated loss (micro_agg_loss) according to the loss aggregation mode (loss_agg_mode) to obtain each micro-batch's contribution (mini_loss_to_acc) to the accumulated mini-batch loss:
if self.config.loss_agg_mode == 'token-mean':
    mini_batch_loss_token_nums = data.meta_info['mini_batch_loss_token_nums']
    mini_batch_loss_token_num = mini_batch_loss_token_nums[mini_idx]
    num_valid_toks = response_mask.sum()
    mini_loss_to_acc = micro_agg_loss * num_valid_toks / mini_batch_loss_token_num
else:  # seq-mean
    mini_loss_to_acc = micro_agg_loss * (len(micro_data_chunk) / self.config.ppo_mini_batch_size)
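A quick numeric check of the token-mean branch (plain Python sketch, not verl code): accumulating micro_agg_loss * num_valid_toks / mini_batch_loss_token_num across a mini-batch's micro-batches should reproduce the token mean computed over the whole mini-batch directly.

```python
# Two micro-batches of one mini-batch, holding the per-token losses of their
# valid tokens only.
micro_batches = [
    [2.0, 2.0, 2.0],  # 3 valid tokens
    [5.0],            # 1 valid token
]
mini_batch_loss_token_num = sum(len(mb) for mb in micro_batches)  # 4

# Accumulate as in the token-mean branch above.
acc = 0.0
for mb in micro_batches:
    micro_agg_loss = sum(mb) / len(mb)  # token-mean within the micro-batch
    num_valid_toks = len(mb)
    acc += micro_agg_loss * num_valid_toks / mini_batch_loss_token_num

# Reference: token-mean over the whole mini-batch at once.
all_toks = [t for mb in micro_batches for t in mb]
direct = sum(all_toks) / len(all_toks)  # (6 + 5) / 4 = 2.75
```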

@tongyx361 tongyx361 requested review from PeterSH6, eric-haibin-lin, hiyouga and vermouth1992 and removed request for vermouth1992 April 3, 2025 22:47

@eric-haibin-lin eric-haibin-lin left a comment

Could you summarize what the issue was and what the impact is on existing users?

@tongyx361 tongyx361 marked this pull request as draft April 3, 2025 23:37

tongyx361 commented Apr 30, 2025

Splitting this work into separate PRs:

  1. Variable name: [refactor] feat: separate data, batch and metric with clear variable names #1339
  2. loss_agg_mode: fix: add loss_agg_mode to critics #1340
  3. Fixing gradient accumulation in DP
  4. Test the fix

@tongyx361 tongyx361 closed this Apr 30, 2025
