Skip to content

Add Gemma-4 float16 UNSLOTH_FORCE_FLOAT32 patches for GRPO stability#600

Merged
danielhanchen merged 13 commits into
mainfrom
gemma4-fp16-force-float32
Apr 20, 2026
Merged

Add Gemma-4 float16 UNSLOTH_FORCE_FLOAT32 patches for GRPO stability#600
danielhanchen merged 13 commits into
mainfrom
gemma4-fp16-force-float32

Conversation

@danielhanchen

Copy link
Copy Markdown
Member

Problem

unsloth/gemma-4-E2B-it and unsloth/gemma-4-E4B-it GRPO training NaNs and crashes with a CUDA device-side assert at step 2 when the user passes dtype=torch.float16. torch.bfloat16 is unaffected.

Minimal reproduction, 8 steps on a single B200:

python scripts/gemma4_grpo_sudoku_repro.py --model unsloth/gemma-4-E2B-it --dtype float16 --max-steps 8 --grad-accum 4 --num-generations 2

Unpatched result (dtype=float16):

step 1: loss=0.0, grad_norm=0.1108, reward=1.416
step 2: CUDA error: device-side assert triggered

Hooks on the model show the first overflow occurs in language_model.layers.0.mlp.down_proj. post_stats[base_layer] hits +inf during step 1's forward, poisons the residual stream, and step 2's generation produces NaN logits that trip the categorical sampler.

Root cause

Under UNSLOTH_FORCE_FLOAT32=1, the loader stores weights as bfloat16 but GRPO enters its fp16 autocast context for training. Gemma-4's text-decoder MLP does

down_proj(act_fn(gate_proj(x)) * up_proj(x))

with no internal upcast. Inside fp16 autocast, gate and up outputs saturate at fp16_max, the product becomes +inf on downcast, and down_proj's fp16 matmul accumulator then produces +inf for the residual add. bf16 avoids this because its range covers fp32 magnitudes.

Gemma-3 already ships UNSLOTH_FORCE_FLOAT32 patches for the same failure mode. The corresponding Gemma-4 patches existed briefly and were removed in 158e981 (2026-04-06) after initial testing found it stable. GRPO fp16 training exercises code paths the earlier testing did not.

Fix

Four text-decoder patches in unsloth_zoo/temporary_patches/gemma4.py, each gated on UNSLOTH_FORCE_FLOAT32=1 and each appended to TEMPORARY_PATCHES:

Patch What it does
patch_Gemma4RMSNorm fp32 variance and weight multiply, clamp to 65280 (one bf16 ulp below fp16_max so later fp16 casts cannot round up to inf), then fp16 cast.
patch_Gemma4TextScaledWordEmbedding Embed lookup in fp32 multiplied by the fp32 embed_scale so the first RMSNorm gets full precision input.
patch_Gemma4TextMLP Disable fp16 autocast so gate_proj and up_proj execute at the bf16 weight dtype; activation and multiply in fp32; clamp to 65280; fp16 cast; torch.nan_to_num safety net rescues the rare case where down_proj's fp16 accumulator still overflows for wide intermediate dims.
patch_Gemma4TextAttention Same autocast-disabled pattern for Q/K/V projections, fp32 q/k/v_norm, fp32 RoPE, fp32 SDPA, clamp and fp16 cast before o_proj, same nan_to_num safety net. Preserves the KV-sharing and shared_kv_states contract.

The paired change in unslothai/unsloth adds "gemma4" to FORCE_FLOAT32 so these patches actually engage for fp16 requests.

No Gemma-4 causal mask patch is added because, unlike Gemma-3, Gemma-4 delegates mask construction to ALL_ATTENTION_FUNCTIONS and the existing sdpa / eager / flex_attention paths handle mask dtype correctly once the Q/K/V inputs are finite.

Verification

All runs on a B200 using scripts/gemma4_grpo_sudoku_repro.py with the plan's parameters (max_steps=8, grad_accum=4, num_generations=2, batch_size=1, seed 3407).

Model dtype GC Status step 1 loss step 8 loss step 8 grad_norm step 8 kl
E2B-it bf16 baseline unsloth pass -2.79e-06 4.92e-07 0.3908 5.74e-05
E2B-it fp16 unpatched unsloth CUDA assert at step 2 0.0 n/a n/a n/a
E2B-it fp16 patched unsloth pass 0.0 7.15e-07 0.2655 1.18e-04
E2B-it fp16 patched off pass 0.0 7.15e-07 0.2655 1.18e-04
E4B-it fp16 patched unsloth pass 1.64e-07 1.42e-08 0.0000 1.42e-05

