Skip to content
192 changes: 192 additions & 0 deletions tests/test_patch_gemma4_text_mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import os
import sys
import types
import importlib.util

# Default both flags to "1" without overriding values the caller already
# exported. This runs before the unsloth_zoo import further down.
os.environ.setdefault("UNSLOTH_IS_PRESENT", "1")
os.environ.setdefault("UNSLOTH_COMPILE_DISABLE", "1")

# Register an empty, package-like placeholder named "unsloth" when no module
# of that name is loaded yet.
# NOTE(review): presumably this lets unsloth_zoo import without the real
# unsloth package installed - confirm against unsloth_zoo's import checks.
if "unsloth" not in sys.modules:
    _stub = types.ModuleType("unsloth")
    _stub.__path__ = []
    _stub.__spec__ = importlib.util.spec_from_loader(
        _stub.__name__, loader=None
    )
    sys.modules[_stub.__name__] = _stub

import pytest
import torch
import torch.nn as nn

from unsloth_zoo.temporary_patches import gemma4 as g4


def _make_mlp_class():
class Gemma4TextMLP(nn.Module):
def __init__(self):
super().__init__()
self.gate_proj = nn.Linear(64, 2048, bias=False)
self.up_proj = nn.Linear(64, 2048, bias=False)
self.down_proj = nn.Linear(2048, 64, bias=False)
self.act_fn = nn.GELU(approximate="tanh")
with torch.no_grad():
for p in (
self.gate_proj.weight,
self.up_proj.weight,
self.down_proj.weight,
):
p.fill_(0.5)

def forward(self, x):
return self.down_proj(
self.act_fn(self.gate_proj(x)) * self.up_proj(x)
)

return Gemma4TextMLP


def _install_module_stub(monkeypatch, cls):
fake = types.ModuleType("transformers.models.gemma4.modeling_gemma4")
fake.Gemma4TextMLP = cls
for pkg in (
"transformers",
"transformers.models",
"transformers.models.gemma4",
):
if pkg not in sys.modules:
p = types.ModuleType(pkg)
p.__path__ = []
monkeypatch.setitem(sys.modules, pkg, p)
monkeypatch.setitem(
sys.modules, "transformers.models.gemma4.modeling_gemma4", fake
)


def test_noop_when_force_float32_unset(monkeypatch):
    """With UNSLOTH_FORCE_FLOAT32 absent, the patch must not touch forward."""
    monkeypatch.delenv("UNSLOTH_FORCE_FLOAT32", raising=False)
    mlp_cls = _make_mlp_class()
    forward_before = mlp_cls.forward
    _install_module_stub(monkeypatch, mlp_cls)

    g4.patch_Gemma4TextMLP()

    assert mlp_cls.forward is forward_before


def test_noop_when_force_float32_zero(monkeypatch):
    """With UNSLOTH_FORCE_FLOAT32="0", the patch must not touch forward."""
    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "0")
    mlp_cls = _make_mlp_class()
    forward_before = mlp_cls.forward
    _install_module_stub(monkeypatch, mlp_cls)

    g4.patch_Gemma4TextMLP()

    assert mlp_cls.forward is forward_before


def test_upstream_without_patch_overflows_fp16():
    """Baseline: the unpatched fp16 MLP yields non-finite values on large input."""
    mlp = _make_mlp_class()().half().eval()
    torch.manual_seed(0)
    inputs = torch.randn(2, 64, dtype=torch.float16) * 20.0
    with torch.no_grad():
        result = mlp(inputs)
    non_finite = ~torch.isfinite(result)
    assert non_finite.any().item()


def test_fp16_overflow_output_is_finite(monkeypatch):
    """The patched fp16 MLP returns finite fp16 output for overflow-sized input."""
    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1")
    mlp_cls = _make_mlp_class()
    _install_module_stub(monkeypatch, mlp_cls)
    g4.patch_Gemma4TextMLP()

    mlp = mlp_cls().half().eval()
    torch.manual_seed(0)
    inputs = torch.randn(2, 64, dtype=torch.float16) * 20.0
    with torch.no_grad():
        result = mlp(inputs)

    assert torch.isfinite(result).all().item()
    assert result.dtype == torch.float16


def test_fp16_nan_to_num_replaces_with_zero(monkeypatch):
    """For the seeded overflow input, every patched output element is zero."""
    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1")
    mlp_cls = _make_mlp_class()
    _install_module_stub(monkeypatch, mlp_cls)
    g4.patch_Gemma4TextMLP()

    mlp = mlp_cls().half().eval()
    torch.manual_seed(0)
    inputs = torch.randn(2, 64, dtype=torch.float16) * 20.0
    with torch.no_grad():
        result = mlp(inputs)

    largest_magnitude = result.abs().max().item()
    assert largest_magnitude == 0.0


