Add UNSLOTH_RETURN_HIDDEN_STATES support for GraniteMoeHybrid #4373

Closed
Maxusmusti wants to merge 5 commits into unslothai:main from Maxusmusti:fix/granitemoehybrid-return-hidden-states

@Maxusmusti
Contributor

Summary

Adds UNSLOTH_RETURN_HIDDEN_STATES support for GraniteMoeHybridForCausalLM, enabling RL training (GRPO) with Granite 4 hybrid models.

Problem

When UNSLOTH_RETURN_HIDDEN_STATES=1 is set, the RL training codepath expects the model's .logits output to contain pre-lm_head hidden states (shape [B, S, hidden_size]). This works for Llama, Mistral, Qwen3, Qwen3Moe, and dense Granite models because their forward() methods are monkey-patched by Unsloth to check this env var.

However, GraniteMoeHybridForCausalLM uses the raw transformers forward method, which ignores the env var and always returns full logits (shape [B, S, vocab_size]). This causes a shape mismatch error during log probability computation:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1024x100352 and 1536x100352)

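The tensor in the .logits slot is already full logits (last dimension vocab_size=100352), not hidden states, so multiplying it by the lm_head weight (which expects inputs with hidden_size=1536) fails.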
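
To make the mismatch concrete, here is a minimal sketch of the shape rule that matrix multiplication enforces, written in plain Python with no tensor library. The helper `matmul_shape` is hypothetical and only mirrors the rule that the inner dimensions must agree. The sizes come from the error above; reading 1024 as the flattened batch-by-sequence dimension is an assumption.

```python
def matmul_shape(a_shape, b_shape):
    """Return the result shape of a 2-D matmul, or raise if inner dims differ."""
    (m, k_a), (k_b, n) = a_shape, b_shape
    if k_a != k_b:
        raise ValueError(
            f"mat1 and mat2 shapes cannot be multiplied ({m}x{k_a} and {k_b}x{n})"
        )
    return (m, n)


HIDDEN_SIZE, VOCAB_SIZE = 1536, 100352
ROWS = 1024  # assumption: batch * sequence positions, flattened

# Shape of lm_head.weight.t(): [hidden_size, vocab_size]
lm_head_weight_t = (HIDDEN_SIZE, VOCAB_SIZE)

# What the RL path expects in .logits: pre-lm_head hidden states.
ok_shape = matmul_shape((ROWS, HIDDEN_SIZE), lm_head_weight_t)

# What the unpatched forward returns: full logits, which cannot be projected again.
try:
    matmul_shape((ROWS, VOCAB_SIZE), lm_head_weight_t)
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

With hidden states the projection yields a `[1024, 100352]` result, while feeding logits back in reproduces the message from the traceback.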

Fix

Wraps GraniteMoeHybridForCausalLM.forward to intercept hidden states before lm_head is applied when UNSLOTH_RETURN_HIDDEN_STATES=1, matching the existing pattern in llama.py (CausalLM_fast_forward) and mistral.py (MistralForCausalLM_fast_forward).

The patch is applied in loader.py when a granitemoehybrid model is detected, alongside the existing env var setup for this architecture.
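
The wrapping pattern described above can be sketched roughly as follows. This is not the code from the PR: `ToyCausalLM` and `patch_return_hidden_states` are hypothetical stand-ins, plain Python lists take the place of tensors, and the only detail taken from the discussion is that the wrapper checks whether `UNSLOTH_RETURN_HIDDEN_STATES` equals `"1"` and otherwise defers to the original forward.

```python
import functools
import os


class ToyCausalLM:
    """Hypothetical stand-in for a CausalLM: a backbone plus an lm_head."""

    def model(self, input_ids, **kwargs):
        # Pretend backbone: one "hidden state" (a float) per token.
        return [float(t) for t in input_ids]

    def lm_head(self, hidden_states):
        # Pretend projection from hidden states to logits.
        return [h * 2.0 for h in hidden_states]

    def forward(self, input_ids, **kwargs):
        return {"logits": self.lm_head(self.model(input_ids, **kwargs))}


def patch_return_hidden_states(cls):
    """Wrap cls.forward so it skips lm_head when the env var is "1"."""
    original_forward = cls.forward

    @functools.wraps(original_forward)
    def forward(self, input_ids, **kwargs):
        if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
            # Stop before lm_head; hidden states travel in the logits slot.
            return {"logits": self.model(input_ids, **kwargs)}
        # Env var unset: behave exactly like the unpatched forward.
        return original_forward(self, input_ids, **kwargs)

    cls.forward = forward
    return cls


patch_return_hidden_states(ToyCausalLM)
lm = ToyCausalLM()

os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
hidden = lm.forward([1, 2, 3])["logits"]

os.environ.pop("UNSLOTH_RETURN_HIDDEN_STATES")
logits = lm.forward([1, 2, 3])["logits"]
```

Reading the environment variable at call time rather than at patch time is a choice made for this sketch; it keeps the normal path untouched whenever the variable is unset.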

Testing

Verified with ibm-granite/granite-4.0-h-tiny:

  • Hidden states output: [1, 64, 1536] (matches lm_head weight shape [100352, 1536]) ✓
  • Logits recomputation via hidden_states @ lm_head.weight.t(): produces correct [1, 64, 100352]
  • Normal forward (env var unset): unchanged behavior ✓
  • Full GRPO training loop with ART: completes successfully ✓

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical compatibility issue that previously prevented Reinforcement Learning (RL) training with Granite 4 hybrid models when the UNSLOTH_RETURN_HIDDEN_STATES environment variable was active. By introducing a targeted patch to the GraniteMoeHybridForCausalLM's forward method, the system can now correctly output pre-lm_head hidden states, resolving a shape mismatch error and enabling seamless integration of these models into RL training workflows.

Highlights

  • GraniteMoeHybrid Support for Hidden States: Added support for UNSLOTH_RETURN_HIDDEN_STATES in GraniteMoeHybridForCausalLM, allowing it to return pre-lm_head hidden states instead of full logits when the environment variable is set.
  • Enabling RL Training: This change enables Reinforcement Learning (RL) training, specifically GRPO, with Granite 4 hybrid models by resolving a shape mismatch error that occurred during log probability computation.
  • Forward Method Patching: Implemented a monkey-patch for GraniteMoeHybridForCausalLM.forward within unsloth/models/loader.py to intercept and return hidden states, aligning its behavior with other models like Llama and Mistral.
Changelog
  • unsloth/models/loader.py
    • Added a new function _patch_granitemoehybrid_return_hidden_states responsible for dynamically modifying the GraniteMoeHybridForCausalLM.forward method.
    • The patched forward method now checks for the UNSLOTH_RETURN_HIDDEN_STATES environment variable and, if set, returns hidden states instead of full logits.
    • Integrated the call to _patch_granitemoehybrid_return_hidden_states within the from_pretrained function, ensuring the patch is applied when a granitemoehybrid model is loaded.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request correctly adds support for UNSLOTH_RETURN_HIDDEN_STATES to GraniteMoeHybridForCausalLM, which is crucial for RL training workflows. The implementation follows the established pattern of monkey-patching the model's forward pass. I've identified two critical issues in the patched forward method where it incorrectly passes extra keyword arguments (**kwargs and logits_to_keep) to underlying functions from the transformers library that do not accept them. This will lead to TypeError exceptions during runtime. I've provided suggestions to fix these issues.

Comment thread unsloth/models/loader.py
Comment on lines +173 to +186
)

outputs = self.model(
    input_ids = input_ids,
    attention_mask = attention_mask,
    position_ids = position_ids,
    past_key_values = past_key_values,
    inputs_embeds = inputs_embeds,
    use_cache = use_cache,
    output_attentions = output_attentions,
    output_hidden_states = output_hidden_states,
    output_router_logits = output_router_logits,
    return_dict = return_dict,
    cache_position = cache_position,
Contributor

critical

The call to self.model incorrectly passes **kwargs. The transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridModel.forward method does not accept arbitrary keyword arguments, which will likely lead to a TypeError if any unexpected arguments are present in kwargs.

            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                output_router_logits=output_router_logits,
                return_dict=return_dict,
                cache_position=cache_position,
            )

Contributor Author

removed excess kwargs

Contributor Author

Actually, your feedback looks to be incorrect: we are looking at GraniteMoeHybridForCausalLM, whose forward does accept **kwargs: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py#L1463

Comment thread unsloth/models/loader.py Outdated
Comment on lines +202 to +218
    attentions = outputs.attentions if return_dict else None,
    router_logits = getattr(outputs, "router_logits", None)
    if return_dict
    else None,
)

return _original_forward(
    self,
    input_ids = input_ids,
    attention_mask = attention_mask,
    position_ids = position_ids,
    past_key_values = past_key_values,
    inputs_embeds = inputs_embeds,
    labels = labels,
    use_cache = use_cache,
    output_attentions = output_attentions,
    output_hidden_states = output_hidden_states,
Contributor

critical

The call to _original_forward incorrectly passes logits_to_keep and **kwargs. The original transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridForCausalLM.forward method does not accept these arguments. This will cause a TypeError when the UNSLOTH_RETURN_HIDDEN_STATES environment variable is not set to "1".

        return _original_forward(
            self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
            cache_position=cache_position,
        )

@Maxusmusti Maxusmusti force-pushed the fix/granitemoehybrid-return-hidden-states branch from 76f3f13 to 4b51002 Compare March 17, 2026 19:53
@Maxusmusti Maxusmusti force-pushed the fix/granitemoehybrid-return-hidden-states branch from 533d206 to 7f731c8 Compare March 17, 2026 19:56

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b1199a11a9

Comment thread unsloth/models/loader.py
Comment on lines +164 to +168
output_router_logits = None,
return_dict = None,
cache_position = None,
logits_to_keep = 0,
):

P1 Badge Accept and forward extra kwargs in patched forward

This wrapper narrows GraniteMoeHybridForCausalLM.forward to a fixed parameter list and drops **kwargs, so any caller that injects additional forward kwargs (for example trainer/PEFT arguments like num_items_in_batch, task_ids, or similar passthrough keys) will now fail with TypeError before _original_forward is reached, even when UNSLOTH_RETURN_HIDDEN_STATES is disabled. Because this patch is applied globally for GraniteMoE Hybrid models, it can break normal training/inference flows that previously relied on the base forward's flexible kwargs handling.
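
The failure mode described here is easy to reproduce in isolation. The sketch below uses hypothetical functions (`base_forward`, `narrow_wrapper`, `passthrough_wrapper`) rather than the real model classes; `num_items_in_batch` is one of the passthrough keys named in the comment above.

```python
import functools


def base_forward(input_ids, labels=None, **kwargs):
    """Hypothetical base forward that tolerates extra keyword arguments."""
    return {"input_ids": input_ids, "extra": sorted(kwargs)}


def narrow_wrapper(input_ids, labels=None):
    # Fixed parameter list: unknown keywords are rejected at call time,
    # before base_forward is ever reached.
    return base_forward(input_ids, labels=labels)


@functools.wraps(base_forward)
def passthrough_wrapper(input_ids, labels=None, **kwargs):
    # Accept and forward anything the caller injects.
    return base_forward(input_ids, labels=labels, **kwargs)


try:
    narrow_wrapper([1, 2], num_items_in_batch=4)
    narrow_failed = False
except TypeError:
    narrow_failed = True

result = passthrough_wrapper([1, 2], num_items_in_batch=4)
```

The narrow wrapper raises `TypeError` regardless of any environment variable, which is why the problem would also affect the normal path.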

Contributor Author

Re-added the kwargs. I was originally following the Gemini suggestion above, but after looking into it I found that they are in fact necessary.

Comment thread unsloth/models/loader.py
cache_position = cache_position,
)

hidden_states = outputs[0]

P2 Badge Move hidden states to lm_head device before returning

In the hidden-state return branch, hidden_states are taken directly from self.model and returned without aligning to self.lm_head.weight.device. In offload/model-parallel setups where decoder outputs and lm_head are on different devices, the downstream GRPO path that multiplies returned hidden states by lm_head can hit device mismatch errors or expensive implicit transfers; Llama/Mistral fast-forward paths already do this device alignment explicitly, so GraniteMoE Hybrid should match that behavior.

Contributor Author

updated in most recent commit

@Maxusmusti Maxusmusti force-pushed the fix/granitemoehybrid-return-hidden-states branch from b1199a1 to 6b8ad4d Compare March 17, 2026 20:03

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f0d70ee3ad

Comment thread unsloth/models/loader.py Outdated
Comment on lines +197 to +200
return MoeCausalLMOutputWithPast(
    loss = None,
    logits = hidden_states,
    past_key_values = outputs.past_key_values if return_dict else None,

P2 Badge Preserve tuple output when return_dict is False

When UNSLOTH_RETURN_HIDDEN_STATES=1, this branch always returns MoeCausalLMOutputWithPast even if the caller requested return_dict=False, and it explicitly sets tuple-only fields like past_key_values to None in that mode. That changes the forward contract versus the original GraniteMoE hybrid implementation, so callers that rely on tuple outputs/cached values under return_dict=False can misbehave (e.g., missing cache entries or indexing assumptions).
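
One way to honor the caller's `return_dict` choice is sketched below. `ToyOutput` and `build_output` are hypothetical names; the dataclass merely stands in for an output container such as MoeCausalLMOutputWithPast, and the sketch does not model how the real library orders or filters tuple entries.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyOutput:
    """Hypothetical stand-in for a ModelOutput-style container."""

    logits: Any
    past_key_values: Optional[Any] = None


def build_output(hidden_states, past_key_values, return_dict=True):
    """Return a container or a plain tuple; the values are the same either way."""
    if return_dict:
        return ToyOutput(logits=hidden_states, past_key_values=past_key_values)
    # Tuple mode keeps the cache instead of replacing it with None.
    return (hidden_states, past_key_values)


as_object = build_output("hidden", "cache", return_dict=True)
as_tuple = build_output("hidden", "cache", return_dict=False)
```

The point of the sketch is only that the cache survives in both modes, so callers indexing into a tuple still find it.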

Contributor Author

updated in most recent commit

@Maxusmusti Maxusmusti force-pushed the fix/granitemoehybrid-return-hidden-states branch from f0d70ee to 669bfa5 Compare March 17, 2026 20:13
@Datta0
Collaborator

Datta0 commented Mar 18, 2026

@danielhanchen I think a better place to do this is in compiler.py, where we write the compiled CausalLM classes, so that we need not deal with this on a per-model basis. Thoughts?

@Maxusmusti
Contributor Author

Hi @Datta0, following up on this: I opened a PR for a compiler.py approach, with some help from Claude on the regex updates. Take a look here: unslothai/unsloth-zoo#562

@Maxusmusti
Contributor Author

Maxusmusti commented Mar 24, 2026

It should make the regex more forgiving and robust for more model families. If the more general fix looks/works better for you, I can close this PR and move forward with the other one. Let me know!

@Datta0
Collaborator

Datta0 commented Mar 25, 2026

If the compiler in unsloth-zoo writes patched forward pass code for Granite, we should not need the RETURN_HIDDEN_STATES patch here, right?

@Maxusmusti
Contributor Author

@Datta0 Yeah, exactly. If the compiler PR fix is added, this patch becomes redundant, so I can close this PR if the other method seems cleaner.

@Datta0
Collaborator

Datta0 commented Jun 2, 2026

This should be dealt with by #5898 and #5142.

@Datta0 Datta0 closed this Jun 2, 2026