Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions unsloth/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,113 @@
]


def _patch_granitemoehybrid_return_hidden_states():
    """Patch GraniteMoeHybridForCausalLM.forward to support UNSLOTH_RETURN_HIDDEN_STATES.

    The GraniteMoeHybrid architecture uses the raw transformers forward method,
    which does not check UNSLOTH_RETURN_HIDDEN_STATES. This causes the RL training
    codepath to receive full logits (vocab_size dim) instead of pre-lm_head hidden
    states (hidden_size dim), resulting in a shape mismatch during log probability
    computation.

    This patch wraps the forward to intercept hidden states before lm_head is applied,
    matching the pattern used in llama.py and mistral.py.

    The patch is idempotent: the installed wrapper carries a marker attribute, and
    a later call that finds the marker on the current forward returns without
    wrapping again. Without this, every call would stack another wrapper on top
    of the previous one.

    Returns:
        None. The function silently does nothing when the installed transformers
        version does not provide the GraniteMoeHybrid model (ImportError).
    """
    try:
        from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
            GraniteMoeHybridForCausalLM,
        )
        from transformers.modeling_outputs import MoeCausalLMOutputWithPast
    except ImportError:
        # Architecture not available in this transformers build: nothing to patch.
        return

    from functools import wraps

    _original_forward = GraniteMoeHybridForCausalLM.forward

    # Already patched by an earlier call: wrapping again would only nest wrappers.
    if getattr(_original_forward, "_unsloth_return_hidden_states_patched", False):
        return

    @wraps(_original_forward)
    def _patched_forward(
        self,
        input_ids = None,
        attention_mask = None,
        position_ids = None,
        past_key_values = None,
        inputs_embeds = None,
        labels = None,
        use_cache = None,
        output_attentions = None,
        output_hidden_states = None,
        output_router_logits = None,
        return_dict = None,
        cache_position = None,
        logits_to_keep = 0,
        **kwargs,
    ):
        # The env var is read on every call, so it can be toggled at runtime.
        if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
            return_dict = (
                return_dict if return_dict is not None else self.config.use_return_dict
            )

            # Run only the decoder stack; lm_head is deliberately skipped.
            # `labels` is not forwarded and no loss is computed on this path.
            outputs = self.model(
                input_ids = input_ids,
                attention_mask = attention_mask,
                position_ids = position_ids,
                past_key_values = past_key_values,
                inputs_embeds = inputs_embeds,
                use_cache = use_cache,
                output_attentions = output_attentions,
                output_hidden_states = output_hidden_states,
                output_router_logits = output_router_logits,
                return_dict = return_dict,
                cache_position = cache_position,
                **kwargs,
            )

            hidden_states = outputs[0]

            # Only an integer logits_to_keep is honoured (keep the last N
            # positions); any non-int value is treated as "keep everything".
            num_logits_to_keep = (
                logits_to_keep if isinstance(logits_to_keep, int) else 0
            )
            if num_logits_to_keep != 0:
                hidden_states = hidden_states[:, -num_logits_to_keep:, :]

            # Align device with lm_head for model-parallel/offload setups, where
            # the decoder output and lm_head may sit on different devices.
            lm_head_device = self.lm_head.weight.device
            if hidden_states.device != lm_head_device:
                hidden_states = hidden_states.to(lm_head_device)

            if not return_dict:
                return (hidden_states,) + outputs[1:]

            # Hidden states are returned in the `logits` slot on purpose.
            return MoeCausalLMOutputWithPast(
                loss = None,
                logits = hidden_states,
                past_key_values = outputs.past_key_values,
                hidden_states = outputs.hidden_states,
                attentions = outputs.attentions,
                router_logits = getattr(outputs, "router_logits", None),
            )

        # Default path: defer to the untouched transformers implementation.
        return _original_forward(
            self,
            input_ids = input_ids,
            attention_mask = attention_mask,
            position_ids = position_ids,
            past_key_values = past_key_values,
            inputs_embeds = inputs_embeds,
            labels = labels,
            use_cache = use_cache,
            output_attentions = output_attentions,
            output_hidden_states = output_hidden_states,
            output_router_logits = output_router_logits,
            return_dict = return_dict,
            cache_position = cache_position,
            logits_to_keep = logits_to_keep,
            **kwargs,
        )

    # Marker checked by the guard above on subsequent calls.
    _patched_forward._unsloth_return_hidden_states_patched = True
    GraniteMoeHybridForCausalLM.forward = _patched_forward


def _fix_rope_inv_freq(model):
"""Fix inv_freq corruption caused by transformers v5 meta-device loading.

Expand Down Expand Up @@ -1186,6 +1293,7 @@ def from_pretrained(
# Granite-4 rms norms are stored as 16 bit, but we upcast
os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
_patch_granitemoehybrid_return_hidden_states()
# Olmo 2
elif "olmo2" in model_types_all and transformers_version < Version(
"4.50.0.dev0"
Expand Down