Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 99 additions & 1 deletion unsloth_zoo/temporary_patches/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import os
from .common import TEMPORARY_PATCHES
from .utils import raise_error
from .utils import raise_error, patch_function


# ============================================================================
Expand Down Expand Up @@ -606,3 +606,101 @@ def forward(self, hidden_states, position_embeddings, attention_mask=None, **kwa
Gemma4AudioAttention.forward = forward
pass
TEMPORARY_PATCHES.append(patch_Gemma4AudioAttention)


# ============================================================================
# Gemma-4 float16 stability patch.
#
# Goal: make float16 training (notably GRPO on E2B and E4B) numerically
# stable without paying blanket fp32 compute cost, so that Tesla T4 - which
# has no bf16 tensor cores - can run Gemma-4 GRPO at full fp16 matmul speed.
#
# The NaN chain (verified via TensorStatisticsHooks on E2B GRPO):
#
# 1. Under UNSLOTH_FORCE_FLOAT32 the loader stores linear weights as fp16
# (see unsloth_zoo/patching_utils.py do_forced_float32 block).
# 2. `down_proj(act_fn(gate_proj(x)) * up_proj(x))` saturates in fp16 at
# `layers.0.mlp.down_proj` (E2B) / `layers.1.mlp.down_proj` (E4B).
# 3. The +-inf in down_proj output poisons the residual stream; the next
# generation step samples NaN logits and the CUDA categorical sampler
# trips a device-side assert.
#
# Minimal fix: a single patch on Gemma4TextMLP. The product is computed in
# fp32, clamped to a safe fp16 bound, and down_proj's output is rescued
# with `nan_to_num` for the rare tail overflow from its fp16 accumulator.
#
# Why we do NOT patch RMSNorm / Attention / Embedding:
# * Gemma4RMSNorm already casts to fp32 internally and returns
# `type_as(hidden_states)` - that dtype contract is fine for fp16.
# * Gemma4TextScaledWordEmbedding multiplies by sqrt(hidden_size) which
# is ~45-60 for E2B/E4B, well within fp16 range.
# * Gemma4TextAttention projections did not overflow in any run; the
# failing path is exclusively the MLP gate*up product + down_proj
# accumulator.
#
# Bisection evidence: 8-step GRPO on unsloth/gemma-4-E2B-it and
# unsloth/gemma-4-E4B-it, with gradient_checkpointing on and off, completes
# cleanly with just patch_Gemma4TextMLP. Adding Attention, RMSNorm, or
# Embedding patches on top produces byte-identical loss / grad_norm / kl
# trajectories.
# ============================================================================


def patch_Gemma4TextMLP():
"""Stabilize Gemma-4 MLP under fp16 autocast (GRPO on fp16, Tesla T4).

Root cause: `down_proj(act_fn(gate_proj(x)) * up_proj(x))` summed over
the wide intermediate dimension can exceed fp16_max = 65504 so the fp16
matmul cast produces +-inf. That inf poisons the residual stream and
generation then samples NaN logits, tripping the categorical assert at
GRPO step ~2 on E2B/E4B with dtype=torch.float16.

Fix, three cheap operations, single patch:

1. act_fn(gate) * up in fp32 so the product cannot overflow.
2. Clamp to 65280 (largest value exactly representable in both fp16
and bf16) before down_proj.
3. nan_to_num on the output, rescuing the rare down_proj fp16
accumulator overflow on wide intermediate dims.

Dtype contract is unchanged from upstream (input dtype -> input dtype),
so RMSNorm / Attention / Embedding need no companion patches and the KV
cache stays aligned with the text attention output. gate_proj, up_proj
and down_proj remain fp16 tensor-core matmuls (full T4 throughput at
65 TFLOPS).
"""
if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0":
return
try:
import transformers.models.gemma4.modeling_gemma4 as mod
except ImportError:
return
try:
Gemma4TextMLP = mod.Gemma4TextMLP
except AttributeError as e:
return raise_error("Gemma4TextMLP.forward", e)

# Largest value exactly representable in both fp16 and bf16 (one bf16
# ULP below 65536, which would round to fp16 inf).
_SAFE_FP16 = 65280.0

def forward(self, x):
gate = self.gate_proj(x)
# Gate on matmul output dtype, not x.dtype, to catch fp16-cast
# projections over bf16/fp32 activations (autocast or forced fp16).
if gate.dtype != torch.float16:
return self.down_proj(self.act_fn(gate) * self.up_proj(x))
product = self.act_fn(gate.float()) * self.up_proj(x).float()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The PR description mentions that the fix should "Disable fp16 autocast so gate_proj and up_proj execute at the bf16 weight dtype". However, the current implementation calls `self.gate_proj(x)` and `self.up_proj(x)` within the existing autocast context. If fp16 autocast is active, these matmuls will still execute in fp16 and potentially saturate at fp16_max (65504) before the `.float()` call. While the subsequent `clamp` and `nan_to_num` provide stability against product and accumulator overflows, they do not recover the range lost during the initial projections if those saturated. Consider wrapping the projections in a `with torch.amp.autocast("cuda", enabled=False):` block (the older `torch.cuda.amp.autocast(enabled=False)` spelling is deprecated) if the intent was to leverage the full bf16 range of the weights.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation calls self.up_proj(x) twice in the float16 path: once implicitly via the matmul and once explicitly to cast to float. This can be optimized by computing the projection once and reusing the result.

Suggested change
- product = self.act_fn(gate.float()) * self.up_proj(x).float()
+ up = self.up_proj(x)
+ product = self.act_fn(gate.float()) * up.float()

product = torch.clamp(product, min=-_SAFE_FP16, max=_SAFE_FP16)
out = self.down_proj(product.to(gate.dtype))
# Replace overflow with 0 so the residual keeps the identity path
# instead of being dominated by a near-fp16_max injection.
return torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
try:
patch_function(
Gemma4TextMLP, "forward", forward, fullgraph=False,
)
except Exception as e:
return raise_error("Gemma4TextMLP.forward", e)
pass
TEMPORARY_PATCHES.append(patch_Gemma4TextMLP)