Fix VLM GRPO matmul shape mismatch in _get_per_token_logps_and_entropies #4301

Merged
danielhanchen merged 5 commits into main from fix-vlm-grpo-matmul-shape on Mar 16, 2026

@danielhanchen

Member

Summary

  • VLM models (e.g. Qwen2.5-VL) can return logits [B*T, vocab_size] instead of hidden states [B*T, hidden_dim] from their forward pass during GRPO training. When this happens, chunked_hidden_states_selective_log_softmax tries logits @ lm_head.t() which crashes with RuntimeError: mat1 and mat2 shapes cannot be multiplied (72x152064 and 3584x152064).
  • Add a shape guard in the VLM branch of _get_per_token_logps_and_entropies: check output.shape[-1] against lm_head.shape[1] (hidden_dim). When hidden states are returned, the existing chunked_hidden_states_selective_log_softmax path is taken. When logits are returned, scaling/softcapping/temperature are applied manually and chunked_selective_log_softmax (which takes logits directly) is used instead.
  • Add chunked_selective_log_softmax to the import from unsloth_zoo.rl_replacements.
  • The text-only branch (pixel_values is None) is unchanged.

Companion PR: unslothai/unsloth-zoo#546 (same fix for grpo_accumulated_loss).
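The crash and the shape guard described in the summary can be illustrated with shapes alone. The sketch below is a minimal illustration, not the code from this PR: `can_matmul` and `route_output` are hypothetical helper names, and shapes are plain tuples instead of tensors. The numbers come from the error message above, where `lm_head.t()` is 3584x152064, so `lm_head` is assumed to be `[vocab_size, hidden_dim]`.

```python
# Illustrative only: shapes are plain tuples, and both helper names are
# hypothetical. They are not part of unsloth or unsloth_zoo.

def can_matmul(a_shape, b_shape):
    """Return True if a 2-D product a @ b is shape-compatible."""
    return a_shape[-1] == b_shape[0]


def route_output(output_shape, lm_head_shape):
    """Classify a forward-pass output by comparing its last dim to hidden_dim.

    lm_head_shape is (vocab_size, hidden_dim), so lm_head_shape[1] is hidden_dim.
    """
    hidden_dim = lm_head_shape[1]
    if output_shape[-1] == hidden_dim:
        return "hidden_states"  # still needs the projection through lm_head.t()
    return "logits"             # already projected, so the matmul must be skipped


lm_head_shape = (152064, 3584)         # (vocab_size, hidden_dim)
lm_head_t_shape = lm_head_shape[::-1]  # shape of lm_head.t()

logits_shape = (72, 152064)            # [B*T, vocab_size]
hidden_shape = (72, 3584)              # [B*T, hidden_dim]

# Logits cannot be multiplied by lm_head.t(): this is the reported crash.
print(can_matmul(logits_shape, lm_head_t_shape))
# Hidden states can, which is the path the existing code expected.
print(can_matmul(hidden_shape, lm_head_t_shape))

print(route_output(logits_shape, lm_head_shape))
print(route_output(hidden_shape, lm_head_shape))
```

A guard of this kind relies on `vocab_size` and `hidden_dim` being different, which holds for the shapes quoted in the error message.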

Test plan

  • Shape guard unit test: hidden_states vs logits routing both produce valid [B, T] log probabilities
  • VLM GRPO smoke test: Qwen2.5-VL-7B with max_steps=3, no matmul crash
  • Text-only GRPO regression: Llama-3.2-1B with max_steps=10, no regression

VLM models (e.g. Qwen2.5-VL) can return logits [B*T, vocab_size] instead
of hidden states [B*T, hidden_dim] from their forward pass. When this
happens, chunked_hidden_states_selective_log_softmax tries to compute
logits @ lm_head.t() which fails with a shape mismatch.

Add a shape guard in the VLM branch of _get_per_token_logps_and_entropies:
check output.shape[-1] against lm_head.shape[1] (hidden_dim). When hidden
states are returned, the existing path is taken. When logits are returned,
scaling/softcapping/temperature are applied manually and
chunked_selective_log_softmax is used instead.

Also add chunked_selective_log_softmax to the import from unsloth_zoo.

The text-only branch (pixel_values is None) is unchanged.

Companion PR to unslothai/unsloth-zoo for grpo_accumulated_loss.
@gemini-code-assist

Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical issue in GRPO training for Vision-Language Models (VLMs) where certain models might output raw logits instead of hidden states, leading to runtime errors due to shape mismatches. The changes introduce a robust mechanism to detect the output type and apply appropriate processing, either by transforming hidden states or directly handling logits with manual scaling and temperature adjustments. This ensures broader compatibility with diverse VLM architectures and maintains the stability of the GRPO training pipeline.

Highlights

  • VLM Logit Handling: Implemented a fix for VLM models that return logits directly instead of hidden states during GRPO training, preventing matmul shape mismatch errors.
  • Shape Guard: Added a shape guard in _get_per_token_logps_and_entropies to correctly route processing based on whether the model output is hidden states or direct logits.
  • Manual Scaling for Logits: Applied manual scaling, softcapping, and temperature adjustments when VLM models return direct logits.
  • chunked_selective_log_softmax Integration: Utilized chunked_selective_log_softmax for processing direct logits, ensuring proper log probability calculation.
  • Dependency Update: Imported chunked_selective_log_softmax from unsloth_zoo.rl_replacements.

Changelog

  • unsloth/models/rl_replacements.py
    • Imported chunked_selective_log_softmax for direct logit processing.
    • Introduced a conditional block within the VLM processing path to check the shape of the model's output.
    • Added logic to manually apply logit_scale_multiply, logit_scale_divide, logit_softcapping, and temperature when the model returns direct logits.
    • Modified the log probability calculation to use chunked_selective_log_softmax for direct logits and chunked_hidden_states_selective_log_softmax for hidden states.

Activity

  • Completed unit tests for the shape guard, verifying correct log probability generation for both hidden states and logits.
  • Successfully ran a VLM GRPO smoke test with Qwen2.5-VL-7B, confirming the absence of matmul crashes.
  • Verified no regressions in text-only GRPO training with Llama-3.2-1B.
  • A companion PR (unslothai/unsloth-zoo#546) was created for a similar fix in grpo_accumulated_loss.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request effectively addresses a matrix multiplication shape mismatch error that occurs during GRPO training with Vision Language Models (VLMs) which may return logits instead of hidden states. The fix introduces a shape guard to correctly identify the output type and applies the necessary transformations manually when logits are returned, preventing the crash. The changes are logical and well-contained. My review includes one suggestion to refactor the newly added logit transformation logic to reduce code duplication and improve long-term maintainability.

Comment thread unsloth/models/rl_replacements.py Outdated
Comment on lines +902 to +912
            chunks = input_ids_chunk.shape[0] * multiplier,
            logit_scale_multiply = logit_scale_multiply,
            logit_scale_divide = logit_scale_divide,
            logit_softcapping = logit_softcapping,
            temperature = temperature,
        )
    )
else:
    # Model returned logits directly, apply scaling manually
    if logit_scale_multiply != 0.0:
        logits_chunk = logits_chunk * logit_scale_multiply

medium

This block manually applies logit scaling, softcapping, and temperature. This logic appears to be a reimplementation of transformations that are also performed within chunked_hidden_states_selective_log_softmax, which is called in the if branch of this conditional.

This introduces code duplication. As the PR description mentions a companion PR with the "same fix," this logic is likely duplicated in at least two places across the codebase.

For better long-term maintainability, consider extracting this transformation logic into a shared utility function. If refactoring the shared unsloth_zoo library is not feasible in this PR, a local helper function within _get_per_token_logps_and_entropies would at least improve readability and centralize the logic within this function.

Example of a local helper:

import torch

def _apply_logit_transformations(logits, temperature, logit_softcapping, logit_scale_multiply, logit_scale_divide):
    # A value of 0.0 leaves the corresponding scaling step disabled.
    if logit_scale_multiply != 0.0:
        logits = logits * logit_scale_multiply
    if logit_scale_divide != 0.0:
        logits = logits / logit_scale_divide
    if logit_softcapping != 0.0:
        # Standard softcap: cap * tanh(logits / cap), bounded to (-cap, cap).
        logits = logit_softcapping * torch.tanh(logits / logit_softcapping)
    # Upcast to float32 before applying the temperature.
    logits = logits.to(torch.float32)
    if temperature != 1.0:
        logits = logits / temperature
    return logits

This would make the code more modular and easier to maintain.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 1869545420

Comment thread unsloth/models/rl_replacements.py Outdated
if logit_scale_divide != 0.0:
    logits_chunk = logits_chunk / logit_scale_divide
if logit_softcapping != 0.0:
    logits_chunk = logits_chunk * torch.tanh(logits_chunk / logit_softcapping)

P2: Use proper softcapping formula for direct-logit VLM path

When a VLM forward pass returns logits directly and final_logit_softcapping is non-zero, this branch applies logits_chunk * tanh(logits_chunk / logit_softcapping). That is not the standard softcap transform, and it changes log-probabilities significantly for large-magnitude logits. Elsewhere in this repo (e.g., the non-RL logits path and kernel code), softcapping is implemented as logit_softcapping * tanh(logits / logit_softcapping). As a result, the GRPO loss in that configuration is computed from distorted probabilities rather than from the model's intended output scaling.
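The difference between the two formulas in this comment is easy to see on scalar values. The following is a small sketch using only the standard library. The function names and the cap value of 30.0 are illustrative choices, not identifiers or settings taken from the repository.

```python
import math


def softcap(logit, cap):
    """Standard soft cap: cap * tanh(logit / cap), bounded to (-cap, cap)."""
    return cap * math.tanh(logit / cap)


def softcap_as_first_written(logit, cap):
    """The form flagged in the review: logit * tanh(logit / cap)."""
    return logit * math.tanh(logit / cap)


cap = 30.0  # arbitrary illustrative value
for x in (-100.0, -5.0, 0.5, 100.0):
    # Columns: input logit, standard softcap, flagged variant.
    print(x, softcap(x, cap), softcap_as_first_written(x, cap))
```

Two properties stand out. The standard form stays within (-cap, cap) and is close to the identity for small logits. The flagged form is an even function, so negative logits come out positive, and it is not bounded by the cap for large inputs.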

danielhanchen and others added 3 commits March 15, 2026 10:59
Remove redundant scaling in logits fallback path

When COMPILE_DISABLE=1 and the model returns logits directly, scaling
and softcapping are already applied by the model forward. Only
temperature (a GRPO training parameter) needs to be applied.

Pass temperature to chunked_selective_log_softmax instead of manual cast

Use the new temperature parameter in chunked_selective_log_softmax
(added in companion zoo PR) to avoid casting the entire logits tensor
to float32 before the function call.
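
What these two commits leave on the logits path is a temperature-scaled selective log-softmax, evaluated chunk by chunk. The sketch below shows that computation on plain Python lists. It is an assumption-laden illustration, not the unsloth_zoo implementation: the function names, the `chunk_size` argument, and the list-based data layout are all hypothetical.

```python
import math


def selective_log_softmax(logits_row, token_id, temperature=1.0):
    """Log-probability of token_id under softmax(logits_row / temperature)."""
    scaled = [x / temperature for x in logits_row]
    m = max(scaled)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return scaled[token_id] - log_z


def chunked_selective_log_softmax_sketch(logits, token_ids, temperature=1.0, chunk_size=2):
    """Process rows in chunks so only one chunk is worked on at a time."""
    out = []
    for start in range(0, len(logits), chunk_size):
        rows = logits[start:start + chunk_size]
        ids = token_ids[start:start + chunk_size]
        out.extend(
            selective_log_softmax(row, tok, temperature)
            for row, tok in zip(rows, ids)
        )
    return out


# Three positions over a toy vocabulary of size 3.
logits = [[2.0, 0.0, -1.0], [0.5, 0.5, 0.5], [-3.0, 1.0, 4.0]]
token_ids = [0, 1, 2]
logps = chunked_selective_log_softmax_sketch(logits, token_ids, temperature=0.7)
print(logps)
```

Applying the temperature inside the per-chunk step means the scaling happens on one chunk at a time, which is in the spirit of the commit's goal of not casting the entire logits tensor up front.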
@danielhanchen danielhanchen merged commit 1144920 into main Mar 16, 2026
5 checks passed
@danielhanchen danielhanchen deleted the fix-vlm-grpo-matmul-shape branch March 16, 2026 10:54
shibizhao pushed a commit to shibizhao/unsloth-npu that referenced this pull request Apr 7, 2026
…ies (unslothai#4301)

* Fix VLM GRPO matmul shape mismatch in _get_per_token_logps_and_entropies

VLM models (e.g. Qwen2.5-VL) can return logits [B*T, vocab_size] instead
of hidden states [B*T, hidden_dim] from their forward pass. When this
happens, chunked_hidden_states_selective_log_softmax tries to compute
logits @ lm_head.t() which fails with a shape mismatch.

Add a shape guard in the VLM branch of _get_per_token_logps_and_entropies:
check output.shape[-1] against lm_head.shape[1] (hidden_dim). When hidden
states are returned, the existing path is taken. When logits are returned,
scaling/softcapping/temperature are applied manually and
chunked_selective_log_softmax is used instead.

Also add chunked_selective_log_softmax to the import from unsloth_zoo.

The text-only branch (pixel_values is None) is unchanged.

Companion PR to unslothai/unsloth-zoo for grpo_accumulated_loss.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove redundant scaling in logits fallback path

When COMPILE_DISABLE=1 and the model returns logits directly, scaling
and softcapping are already applied by the model forward. Only
temperature (a GRPO training parameter) needs to be applied.

* Pass temperature to chunked_selective_log_softmax instead of manual cast

Use the new temperature parameter in chunked_selective_log_softmax
(added in companion zoo PR) to avoid casting the entire logits tensor
to float32 before the function call.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>