Fix microbatch loss scale when loss_agg_mode is "token-mean" #336
Conversation
Summary of Changes
Hello @yanxi-chen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request corrects an inconsistency in how loss scaling is applied across microbatches within the Trinity-RFT framework when using "token-mean" loss aggregation. It introduces an experimental configuration flag that, when enabled, adjusts the microbatch loss scaling to be proportional to the number of response tokens, aligning it with the intended behavior of token-mean aggregation. This change is currently in a work-in-progress state, pending validation experiments.
Code Review
This pull request introduces an experimental feature to fix the microbatch loss scaling for "token-mean" loss aggregation. The changes, controlled by a new `fix_actor_microbatch_loss_scale` flag, correctly adjust the loss scaling logic in `dp_actor.py` to be proportional to the number of tokens in a microbatch. The implementation is clear and well-contained. My review includes one suggestion to improve robustness by handling a potential division-by-zero edge case, which would prevent crashes during training.
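For illustration only, a guard along the following lines (hypothetical helper and variable names, not the exact change suggested in the review) would avoid a division by zero when a batch contains no response tokens:

```python
# Hypothetical guard (illustrates the review suggestion; not the actual diff):
# avoid dividing by zero when a minibatch happens to contain no response tokens.
def safe_token_scale(response_tokens_in_microbatch: int,
                     response_tokens_in_minibatch: int) -> float:
    if response_tokens_in_minibatch <= 0:
        # Degenerate case: no response tokens anywhere in the minibatch.
        # Returning 0.0 skips the contribution; raising an error or falling
        # back to sequence-based scaling are other reasonable choices.
        return 0.0
    return response_tokens_in_microbatch / response_tokens_in_minibatch
```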
/unittest-module-trainer
/unittest-module-common |
Description
Motivation:
In Trinity-RFT (and also in verl), `loss_agg_mode = "token-mean"` only applies to the loss calculation within a microbatch, whereas the loss scaling / weighting across the multiple microbatches of the same minibatch is still determined by the number of sequences, as shown in `dp_actor.py`. This deviates from the desired behavior of token-mean loss aggregation; a toy example of the discrepancy is sketched below.
Solution:
This PR adds an experimental config parameter `fix_actor_microbatch_loss_scale` (bool). When it is set to True and `loss_agg_mode` is "token-mean", the loss scale of each microbatch is set proportional to its total number of response tokens, which mathematically matches the desired behavior of token-mean loss aggregation over a mini-batch; see the sketch after this paragraph.
The default value of `fix_actor_microbatch_loss_scale` is currently False, so the default behavior of Trinity is kept unchanged for now, until this fix has been thoroughly validated and shown to be beneficial.
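As a rough illustration of the change, the sketch below contrasts the two scaling rules; the function and variable names are hypothetical and do not mirror the actual `dp_actor.py` code:

```python
# Minimal sketch of the intended scaling (hypothetical names, not the actual
# dp_actor.py code). Each microbatch loss is a token-mean over that microbatch;
# the question is how to weight microbatches when accumulating gradients.
def microbatch_loss_scale(
    response_tokens_in_microbatch: int,
    sequences_in_microbatch: int,
    response_tokens_in_minibatch: int,
    sequences_in_minibatch: int,
    fix_actor_microbatch_loss_scale: bool,
    loss_agg_mode: str = "token-mean",
) -> float:
    if fix_actor_microbatch_loss_scale and loss_agg_mode == "token-mean":
        # Fixed behavior: weight each microbatch by its share of response
        # tokens, so the accumulated gradient matches a token-mean over the
        # whole minibatch.
        return response_tokens_in_microbatch / response_tokens_in_minibatch
    # Original behavior: weight each microbatch by its share of sequences.
    return sequences_in_microbatch / sequences_in_minibatch


# Accumulation over microbatches then looks roughly like:
#   loss = token_mean_loss_of_microbatch * microbatch_loss_scale(...)
#   loss.backward()
```

With `fix_actor_microbatch_loss_scale=False` the scale reduces to the original sequence-share weighting, so existing runs are unaffected.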
Experimental results:
Marginal difference from the original implementation in a single-turn math reasoning task.
[TODO] Try out multi-step agentic tasks (where sequence lengths could be more diverse).
Limitations:
Only applicable to the DP actor, not to the critic or the Megatron actor.
Checklist
Please check the following items before the code is ready to be reviewed.