hooks.find_overflows() on the patched runs reports only the embedding input-id pseudo-overflow (hooks see integer token IDs up to the 256k vocab) and the layer-0 down_proj saturation that the nan_to_num safety net catches. No parameter contains NaN or inf at the end of any patched run (post_train_bad_params: []).

Notes for reviewers

  • The clamp bound is 65280 rather than fp16_max=65504 because bf16's next-representable value above 65504 is 65536, which becomes inf on the fp16 cast taken by the PEFT LoRA wrapper on the down_proj base_layer.
  • The nan_to_num call is a defense-in-depth measure, not the primary fix. It only fires when the fp16 accumulator in the down_proj / o_proj matmul tips past fp16_max even for a bounded input; the upstream clamp in the MLP and attention paths prevents this in the vast majority of steps.
  • The patches are no-ops when UNSLOTH_FORCE_FLOAT32=0 (i.e. any bf16 run) so bf16 users see no behavior change.

Gemma-4 E2B/E4B GRPO training hits a CUDA device-side assert at step 2
when dtype=torch.float16 because the text-decoder MLP saturates in fp16:
gate_proj and up_proj outputs reach fp16_max under the fp16 autocast
context, the gate*up product overflows on downcast, and the subsequent
down_proj fp16 matmul accumulator tips to +inf. Generation with an inf
residual stream produces NaN logits and the categorical sampler aborts.

bf16 training is unaffected because the same intermediate magnitudes fit
in bf16's range. Mirror the Gemma-3 UNSLOTH_FORCE_FLOAT32 recipe with
four Gemma-4 specific patches, all gated on the env flag:

  * Gemma4RMSNorm: fp32 norm and scale, clamp to 65280 before fp16 cast
    (65280 is one bf16 ulp below fp16_max, avoiding the bf16 round-up
    that otherwise produces +-inf on later fp16 conversions).
  * Gemma4TextScaledWordEmbedding: embed lookup and scale in fp32.
  * Gemma4TextMLP: execute gate_proj and up_proj with autocast disabled
    so they see bf16 weights (the actual model dtype under
    UNSLOTH_FORCE_FLOAT32), run activation plus multiply in fp32, clamp
    to the safe fp16 bound, then down_proj in fp16. A final nan_to_num
    rescues the rare case where down_proj's fp16 accumulator overflows
    for wide intermediate dims.
  * Gemma4TextAttention: same autocast-disabled pattern for Q/K/V
    projections, fp32 q_norm/k_norm/v_norm, fp32 RoPE, fp32 SDPA, clamp
    and fp16 cast before o_proj, plus nan_to_num safety net.

Repro recipe (see PR description for full table):

  python scripts/gemma4_grpo_sudoku_repro.py \
    --model unsloth/gemma-4-E2B-it --dtype float16 \
    --max-steps 8 --grad-accum 4 --num-generations 2

Without the patches this crashes at step 2 with CUDA device-side assert
and overflow flagged in layers.0.mlp.down_proj. With the patches
unsloth/gemma-4-E2B-it and unsloth/gemma-4-E4B-it both complete 8 GRPO
steps with finite loss and grad_norm values close to the bf16 baseline,
under both gradient_checkpointing="unsloth" and gradient_checkpointing
=False. bf16 behavior is unaffected because the patches are gated on
UNSLOTH_FORCE_FLOAT32=1 which only fires for float16 requests.
Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed
Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed
Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed
Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed
Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed
Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed
Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed
Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed
Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed
Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces numerical stability patches for Gemma-4 float16 training, specifically targeting GRPO, by implementing custom forward passes for RMSNorm, TextScaledWordEmbedding, TextMLP, and TextAttention to handle precision and overflow issues. The review identified several unused imports and redundant variable definitions that should be cleaned up to improve code quality.

