Guard GRPO text branch against models returning logits #602
The vision branch in grpo_accumulated_loss already dispatches between efficient_log_softmax and chunked_selective_log_softmax based on whether the model's .logits is actually hidden states or real logits. The text branch (pixel_values is None) was missing the same guard and passed the tensor straight into efficient_log_softmax, which assumes hidden states and does hidden_states @ lm_head.t(). This crashes with a shape-mismatch matmul for any text model whose forward genuinely returns logits (e.g. Nemotron-H hybrid Mamba+attention, custom remote-code models). Mirror the vision branch's guard so text models that return logits go directly to chunked_selective_log_softmax.
Code Review
This pull request introduces a guard in unsloth_zoo/rl_replacements.py to handle cases where the model returns logits directly instead of hidden states. A review comment points out that the logic for handling direct logits uses a fixed chunk size which may cause memory issues for large vocabularies and notes that scaling and softcapping parameters are currently ignored in that execution path.
    if new_hidden_states_chunk.shape[-1] == lm_head.shape[1]:
        logprobs_chunk = efficient_log_softmax(
            new_hidden_states_chunk,
            lm_head,
            completion_ids,
            chunks=input_ids_chunk.shape[0]*multiplier,
            logit_scale_multiply=logit_scale_multiply,
            logit_scale_divide=logit_scale_divide,
            logit_softcapping=logit_softcapping,
            temperature=temperature,
            batch_size = B
        )
    else:
        # Model returned logits directly - scaling/softcapping already applied by model forward
        logprobs_chunk = chunked_selective_log_softmax(new_hidden_states_chunk, completion_ids, temperature)
While this guard mirrors the logic in the vision branch, chunked_selective_log_softmax uses a fixed number of chunks (4), whereas efficient_log_softmax uses a dynamic number of chunks based on the multiplier. For models with very large vocabularies that return logits directly, the else branch might be more susceptible to OOM errors. Additionally, note that logit_scale_multiply, logit_scale_divide, and logit_softcapping are ignored in the else branch. If these parameters are expected to be applied to the logits returned by the model, they should be handled before or within the chunked_selective_log_softmax call.
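To put rough numbers on the memory concern, the sketch below estimates the float32 working set of the largest chunk. It assumes the logits are flattened to (rows, vocab) and split along the row dimension, and the sizes are hypothetical illustrative values, not measurements from this PR.

```python
import math

def chunk_bytes(rows: int, vocab: int, chunks: int, bytes_per_elem: int = 4) -> int:
    """Bytes held by the largest chunk when `rows` rows are split into `chunks` pieces."""
    rows_per_chunk = math.ceil(rows / chunks)
    return rows_per_chunk * vocab * bytes_per_elem

# Hypothetical sizes: 8 sequences x 1024 completion tokens, 256k-entry vocabulary.
rows, vocab = 8 * 1024, 256_000
fixed = chunk_bytes(rows, vocab, chunks=4)     # fixed 4-way split
dynamic = chunk_bytes(rows, vocab, chunks=32)  # finer, caller-chosen split

print(f"4 chunks:  {fixed / 2**30:.2f} GiB per chunk")
print(f"32 chunks: {dynamic / 2**30:.2f} GiB per chunk")
```

Under these assumptions the per-chunk footprint shrinks in proportion to the chunk count, which is why a fixed split can be the limiting factor for large vocabularies.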
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c76e33eee0
    logit_scale_multiply = logit_scale_multiply,
    logit_scale_divide = logit_scale_divide,
    logit_softcapping = logit_softcapping,
Skip scaling when logits already come from model forward
In the direct-logits fallback (else branch), this now passes logit_scale_multiply/divide/softcapping into chunked_selective_log_softmax, which re-applies those transforms to logits that may already have had them applied by the model forward. For models that ignore UNSLOTH_RETURN_HIDDEN_STATES and return post-head logits (including softcapped/scaled logits), this silently changes token logprobs and therefore GRPO loss/KL, corrupting training rather than just fixing the shape mismatch. The pre-change vision fallback explicitly avoided this double application, so this is a behavior regression for direct-logits models.
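A small, dependency-free illustration of the double-application concern follows. It assumes tanh softcapping of the form cap * tanh(x / cap); the logits and the cap value are hypothetical, and plain Python lists stand in for tensors.

```python
import math

def softcap(logits, cap):
    """Tanh softcapping: squashes each logit into (-cap, cap)."""
    return [cap * math.tanh(x / cap) for x in logits]

def log_softmax(logits):
    """Numerically stable log-softmax over a single row."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

raw = [12.0, 3.0, -5.0, 0.5]  # hypothetical pre-cap logits
cap = 10.0                     # hypothetical softcapping value

once = softcap(raw, cap)    # what a model forward that caps internally would return
twice = softcap(once, cap)  # what re-applying the cap downstream would produce

lp_once = log_softmax(once)
lp_twice = log_softmax(twice)
drift = max(abs(a - b) for a, b in zip(lp_once, lp_twice))
print(f"max logprob drift from double capping: {drift:.4f}")
```

Because softcapping is not idempotent, the second application shifts every token's log-probability, which would feed directly into the loss and KL terms.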
c76e33e to 96b0a2a
Extend chunked_selective_log_softmax with a chunks parameter (default 4 for back-compat) and have the text-branch direct-logits fallback pass input_ids_chunk.shape[0] * multiplier, mirroring the lm_head matmul path above. Large-vocab models that return post-head logits previously went through a hardcoded 4-way chunking that could blow out VRAM. Scaling/softcapping are intentionally not forwarded: when the model's forward returns real logits it has already applied those transforms, and re-applying would double-transform the logits and corrupt GRPO logprobs.
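The idea of a configurable chunk count can be sketched in plain Python. This is a stand-in written for illustration, not the unsloth_zoo implementation: the function names carry a _sketch suffix, lists stand in for tensors, and row-wise chunking is an assumption.

```python
import math

def selective_log_softmax_sketch(rows, targets, temperature=1.0):
    """Per row: log p(target) = logit[target]/T - logsumexp(logits/T)."""
    out = []
    for logits, t in zip(rows, targets):
        scaled = [x / temperature for x in logits]
        m = max(scaled)
        lse = m + math.log(sum(math.exp(x - m) for x in scaled))
        out.append(scaled[t] - lse)
    return out

def chunked_selective_log_softmax_sketch(rows, targets, temperature=1.0, chunks=4):
    """Same values, but at most ceil(len(rows) / chunks) rows are processed at a time."""
    size = max(1, math.ceil(len(rows) / chunks))
    out = []
    for start in range(0, len(rows), size):
        out.extend(selective_log_softmax_sketch(
            rows[start:start + size], targets[start:start + size], temperature))
    return out

rows = [[0.1 * (i + j) for j in range(6)] for i in range(10)]  # toy (10, 6) logits
targets = [i % 6 for i in range(10)]

a = chunked_selective_log_softmax_sketch(rows, targets, chunks=4)   # default split
b = chunked_selective_log_softmax_sketch(rows, targets, chunks=10)  # finer split
```

Each row is handled independently, so the chunk count changes only how many rows are live at once, not the resulting log-probabilities.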
96b0a2a to 4a8e82d
Nice find. Maybe we should consolidate the two into a function and reuse?
Per @Datta0's review on unslothai#602, the text and vision branches had identical guard + dispatch logic. Extract into compute_logprobs_chunk() to remove the duplication.
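As a rough picture of what a shared guard-and-dispatch helper could look like, here is a minimal sketch. The signature is hypothetical and does not reflect the actual compute_logprobs_chunk() in the PR; the two callables stand in for the real log-softmax paths, and lm_head is assumed to have shape (vocab, hidden).

```python
from typing import Callable, Sequence, Tuple

def compute_logprobs_chunk_sketch(
    output_shape: Sequence[int],
    lm_head_shape: Tuple[int, int],
    from_hidden_states: Callable[[], str],
    from_logits: Callable[[], str],
) -> str:
    """Pick the log-prob path from the last dimension of the model output."""
    hidden_size = lm_head_shape[1]  # assumed layout: (vocab, hidden)
    if output_shape[-1] == hidden_size:
        return from_hidden_states()  # output is pre-lm_head hidden states
    return from_logits()             # output is already post-head logits

lm_head_shape = (1000, 64)  # hypothetical (vocab, hidden)
hidden_path = lambda: "efficient_log_softmax"
logits_path = lambda: "chunked_selective_log_softmax"

# Both the text and the vision branch would call the same helper.
hidden_case = compute_logprobs_chunk_sketch((2, 16, 64), lm_head_shape, hidden_path, logits_path)
logits_case = compute_logprobs_chunk_sketch((2, 16, 1000), lm_head_shape, hidden_path, logits_path)
```

Keeping the shape check in one place means a future change to the guard applies to both branches at once.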
Great suggestion! I've merged them into a shared helper function.
Hi @Datta0, may I ask if this PR is going to be merged soon? This is blocking text-only GRPO for Qwen 3.5.
Yes. This update breaks text-only GRPO for Qwen 3.5. Once I load the model using FastVisionModel, it still expects images even when my input data contains only text and no images at all, triggering RuntimeError: The size of tensor a (2) must match the size of tensor b (0) at non-singleton dimension 1.
Problem
In grpo_accumulated_loss (unsloth_zoo/rl_replacements.py), the branch handling text-only inputs (if pixel_values is None:) passes the model's .logits output unconditionally into efficient_log_softmax, which assumes hidden states and computes hidden_states @ lm_head.t(). When a model's forward actually returns real logits (last dim = vocab) rather than pre-lm_head hidden states, this produces a shape-mismatched matmul and crashes the RL step.
The vision branch (else:, lines 986–1001) already dispatches based on a shape check. The text branch is missing the same guard.
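The failure mode can be reproduced with shapes alone. The sketch below uses hypothetical sizes and a tiny stand-in for a 2-D matmul shape check; it is only meant to show why hidden states multiply cleanly against lm_head.t() while real logits do not.

```python
def matmul_shape(a_shape, b_shape):
    """Result shape of a 2-D matmul, or raise the way a tensor library would."""
    if a_shape[-1] != b_shape[0]:
        raise ValueError(f"cannot multiply {a_shape} by {b_shape}")
    return (a_shape[0], b_shape[1])

tokens, hidden, vocab = 16, 64, 1000  # hypothetical sizes
lm_head_t = (hidden, vocab)           # shape of lm_head.t()

# Hidden states (tokens, hidden): inner dimensions agree.
ok = matmul_shape((tokens, hidden), lm_head_t)

# Real logits (tokens, vocab): inner dimensions differ, so the matmul fails.
try:
    matmul_shape((tokens, vocab), lm_head_t)
    crashed = False
except ValueError:
    crashed = True
```

Checking the last dimension before choosing a path avoids ever issuing the second multiplication.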
Repro
Training nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 (hybrid Mamba+attention MoE, text-only) with Unsloth PatchFastRL("GRPO", ...) and trl.GRPOTrainer. The model's remote modeling code returns genuine logits from .logits, and the text branch crashes with a dimension mismatch inside the compiled chunked_hidden_states_selective_log_softmax. Other text models that ship with custom modeling code can hit the same path.
Workaround in the wild: monkey-patching the compiled cache file to add the guard.
Fix
Mirror the existing vision-branch guard into the text branch. When the model returns actual logits, skip the lm_head matmul and go directly to chunked_selective_log_softmax. No behavioral change for models that follow Unsloth's usual pattern of returning hidden states through .logits.
Diff
Pure addition of the shape-check branch in the text path; no logic changes elsewhere.