Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion unsloth_zoo/temporary_patches/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import os
from .common import TEMPORARY_PATCHES
from .utils import raise_error
from .utils import raise_error, patch_function


# ============================================================================
Expand Down Expand Up @@ -606,3 +606,56 @@ def forward(self, hidden_states, position_embeddings, attention_mask=None, **kwa
Gemma4AudioAttention.forward = forward
pass
TEMPORARY_PATCHES.append(patch_Gemma4AudioAttention)


# Gemma-4 float16 MLP overflow fix.
#
# `down_proj(act_fn(gate_proj(x)) * up_proj(x))` overflows fp16 at layers.0
# (E2B) / layers.1 (E4B): the product + fp16 matmul accumulator saturate to
# +-inf, poison the residual stream, and generation samples NaN logits that
# trip the CUDA categorical sampler on GRPO step ~2.
#
# Fix: fp32 gate*up, clamp to a safe bound, fp16 cast, nan_to_num on the
# down_proj output. Gated on gate output dtype so bf16/fp32 users see no
# change and no env flag is required. RMSNorm / Attention / Embedding
# patches are unnecessary (verified by bisection - identical loss/kl/grad
# trajectories).


def patch_Gemma4TextMLP():
    """Install an fp16-safe ``forward`` on ``Gemma4TextMLP``.

    When the ``gate_proj`` output is ``torch.float16`` the replacement
    forward multiplies ``act_fn(gate)`` by ``up_proj(x)`` in fp32, clamps
    the product to ``+-65280``, casts it back to the gate dtype for
    ``down_proj`` and replaces NaN / +-inf in the result with zeros. For
    any other gate dtype it evaluates the plain
    ``down_proj(act_fn(gate) * up_proj(x))`` expression.

    Returns without patching if the gemma4 modeling module cannot be
    imported. A missing ``Gemma4TextMLP`` attribute or a failing
    ``patch_function`` call is handed to ``raise_error``.
    """
    try:
        import transformers.models.gemma4.modeling_gemma4 as modeling_gemma4
    except ImportError:
        return

    try:
        mlp_cls = modeling_gemma4.Gemma4TextMLP
    except AttributeError as e:
        return raise_error("Gemma4TextMLP.forward", e)

    # Clamp bound: largest value representable in both fp16 and bf16
    # (65536 would round to inf in fp16).
    _FP16_CLAMP = 65280.0

    def forward(self, x):
        gate = self.gate_proj(x)
        # Decide on the matmul *output* dtype so that autocast / PEFT fp16
        # casts are detected as well.
        if gate.dtype == torch.float16:
            activated = self.act_fn(gate.float())
            up = self.up_proj(x).float()
            bounded = (activated * up).clamp(min=-_FP16_CLAMP, max=_FP16_CLAMP)
            projected = self.down_proj(bounded.to(gate.dtype))
            # Zero out any remaining overflow so the residual identity path
            # survives.
            return torch.nan_to_num(projected, nan=0.0, posinf=0.0, neginf=0.0)
        # Non-fp16 path: plain gated-MLP expression, no clamping.
        return self.down_proj(self.act_fn(gate) * self.up_proj(x))

    try:
        patch_function(mlp_cls, "forward", forward, fullgraph=False)
    except Exception as e:
        return raise_error("Gemma4TextMLP.forward", e)
    pass
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
# Add the Gemma4TextMLP patch function to the TEMPORARY_PATCHES list
# imported from .common.
TEMPORARY_PATCHES.append(patch_Gemma4TextMLP)
Loading