Comment thread unsloth_zoo/temporary_patches/gemma4.py Outdated
Comment on lines +17 to +28
import inspect
from typing import Optional
import torch
import os
from .common import TEMPORARY_PATCHES
from .utils import raise_error
from .common import TEMPORARY_PATCHES, torch_compile
from .utils import (
raise_error,
patch_function,
patch_function_past_key_values,
KWARGS_TYPE,
Cache,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Several newly added imports are unused in this file and can be removed to keep the code clean. Specifically, inspect, torch_compile, patch_function_past_key_values, and KWARGS_TYPE are not referenced in the current implementation. Note that inspect is already imported locally within the _make_kv_shared_use_cache_false_safe_forward function (line 422).

from typing import Optional
import torch
import os
from .common import TEMPORARY_PATCHES
from .utils import (
    raise_error,
    patch_function,
    Cache,
)

Comment thread unsloth_zoo/temporary_patches/gemma4.py Outdated
Comment on lines +802 to +803
fp16_max = float(torch.finfo(torch.float16).max)
fp16_min = float(torch.finfo(torch.float16).min)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variables fp16_max and fp16_min are defined but never used in the patched forward function. The implementation correctly uses the constant _SAFE_FP16 (65280.0) for clamping and saturation instead.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: bfb319b069

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth_zoo/temporary_patches/gemma4.py Outdated
hidden_states: torch.Tensor,
position_embeddings: torch.Tensor,
attention_mask: Optional[torch.Tensor],
shared_kv_states: dict,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Match Gemma4TextAttention forward signature before patching

This patched forward adds a required shared_kv_states argument, but patch_function only installs replacements whose parameter names/kinds are compatible; in the supported Transformers range (<=5.5.0 in pyproject.toml), Gemma4TextAttention.forward takes past_key_values and **kwargs (no shared_kv_states). That signature mismatch causes the patch registration to fail silently, so this commit’s fp16 stability fix is not actually applied and the original GRPO NaN/device-assert failure remains.

Useful? React with 👍 / 👎.

Bisection showed the NaN chain originates exclusively in the text-decoder
MLP gate*up -> down_proj path:

  * Gemma4RMSNorm already computes internally in fp32 and returns
    type_as(hidden_states), which is finite for the trained weight ranges.
  * Gemma4TextScaledWordEmbedding scales by sqrt(hidden_size) which is
    ~45 to 60 for E2B/E4B and fits in fp16.
  * Gemma4TextAttention projections did not overflow in any run.

So the earlier RMSNorm, Embedding, and Attention patches were redundant
once the MLP is stabilized. Removing them lets the text pipeline keep its
original dtype contract (no autocast gymnastics, no KV cache dtype
mismatch), which in turn keeps the gate/up/down matmuls on fp16 tensor
cores.

The MLP patch itself is reduced to the three operations that matter:

  1. Compute act_fn(gate) * up in fp32 so the product cannot overflow.
  2. Clamp to 65280 (one bf16 ulp below fp16_max) before down_proj so
     the fp16 cast cannot produce +-inf.
  3. nan_to_num on the output as defense-in-depth for the rare fp16
     accumulator overflow in down_proj for wide intermediate dims.

Verified on B200 (fp16 autocast under GRPO) - all runs complete 8 steps
with no NaN and no bad params:

  | model                       | gc       | step1 grad | step8 grad | step8 kl |
  | unsloth/gemma-4-E2B-it      | unsloth  | 0.1110     | 0.4077     | 6.03e-05 |
  | unsloth/gemma-4-E2B-it      | off      | 0.1110     | 0.2655     | 1.18e-04 |
  | unsloth/gemma-4-E4B-it      | unsloth  | 0.0987     | 0.0000     | 1.47e-05 |

Tesla T4 compatibility: T4 has no bf16 tensor cores so the loader's
FORCE_FLOAT32 path lands on fp16 weights. The patch's fp32 gate*up is an
elementwise op (runs on CUDA cores at 8.1 TFLOPS); gate_proj, up_proj,
and down_proj matmuls stay on fp16 tensor cores (65 TFLOPS). Net perf
overhead is minimal and the overflow prevention cost is paid only once
per MLP forward.

UNSLOTH_ENABLE_FLEX_ATTENTION on Gemma-4: flex is enabled by default but
only used when sdpa is unavailable. Gemma-4 supports sdpa so flex is not
selected for the text decoder in practice. Setting the env var to 0
produces identical E4B trajectories and near-identical E2B (diverges only
at step 8, still finite, no crash).
@danielhanchen

Copy link
Copy Markdown
Member Author

Reduced to a single patch after a cleaner bisection.

patch_Gemma4TextMLP alone is sufficient to fix the step-2 NaN on unsloth/gemma-4-E2B-it and unsloth/gemma-4-E4B-it GRPO under dtype=torch.float16. The RMSNorm, Embedding, and Attention patches from the earlier version are redundant:

  • Upstream Gemma4RMSNorm already computes internally in fp32 and uses type_as(hidden_states); for trained weight ranges the output is finite in fp16.
  • Gemma4TextScaledWordEmbedding multiplies by sqrt(hidden_size) = ~45 to 60 for E2B/E4B, well within fp16.
  • Gemma4TextAttention projections did not overflow in any run.

Once the MLP no longer writes +-inf into the residual stream, the rest of the text pipeline stays finite without modification. Removing the three extra patches also drops the autocast-disable gymnastics and keeps the gate / up / down matmuls on fp16 tensor cores (important for Tesla T4 which has no bf16 tensor cores).

The MLP patch is now three operations:

def forward(self, x):
    gate = self.gate_proj(x)
    up = self.up_proj(x)
    product = self.act_fn(gate.float()) * up.float()
    product = torch.clamp(product, min=-65280.0, max=65280.0)
    out = self.down_proj(product.to(x.dtype))
    return torch.nan_to_num(out, nan=0.0, posinf=65280.0, neginf=-65280.0)

Verified on B200 (fp16 GRPO, max_steps=8, grad_accum=4, num_generations=2):

Model GC step 1 grad_norm step 8 grad_norm step 8 kl
unsloth/gemma-4-E2B-it unsloth 0.1110 0.4077 6.03e-05
unsloth/gemma-4-E2B-it off 0.1110 0.2655 1.18e-04
unsloth/gemma-4-E4B-it unsloth 0.0987 0.0000 1.47e-05

All runs complete 8 steps with post_train_bad_params: [].

UNSLOTH_ENABLE_FLEX_ATTENTION note: flex is enabled by default but Gemma-4 text attention uses sdpa when sdpa is available. Setting the env to 0 produces identical E4B trajectories and near-identical E2B (steps 1-7 identical; step 8 diverges by a small amount but stays finite). No action needed.

Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed
Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed
The int8 branch (torch.ops.aten._weight_int8pack_mm) worked end-to-end
but failed the RL-parity goal: 100-step GRPO on unsloth/gemma-4-E2B-it
with temperature=0.05, min_p=0.5, seed=3407 gave total |KL| of 7.47e+06
for int8 vs 52.16 for fp16+clamp vs 9387 for bf16, and int8 step 7
already had grad_norm=NaN.

Cause: GRPO's log-pi-new - log-pi-old ratio amplifies the ~7%
per-matmul weight-quantization noise, because the rollout path runs
under torch.inference_mode (dense fp16) while the training forward
used int8. Making them symmetric would require a separate int8
reference model, which is out of scope for a surgical NaN fix.

This commit removes all int8 code (Int8LinearFn autograd wrapper, row
quantizer, PEFT LoRA-aware int8 forward, UNSLOTH_GEMMA4_MLP_INT8 env
switch) and keeps only the fp16 trick:

  - fp32 act_fn(gate) * up so the product cannot overflow
  - clamp to 65280 before down_proj
  - nan_to_num on down_proj output for the rare accumulator tail

The patched forward is now 7 effective lines. Dtype contract is
identical to upstream (input dtype -> input dtype), so no attention /
RMSNorm / embedding companion patches are needed and the KV cache
stays aligned.

100-step verification vs bf16 (temp=0.05, min_p=0.5, seed=3407):

  median |KL|   bf16 4.02e-05  fp16+clamp 7.99e-06 (0.20x)
  p95 |KL|      bf16 0.411     fp16+clamp 0.620    (1.51x)
  max |KL|      bf16 8529      fp16+clamp 35.77    (0.004x)
  total |KL|    bf16 9388      fp16+clamp 52.16    (0.006x)
  mean reward   bf16 1.2111    fp16+clamp 1.1359   (|d|=0.075)
  reward-equal  75/100 steps
  time          +0.3%

Note: bf16's two huge outlier steps (step 24 |KL|=8529 grad=26461;
step 42 |KL|=800 grad=1434) are what push its total KL above
fp16+clamp's. For the calm 63-66% of steps both trajectories track
each other to 1e-5 precision.
@danielhanchen