def test_fp16_normal_input_produces_nonzero_output(monkeypatch):
    """Small fp16 input passes through the patched MLP finite and not all-zero."""
    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1")
    mlp_cls = _make_mlp_class()
    _install_module_stub(monkeypatch, mlp_cls)
    g4.patch_Gemma4TextMLP()

    mlp = mlp_cls().half().eval()
    torch.manual_seed(0)
    inputs = torch.randn(2, 64, dtype=torch.float16) * 0.1
    with torch.no_grad():
        result = mlp(inputs)

    assert torch.isfinite(result).all().item()
    has_nonzero = (result != 0).any().item()
    assert has_nonzero


def test_bf16_input_matches_upstream(monkeypatch):
    """A bf16 model with bf16 input gives bit-identical results to the original forward."""
    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1")
    mlp_cls = _make_mlp_class()
    _install_module_stub(monkeypatch, mlp_cls)
    # Keep a handle on the pre-patch forward so it can be called directly.
    original_forward = mlp_cls.forward
    g4.patch_Gemma4TextMLP()

    mlp = mlp_cls().to(torch.bfloat16).eval()
    torch.manual_seed(0)
    inputs = torch.randn(2, 64, dtype=torch.bfloat16) * 50.0
    with torch.no_grad():
        actual = mlp(inputs)
        reference = original_forward(mlp, inputs)

    assert torch.equal(actual, reference)
    assert actual.dtype == torch.bfloat16


def test_fp32_input_matches_upstream(monkeypatch):
    """An fp32 model with fp32 input gives bit-identical results to the original forward."""
    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1")
    mlp_cls = _make_mlp_class()
    _install_module_stub(monkeypatch, mlp_cls)
    # Keep a handle on the pre-patch forward so it can be called directly.
    original_forward = mlp_cls.forward
    g4.patch_Gemma4TextMLP()

    mlp = mlp_cls().eval()
    torch.manual_seed(0)
    inputs = torch.randn(2, 64)
    with torch.no_grad():
        actual = mlp(inputs)
        reference = original_forward(mlp, inputs)

    assert torch.equal(actual, reference)


def test_idempotent_patch_install(monkeypatch):
    """Installing the patch twice leaves a working method still named ``forward``."""
    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1")
    mlp_cls = _make_mlp_class()
    _install_module_stub(monkeypatch, mlp_cls)
    for _ in range(2):
        g4.patch_Gemma4TextMLP()
    assert mlp_cls.forward.__name__ == "forward"

    mlp = mlp_cls().half().eval()
    torch.manual_seed(0)
    inputs = torch.randn(2, 64, dtype=torch.float16) * 0.1
    with torch.no_grad():
        result = mlp(inputs)

    assert torch.isfinite(result).all().item()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_bf16_weights_fp16_autocast_stabilizes(monkeypatch):
    """bf16 weights and input under fp16 autocast must still produce finite output."""
    # Scenario: bf16-capable GPU with a user-supplied
    # torch.amp.autocast(fp16). x.dtype remains bf16, yet self.gate_proj(x)
    # executes in fp16 and may overflow. A guard keyed on gate.dtype takes
    # the stabilization path here, whereas a guard keyed on x.dtype would
    # skip it and leave the overflow in place.
    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1")
    mlp_cls = _make_mlp_class()
    _install_module_stub(monkeypatch, mlp_cls)
    g4.patch_Gemma4TextMLP()

    mlp = mlp_cls().cuda().to(torch.bfloat16).eval()
    torch.manual_seed(0)
    inputs = torch.randn(2, 64, dtype=torch.bfloat16, device="cuda") * 20.0
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16):
        result = mlp(inputs)

    assert torch.isfinite(result).all().item()
108 changes: 107 additions & 1 deletion unsloth_zoo/temporary_patches/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import os
from .common import TEMPORARY_PATCHES
from .utils import raise_error
from .utils import raise_error, patch_function


