-
Notifications
You must be signed in to change notification settings - Fork 265
Include completion_mask in return statement #528
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
74c0994
Include completion_mask in return statement
pluesclues a41bdbd
Refactor hidden_states handling in forward method
pluesclues f3936d2
Include mask in grpo_compute_loss return values
pluesclues bb459d2
Merge branch 'unslothai:main' into completion_mask_fix
pluesclues 6d42b0c
Fix argument passing in compute_loss function
pluesclues cc0ffd3
Clarify memory management in forward method
pluesclues File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -481,7 +481,7 @@ def masked_batch_mean(x): | |
| mean_kl = mean_kl_per_reward.mean() | ||
| return completion_length, mean_kl | ||
| completion_length, mean_kl = masked_batch_mean(kl_i) | ||
| return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 | ||
| return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, mask | ||
| pass | ||
| RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss | ||
| RL_REPLACEMENTS["grpo_compute_loss_slow"] = \ | ||
|
|
@@ -501,7 +501,7 @@ def forward(ctx, _new_logps, _old_logps, _ref_logps, _sampling_per_token_logps, | |
| if extra_kwargs is None: | ||
| extra_kwargs = {} | ||
| def compute_loss(new_logps, old_logps, ref_logps, sampling_per_token_logps, input_ids, mask, advantages, scaling): | ||
| loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = grpo_compute_loss( | ||
| loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, _mask = grpo_compute_loss( | ||
| ref_logps, | ||
| new_logps, | ||
| old_logps, | ||
|
|
@@ -845,8 +845,8 @@ class Unsloth_Offloaded_Log_Softmax(torch.autograd.Function): | |
| def forward(ctx, hidden_states, lm_head, index, chunks, | ||
| logit_scale_multiply, logit_scale_divide, | ||
| logit_softcapping, temperature): | ||
|
|
||
| ctx.saved_hidden_states = to_device(hidden_states, "cpu", non_blocking=True) | ||
| #Only the activations are needed so if we keep entire computational graph, keeps unnecessary memory on CPU so we detach it | ||
| ctx.saved_hidden_states = hidden_states.detach().contiguous().to("cpu", non_blocking=True) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please include a comment here describing the change and its implications (can ref the issue).
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did that |
||
| ctx.device = hidden_states.device | ||
| ctx.dtype = hidden_states.dtype | ||
|
|
||
|
|
@@ -990,7 +990,7 @@ def efficient_log_softmax(hidden_states, lm_head, index, chunks=32, | |
| # Must force not returning hidden states but logits otherwise gibberish | ||
| os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" | ||
|
|
||
| return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 | ||
| return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, completion_mask | ||
| # Old non efficient code path | ||
| new_logits = torch.matmul(new_hidden_states, lm_head.t()) | ||
| new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
grpo_compute_lossnow returns 7 values, but theUnslothEfficientGRPO.forwardhelpercompute_lossstill unpacks only 6 (loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = grpo_compute_loss(...)atrl_replacements.py:504). In the GRPO training path this raisesValueError: too many values to unpackas soon ascompute_lossis invoked, so training fails before loss/grad accumulation completes.Useful? React with 👍 / 👎.