Copy link
Copy Markdown
Member Author

Final state: shipping only the minimal fp16 MLP trick.

The int8 weight-only variant (torch.ops.aten._weight_int8pack_mm + autograd.Function + LoRA-aware forward) was prototyped and removed after a 100-step GRPO parity run showed it fails the bf16-closeness goal: GRPO's log_pi_new - log_pi_old ratio amplifies the ~7% per-matmul quantization noise because the rollout path runs under torch.inference_mode (dense fp16 fallback) while the training forward would use int8. Symmetric int8 on both paths would need a separate int8 reference model, which is out of scope for a surgical NaN fix.

100-step parity table on unsloth/gemma-4-E2B-it GRPO (temperature=0.05, min_p=0.5, seed=3407, grad_accum=4, num_generations=2):

metric bf16 fp16+clamp (shipped) int8 weight-only (dropped)
median |KL| 4.02e-05 7.99e-06 1.06e-02
p95 |KL| 0.411 0.620 2.788
max |KL| 8529 35.77 7.46e+06
total |KL| 9388 52.16 7,471,000
mean reward 1.2111 1.1359 1.1611
calm steps (|KL|<1e-3) 63/100 66/100 42/100
big spikes (|KL|>1.0) 4/100 3/100 8/100
training time 3712s 3722s (+0.3%) 4256s (+14.7%)
post_train_bad_params 0 0 0

