Add UNSLOTH_RETURN_HIDDEN_STATES support for GraniteMoeHybrid#4373
Add UNSLOTH_RETURN_HIDDEN_STATES support for GraniteMoeHybrid#4373Maxusmusti wants to merge 5 commits into
Conversation
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a critical compatibility issue that previously prevented Reinforcement Learning (RL) training with Granite 4 hybrid models when the Highlights
Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request correctly adds support for UNSLOTH_RETURN_HIDDEN_STATES to GraniteMoeHybridForCausalLM, which is crucial for RL training workflows. The implementation follows the established pattern of monkey-patching the model's forward pass. I've identified two critical issues in the patched forward method where it incorrectly passes extra keyword arguments (**kwargs and logits_to_keep) to underlying functions from the transformers library that do not accept them. This will lead to TypeError exceptions during runtime. I've provided suggestions to fix these issues.
| ) | ||
|
|
||
| outputs = self.model( | ||
| input_ids = input_ids, | ||
| attention_mask = attention_mask, | ||
| position_ids = position_ids, | ||
| past_key_values = past_key_values, | ||
| inputs_embeds = inputs_embeds, | ||
| use_cache = use_cache, | ||
| output_attentions = output_attentions, | ||
| output_hidden_states = output_hidden_states, | ||
| output_router_logits = output_router_logits, | ||
| return_dict = return_dict, | ||
| cache_position = cache_position, |
There was a problem hiding this comment.
The call to self.model incorrectly passes **kwargs. The transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridModel.forward method does not accept arbitrary keyword arguments, which will likely lead to a TypeError if any unexpected arguments are present in kwargs.
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
return_dict=return_dict,
cache_position=cache_position,
)There was a problem hiding this comment.
removed excess kwargs
There was a problem hiding this comment.
actually, your feedback looks to be incorrect, we are looking at granitemoehybridforcausallm, which has kwargs: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py#L1463
| attentions = outputs.attentions if return_dict else None, | ||
| router_logits = getattr(outputs, "router_logits", None) | ||
| if return_dict | ||
| else None, | ||
| ) | ||
|
|
||
| return _original_forward( | ||
| self, | ||
| input_ids = input_ids, | ||
| attention_mask = attention_mask, | ||
| position_ids = position_ids, | ||
| past_key_values = past_key_values, | ||
| inputs_embeds = inputs_embeds, | ||
| labels = labels, | ||
| use_cache = use_cache, | ||
| output_attentions = output_attentions, | ||
| output_hidden_states = output_hidden_states, |
There was a problem hiding this comment.
The call to _original_forward incorrectly passes logits_to_keep and **kwargs. The original transformers.models.granitemoehybrid.modeling_granitemoehybrid.GraniteMoeHybridForCausalLM.forward method does not accept these arguments. This will cause a TypeError when the UNSLOTH_RETURN_HIDDEN_STATES environment variable is not set to "1".
return _original_forward(
self,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
return_dict=return_dict,
cache_position=cache_position,
)There was a problem hiding this comment.
removed kwargs, however the logits_to_keep field is part of the original forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py#L1463
There was a problem hiding this comment.
again, re-adding kwargs since they are part of the causallm forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py#L1463
76f3f13 to
4b51002
Compare
533d206 to
7f731c8
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b1199a11a9
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| output_router_logits = None, | ||
| return_dict = None, | ||
| cache_position = None, | ||
| logits_to_keep = 0, | ||
| ): |
There was a problem hiding this comment.
Accept and forward extra kwargs in patched forward
This wrapper narrows GraniteMoeHybridForCausalLM.forward to a fixed parameter list and drops **kwargs, so any caller that injects additional forward kwargs (for example trainer/PEFT arguments like num_items_in_batch, task_ids, or similar passthrough keys) will now fail with TypeError before _original_forward is reached, even when UNSLOTH_RETURN_HIDDEN_STATES is disabled. Because this patch is applied globally for GraniteMoE Hybrid models, it can break normal training/inference flows that previously relied on the base forward's flexible kwargs handling.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
re-added the kwargs, originally was following above gemini comment suggestion, but looked into it and found it was in fact necessary
| cache_position = cache_position, | ||
| ) | ||
|
|
||
| hidden_states = outputs[0] |
There was a problem hiding this comment.
Move hidden states to lm_head device before returning
In the hidden-state return branch, hidden_states are taken directly from self.model and returned without aligning to self.lm_head.weight.device. In offload/model-parallel setups where decoder outputs and lm_head are on different devices, the downstream GRPO path that multiplies returned hidden states by lm_head can hit device mismatch errors or expensive implicit transfers; Llama/Mistral fast-forward paths already do this device alignment explicitly, so GraniteMoE Hybrid should match that behavior.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
updated in most recent commit
b1199a1 to
6b8ad4d
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f0d70ee3ad
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| return MoeCausalLMOutputWithPast( | ||
| loss = None, | ||
| logits = hidden_states, | ||
| past_key_values = outputs.past_key_values if return_dict else None, |
There was a problem hiding this comment.
Preserve tuple output when return_dict is False
When UNSLOTH_RETURN_HIDDEN_STATES=1, this branch always returns MoeCausalLMOutputWithPast even if the caller requested return_dict=False, and it explicitly sets tuple-only fields like past_key_values to None in that mode. That changes the forward contract versus the original GraniteMoE hybrid implementation, so callers that rely on tuple outputs/cached values under return_dict=False can misbehave (e.g., missing cache entries or indexing assumptions).
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
updated in most recent commit
f0d70ee to
669bfa5
Compare
for more information, see https://pre-commit.ci
|
@danielhanchen I think a better place to do this is to do in compiler.py where we write compiled CausalLM classes so that we need not deal with this on a per model basis. Thoughts? |
|
Hi @Datta0, following up with this, I opened a PR for a |
|
It should make the regex more forgiving and robust for more model families. If the more general fix looks/works better for you, I can close this PR and move forward with the other one. Let me know! |
|
If the compiler in unsloth-zoo writes a patched forward pass code for granite, we should not need to do the RETURN_HIDDEN_STATES patch here right? |
|
@Datta0 yeah exactly, if the compiler PR fix is added, this patch becomes redundant and unnecessary, so I can close this PR if the other method seems cleaner |
Summary
Adds
UNSLOTH_RETURN_HIDDEN_STATESsupport forGraniteMoeHybridForCausalLM, enabling RL training (GRPO) with Granite 4 hybrid models.Problem
When
UNSLOTH_RETURN_HIDDEN_STATES=1is set, the RL training codepath expects the model's.logitsoutput to contain pre-lm_head hidden states (shape[B, S, hidden_size]). This works for Llama, Mistral, Qwen3, Qwen3Moe, and dense Granite models because theirforward()methods are monkey-patched by Unsloth to check this env var.However,
GraniteMoeHybridForCausalLMuses the raw transformers forward method, which ignores the env var and always returns full logits (shape[B, S, vocab_size]). This causes a shape mismatch error during log probability computation:The hidden states (vocab_size=100352) are incorrectly multiplied with lm_head (hidden_size=1536).
Fix
Wraps
GraniteMoeHybridForCausalLM.forwardto intercept hidden states beforelm_headis applied whenUNSLOTH_RETURN_HIDDEN_STATES=1, matching the existing pattern inllama.py(CausalLM_fast_forward) andmistral.py(MistralForCausalLM_fast_forward).The patch is applied in
loader.pywhen agranitemoehybridmodel is detected, alongside the existing env var setup for this architecture.Testing
Verified with
ibm-granite/granite-4.0-h-tiny:[1, 64, 1536](matcheslm_headweight shape[100352, 1536]) ✓hidden_states @ lm_head.weight.t(): produces correct[1, 64, 100352]✓