# Gemma-4 float16 MLP overflow fix.
#
# In fp16, `down_proj(act_fn(gate_proj(x)) * up_proj(x))` overflows at
# layers.0 (E2B) / layers.1 (E4B): the gate*up product and the fp16 matmul
# accumulator saturate to +-inf, which poisons the residual stream. Generation
# then samples NaN logits and trips the CUDA categorical sampler around GRPO
# step ~2.
#
# Remedy: form gate*up in fp32, clamp it to a safe bound, cast back to fp16,
# and nan_to_num the down_proj output. The fix keys off the dtype of the gate
# projection's output, so bf16/fp32 users see no change and no env flag is
# needed. RMSNorm / Attention / Embedding patches are unnecessary (verified by
# bisection - identical loss/kl/grad trajectories).


def patch_Gemma4TextMLP():
    """Install an fp16-overflow-tolerant ``forward`` on ``Gemma4TextMLP``.

    The replacement ``forward`` inspects the dtype of ``gate_proj(x)``:

    * anything other than ``torch.float16``: computes
      ``down_proj(act_fn(gate) * up_proj(x))`` with no extra casts or clamps;
    * ``torch.float16``: multiplies the activated gate and the up projection
      in fp32, clamps the product to ``[-65280, 65280]``, casts it back to the
      gate dtype for ``down_proj``, and replaces NaN / +inf / -inf in the
      result with ``0.0``.

    Returns:
        ``None`` when ``transformers.models.gemma4.modeling_gemma4`` cannot be
        imported (the patch is skipped silently) or when patching succeeds.
        If ``Gemma4TextMLP`` is missing from that module, or ``patch_function``
        raises, the result of ``raise_error("Gemma4TextMLP.forward", e)`` is
        returned instead.
    """
    try:
        import transformers.models.gemma4.modeling_gemma4 as modeling_gemma4
    except ImportError:
        # This transformers install has no Gemma-4 module: nothing to patch.
        return

    try:
        Gemma4TextMLP = modeling_gemma4.Gemma4TextMLP
    except AttributeError as e:
        return raise_error("Gemma4TextMLP.forward", e)

    # Clamp bound: the largest value representable in both fp16 and bf16
    # (65536 would round to inf in fp16).
    fp16_safe_limit = 65280.0

    def forward(self, x):
        gate_out = self.gate_proj(x)
        # Look at the matmul *output* dtype (not the input / weight dtype) so
        # that autocast and PEFT fp16 casts are caught as well.
        if gate_out.dtype == torch.float16:
            # Do the elementwise product in fp32 so it cannot saturate, then
            # bound it before going back down to half precision.
            activated = self.act_fn(gate_out.to(torch.float32))
            up_out = self.up_proj(x).to(torch.float32)
            gated = (activated * up_out).clamp(
                min=-fp16_safe_limit, max=fp16_safe_limit,
            )
            projected = self.down_proj(gated.to(gate_out.dtype))
            # Zero any remaining overflow / NaN so the residual identity path
            # survives.
            return projected.nan_to_num(nan=0.0, posinf=0.0, neginf=0.0)
        # bf16 / fp32: plain gated MLP with no extra casts or clamps.
        return self.down_proj(self.act_fn(gate_out) * self.up_proj(x))

    try:
        patch_function(
            Gemma4TextMLP, "forward", forward, fullgraph=False,
        )
    except Exception as e:
        return raise_error("Gemma4TextMLP.forward", e)
pass
TEMPORARY_PATCHES.append(patch_Gemma4TextMLP)