def _patch_granitemoehybrid_return_hidden_states():
    """Patch GraniteMoeHybridForCausalLM.forward to support UNSLOTH_RETURN_HIDDEN_STATES.

    The GraniteMoeHybrid architecture uses the raw transformers forward method,
    which does not check UNSLOTH_RETURN_HIDDEN_STATES. This causes the RL training
    codepath to receive full logits (vocab_size dim) instead of pre-lm_head hidden
    states (hidden_size dim), resulting in a shape mismatch during log probability
    computation.

    This patch wraps the forward to intercept hidden states before lm_head is
    applied, matching the pattern used in llama.py and mistral.py. When the
    environment variable is not "1" the wrapper delegates to the original forward
    with every argument passed through unchanged.

    The patch is idempotent: the installed wrapper carries a sentinel attribute,
    so calling this helper again (it runs on every Granite model load) does not
    stack another wrapper on top of the previous one.

    Returns:
        None. The function is a no-op when the installed transformers version
        does not provide GraniteMoeHybrid, or when the patch is already applied.
    """
    try:
        from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
            GraniteMoeHybridForCausalLM,
        )
        from transformers.modeling_outputs import MoeCausalLMOutputWithPast
    except ImportError:
        # GraniteMoeHybrid is not available in this transformers install.
        return

    from functools import wraps

    _original_forward = GraniteMoeHybridForCausalLM.forward

    # Without this guard each call would wrap the previously installed wrapper,
    # adding one extra call frame and env lookup per model load.
    if getattr(_original_forward, "_unsloth_return_hidden_states_patched", False):
        return

    @wraps(_original_forward)
    def _patched_forward(
        self,
        input_ids = None,
        attention_mask = None,
        position_ids = None,
        past_key_values = None,
        inputs_embeds = None,
        labels = None,
        use_cache = None,
        output_attentions = None,
        output_hidden_states = None,
        output_router_logits = None,
        return_dict = None,
        cache_position = None,
        logits_to_keep = 0,
        **kwargs,
    ):
        # The flag is read on every call so it can be toggled at runtime.
        if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") != "1":
            return _original_forward(
                self,
                input_ids = input_ids,
                attention_mask = attention_mask,
                position_ids = position_ids,
                past_key_values = past_key_values,
                inputs_embeds = inputs_embeds,
                labels = labels,
                use_cache = use_cache,
                output_attentions = output_attentions,
                output_hidden_states = output_hidden_states,
                output_router_logits = output_router_logits,
                return_dict = return_dict,
                cache_position = cache_position,
                logits_to_keep = logits_to_keep,
                **kwargs,
            )

        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Run only the backbone; lm_head is deliberately skipped and `labels`
        # is not used on this path (the returned loss is always None).
        outputs = self.model(
            input_ids = input_ids,
            attention_mask = attention_mask,
            position_ids = position_ids,
            past_key_values = past_key_values,
            inputs_embeds = inputs_embeds,
            use_cache = use_cache,
            output_attentions = output_attentions,
            output_hidden_states = output_hidden_states,
            output_router_logits = output_router_logits,
            return_dict = return_dict,
            cache_position = cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]

        # Only integer values trim the sequence; any non-int `logits_to_keep`
        # is treated as 0, i.e. the full sequence is kept.
        num_logits_to_keep = logits_to_keep if isinstance(logits_to_keep, int) else 0
        if num_logits_to_keep != 0:
            hidden_states = hidden_states[:, -num_logits_to_keep:, :]

        # Align device with lm_head for model-parallel/offload setups
        lm_head_device = self.lm_head.weight.device
        if hidden_states.device != lm_head_device:
            hidden_states = hidden_states.to(lm_head_device)

        if not return_dict:
            return (hidden_states,) + outputs[1:]

        # Hidden states are returned in the `logits` slot on purpose: that is
        # the field the caller reads when UNSLOTH_RETURN_HIDDEN_STATES is set.
        return MoeCausalLMOutputWithPast(
            loss = None,
            logits = hidden_states,
            past_key_values = outputs.past_key_values,
            hidden_states = outputs.hidden_states,
            attentions = outputs.attentions,
            router_logits = getattr(outputs, "router_logits", None),
        )

    # Sentinel checked by the idempotency guard above.
    _patched_forward._unsloth_return_hidden_states_patched = True
    GraniteMoeHybridForCausalLM.forward = _patched_forward