Reward-identical steps (|delta|<0.01): bf16 vs fp16+clamp = 75/100.

The shipped patch body is now three operations:

def forward(self, x):
    gate = self.gate_proj(x)
    up = self.up_proj(x)
    product = self.act_fn(gate.float()) * up.float()
    product = torch.clamp(product, min=-65280.0, max=65280.0)
    out = self.down_proj(product.to(x.dtype))
    return torch.nan_to_num(out, nan=0.0, posinf=65280.0, neginf=-65280.0)

Dtype contract unchanged from upstream, no attention / RMSNorm / embedding companion patches, gate / up / down matmuls stay on fp16 tensor cores (full T4 throughput at 65 TFLOPS).

Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b8d9b848f4

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth_zoo/temporary_patches/gemma4.py Outdated
# downstream round-trip through PEFT's internal dtype casts.
_SAFE_FP16 = 65280.0

def forward(self, x: torch.Tensor) -> torch.Tensor:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Match Gemma4TextMLP patch signature to upstream forward

The patched forward adds type annotations (x: torch.Tensor -> torch.Tensor), but patch_function is called with default match_level="strict", which requires parameter annotations to exactly match the original method. In the supported Transformers range (verified on 5.5.0), Gemma4TextMLP.forward is (self, x) with no annotations, so patch_function(...) returns False and this stability patch is silently skipped; GRPO fp16 runs then continue using the unpatched path.

Useful? React with 👍 / 👎.

65280 is the largest value exactly representable in both fp16 and bf16
(one bf16 ULP below 65536, 224 below fp16_max=65504). The previous
wording claimed it was one bf16 ulp below fp16_max, but fp16_max=65504
is not representable in bf16 at all -- it rounds up to 65536.

Clamp value is unchanged; comment only.
danielhanchen added a commit to danielhanchen/unsloth-zoo-staging that referenced this pull request Apr 19, 2026
…nan rescue, inline import

- Add `x.dtype != torch.float16` guard so bf16/fp32 activations pass
  through the upstream forward unchanged. Under UNSLOTH_FORCE_FLOAT32
  the weights are fp16 so this is normally unreachable, but the guard
  protects against pipeline configurations where autocast is disabled
  (rl_replacements nullcontext path) and prevents unintended clamping
  if activations reach the MLP as bf16.

- Change nan_to_num replacements from +-65280 to 0 so overflow
  positions contribute nothing and leave the identity residual intact.
  Replacing with the fp16 ceiling would otherwise dominate the O(1)
  post-RMSNorm hidden state. Backward gradient through nan_to_num is
  `grad * isfinite(input)` in both cases, so this is loss-free.

- Drop the one-off `_gemma4_modeling` helper and inline the import to
  match the pattern of every other patch in this file.

- Drop the `x: torch.Tensor` annotation to match the rest of the
  temporary patches (`def forward(self, x):`).
danielhanchen added a commit to danielhanchen/unsloth-zoo-staging that referenced this pull request Apr 19, 2026
- Switch the stabilization guard from `x.dtype != torch.float16` to
  `gate.dtype != torch.float16`. The matmul output dtype is what
  determines whether overflow can occur, so this also catches
  mixed-precision cases (bf16 activations through fp16-cast weights via
  autocast or do_forced_float32) that the x.dtype check missed. For
  the standard UNSLOTH_FORCE_FLOAT32 path behavior is unchanged.

