Include completion_mask in return statement#528
Conversation
Added completion_mask to the return values of the function.
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request updates a core function to ensure that a necessary Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request modifies the grpo_accumulated_loss function in unsloth_zoo/rl_replacements.py to include completion_mask in its return values. The change is straightforward and correctly adds an existing variable to the function's output tuple. This aligns with the stated goal of the pull request. The implementation appears correct and I have no further feedback.
Add mask to the return values of grpo_compute_loss function.
| logit_softcapping, temperature): | ||
|
|
||
| ctx.saved_hidden_states = to_device(hidden_states, "cpu", non_blocking=True) | ||
| ctx.saved_hidden_states = hidden_states.detach().contiguous().to("cpu", non_blocking=True) |
There was a problem hiding this comment.
Please include a comment here describing the change and its implications (can ref the issue).
danielhanchen
left a comment
There was a problem hiding this comment.
Review: PR #528 + PR #4140 (companion)
Summary
These two PRs together fix issues #4081 (shape mismatch crash in masked_batch_mean) and #4122 (CPU OOM from hidden state references). The approach is correct -- returning completion_mask from grpo_accumulated_loss and grpo_compute_loss so the caller uses the correctly-shaped mask (after create_completion_attention_mask reshapes it to include max_left_pad tokens).
Critical bug: UnslothEfficientGRPO.compute_loss will crash
This PR changes grpo_compute_loss return from 6 to 7 values (adding mask), but line 504 in UnslothEfficientGRPO.compute_loss still unpacks only 6 values:
# Line 504 -- WILL CRASH with ValueError: too many values to unpack (expected 6)
loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = grpo_compute_loss(...)This is the active code path called from grpo_accumulated_loss -> UnslothEfficientGRPO.apply. Fix:
loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, _mask = grpo_compute_loss(...)The _mask can be discarded since grpo_accumulated_loss already has the correctly-reshaped completion_mask from its own scope (line 743). The aux tuple at line 519 should remain 6 values (no need to pass _mask through), keeping UnslothEfficientGRPO.forward return and backward signature unchanged.
.detach().contiguous() change (line 849)
Good change. Replacing to_device(hidden_states, "cpu", non_blocking=True) with hidden_states.detach().contiguous().to("cpu", non_blocking=True):
.detach()prevents computation graph from holding CPU tensor references (fixes OOM leak in #4122).contiguous()required for correct non-blocking CPU transfer
Minor: backwards compat branch in PR #4140
The backwards compat branch (~line 1085 in unsloth/models/rl_replacements.py) unpacks 5 values from grpo_accumulated_loss, but after this PR it returns 7. This is dead code (self.args always has loss_type in TRL 0.24+), so won't crash in practice.
Test Results
Tested with critical fix applied (line 504 unpack 7 values):
| Test | Model | Steps | Result | Peak Memory |
|---|---|---|---|---|
| GRPO (with PRs) | GPT-OSS 20B (4bit, MoE) | 20 | PASSED | 11.34 GB avg |
| GRPO (with PRs) | Qwen3 4B (4bit) | 20 | PASSED | 4.56 GB avg |
GPT-OSS 20B Losses: [-0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.101, 0.0, 0.0, 0.0, -0.028, 0.0, 0.0, 0.0, 0.0]
GPT-OSS 20B Grad Norms: [0.0, 0.0, 0.0, 0.0, 0.272, 0.256, 0.0, 0.0, 0.0, 0.299, 0.0, 0.268, 0.292, 0.0, 0.347, 0.286, 0.245, 0.247, 0.231, 0.293]
Qwen3 4B Losses: [0.104, 0.0, 0.361, 0.182, -0.357, 0.049, 0.412, 0.523, -0.076, 0.216, 0.0, 0.46, 0.39, -0.533, -0.352, 0.045, -0.298, -0.523, -0.198, -0.07]
Qwen3 4B Grad Norms: [0.181, 0.0, 0.229, 0.197, 0.331, 0.129, 0.227, 0.298, 0.217, 0.21, 0.0, 0.277, 0.218, 0.249, 0.252, 0.138, 0.198, 0.237, 0.268, 0.431]
Training runs completed without RuntimeError shape mismatch. Losses and grad norms are reasonable (no NaN/divergence). Memory usage is stable (no CPU OOM spikes seen).
Action Required
Please update line 504 in unsloth_zoo/rl_replacements.py to unpack 7 values before merging. Without this fix, all GRPO training will crash with ValueError: too many values to unpack.
danielhanchen
left a comment
There was a problem hiding this comment.
Thanks for the PR -- the overall approach is correct. Returning completion_mask from grpo_accumulated_loss and grpo_compute_loss is the right fix for the shape mismatch in masked_batch_mean (#4081), and the .detach().contiguous() change fixes the CPU OOM leak (#4122). However there is a critical bug that needs to be fixed before this can be merged.
Critical: UnslothEfficientGRPO.compute_loss will crash
grpo_compute_loss now returns 7 values (added mask), but UnslothEfficientGRPO.compute_loss at line 504 still unpacks only 6:
loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = grpo_compute_loss(...)This is the active code path -- called from grpo_accumulated_loss via UnslothEfficientGRPO.apply -> forward -> compute_loss. Every GRPO training run will hit this and crash with:
ValueError: too many values to unpack (expected 6)
Fix needed at line 504:
loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, _mask = grpo_compute_loss(...)The _mask can be discarded here since grpo_accumulated_loss already has the correctly-reshaped completion_mask from its own scope (line 743 via create_completion_attention_mask). The aux tuple at line 519 should stay at 6 values -- no need to pass _mask through -- which keeps UnslothEfficientGRPO.forward (line 625-632) and backward (line 636) signatures unchanged.
.detach().contiguous() change (line 849) -- good
The old code to_device(hidden_states, "cpu", non_blocking=True) kept the autograd graph alive on CPU because hidden_states.requires_grad == True. The .to("cpu") call produced a CPU tensor with grad_fn=CopyBackward, anchoring the full forward computation graph in CPU memory. Over hundreds of steps this accumulated tens of GB of CPU memory (as reported in #4122 with memory jumping from 40GB to 90GB+).
The fix is correct:
.detach()severs the autograd graph reference.contiguous()ensures correct memory layout for the async CPU transfernon_blocking=Trueis safe here sincebackward()at line 866-869 moves the tensor back to GPU and recomputes forward insidetorch.enable_grad()-- standard manual gradient checkpointing pattern
Returning completion_mask from grpo_accumulated_loss (line 993) -- good
The returned completion_mask is the one from line 743, reshaped by create_completion_attention_mask() to (bsz, logits_to_keep + max_left_pad). This matches the shape of coef_1 and all other returned tensors. The companion PR (#4140 in unsloth) then reassigns the completion_mask variable in compute_loss so that masked_batch_mean uses the correctly-shaped mask instead of the original inputs["completion_mask"].
Test results
I tested both PRs together (with the line 504 fix applied locally) on two GRPO notebooks:
| Test | Model | Steps | Result | Peak Memory |
|---|---|---|---|---|
| GRPO | GPT-OSS 20B (4bit, MoE) | 20 | Passed | 11.34 GB avg |
| GRPO | Qwen3 4B (4bit) | 20 | Passed | 4.56 GB avg |
No shape mismatch errors, losses and grad norms are reasonable (no NaN/divergence), memory usage is stable.
GPT-OSS 20B losses: [-0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.101, 0.0, 0.0, 0.0, -0.028, 0.0, 0.0, 0.0, 0.0]
Qwen3 4B losses: [0.104, 0.0, 0.361, 0.182, -0.357, 0.049, 0.412, 0.523, -0.076, 0.216, 0.0, 0.46, 0.39, -0.533, -0.352, 0.045, -0.298, -0.523, -0.198, -0.07]
Backwards compatible with TRL 0.22.2 through main -- all accessed attributes (loss_type, importance_sampling_level, epsilon_low, epsilon_high, delta) exist in TRL >= 0.22.
Please fix line 504 and I can merge this.
|
I went through both PRs in detail and tested them together. Here is a full breakdown of what they do, what works, what does not, and what needs to change before merging. What these PRs fixIssue #4081 -- shape mismatch in During GRPO training, The 4-token difference (828 vs 824) is exactly Issue #4122 -- CPU OOM from hidden state offload: In How the shape mismatch fix works (code path)
How the CPU OOM fix works (code path)Line 849 in
Critical bug: line 504 still unpacks 6 values
loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = grpo_compute_loss(...)This is the active code path called from Required fix at line 504: loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, _mask = grpo_compute_loss(...)
Other observations (pre-existing, not blocking)These are all pre-existing issues on main, not introduced by this PR:
Test resultsTested both PRs together with the line 504 fix applied locally:
GPT-OSS 20B losses: Qwen3 4B losses: No shape mismatch errors. Losses and grad norms look reasonable with no NaN or divergence. Memory usage is stable with no CPU OOM spikes across all 20 steps. Backwards compatibilityAll attributes accessed in the main code path ( SummaryThe approach is correct and both issues (#4081, #4122) are properly addressed. The only required change before merge is fixing the unpack count at line ~504 in |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: bb459d2664
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| return completion_length, mean_kl | ||
| completion_length, mean_kl = masked_batch_mean(kl_i) | ||
| return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 | ||
| return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, mask |
There was a problem hiding this comment.
Keep grpo_compute_loss return arity consistent
grpo_compute_loss now returns 7 values, but the UnslothEfficientGRPO.forward helper compute_loss still unpacks only 6 (loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = grpo_compute_loss(...) at rl_replacements.py:504). In the GRPO training path this raises ValueError: too many values to unpack as soon as compute_loss is invoked, so training fails before loss/grad accumulation completes.
Useful? React with 👍 / 👎.
Added a comment to clarify the purpose of detaching hidden states.
Added completion_mask to the return values of the function. Relies on: unslothai/unsloth#4140. This pr also detaches the hidden states activations from the computation graph when offloading to cpu to save more memory and prevent crashes like unslothai/unsloth#4122.