
Include completion_mask in return statement #528

Merged
danielhanchen merged 6 commits into unslothai:main from pluesclues:completion_mask_fix
Mar 8, 2026

Collaborator

@pluesclues pluesclues commented Mar 2, 2026

Added completion_mask to the return values of the function. Relies on: unslothai/unsloth#4140. This PR also detaches the hidden-state activations from the computation graph when offloading them to CPU, which saves memory and prevents crashes like unslothai/unsloth#4122.

Added completion_mask to the return values of the function.
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request updates a core function to ensure that a necessary completion_mask is explicitly returned. This change is crucial for maintaining data integrity and enabling subsequent operations that depend on this mask, particularly in the context of reinforcement learning replacements within the Unsloth framework.

Highlights

  • Function Return Value: The grpo_accumulated_loss and grpo_compute_loss functions in unsloth_zoo/rl_replacements.py now include completion_mask in their return statements.

Changelog
  • unsloth_zoo/rl_replacements.py
    • Added completion_mask to the return values of the grpo_accumulated_loss and grpo_compute_loss functions.
Activity
  • This change relies on the completion of pull request #4140.
  • No other review comments or activity have been recorded yet.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request modifies the grpo_accumulated_loss function in unsloth_zoo/rl_replacements.py to include completion_mask in its return values. The change is straightforward and correctly adds an existing variable to the function's output tuple. This aligns with the stated goal of the pull request. The implementation appears correct and I have no further feedback.

  logit_softcapping, temperature):

- ctx.saved_hidden_states = to_device(hidden_states, "cpu", non_blocking=True)
+ ctx.saved_hidden_states = hidden_states.detach().contiguous().to("cpu", non_blocking=True)
Collaborator

Please include a comment here describing the change and its implications (can ref the issue).

Collaborator Author

Did that

Member

@danielhanchen danielhanchen left a comment

Review: PR #528 + PR #4140 (companion)

Summary

These two PRs together fix issues #4081 (shape mismatch crash in masked_batch_mean) and #4122 (CPU OOM from hidden state references). The approach is correct -- returning completion_mask from grpo_accumulated_loss and grpo_compute_loss so the caller uses the correctly-shaped mask (after create_completion_attention_mask reshapes it to include max_left_pad tokens).

Critical bug: UnslothEfficientGRPO.compute_loss will crash

This PR changes grpo_compute_loss return from 6 to 7 values (adding mask), but line 504 in UnslothEfficientGRPO.compute_loss still unpacks only 6 values:

# Line 504 -- WILL CRASH with ValueError: too many values to unpack (expected 6)
loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = grpo_compute_loss(...)

This is the active code path called from grpo_accumulated_loss -> UnslothEfficientGRPO.apply. Fix:

loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, _mask = grpo_compute_loss(...)

The _mask can be discarded since grpo_accumulated_loss already has the correctly-reshaped completion_mask from its own scope (line 743). The aux tuple at line 519 should remain 6 values (no need to pass _mask through), keeping UnslothEfficientGRPO.forward return and backward signature unchanged.
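The failure mode described above can be reproduced without any Unsloth code. The sketch below uses a hypothetical stand-in, fake_grpo_compute_loss, that returns seven placeholder strings; it is not the real function and only illustrates why a six-name unpack raises and why binding the extra value to a throwaway _mask name avoids the error.

```python
def fake_grpo_compute_loss():
    # Hypothetical stand-in: returns 7 placeholder values to mirror the new arity.
    return ("loss", "completion_length", "mean_kl", "delta",
            "flat_is_ratio", "coef_1", "mask")


def unpack_six():
    # Old call site: six names for seven values.
    loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = fake_grpo_compute_loss()
    return loss


def unpack_seven():
    # Fixed call site: the extra value is bound to a throwaway name.
    loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, _mask = fake_grpo_compute_loss()
    return loss, coef_1


try:
    unpack_six()
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)
print(unpack_seven())
```

Because the mask is dropped at the call site, the values passed on (here just loss and coef_1) keep their existing shape and order.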

.detach().contiguous() change (line 849)

Good change. Replacing to_device(hidden_states, "cpu", non_blocking=True) with hidden_states.detach().contiguous().to("cpu", non_blocking=True):

  • .detach() prevents computation graph from holding CPU tensor references (fixes OOM leak in #4122)
  • .contiguous() required for correct non-blocking CPU transfer

Minor: backwards compat branch in PR #4140

The backwards compat branch (~line 1085 in unsloth/models/rl_replacements.py) unpacks 5 values from grpo_accumulated_loss, but after this PR it returns 7. This is dead code (self.args always has loss_type in TRL 0.24+), so won't crash in practice.

Test Results

Tested with critical fix applied (line 504 unpack 7 values):

Test | Model | Steps | Result | Peak Memory (avg)
--- | --- | --- | --- | ---
GRPO (with PRs) | GPT-OSS 20B (4bit, MoE) | 20 | PASSED | 11.34 GB
GRPO (with PRs) | Qwen3 4B (4bit) | 20 | PASSED | 4.56 GB

GPT-OSS 20B Losses: [-0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.101, 0.0, 0.0, 0.0, -0.028, 0.0, 0.0, 0.0, 0.0]
GPT-OSS 20B Grad Norms: [0.0, 0.0, 0.0, 0.0, 0.272, 0.256, 0.0, 0.0, 0.0, 0.299, 0.0, 0.268, 0.292, 0.0, 0.347, 0.286, 0.245, 0.247, 0.231, 0.293]

Qwen3 4B Losses: [0.104, 0.0, 0.361, 0.182, -0.357, 0.049, 0.412, 0.523, -0.076, 0.216, 0.0, 0.46, 0.39, -0.533, -0.352, 0.045, -0.298, -0.523, -0.198, -0.07]
Qwen3 4B Grad Norms: [0.181, 0.0, 0.229, 0.197, 0.331, 0.129, 0.227, 0.298, 0.217, 0.21, 0.0, 0.277, 0.218, 0.249, 0.252, 0.138, 0.198, 0.237, 0.268, 0.431]

Training runs completed without RuntimeError shape mismatch. Losses and grad norms are reasonable (no NaN/divergence). Memory usage is stable (no CPU OOM spikes seen).

Action Required

Please update line 504 in unsloth_zoo/rl_replacements.py to unpack 7 values before merging. Without this fix, all GRPO training will crash with ValueError: too many values to unpack.

Member

@danielhanchen danielhanchen left a comment

Thanks for the PR -- the overall approach is correct. Returning completion_mask from grpo_accumulated_loss and grpo_compute_loss is the right fix for the shape mismatch in masked_batch_mean (#4081), and the .detach().contiguous() change fixes the CPU OOM leak (#4122). However there is a critical bug that needs to be fixed before this can be merged.

Critical: UnslothEfficientGRPO.compute_loss will crash

grpo_compute_loss now returns 7 values (added mask), but UnslothEfficientGRPO.compute_loss at line 504 still unpacks only 6:

loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = grpo_compute_loss(...)

This is the active code path -- called from grpo_accumulated_loss via UnslothEfficientGRPO.apply -> forward -> compute_loss. Every GRPO training run will hit this and crash with:

ValueError: too many values to unpack (expected 6)

Fix needed at line 504:

loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, _mask = grpo_compute_loss(...)

The _mask can be discarded here since grpo_accumulated_loss already has the correctly-reshaped completion_mask from its own scope (line 743 via create_completion_attention_mask). The aux tuple at line 519 should stay at 6 values -- no need to pass _mask through -- which keeps UnslothEfficientGRPO.forward (lines 625-632) and backward (line 636) signatures unchanged.

.detach().contiguous() change (line 849) -- good

The old code to_device(hidden_states, "cpu", non_blocking=True) kept the autograd graph alive on CPU because hidden_states.requires_grad == True. The .to("cpu") call produced a CPU tensor with grad_fn=CopyBackward, anchoring the full forward computation graph in CPU memory. Over hundreds of steps this accumulated tens of GB of CPU memory (as reported in #4122 with memory jumping from 40GB to 90GB+).

The fix is correct:

  • .detach() severs the autograd graph reference
  • .contiguous() ensures correct memory layout for the async CPU transfer
  • non_blocking=True is safe here since backward() at lines 866-869 moves the tensor back to GPU and recomputes forward inside torch.enable_grad() -- standard manual gradient checkpointing pattern
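The difference between the attached and detached copies can be inspected directly. This is a minimal CPU-only sketch, not the Unsloth code: the tensor names are made up, and copy=True is passed to .to() only to force a real copy on a machine without a GPU (on a GPU tensor, .to("cpu") would already copy).

```python
import torch

hidden_states = torch.randn(4, 8, requires_grad=True)
activations = hidden_states * 2.0  # non-leaf tensor with a grad_fn

# Old-style offload: the copy is still part of the autograd graph.
attached = activations.to("cpu", copy=True)

# New-style offload: detach first, then make contiguous and transfer.
detached = activations.detach().contiguous().to("cpu", non_blocking=True)

print(attached.grad_fn is not None, attached.requires_grad)
print(detached.grad_fn is None, detached.requires_grad)
```

The attached copy keeps a grad_fn that points back into the forward graph, while the detached copy has none, so storing it on ctx does not keep that graph alive.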

Returning completion_mask from grpo_accumulated_loss (line 993) -- good

The returned completion_mask is the one from line 743, reshaped by create_completion_attention_mask() to (bsz, logits_to_keep + max_left_pad). This matches the shape of coef_1 and all other returned tensors. The companion PR (#4140 in unsloth) then reassigns the completion_mask variable in compute_loss so that masked_batch_mean uses the correctly-shaped mask instead of the original inputs["completion_mask"].

Test results

I tested both PRs together (with the line 504 fix applied locally) on two GRPO notebooks:

Test | Model | Steps | Result | Peak Memory (avg)
--- | --- | --- | --- | ---
GRPO | GPT-OSS 20B (4bit, MoE) | 20 | Passed | 11.34 GB
GRPO | Qwen3 4B (4bit) | 20 | Passed | 4.56 GB

No shape mismatch errors, losses and grad norms are reasonable (no NaN/divergence), memory usage is stable.

GPT-OSS 20B losses: [-0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.101, 0.0, 0.0, 0.0, -0.028, 0.0, 0.0, 0.0, 0.0]
Qwen3 4B losses: [0.104, 0.0, 0.361, 0.182, -0.357, 0.049, 0.412, 0.523, -0.076, 0.216, 0.0, 0.46, 0.39, -0.533, -0.352, 0.045, -0.298, -0.523, -0.198, -0.07]

Backwards compatible with TRL 0.22.2 through main -- all accessed attributes (loss_type, importance_sampling_level, epsilon_low, epsilon_high, delta) exist in TRL >= 0.22.

Please fix line 504 and I can merge this.

@danielhanchen
Member

I went through both PRs in detail and tested them together. Here is a full breakdown of what they do, what works, what does not, and what needs to change before merging.

What these PRs fix

Issue #4081 -- shape mismatch in masked_batch_mean:

During GRPO training, create_completion_attention_mask() inside grpo_accumulated_loss reshapes completion_mask from (bsz, logits_to_keep) to (bsz, logits_to_keep + max_left_pad). The reshaped mask was never returned to the caller. The caller in compute_loss continued using the original completion_mask from inputs["completion_mask"], which has the smaller shape. When max_left_pad > 0, loss tensors like coef_1 have the larger shape, so the elementwise multiply x * completion_mask inside masked_batch_mean crashes:

RuntimeError: The size of tensor a (828) must match the size of tensor b (824) at non-singleton dimension 1

The 4-token difference (828 vs 824) is exactly max_left_pad. This only triggers when the batch has variable-length prompts that require left-padding.
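The error above is an ordinary broadcasting failure and can be reproduced with plain tensors. The sketch below assumes a simplified masked mean (the real masked_batch_mean may differ) and uses made-up all-ones tensors whose widths match the numbers in the traceback.

```python
import torch

bsz, logits_to_keep, max_left_pad = 2, 824, 4

# Loss-side tensor already has the padded width (logits_to_keep + max_left_pad).
coef_1 = torch.ones(bsz, logits_to_keep + max_left_pad)

# Hypothetical stand-ins for the stale mask from the inputs and the reshaped mask.
stale_mask = torch.ones(bsz, logits_to_keep)
reshaped_mask = torch.ones(bsz, logits_to_keep + max_left_pad)


def masked_mean(x, mask):
    # Simplified masked mean: elementwise multiply, then normalise by mask count.
    return (x * mask).sum() / mask.sum().clamp(min=1.0)


try:
    masked_mean(coef_1, stale_mask)
    mismatch_error = None
except RuntimeError as exc:
    mismatch_error = str(exc)

print(mismatch_error)
print(masked_mean(coef_1, reshaped_mask).item())
```

With max_left_pad set to 0 both masks have the same width and the first call succeeds, which matches the observation that the crash only appears when left-padding is needed.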

Issue #4122 -- CPU OOM from hidden state offload:

In Unsloth_Offloaded_Log_Softmax.forward, hidden states are offloaded to CPU for gradient checkpointing. The old code used to_device(hidden_states, "cpu", non_blocking=True), but hidden_states has requires_grad=True. Calling .to("cpu") on a tensor with a grad function produces a CPU tensor with grad_fn=CopyBackward, which holds a reference to the full forward computation graph. Over hundreds of steps these references accumulate on CPU, causing memory to climb from ~40GB to 90GB+ until the OOM killer terminates the process.

How the shape mismatch fix works (code path)

  1. compute_loss in unsloth/models/rl_replacements.py reads completion_mask = inputs["completion_mask"] -- shape (bsz, logits_to_keep)
  2. This mask is passed into grpo_accumulated_loss in unsloth_zoo/rl_replacements.py
  3. Inside grpo_accumulated_loss, at line 743, create_completion_attention_mask() produces a new mask with shape (bsz, logits_to_keep + max_left_pad) and assigns it to the local completion_mask variable
  4. All subsequent computations inside grpo_accumulated_loss use this reshaped mask, so loss tensors like coef_1 also have the larger + max_left_pad dimension
  5. PR #528 returns this reshaped completion_mask as the 7th value from both grpo_accumulated_loss (line 993) and grpo_compute_loss (line 484)
  6. PR #4140 captures this 7th return value and reassigns the completion_mask variable in compute_loss (line ~1059 in the main branch)
  7. masked_batch_mean is a closure defined after this reassignment, so it captures the updated completion_mask by reference. When masked_batch_mean(coef_1) runs, both coef_1 and completion_mask have shape (bsz, logits_to_keep + max_left_pad) and the shapes match
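The steps above can be sketched in plain Python with lists instead of tensors. Everything here is a toy under stated assumptions: compute_loss_sketch and toy_accumulated_loss are hypothetical names, the callee returns only two values rather than seven, and padding the mask with leading zeros is an illustration, not a claim about what create_completion_attention_mask does.

```python
def toy_accumulated_loss(mask, max_left_pad):
    # Hypothetical stand-in: widen the mask by max_left_pad and return it with coef_1.
    reshaped = [0] * max_left_pad + list(mask)
    coef_1 = [2.0] * len(reshaped)
    return coef_1, reshaped


def compute_loss_sketch(inputs, accumulated_loss_fn):
    # Step 1: the original, narrower mask from the inputs.
    completion_mask = inputs["completion_mask"]

    # Steps 2-6: the callee hands back the reshaped mask, which rebinds the local name.
    coef_1, completion_mask = accumulated_loss_fn(completion_mask, inputs["max_left_pad"])

    # Step 7: the closure reads completion_mask from the enclosing scope when it runs.
    def masked_batch_mean(x):
        if len(x) != len(completion_mask):
            raise ValueError("shape mismatch")
        kept = sum(completion_mask)
        return sum(v * m for v, m in zip(x, completion_mask)) / max(kept, 1)

    return masked_batch_mean(coef_1), len(completion_mask)


result = compute_loss_sketch(
    {"completion_mask": [1, 1, 1], "max_left_pad": 2},
    toy_accumulated_loss,
)
print(result)
```

Python closures look up enclosing variables at call time, so the requirement is that the rebinding happens before masked_batch_mean is called.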

How the CPU OOM fix works (code path)

Line 849 in unsloth_zoo/rl_replacements.py:

  • Old: ctx.saved_hidden_states = to_device(hidden_states, "cpu", non_blocking=True)
  • New: ctx.saved_hidden_states = hidden_states.detach().contiguous().to("cpu", non_blocking=True)

.detach() severs the autograd graph reference so the CPU copy does not anchor the forward computation graph in memory. .contiguous() ensures correct memory layout for the async transfer. In backward() (lines 866-869), the tensor is moved back to GPU, requires_grad_(True) is re-set, and forward is recomputed inside torch.enable_grad(). This is the standard manual gradient checkpointing pattern, so detaching at the save point is correct and does not break gradient flow.
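The save-detached, recompute-in-backward pattern can be sketched with a small custom autograd function. This is a toy under stated assumptions, not Unsloth_Offloaded_Log_Softmax: the class name OffloadedTanhSum and the operation sum(tanh(x * w)) are invented for illustration, and on a CPU-only machine the "offload" transfer is effectively a no-op.

```python
import torch


class OffloadedTanhSum(torch.autograd.Function):
    """Toy op y = sum(tanh(x * w)) that offloads x and recomputes in backward."""

    @staticmethod
    def forward(ctx, x, w):
        # Save a detached copy on CPU so the saved tensor carries no autograd history.
        ctx.saved_x = x.detach().contiguous().to("cpu", non_blocking=True)
        ctx.device = x.device
        ctx.save_for_backward(w)
        return torch.tanh(x * w).sum()

    @staticmethod
    def backward(ctx, grad_output):
        (w,) = ctx.saved_tensors
        # Bring the copy back to the original device and mark it as needing grad.
        x = ctx.saved_x.to(ctx.device).detach().requires_grad_(True)
        w_local = w.detach().requires_grad_(True)
        # Recompute the forward pass with autograd enabled, then differentiate it.
        with torch.enable_grad():
            y = torch.tanh(x * w_local).sum()
        grad_x, grad_w = torch.autograd.grad(y, (x, w_local), grad_output)
        return grad_x, grad_w


torch.manual_seed(0)
x = torch.randn(3, 5, requires_grad=True)
w = torch.randn(5, requires_grad=True)
OffloadedTanhSum.apply(x, w).backward()

# Reference gradients from ordinary autograd on the same values.
ref_x = x.detach().clone().requires_grad_(True)
ref_w = w.detach().clone().requires_grad_(True)
torch.tanh(ref_x * ref_w).sum().backward()

print(torch.allclose(x.grad, ref_x.grad), torch.allclose(w.grad, ref_w.grad))
```

The gradients match the reference because backward rebuilds the graph from the saved values, so nothing is lost by detaching at the save point.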

Critical bug: line 504 still unpacks 6 values

grpo_compute_loss now returns 7 values after this PR, but UnslothEfficientGRPO.compute_loss at line ~504 still unpacks only 6:

loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = grpo_compute_loss(...)

This is the active code path called from grpo_accumulated_loss -> UnslothEfficientGRPO.apply -> forward -> compute_loss. It will crash every GRPO run with:

ValueError: too many values to unpack (expected 6)

Required fix at line 504:

loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, _mask = grpo_compute_loss(...)

_mask can be discarded here since grpo_accumulated_loss already holds the correctly-reshaped completion_mask from its own scope (line 743). The aux tuple at line 519 should stay at 6 values -- no need to propagate _mask through the autograd boundary -- keeping UnslothEfficientGRPO.forward (lines 625-632) and backward (line 636) signatures unchanged.

Other observations (pre-existing, not blocking)

These are all pre-existing issues on main, not introduced by this PR:

  1. grpo_compute_loss_slow call (PR #4140, lines ~1024-1049): The per_token_logps is not None branch is dead code (never reached in current TRL). The positional args to grpo_compute_loss_slow are in the wrong order, and sampling_per_token_logps is passed both positionally and as a keyword argument, which would cause a TypeError. The PR reformats this block but does not fix the arg order. Dead code, not blocking.

  2. Backwards compat branch (PR #4140, line ~1088): The else branch for not hasattr(self.args, "loss_type") unpacks 5 values from grpo_accumulated_loss which now returns 7. Would crash if reached, but loss_type has been in TRL since v0.17.0, and Unsloth only supports TRL >= 0.22. Dead code, not blocking.

  3. Unreachable code after return (PR #528, lines ~999-1007): After the return at line 993, old code calls grpo_compute_loss with wrong positional arg order. Unreachable, not blocking.

Test results

Tested both PRs together with the line 504 fix applied locally:

Model | Steps | Result | Avg Peak Memory
--- | --- | --- | ---
GPT-OSS 20B (4-bit, MoE) | 20 | Passed | 11.34 GB
Qwen3 4B (4-bit) | 20 | Passed | 4.56 GB

GPT-OSS 20B losses: [-0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.101, 0.0, 0.0, 0.0, -0.028, 0.0, 0.0, 0.0, 0.0]
GPT-OSS 20B grad norms: [0.0, 0.0, 0.0, 0.0, 0.272, 0.256, 0.0, 0.0, 0.0, 0.299, 0.0, 0.268, 0.292, 0.0, 0.347, 0.286, 0.245, 0.247, 0.231, 0.293]

Qwen3 4B losses: [0.104, 0.0, 0.361, 0.182, -0.357, 0.049, 0.412, 0.523, -0.076, 0.216, 0.0, 0.46, 0.39, -0.533, -0.352, 0.045, -0.298, -0.523, -0.198, -0.07]
Qwen3 4B grad norms: [0.181, 0.0, 0.229, 0.197, 0.331, 0.129, 0.227, 0.298, 0.217, 0.21, 0.0, 0.277, 0.218, 0.249, 0.252, 0.138, 0.198, 0.237, 0.268, 0.431]

No shape mismatch errors. Losses and grad norms look reasonable with no NaN or divergence. Memory usage is stable with no CPU OOM spikes across all 20 steps.

Backwards compatibility

All attributes accessed in the main code path (loss_type, importance_sampling_level, epsilon_low, epsilon_high, delta) exist in TRL >= 0.20.0. loss_type has been in TRL since v0.17.0. No compatibility issues with TRL 0.22.2 through main.

Summary

The approach is correct and both issues (#4081, #4122) are properly addressed. The only required change before merge is fixing the unpack count at line ~504 in unsloth_zoo/rl_replacements.py from 6 to 7 values. Once that is pushed, this is ready to merge.


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: bb459d2664

  return completion_length, mean_kl
  completion_length, mean_kl = masked_batch_mean(kl_i)
- return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1
+ return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, mask
[P1] Keep grpo_compute_loss return arity consistent

grpo_compute_loss now returns 7 values, but the UnslothEfficientGRPO.forward helper compute_loss still unpacks only 6 (loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = grpo_compute_loss(...) at rl_replacements.py:504). In the GRPO training path this raises ValueError: too many values to unpack as soon as compute_loss is invoked, so training fails before loss/grad accumulation completes.

Added a comment to clarify the purpose of detaching hidden states.
3 participants