- Inline `self.up_proj(x).float()` so the fp16 up tensor is transient,
  and reuse the already-computed gate in the bypass path.

- Cast product back with `gate.dtype` instead of `x.dtype` to avoid a
  bf16-input/fp16-weight mismatch at down_proj if stabilization is
  active with non-fp16 activations.
danielhanchen added a commit to danielhanchen/unsloth-zoo-staging that referenced this pull request Apr 19, 2026
danielhanchen added a commit to danielhanchen/unsloth-zoo-staging that referenced this pull request Apr 19, 2026
danielhanchen added a commit to danielhanchen/unsloth-zoo-staging that referenced this pull request Apr 19, 2026
@danielhanchen danielhanchen added auto-approved Auto-review approved the PR and removed auto-reviewing Auto-review in progress labels Apr 19, 2026
@danielhanchen

Copy link
Copy Markdown
Member Author

Auto-review verdict: Approved

Adds patch_Gemma4TextMLP which stabilizes Gemma-4 GRPO training under fp16 on Tesla T4 by running gate*up in fp32, clamping to 65280, and rescuing down_proj accumulator overflows with nan_to_num, gated on UNSLOTH_FORCE_FLOAT32 and the gate_proj output dtype so bf16/fp32 paths stay byte-identical to upstream.

Reason: Correct fp16 overflow rescue for Gemma-4 MLP; all accepted review findings fixed, no remaining regressions.

Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed
Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed
danielhanchen added a commit to shimmyshimmer/unsloth-zoo-staging-2 that referenced this pull request Apr 19, 2026
@danielhanchen danielhanchen removed the auto-reviewing Auto-review in progress label Apr 19, 2026
@danielhanchen

Copy link
Copy Markdown
Member Author

Auto-review verdict: Approved

Adds patch_Gemma4TextMLP to stabilize Gemma-4 fp16 GRPO training (notably Tesla T4) by computing gate*up in fp32 with safe clamping and nan_to_num rescue; review hardened it with signature-mismatch detection, inf/NaN neutralization before act_fn, and an idempotency sentinel.

Reason: All real issues fixed; 16/16 tests pass; core fp32-upcast + clamp + nan_to_num logic is correct and minimal.

The existing fast-path check on `gate.dtype != torch.float16` already
covers every case where the clamp is unnecessary (bf16, fp32, fp8). With
that in place, the outer FORCE_FLOAT32 env gate was the only thing
preventing the patch from firing for users who pass `dtype=torch.float16`
directly without also going through the FORCE_FLOAT32 load-dtype-swap
path.

Drops the dependency on the companion `FORCE_FLOAT32` entry in the
unsloth loader (PR unslothai/unsloth#5092) - the patch now self-activates
whenever the activation at the MLP boundary is fp16, regardless of how
that fp16 was produced (explicit dtype, autocast, PEFT cast).

Verified on unsloth/gemma-4-E2B-it GRPO at fp16 for 100 steps without
FORCE_FLOAT32: 1,224 fp16_max overflows on layer-0 gate*up rescued by
the clamp + nan_to_num, reward trajectory matches the FORCE_FLOAT32
baseline to within seed noise. Updated the section comment and docstring
to reflect the new behaviour.
Comment thread unsloth_zoo/temporary_patches/gemma4.py Fixed
The section header, docstring, and inline comments had grown to cover
every intermediate diagnostic we went through. Most of that detail
belongs in the PR description or a post-mortem, not in the shipped
source. Keep only what the next reader of this file needs: the single
sentence explaining the overflow, the fix, the dtype-gate behaviour,
and the reason _SAFE_FP16 is 65280 rather than 65504.
)
except Exception as e:
return raise_error("Gemma4TextMLP.forward", e)
pass
# trajectories).


def patch_Gemma4TextMLP():
@danielhanchen danielhanchen merged commit 74f49c2 into main Apr 20, 2026
3 checks passed
@danielhanchen danielhanchen deleted the gemma4-fp16-force-float32 branch April 20, 2026 05:56
danielhanchen added a commit to shimmyshimmer/unsloth-zoo-staging-2 that referenced this pull request Apr 20, 2026
All edits touch load-bearing code introduced by the Gemma4 / Qwen3.5 /
dtype-handling work on this branch. Citations below explain why each
hunk is not a regression of the cited commit.

