Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions unsloth_zoo/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def masked_batch_mean(x):
mean_kl = mean_kl_per_reward.mean()
return completion_length, mean_kl
completion_length, mean_kl = masked_batch_mean(kl_i)
return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1
return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, mask
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep grpo_compute_loss return arity consistent

`grpo_compute_loss` now returns 7 values, but the `compute_loss` helper inside `UnslothEfficientGRPO.forward` still unpacks only 6 (`loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = grpo_compute_loss(...)` at rl_replacements.py:504). In the GRPO training path this raises `ValueError: too many values to unpack` as soon as `compute_loss` is invoked, so training fails before loss/grad accumulation completes.

Useful? React with 👍 / 👎.

pass
RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss
RL_REPLACEMENTS["grpo_compute_loss_slow"] = \
Expand All @@ -501,7 +501,7 @@ def forward(ctx, _new_logps, _old_logps, _ref_logps, _sampling_per_token_logps,
if extra_kwargs is None:
extra_kwargs = {}
def compute_loss(new_logps, old_logps, ref_logps, sampling_per_token_logps, input_ids, mask, advantages, scaling):
loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = grpo_compute_loss(
loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, _mask = grpo_compute_loss(
ref_logps,
new_logps,
old_logps,
Expand Down Expand Up @@ -845,8 +845,8 @@ class Unsloth_Offloaded_Log_Softmax(torch.autograd.Function):
def forward(ctx, hidden_states, lm_head, index, chunks,
logit_scale_multiply, logit_scale_divide,
logit_softcapping, temperature):

ctx.saved_hidden_states = to_device(hidden_states, "cpu", non_blocking=True)
# Only the activations are needed, so detach before moving to CPU; otherwise the saved tensor keeps the entire computational graph alive and holds unnecessary memory.
ctx.saved_hidden_states = hidden_states.detach().contiguous().to("cpu", non_blocking=True)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please include a comment here describing the change and its implications (you can reference the issue).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — added the comment.

ctx.device = hidden_states.device
ctx.dtype = hidden_states.dtype

Expand Down Expand Up @@ -990,7 +990,7 @@ def efficient_log_softmax(hidden_states, lm_head, index, chunks=32,
# Must force not returning hidden states but logits otherwise gibberish
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"

return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1
return loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, completion_mask
# Old non efficient code path
new_logits = torch.matmul(new_hidden_states, lm_head.t())
new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
Expand Down