Guard GRPO text branch against models returning logits #602

Merged: Datta0 merged 3 commits into unslothai:main from Rorical:fix-grpo-text-branch-logits-guard on May 22, 2026

@Rorical Rorical commented Apr 19, 2026

Problem

In grpo_accumulated_loss (unsloth_zoo/rl_replacements.py), the branch handling text-only inputs (if pixel_values is None:) passes the model's .logits output unconditionally into efficient_log_softmax, which assumes hidden states and computes hidden_states @ lm_head.t(). When a model's forward actually returns real logits (last dim = vocab) rather than pre-lm_head hidden states, this produces a shape-mismatched matmul and crashes the RL step.

The vision branch (else:, lines 986–1001) already dispatches based on a shape check:

if new_hidden_states_chunk.shape[-1] == lm_head.shape[1]:
    logprobs_chunk = efficient_log_softmax(...)
else:
    logprobs_chunk = chunked_selective_log_softmax(new_hidden_states_chunk, completion_ids, temperature)

The text branch is missing the same guard.

Repro

Training nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 (hybrid Mamba+attention MoE, text-only) with Unsloth PatchFastRL("GRPO", ...) and trl.GRPOTrainer. The model's remote modeling code returns genuine logits from .logits, and the text branch crashes with a dimension mismatch inside the compiled chunked_hidden_states_selective_log_softmax. Other text models that ship with custom modeling code can hit the same path.

Workaround in the wild: monkey-patching the compiled cache file to add the guard.

Fix

Mirror the existing vision-branch guard into the text branch. When the model returns actual logits, skip the lm_head matmul and go directly to chunked_selective_log_softmax. No behavioral change for models that follow Unsloth's usual pattern of returning hidden states through .logits.
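
To make the dispatch concrete, here is a minimal sketch of the guard, not the code in unsloth_zoo. The helper names `selective_log_softmax` and `logprobs_from_output` are hypothetical, and the real functions also handle chunking, logit scaling, and softcapping. The sketch assumes `lm_head` is a weight matrix of shape (vocab, hidden), so `lm_head.shape[1]` is the hidden size.

```python
import torch

def selective_log_softmax(logits, token_ids, temperature=1.0):
    # Log-probability of each chosen token under softmax(logits / temperature).
    logp = torch.log_softmax(logits.float() / temperature, dim=-1)
    return logp.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)

def logprobs_from_output(output, lm_head, token_ids, temperature=1.0):
    # Hypothetical stand-in for the guarded dispatch described above.
    if output.shape[-1] == lm_head.shape[1]:
        # Last dim equals the hidden size: treat as hidden states and project.
        logits = output @ lm_head.t()
    else:
        # Last dim is already the vocab size: the model returned real logits.
        logits = output
    return selective_log_softmax(logits, token_ids, temperature)

torch.manual_seed(0)
hidden_size, vocab_size = 8, 32          # illustrative sizes only
lm_head = torch.randn(vocab_size, hidden_size)
hidden = torch.randn(2, 5, hidden_size)
token_ids = torch.randint(0, vocab_size, (2, 5))

# Both call styles reach the same log-probabilities.
from_hidden = logprobs_from_output(hidden, lm_head, token_ids)
from_logits = logprobs_from_output(hidden @ lm_head.t(), lm_head, token_ids)
```

Without the shape check, the second call would attempt a (…, vocab) by (hidden, vocab) matmul and fail, which is the crash described under Problem.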

Diff

Pure addition of the shape-check branch in the text path; no logic changes elsewhere.

The vision branch in grpo_accumulated_loss already dispatches between
efficient_log_softmax and chunked_selective_log_softmax based on whether
the model's .logits is actually hidden states or real logits. The text
branch (pixel_values is None) was missing the same guard and passed the
tensor straight into efficient_log_softmax, which assumes hidden states
and does hidden_states @ lm_head.t(). This crashes with a shape-mismatch
matmul for any text model whose forward genuinely returns logits
(e.g. Nemotron-H hybrid Mamba+attention, custom remote-code models).

Mirror the vision branch's guard so text models that return logits go
directly to chunked_selective_log_softmax.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a guard in unsloth_zoo/rl_replacements.py to handle cases where the model returns logits directly instead of hidden states. A review comment points out that the logic for handling direct logits uses a fixed chunk size which may cause memory issues for large vocabularies and notes that scaling and softcapping parameters are currently ignored in that execution path.

Comment thread unsloth_zoo/rl_replacements.py Outdated
Comment on lines +963 to +977
if new_hidden_states_chunk.shape[-1] == lm_head.shape[1]:
    logprobs_chunk = efficient_log_softmax(
        new_hidden_states_chunk,
        lm_head,
        completion_ids,
        chunks=input_ids_chunk.shape[0]*multiplier,
        logit_scale_multiply=logit_scale_multiply,
        logit_scale_divide=logit_scale_divide,
        logit_softcapping=logit_softcapping,
        temperature=temperature,
        batch_size = B
    )
else:
    # Model returned logits directly - scaling/softcapping already applied by model forward
    logprobs_chunk = chunked_selective_log_softmax(new_hidden_states_chunk, completion_ids, temperature)

medium

While this guard mirrors the logic in the vision branch, chunked_selective_log_softmax uses a fixed number of chunks (4), whereas efficient_log_softmax uses a dynamic number of chunks based on the multiplier. For models with very large vocabularies that return logits directly, the else branch might be more susceptible to OOM errors. Additionally, note that logit_scale_multiply, logit_scale_divide, and logit_softcapping are ignored in the else branch. If these parameters are expected to be applied to the logits returned by the model, they should be handled before or within the chunked_selective_log_softmax call.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c76e33eee0

Comment thread unsloth_zoo/rl_replacements.py Outdated
Comment on lines +998 to +1000
    logit_scale_multiply = logit_scale_multiply,
    logit_scale_divide = logit_scale_divide,
    logit_softcapping = logit_softcapping,

P1: Skip scaling when logits already come from model forward

In the direct-logits fallback (else branch), this now passes logit_scale_multiply/divide/softcapping into chunked_selective_log_softmax, which re-applies those transforms to logits that may already have had them applied by the model forward. For models that ignore UNSLOTH_RETURN_HIDDEN_STATES and return post-head logits (including softcapped/scaled logits), this silently changes token logprobs and therefore GRPO loss/KL, corrupting training rather than just fixing the shape mismatch. The pre-change vision fallback explicitly avoided this double application, so this is a behavior regression for direct-logits models.
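
A small numeric sketch of why double application matters. It assumes softcapping has the common form `cap * tanh(x / cap)`; the PR does not show the actual transform, and the `cap` value and logits below are made up for illustration.

```python
import math

def softcap(logits, cap):
    # Assumed softcapping form: squashes logits smoothly into (-cap, cap).
    return [cap * math.tanh(x / cap) for x in logits]

def log_softmax(logits):
    # Numerically stable log-softmax over a plain list of floats.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

raw = [12.0, 3.0, -4.0, 0.5]   # made-up pre-cap logits
cap = 5.0                      # hypothetical softcapping value

# Applied once, as a model forward that returns real logits would do.
once = log_softmax(softcap(raw, cap))
# Applied again downstream, as happens if the fallback re-applies the transform.
twice = log_softmax(softcap(softcap(raw, cap), cap))

# Largest per-token change in log-probability caused by the second application.
shift = max(abs(a - b) for a, b in zip(once, twice))
```

The second application flattens the distribution, so every token log-probability moves, and any loss or KL term computed from those values moves with it.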

@Rorical Rorical force-pushed the fix-grpo-text-branch-logits-guard branch from c76e33e to 96b0a2a Compare April 19, 2026 14:49
Extend chunked_selective_log_softmax with a chunks parameter (default 4
for back-compat) and have the text-branch direct-logits fallback pass
input_ids_chunk.shape[0] * multiplier, mirroring the lm_head matmul path
above. Large-vocab models that return post-head logits previously went
through a hardcoded 4-way chunking that could blow out VRAM.

Scaling/softcapping are intentionally not forwarded: when the model's
forward returns real logits it has already applied those transforms,
and re-applying would double-transform the logits and corrupt GRPO
logprobs.
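
The chunking idea in this commit message can be sketched as follows. This is an illustration under assumptions, not the library's implementation: the function name `chunked_selective_log_softmax_sketch` is hypothetical, and only the `chunks` parameter with its default of 4 comes from the commit message.

```python
import torch

def chunked_selective_log_softmax_sketch(logits, token_ids, temperature=1.0, chunks=4):
    # Flatten to (tokens, vocab) so the row dimension can be split into pieces.
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_ids = token_ids.reshape(-1)
    pieces = []
    for logit_piece, id_piece in zip(
        torch.chunk(flat_logits, chunks, dim=0),
        torch.chunk(flat_ids, chunks, dim=0),
    ):
        # Only one piece's full (rows, vocab) log-softmax is alive at a time.
        logp = torch.log_softmax(logit_piece.float() / temperature, dim=-1)
        pieces.append(logp.gather(-1, id_piece.unsqueeze(-1)).squeeze(-1))
    return torch.cat(pieces).reshape(token_ids.shape)

torch.manual_seed(0)
logits = torch.randn(3, 7, 50)            # illustrative (batch, seq, vocab)
token_ids = torch.randint(0, 50, (3, 7))

coarse = chunked_selective_log_softmax_sketch(logits, token_ids, chunks=4)
fine = chunked_selective_log_softmax_sketch(logits, token_ids, chunks=21)
```

Because log-softmax is computed independently per row, the chunk count changes only peak memory, not the result, which is why raising it for large-vocabulary models is safe.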
@Rorical Rorical force-pushed the fix-grpo-text-branch-logits-guard branch from 96b0a2a to 4a8e82d Compare April 19, 2026 14:53

Datta0 commented Apr 20, 2026

Nice find. Maybe we should consolidate the two into a function and reuse?

Per @Datta0's review on unslothai#602, the text and vision branches had identical
guard + dispatch logic. Extract into compute_logprobs_chunk() to remove
the duplication.

Rorical commented Apr 22, 2026

Great suggestion! I've merged them into a shared helper function.

@Datta0 Datta0 left a comment

LGTM

bi1101 commented May 9, 2026

Hi @Datta0, may I ask whether this PR will be merged soon? It is blocking text-only GRPO for Qwen 3.5.

lastrei commented May 21, 2026

> Hi @Datta0 , may I ask if this PR is going to be merged soon? This is blocking text-only GRPO for Qwen 3.5

Yes, this update breaks text-only GRPO for Qwen 3.5. Once I load the model using FastVisionModel, it still expects images even when my input data contains only text and no images at all, triggering RuntimeError: The size of tensor a (2) must match the size of tensor b (0) at non-singleton dimension 1.

@Datta0 Datta0 merged commit 388d935 into unslothai:main May 22, 2026
@Rorical Rorical deleted the fix-grpo-text-branch-logits-guard branch May 22, 2026 13:50