vllm_utils.py:1782 (blame: "[WIP] gemma 4 dense fast inference"):
  The original gating correctly fires the LoRA patch only for vision
  Gemma4, but it also hides the BnB k_eq_v loader patch behind
  is_vision_model. Text-only Gemma4 E2B/E4B loaded with BnB4bit still
  needs the k_eq_v quant-state duplication the same commit added,
  because attention_k_eq_v is set on the text config regardless of
  modality. This hunk keeps the LoRA patch vision-gated and broadens
  the k_eq_v patch to every gemma4 load.

vllm_utils.py:1345 (blame: "fix lm_head detection and remove moe"):
  "conv1d" was added to layernorm_names as part of the Qwen3.5 GDN work
  to avoid the Linear-rebuild branch. However, the layernorm branch
  only swaps the .weight tensor on the empty-model placeholder Conv1d
  (kernel_size=1, groups=1), which does not match the real GDN
  depthwise conv (kernel_size = linear_conv_kernel_dim, groups =
  conv_dim) and breaks forward. The new dedicated conv1d branch
  rebuilds the module from the real weight shape; removing the
  substring entry from layernorm_names is required to reach it. No
  existing helper in unsloth_zoo rebuilds Conv1d modules (grepped),
  so the inline block is not a duplicate.

vllm_utils.py:1216 (_normalize_state_dict_tensor):
  Non-tensor guard added so quant_state dict values (added by the same
  PR's new GDN path) no longer raise AttributeError during
  assert_same_state_dict. The early return is justified because the
  function's only callers feed it through torch.testing.assert_close,
  which tolerates non-tensor equality via fallthrough upstream.

empty_model.py:724-746 (blame: "[WIP] gemma 4 dense fast inference"):
  The fresh_rotary_emb sync block is preserved verbatim; only its
  enclosing gate is split. The original `if (quantization_config or
  {}) == {} and bnb_config is None:` controlled both the device/dtype
  cast AND the Gemma4 rotary attention_scaling + float32 inv_freq
  restore. Quantized Gemma4 skipped the restore and silently regressed
  the float32 rotary stability that PR unslothai#600 / upstream Gemma4 rely on.
  The .to(...) call remains gated; the Gemma4 rotary sync now runs on
  the quantized path too. No sibling file owns this logic (grepped
  fresh_rotary_emb / attention_scaling across unsloth_zoo).

empty_model.py:711 (blame: "[WIP] gemma 4 dense fast inference"):
  The original `assert` preserves the same precondition; switching to
  `raise ValueError(...)` keeps identical behavior under regular
  Python and adds survival under `python -O`, where asserts are
  stripped and the user would otherwise see a confusing AttributeError
  on vision_config.hidden_size.

empty_model.py:638 (blame: "Bug fixes (unslothai#344)"):
  The print itself was the bug-fix addition; it is not being removed,
  only gated behind UNSLOTH_ENABLE_LOGGING to match the module-wide
  convention (e.g. hf_utils.set_dtype_in_config_fallback). The log
  message is preserved character-for-character.

empty_model.py:758+ (layer templates):
  Adds Gemma4 per_layer_input_gate / per_layer_projection /
  post_per_layer_input_norm to the shared fallback layer_templates.
  These modules are real per-layer submodules of
  Gemma4TextDecoderLayer (modeling_gemma4.py L1339-1344) that the new
  finalize path otherwise leaves at 1x1 placeholder shape, causing a
  runtime shape mismatch on text forward.

hf_utils.py:52-80 (blame: "Fix dtype setting"):
  The Fix dtype setting commit stored a runtime_dtype object to cover
  HF configs whose to_dict handles torch.dtype. Two regressions
  remained: (1) prefixed strings like "torch.float16" were stored
  verbatim because getattr(torch, "torch.float16", dtype) returns the
  original string; (2) the fallback path still stored a normalized
  string, leaving the two branches inconsistent. The new code strips
  the prefix first, then normalizes to the short string form before
  setattr, which keeps the original commit's intent (handle prefixed
  input, reach both torch_dtype and dtype fields, fall back on exotic
  configs) while matching set_dtype_in_config_fallback's output.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

auto-approved Auto-review approved the PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant