diff --git a/tests/test_patch_gemma4_text_mlp.py b/tests/test_patch_gemma4_text_mlp.py
new file mode 100644
index 000000000..99d39e1ff
--- /dev/null
+++ b/tests/test_patch_gemma4_text_mlp.py
@@ -0,0 +1,225 @@
+# Unit tests for unsloth_zoo.temporary_patches.gemma4.patch_Gemma4TextMLP.
+# They run against a small stand-in Gemma4TextMLP registered in sys.modules,
+# so no Gemma-4 checkpoint is loaded.
+
+import os
+import sys
+import types
+import importlib.util
+
+# Set before the unsloth_zoo import below. NOTE(review): presumably these
+# let unsloth_zoo import without the real unsloth package and with
+# compilation disabled - confirm against unsloth_zoo's import-time checks.
+os.environ.setdefault("UNSLOTH_IS_PRESENT", "1")
+os.environ.setdefault("UNSLOTH_COMPILE_DISABLE", "1")
+if "unsloth" not in sys.modules:
+    _stub = types.ModuleType("unsloth")
+    _stub.__spec__ = importlib.util.spec_from_loader("unsloth", loader=None)
+    _stub.__path__ = []
+    sys.modules["unsloth"] = _stub
+
+import pytest
+import torch
+import torch.nn as nn
+
+from unsloth_zoo.temporary_patches import gemma4 as g4
+
+
+def _make_mlp_class():
+    """Return a fresh stand-in ``Gemma4TextMLP`` class.
+
+    A new class object is built on every call, so patching ``forward`` on
+    one returned class does not touch a class returned by another call. All
+    three projection weights are filled with 0.5.
+    """
+    class Gemma4TextMLP(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gate_proj = nn.Linear(64, 2048, bias=False)
+            self.up_proj = nn.Linear(64, 2048, bias=False)
+            self.down_proj = nn.Linear(2048, 64, bias=False)
+            self.act_fn = nn.GELU(approximate="tanh")
+            with torch.no_grad():
+                for p in (
+                    self.gate_proj.weight,
+                    self.up_proj.weight,
+                    self.down_proj.weight,
+                ):
+                    p.fill_(0.5)
+
+        def forward(self, x):
+            return self.down_proj(
+                self.act_fn(self.gate_proj(x)) * self.up_proj(x)
+            )
+
+    return Gemma4TextMLP
+
+
+def _install_module_stub(monkeypatch, cls):
+    """Register a fake ``modeling_gemma4`` module that exposes ``cls``.
+
+    Parent packages missing from ``sys.modules`` get empty placeholder
+    packages. Every entry is set through ``monkeypatch``, so it is undone
+    at test teardown.
+
+    NOTE(review): when a real ``transformers.models.gemma4`` package is
+    importable, confirm that ``import ... as mod`` in the patch resolves to
+    this fake rather than to an attribute cached on the real package.
+    """
+    fake = types.ModuleType("transformers.models.gemma4.modeling_gemma4")
+    fake.Gemma4TextMLP = cls
+    for pkg in (
+        "transformers",
+        "transformers.models",
+        "transformers.models.gemma4",
+    ):
+        if pkg not in sys.modules:
+            p = types.ModuleType(pkg)
+            p.__path__ = []
+            monkeypatch.setitem(sys.modules, pkg, p)
+    monkeypatch.setitem(
+        sys.modules, "transformers.models.gemma4.modeling_gemma4", fake
+    )
+
+
+def test_noop_when_force_float32_unset(monkeypatch):
+    """forward is left untouched when UNSLOTH_FORCE_FLOAT32 is unset."""
+    monkeypatch.delenv("UNSLOTH_FORCE_FLOAT32", raising=False)
+    cls = _make_mlp_class()
+    _install_module_stub(monkeypatch, cls)
+    original = cls.forward
+    g4.patch_Gemma4TextMLP()
+    assert cls.forward is original
+
+
+def test_noop_when_force_float32_zero(monkeypatch):
+    """forward is left untouched when UNSLOTH_FORCE_FLOAT32 is '0'."""
+    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "0")
+    cls = _make_mlp_class()
+    _install_module_stub(monkeypatch, cls)
+    original = cls.forward
+    g4.patch_Gemma4TextMLP()
+    assert cls.forward is original
+
+
+def test_upstream_without_patch_overflows_fp16():
+    """Baseline: the unpatched fp16 MLP yields non-finite values here."""
+    cls = _make_mlp_class()
+    m = cls().half().eval()
+    torch.manual_seed(0)
+    x = torch.randn(2, 64, dtype=torch.float16) * 20.0
+    with torch.no_grad():
+        out = m(x)
+    assert (~torch.isfinite(out)).any().item()
+
+
+def test_fp16_overflow_output_is_finite(monkeypatch):
+    """Patched fp16 MLP returns finite fp16 output for the baseline input."""
+    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1")
+    cls = _make_mlp_class()
+    _install_module_stub(monkeypatch, cls)
+    g4.patch_Gemma4TextMLP()
+    m = cls().half().eval()
+    torch.manual_seed(0)
+    x = torch.randn(2, 64, dtype=torch.float16) * 20.0
+    with torch.no_grad():
+        out = m(x)
+    assert torch.all(torch.isfinite(out)).item()
+    assert out.dtype == torch.float16
+
+
+def test_fp16_nan_to_num_replaces_with_zero(monkeypatch):
+    """Every output element is exactly 0 for the overflowing input."""
+    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1")
+    cls = _make_mlp_class()
+    _install_module_stub(monkeypatch, cls)
+    g4.patch_Gemma4TextMLP()
+    m = cls().half().eval()
+    torch.manual_seed(0)
+    x = torch.randn(2, 64, dtype=torch.float16) * 20.0
+    with torch.no_grad():
+        out = m(x)
+    assert out.abs().max().item() == 0.0
+
+
+def test_fp16_normal_input_produces_nonzero_output(monkeypatch):
+    """Small fp16 inputs still give finite, not-all-zero output."""
+    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1")
+    cls = _make_mlp_class()
+    _install_module_stub(monkeypatch, cls)
+    g4.patch_Gemma4TextMLP()
+    m = cls().half().eval()
+    torch.manual_seed(0)
+    x = torch.randn(2, 64, dtype=torch.float16) * 0.1
+    with torch.no_grad():
+        out = m(x)
+    assert torch.all(torch.isfinite(out)).item()
+    assert (out != 0).any().item()
+
+
+def test_bf16_input_matches_upstream(monkeypatch):
+    """bf16 weights and input: patched output equals upstream exactly."""
+    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1")
+    cls = _make_mlp_class()
+    _install_module_stub(monkeypatch, cls)
+    upstream = cls.forward
+    g4.patch_Gemma4TextMLP()
+    m = cls().to(torch.bfloat16).eval()
+    torch.manual_seed(0)
+    x = torch.randn(2, 64, dtype=torch.bfloat16) * 50.0
+    with torch.no_grad():
+        patched = m(x)
+        expected = upstream(m, x)
+    assert torch.equal(patched, expected)
+    assert patched.dtype == torch.bfloat16
+
+
+def test_fp32_input_matches_upstream(monkeypatch):
+    """fp32 weights and input: patched output equals upstream exactly."""
+    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1")
+    cls = _make_mlp_class()
+    _install_module_stub(monkeypatch, cls)
+    upstream = cls.forward
+    g4.patch_Gemma4TextMLP()
+    m = cls().eval()
+    torch.manual_seed(0)
+    x = torch.randn(2, 64)
+    with torch.no_grad():
+        patched = m(x)
+        expected = upstream(m, x)
+    assert torch.equal(patched, expected)
+
+
+def test_idempotent_patch_install(monkeypatch):
+    """Installing the patch twice still leaves a working ``forward``."""
+    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1")
+    cls = _make_mlp_class()
+    _install_module_stub(monkeypatch, cls)
+    g4.patch_Gemma4TextMLP()
+    g4.patch_Gemma4TextMLP()
+    assert cls.forward.__name__ == "forward"
+    m = cls().half().eval()
+    torch.manual_seed(0)
+    x = torch.randn(2, 64, dtype=torch.float16) * 0.1
+    with torch.no_grad():
+        out = m(x)
+    assert torch.all(torch.isfinite(out)).item()
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
+def test_bf16_weights_fp16_autocast_stabilizes(monkeypatch):
+    """bf16 weights under fp16 CUDA autocast still produce finite output."""
+    # bf16-capable GPU + user-wrapped torch.amp.autocast(fp16): x.dtype stays
+    # bf16 but self.gate_proj(x) runs in fp16 and can overflow. The
+    # gate.dtype guard must enter stabilization here; an x.dtype guard would
+    # bypass and leave the overflow unfixed.
+    monkeypatch.setenv("UNSLOTH_FORCE_FLOAT32", "1")
+    cls = _make_mlp_class()
+    _install_module_stub(monkeypatch, cls)
+    g4.patch_Gemma4TextMLP()
+    m = cls().cuda().to(torch.bfloat16).eval()
+    torch.manual_seed(0)
+    x = torch.randn(2, 64, dtype=torch.bfloat16, device="cuda") * 20.0
+    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16):
+        out = m(x)
+    assert torch.all(torch.isfinite(out)).item()
diff --git a/unsloth_zoo/temporary_patches/gemma4.py b/unsloth_zoo/temporary_patches/gemma4.py
index 3a91bba0c..647a157fe 100644
--- a/unsloth_zoo/temporary_patches/gemma4.py
+++ b/unsloth_zoo/temporary_patches/gemma4.py
@@ -17,7 +17,7 @@
 import torch
 import os
 from .common import TEMPORARY_PATCHES
-from .utils import raise_error
+from .utils import raise_error, patch_function
 
 
 # ============================================================================
@@ -606,3 +606,109 @@ def forward(self, hidden_states, position_embeddings, attention_mask=None, **kwa
     Gemma4AudioAttention.forward = forward
 pass
 TEMPORARY_PATCHES.append(patch_Gemma4AudioAttention)
+
+
+# ============================================================================
+# Gemma-4 float16 stability patch.
+#
+# Goal: make float16 training (notably GRPO on E2B and E4B) numerically
+# stable without paying blanket fp32 compute cost, so that Tesla T4 - which
+# has no bf16 tensor cores - can run Gemma-4 GRPO at full fp16 matmul speed.
+#
+# The NaN chain (verified via TensorStatisticsHooks on E2B GRPO):
+#
+#   1. Under UNSLOTH_FORCE_FLOAT32 the loader stores linear weights as fp16
+#      (see unsloth_zoo/patching_utils.py do_forced_float32 block).
+#   2. `down_proj(act_fn(gate_proj(x)) * up_proj(x))` saturates in fp16 at
+#      `layers.0.mlp.down_proj` (E2B) / `layers.1.mlp.down_proj` (E4B).
+#   3. The +-inf in down_proj output poisons the residual stream; the next
+#      generation step samples NaN logits and the CUDA categorical sampler
+#      trips a device-side assert.
+#
+# Minimal fix: a single patch on Gemma4TextMLP. The product is computed in
+# fp32, clamped to a safe fp16 bound, and down_proj's output is rescued
+# with `nan_to_num` for the rare tail overflow from its fp16 accumulator.
+#
+# Why we do NOT patch RMSNorm / Attention / Embedding:
+#   * Gemma4RMSNorm already casts to fp32 internally and returns
+#     `type_as(hidden_states)` - that dtype contract is fine for fp16.
+#   * Gemma4TextScaledWordEmbedding multiplies by sqrt(hidden_size) which
+#     is ~45-60 for E2B/E4B, well within fp16 range.
+#   * Gemma4TextAttention projections did not overflow in any run; the
+#     failing path is exclusively the MLP gate*up product + down_proj
+#     accumulator.
+#
+# Bisection evidence: 8-step GRPO on unsloth/gemma-4-E2B-it and
+# unsloth/gemma-4-E4B-it, with gradient_checkpointing on and off, completes
+# cleanly with just patch_Gemma4TextMLP. Adding Attention, RMSNorm, or
+# Embedding patches on top produces byte-identical loss / grad_norm / kl
+# trajectories.
+# ============================================================================
+
+
+def patch_Gemma4TextMLP():
+    """Stabilize Gemma-4 MLP under fp16 autocast (GRPO on fp16, Tesla T4).
+
+    Root cause: `down_proj(act_fn(gate_proj(x)) * up_proj(x))` summed over
+    the wide intermediate dimension can exceed fp16_max = 65504 so the fp16
+    matmul cast produces +-inf. That inf poisons the residual stream and
+    generation then samples NaN logits, tripping the categorical assert at
+    GRPO step ~2 on E2B/E4B with dtype=torch.float16.
+
+    Fix (taken only when gate_proj's output is fp16; otherwise the upstream
+    expression runs unchanged):
+      1. act_fn(gate) * up in fp32 so the product cannot overflow.
+      2. Clamp to +-65280 (largest value exactly representable in both fp16
+         and bf16) before down_proj.
+      3. nan_to_num on the output (NaN / +-inf -> 0), rescuing the rare
+         down_proj fp16 accumulator overflow on wide intermediate dims.
+
+    No-op if UNSLOTH_FORCE_FLOAT32 is unset or "0", or if the gemma4 modeling
+    module cannot be imported. Dtype contract is unchanged from upstream, so
+    RMSNorm / Attention / Embedding need no companion patches and the KV
+    cache stays aligned with the text attention output. gate_proj, up_proj
+    and down_proj remain fp16 tensor-core matmuls (65 TFLOPS on T4).
+    """
+    if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0":
+        return  # opt-in: any value other than "0" enables the patch
+    try:
+        import transformers.models.gemma4.modeling_gemma4 as mod
+    except ImportError:
+        return  # gemma4 modeling module not importable: nothing to patch
+    try:
+        Gemma4TextMLP = mod.Gemma4TextMLP
+    except AttributeError as e:
+        return raise_error("Gemma4TextMLP.forward", e)
+
+    # 65280 is the largest value exactly representable in both fp16 and bf16:
+    # one bf16 ULP below 65536 (the next representable bf16 value) and 224
+    # below fp16_max=65504. Note fp16_max itself is not representable in bf16
+    # (it rounds up to 65536). Clamping here therefore survives any downstream
+    # round-trip through PEFT's internal dtype casts without rounding to inf.
+    _SAFE_FP16 = 65280.0
+
+    def forward(self, x):
+        gate = self.gate_proj(x)
+        # Gate on the matmul output dtype rather than x.dtype so that
+        # bf16/fp32 activations combined with fp16 weights (via autocast or
+        # do_forced_float32) still enter the stabilization path when the
+        # projection actually produces fp16 outputs.
+        if gate.dtype != torch.float16:
+            return self.down_proj(self.act_fn(gate) * self.up_proj(x))
+        # fp32 act + multiply so the product cannot overflow before clamp.
+        product = self.act_fn(gate.float()) * self.up_proj(x).float()
+        product = torch.clamp(product, min=-_SAFE_FP16, max=_SAFE_FP16)
+        out = self.down_proj(product.to(gate.dtype))
+        # nan_to_num catches the rare down_proj fp16 accumulator overflow
+        # on wide intermediate dims. Replacements are 0 so the MLP output
+        # at overflow positions defers to the identity residual instead of
+        # injecting a near-fp16_max value that would dominate hidden_states.
+        return torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
+    try:
+        patch_function(
+            Gemma4TextMLP, "forward", forward, fullgraph=False,
+        )
+    except Exception as e:
+        return raise_error("Gemma4TextMLP.forward", e)
+pass  # terminator, same convention as patch_Gemma4AudioAttention
+TEMPORARY_PATCHES.append(patch_Gemma4TextMLP)