From a2df1064993eb65b69f2ccfa2784edce7b3f3833 Mon Sep 17 00:00:00 2001 From: Mustafa Eyceoz Date: Tue, 17 Mar 2026 17:10:08 +0000 Subject: [PATCH 1/5] Add UNSLOTH_RETURN_HIDDEN_STATES support for GraniteMoeHybrid --- unsloth/models/loader.py | 96 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index b54ceaf842..ce176008ee 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -125,6 +125,101 @@ ] +def _patch_granitemoehybrid_return_hidden_states(): + """Patch GraniteMoeHybridForCausalLM.forward to support UNSLOTH_RETURN_HIDDEN_STATES. + + The GraniteMoeHybrid architecture uses the raw transformers forward method, + which does not check UNSLOTH_RETURN_HIDDEN_STATES. This causes the RL training + codepath to receive full logits (vocab_size dim) instead of pre-lm_head hidden + states (hidden_size dim), resulting in a shape mismatch during log probability + computation. + + This patch wraps the forward to intercept hidden states before lm_head is applied, + matching the pattern used in llama.py and mistral.py. + """ + try: + from transformers.models.granitemoehybrid.modeling_granitemoehybrid import ( + GraniteMoeHybridForCausalLM, + ) + from transformers.modeling_outputs import MoeCausalLMOutputWithPast + except ImportError: + return + + from functools import wraps + + _original_forward = GraniteMoeHybridForCausalLM.forward + + @wraps(_original_forward) + def _patched_forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + output_router_logits=None, + return_dict=None, + cache_position=None, + logits_to_keep=0, + **kwargs, + ): + if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + num_logits_to_keep = logits_to_keep if isinstance(logits_to_keep, int) else 0 + if num_logits_to_keep != 0: + hidden_states = hidden_states[:, -num_logits_to_keep:, :] + + return MoeCausalLMOutputWithPast( + loss=None, + logits=hidden_states, + past_key_values=outputs.past_key_values if return_dict else None, + hidden_states=outputs.hidden_states if return_dict else None, + attentions=outputs.attentions if return_dict else None, + router_logits=getattr(outputs, "router_logits", None) if return_dict else None, + ) + + return _original_forward( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + GraniteMoeHybridForCausalLM.forward = _patched_forward + + def _fix_rope_inv_freq(model): """Fix inv_freq corruption caused by transformers v5 meta-device loading. @@ -1186,6 +1281,7 @@ def from_pretrained( # Granite-4 rms norms are stored as 16 bit, but we upcast os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1" os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" + _patch_granitemoehybrid_return_hidden_states() # Olmo 2 elif "olmo2" in model_types_all and transformers_version < Version( "4.50.0.dev0" From 7f731c8f631e97af98ff59c6771fd3b6f907c1d0 Mon Sep 17 00:00:00 2001 From: Mustafa Eyceoz Date: Tue, 17 Mar 2026 19:53:12 +0000 Subject: [PATCH 2/5] Remove unnecessary **kwargs passthrough --- unsloth/models/loader.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index ce176008ee..263b005b34 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -165,7 +165,6 @@ def _patched_forward( return_dict=None, cache_position=None, logits_to_keep=0, - **kwargs, ): if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -182,7 +181,6 @@ def _patched_forward( output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, - **kwargs, ) hidden_states = outputs[0] @@ -214,7 +212,6 @@ def _patched_forward( return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, - **kwargs, ) GraniteMoeHybridForCausalLM.forward = _patched_forward From 6b8ad4db35ba6a9f513f6671e75bcae91c00f615 Mon Sep 17 00:00:00 2001 From: Mustafa Eyceoz Date: Tue, 17 Mar 2026 20:03:32 +0000 Subject: [PATCH 3/5] Restore **kwargs passthrough for flash attention compatibility --- unsloth/models/loader.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 263b005b34..ce176008ee 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -165,6 +165,7 @@ def _patched_forward( return_dict=None, cache_position=None, logits_to_keep=0, + **kwargs, ): if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -181,6 +182,7 @@ def _patched_forward( output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -212,6 +214,7 @@ def _patched_forward( return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, + **kwargs, ) GraniteMoeHybridForCausalLM.forward = _patched_forward From 669bfa5892eae194c903b8c8af98619a77f46e6e Mon Sep 17 00:00:00 2001 From: Mustafa Eyceoz Date: Tue, 17 Mar 2026 20:13:07 +0000 Subject: [PATCH 4/5] Add device alignment and tuple return support --- unsloth/models/loader.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index ce176008ee..99c4b07aab 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -190,13 +190,21 @@ def _patched_forward( if num_logits_to_keep != 0: hidden_states = hidden_states[:, -num_logits_to_keep:, :] + # Align device with lm_head for model-parallel/offload setups + lm_head_device = self.lm_head.weight.device + if hidden_states.device != lm_head_device: + hidden_states = hidden_states.to(lm_head_device) + + if not return_dict: + return (hidden_states,) + outputs[1:] + return MoeCausalLMOutputWithPast( loss=None, logits=hidden_states, - past_key_values=outputs.past_key_values if return_dict else None, - hidden_states=outputs.hidden_states if return_dict else None, - attentions=outputs.attentions if return_dict else None, - router_logits=getattr(outputs, "router_logits", None) if return_dict else None, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=getattr(outputs, "router_logits", None), ) return _original_forward( From 62623a1947d2995344f8f871d6e79a106b3aa84d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 20:13:16 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/loader.py | 94 +++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 45 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 99c4b07aab..e3e7440f09 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -152,41 +152,45 @@ def _patch_granitemoehybrid_return_hidden_states(): @wraps(_original_forward) def _patched_forward( self, - input_ids=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - output_router_logits=None, - return_dict=None, - cache_position=None, - logits_to_keep=0, + input_ids = None, + attention_mask = None, + position_ids = None, + past_key_values = None, + inputs_embeds = None, + labels = None, + use_cache = None, + output_attentions = None, + output_hidden_states = None, + output_router_logits = None, + return_dict = None, + cache_position = None, + logits_to_keep = 0, **kwargs, ): if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_router_logits=output_router_logits, - return_dict=return_dict, - cache_position=cache_position, + input_ids = input_ids, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_values = past_key_values, + inputs_embeds = inputs_embeds, + use_cache = use_cache, + output_attentions = output_attentions, + output_hidden_states = output_hidden_states, + output_router_logits = output_router_logits, + return_dict = return_dict, + cache_position = cache_position, **kwargs, ) hidden_states = outputs[0] - num_logits_to_keep = logits_to_keep if isinstance(logits_to_keep, int) else 0 + num_logits_to_keep = ( + logits_to_keep if isinstance(logits_to_keep, int) else 0 + ) if num_logits_to_keep != 0: hidden_states = hidden_states[:, -num_logits_to_keep:, :] @@ -199,29 +203,29 @@ def _patched_forward( return (hidden_states,) + outputs[1:] return MoeCausalLMOutputWithPast( - loss=None, - logits=hidden_states, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - router_logits=getattr(outputs, "router_logits", None), + loss = None, + logits = hidden_states, + past_key_values = outputs.past_key_values, + hidden_states = outputs.hidden_states, + attentions = outputs.attentions, + router_logits = getattr(outputs, "router_logits", None), ) return _original_forward( self, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_router_logits=output_router_logits, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, + input_ids = input_ids, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_values = past_key_values, + inputs_embeds = inputs_embeds, + labels = labels, + use_cache = use_cache, + output_attentions = output_attentions, + output_hidden_states = output_hidden_states, + output_router_logits = output_router_logits, + return_dict = return_dict, + cache_position = cache_position, + logits_to_keep = logits_to_keep, **kwargs, )