Fix VLM GRPO matmul shape mismatch in _get_per_token_logps_and_entropies#4301
Conversation
VLM models (e.g. Qwen2.5-VL) can return logits [B*T, vocab_size] instead of hidden states [B*T, hidden_dim] from their forward pass. When this happens, chunked_hidden_states_selective_log_softmax tries to compute logits @ lm_head.t() which fails with a shape mismatch. Add a shape guard in the VLM branch of _get_per_token_logps_and_entropies: check output.shape[-1] against lm_head.shape[1] (hidden_dim). When hidden states are returned, the existing path is taken. When logits are returned, scaling/softcapping/temperature are applied manually and chunked_selective_log_softmax is used instead. Also add chunked_selective_log_softmax to the import from unsloth_zoo. The text-only branch (pixel_values is None) is unchanged. Companion PR to unslothai/unsloth-zoo for grpo_accumulated_loss.
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a critical issue in GRPO training for Vision-Language Models (VLMs) where certain models might output raw logits instead of hidden states, leading to runtime errors due to shape mismatches. The changes introduce a robust mechanism to detect the output type and apply appropriate processing, either by transforming hidden states or directly handling logits with manual scaling and temperature adjustments. This ensures broader compatibility with diverse VLM architectures and maintains the stability of the GRPO training pipeline. Highlights
Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Code Review
This pull request effectively addresses a matrix multiplication shape mismatch error that occurs during GRPO training with Vision Language Models (VLMs) which may return logits instead of hidden states. The fix introduces a shape guard to correctly identify the output type and applies the necessary transformations manually when logits are returned, preventing the crash. The changes are logical and well-contained. My review includes one suggestion to refactor the newly added logit transformation logic to reduce code duplication and improve long-term maintainability.
| chunks = input_ids_chunk.shape[0] * multiplier, | ||
| logit_scale_multiply = logit_scale_multiply, | ||
| logit_scale_divide = logit_scale_divide, | ||
| logit_softcapping = logit_softcapping, | ||
| temperature = temperature, | ||
| ) | ||
| ) | ||
| else: | ||
| # Model returned logits directly, apply scaling manually | ||
| if logit_scale_multiply != 0.0: | ||
| logits_chunk = logits_chunk * logit_scale_multiply |
There was a problem hiding this comment.
This block manually applies logit scaling, softcapping, and temperature. This logic appears to be a reimplementation of transformations that are also performed within chunked_hidden_states_selective_log_softmax, which is called in the if branch of this conditional.
This introduces code duplication. As the PR description mentions a companion PR with the "same fix," this logic is likely duplicated in at least two places across the codebase.
For better long-term maintainability, consider extracting this transformation logic into a shared utility function. If refactoring the shared unsloth_zoo library is not feasible in this PR, a local helper function within _get_per_token_logps_and_entropies would at least improve readability and centralize the logic within this function.
Example of a local helper:
def _apply_logit_transformations(logits, temperature, logit_softcapping, logit_scale_multiply, logit_scale_divide):
if logit_scale_multiply != 0.0:
logits = logits * logit_scale_multiply
if logit_scale_divide != 0.0:
logits = logits / logit_scale_divide
if logit_softcapping != 0.0:
logits = logits * torch.tanh(logits / logit_softcapping)
logits = logits.to(torch.float32)
if temperature != 1.0:
logits = logits / temperature
return logitsThis would make the code more modular and easier to maintain.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 1869545420
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if logit_scale_divide != 0.0: | ||
| logits_chunk = logits_chunk / logit_scale_divide | ||
| if logit_softcapping != 0.0: | ||
| logits_chunk = logits_chunk * torch.tanh(logits_chunk / logit_softcapping) |
There was a problem hiding this comment.
Use proper softcapping formula for direct-logit VLM path
When a VLM forward returns logits directly and final_logit_softcapping is non-zero, this branch applies logits_chunk * tanh(logits_chunk / logit_softcapping), which is not the standard softcap transform and changes log-probabilities significantly for large-magnitude logits; elsewhere in this repo (e.g., the non-RL logits path and kernel code) softcapping is implemented as logit_softcapping * tanh(logits / logit_softcapping). This means GRPO loss is computed from distorted probabilities in that configuration rather than matching the model’s intended output scaling.
Useful? React with 👍 / 👎.
When COMPILE_DISABLE=1 and the model returns logits directly, scaling and softcapping are already applied by the model forward. Only temperature (a GRPO training parameter) needs to be applied.
Use the new temperature parameter in chunked_selective_log_softmax (added in companion zoo PR) to avoid casting the entire logits tensor to float32 before the function call.
for more information, see https://pre-commit.ci
…ies (unslothai#4301) * Fix VLM GRPO matmul shape mismatch in _get_per_token_logps_and_entropies VLM models (e.g. Qwen2.5-VL) can return logits [B*T, vocab_size] instead of hidden states [B*T, hidden_dim] from their forward pass. When this happens, chunked_hidden_states_selective_log_softmax tries to compute logits @ lm_head.t() which fails with a shape mismatch. Add a shape guard in the VLM branch of _get_per_token_logps_and_entropies: check output.shape[-1] against lm_head.shape[1] (hidden_dim). When hidden states are returned, the existing path is taken. When logits are returned, scaling/softcapping/temperature are applied manually and chunked_selective_log_softmax is used instead. Also add chunked_selective_log_softmax to the import from unsloth_zoo. The text-only branch (pixel_values is None) is unchanged. Companion PR to unslothai/unsloth-zoo for grpo_accumulated_loss. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove redundant scaling in logits fallback path When COMPILE_DISABLE=1 and the model returns logits directly, scaling and softcapping are already applied by the model forward. Only temperature (a GRPO training parameter) needs to be applied. * Pass temperature to chunked_selective_log_softmax instead of manual cast Use the new temperature parameter in chunked_selective_log_softmax (added in companion zoo PR) to avoid casting the entire logits tensor to float32 before the function call. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Summary
[B*T, vocab_size]instead of hidden states[B*T, hidden_dim]from their forward pass during GRPO training. When this happens,chunked_hidden_states_selective_log_softmaxtrieslogits @ lm_head.t()which crashes withRuntimeError: mat1 and mat2 shapes cannot be multiplied (72x152064 and 3584x152064)._get_per_token_logps_and_entropies: checkoutput.shape[-1]againstlm_head.shape[1](hidden_dim). When hidden states are returned, the existingchunked_hidden_states_selective_log_softmaxpath is taken. When logits are returned, scaling/softcapping/temperature are applied manually andchunked_selective_log_softmax(which takes logits directly) is used instead.chunked_selective_log_softmaxto the import fromunsloth_zoo.rl_replacements.pixel_values is None) is unchanged.Companion PR: unslothai/unsloth-zoo#546 (same fix for
grpo_accumulated_loss).Test plan
[B, T]log probabilitiesmax_steps=3, no matmul crashmax_steps=10, no regression