# ============================================================================
Expand Down Expand Up @@ -606,3 +606,109 @@ def forward(self, hidden_states, position_embeddings, attention_mask=None, **kwa
Gemma4AudioAttention.forward = forward
pass
TEMPORARY_PATCHES.append(patch_Gemma4AudioAttention)


# ============================================================================
# Gemma-4 float16 stability patch.
#
# Goal: make float16 training (notably GRPO on E2B and E4B) numerically
# stable without paying blanket fp32 compute cost, so that Tesla T4 - which
# has no bf16 tensor cores - can run Gemma-4 GRPO at full fp16 matmul speed.
#
# The NaN chain (verified via TensorStatisticsHooks on E2B GRPO):
#
# 1. Under UNSLOTH_FORCE_FLOAT32 the loader stores linear weights as fp16
# (see unsloth_zoo/patching_utils.py do_forced_float32 block).
# 2. `down_proj(act_fn(gate_proj(x)) * up_proj(x))` saturates in fp16 at
# `layers.0.mlp.down_proj` (E2B) / `layers.1.mlp.down_proj` (E4B).
# 3. The +-inf in down_proj output poisons the residual stream; the next
# generation step samples NaN logits and the CUDA categorical sampler
# trips a device-side assert.
#
# Minimal fix: a single patch on Gemma4TextMLP. The product is computed in
# fp32, clamped to a safe fp16 bound, and down_proj's output is rescued
# with `nan_to_num` for the rare tail overflow from its fp16 accumulator.
#
# Why we do NOT patch RMSNorm / Attention / Embedding:
# * Gemma4RMSNorm already casts to fp32 internally and returns
# `type_as(hidden_states)` - that dtype contract is fine for fp16.
# * Gemma4TextScaledWordEmbedding multiplies by sqrt(hidden_size) which
# is ~45-60 for E2B/E4B, well within fp16 range.
# * Gemma4TextAttention projections did not overflow in any run; the
# failing path is exclusively the MLP gate*up product + down_proj
# accumulator.
#
# Bisection evidence: 8-step GRPO on unsloth/gemma-4-E2B-it and
# unsloth/gemma-4-E4B-it, with gradient_checkpointing on and off, completes
# cleanly with just patch_Gemma4TextMLP. Adding Attention, RMSNorm, or
# Embedding patches on top produces byte-identical loss / grad_norm / kl
# trajectories.
# ============================================================================


def patch_Gemma4TextMLP():
    """Install an fp16-safe ``forward`` on transformers' ``Gemma4TextMLP``.

    Problem addressed: in float16,
    ``down_proj(act_fn(gate_proj(x)) * up_proj(x))`` can exceed
    fp16_max = 65504 once summed over the wide intermediate dimension. The
    resulting +-inf propagates through the residual stream, generation
    samples NaN logits, and the categorical sampler asserts (seen around
    GRPO step 2 on E2B/E4B with dtype=torch.float16).

    The replacement ``forward`` only changes behaviour when ``gate_proj``
    produces float16; otherwise it evaluates the upstream expression as is.
    On the float16 path it:

    1. evaluates ``act_fn(gate) * up`` in float32 so the product itself
       cannot overflow;
    2. clamps the product to +-65280, the largest magnitude exactly
       representable in both fp16 and bf16;
    3. runs ``down_proj`` on the fp16 cast and replaces NaN / +-inf in its
       output with 0.

    The three projections still run as fp16 matmuls, and the output dtype is
    whatever ``down_proj`` returns, as upstream.

    The patch is skipped when the ``UNSLOTH_FORCE_FLOAT32`` environment
    variable is unset or "0", or when
    ``transformers.models.gemma4.modeling_gemma4`` cannot be imported. A
    missing ``Gemma4TextMLP`` attribute or a failure inside
    ``patch_function`` is passed to ``raise_error``.
    """
    # Opt-in switch: do nothing unless the variable is set to a non-"0" value.
    if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0":
        return
    try:
        import transformers.models.gemma4.modeling_gemma4 as mod
    except ImportError:
        # No importable Gemma-4 module: nothing to patch.
        return
    try:
        mlp_cls = mod.Gemma4TextMLP
    except AttributeError as err:
        return raise_error("Gemma4TextMLP.forward", err)

    # 65280 sits one bf16 ULP under 65536 and 224 under fp16_max (65504).
    # fp16_max itself is not a bf16 value (it rounds up to 65536), so
    # clamping to 65280 keeps the bound finite through any later
    # fp16 <-> bf16 round-trip (for example dtype casts inside PEFT).
    _SAFE_FP16 = 65280.0

    def forward(self, x):
        gate = self.gate_proj(x)
        # Branch on the projection's output dtype, not on x.dtype: bf16/fp32
        # activations can still yield fp16 matmul results (autocast, or fp16
        # weights from do_forced_float32), and those need stabilizing too.
        if gate.dtype == torch.float16:
            # Activation and product in fp32 so nothing overflows pre-clamp.
            activated = self.act_fn(gate.float())
            product = activated * self.up_proj(x).float()
            product = product.clamp(min=-_SAFE_FP16, max=_SAFE_FP16)
            out = self.down_proj(product.to(gate.dtype))
            # down_proj's fp16 accumulation can still overflow on wide
            # intermediate dims. Zero is used as the replacement so those
            # positions fall back to the residual connection instead of
            # injecting a near-fp16_max value into hidden_states.
            return out.nan_to_num(nan=0.0, posinf=0.0, neginf=0.0)
        # Any other dtype: identical to the upstream computation.
        return self.down_proj(self.act_fn(gate) * self.up_proj(x))

    try:
        patch_function(mlp_cls, "forward", forward, fullgraph=False)
    except Exception as err:
        return raise_error("Gemma4TextMLP.forward", err)
pass
# Add the patch function to the TEMPORARY_PATCHES list imported from .common.
# NOTE(review): presumably that list is iterated elsewhere to apply every
# registered patch - confirm against the code that consumes it.
TEMPORARY_PATCHES.append(patch_Gemma4TextMLP)