Add Gemma-4 float16 UNSLOTH_FORCE_FLOAT32 patches for GRPO stability#600
Conversation
Gemma-4 E2B/E4B GRPO training hits a CUDA device-side assert at step 2
when dtype=torch.float16 because the text-decoder MLP saturates in fp16:
gate_proj and up_proj outputs reach fp16_max under the fp16 autocast
context, the gate*up product overflows on downcast, and the subsequent
down_proj fp16 matmul accumulator tips to +inf. Generation with an inf
residual stream produces NaN logits and the categorical sampler aborts.
bf16 training is unaffected because the same intermediate magnitudes fit
in bf16's range. Mirror the Gemma-3 UNSLOTH_FORCE_FLOAT32 recipe with
four Gemma-4 specific patches, all gated on the env flag:
* Gemma4RMSNorm: fp32 norm and scale, clamp to 65280 before fp16 cast
(65280 is one bf16 ulp below fp16_max, avoiding the bf16 round-up
that otherwise produces +-inf on later fp16 conversions).
* Gemma4TextScaledWordEmbedding: embed lookup and scale in fp32.
* Gemma4TextMLP: execute gate_proj and up_proj with autocast disabled
so they see bf16 weights (the actual model dtype under
UNSLOTH_FORCE_FLOAT32), run activation plus multiply in fp32, clamp
to the safe fp16 bound, then down_proj in fp16. A final nan_to_num
rescues the rare case where down_proj's fp16 accumulator overflows
for wide intermediate dims.
* Gemma4TextAttention: same autocast-disabled pattern for Q/K/V
projections, fp32 q_norm/k_norm/v_norm, fp32 RoPE, fp32 SDPA, clamp
and fp16 cast before o_proj, plus nan_to_num safety net.
Repro recipe (see PR description for full table):
python scripts/gemma4_grpo_sudoku_repro.py \
--model unsloth/gemma-4-E2B-it --dtype float16 \
--max-steps 8 --grad-accum 4 --num-generations 2
Without the patches this crashes at step 2 with CUDA device-side assert
and overflow flagged in layers.0.mlp.down_proj. With the patches
unsloth/gemma-4-E2B-it and unsloth/gemma-4-E4B-it both complete 8 GRPO
steps with finite loss and grad_norm values close to the bf16 baseline,
under both gradient_checkpointing="unsloth" and gradient_checkpointing
=False. bf16 behavior is unaffected because the patches are gated on
UNSLOTH_FORCE_FLOAT32=1 which only fires for float16 requests.
There was a problem hiding this comment.
Code Review
This pull request introduces numerical stability patches for Gemma-4 float16 training, specifically targeting GRPO, by implementing custom forward passes for RMSNorm, TextScaledWordEmbedding, TextMLP, and TextAttention to handle precision and overflow issues. The review identified several unused imports and redundant variable definitions that should be cleaned up to improve code quality.
| import inspect | ||
| from typing import Optional | ||
| import torch | ||
| import os | ||
| from .common import TEMPORARY_PATCHES | ||
| from .utils import raise_error | ||
| from .common import TEMPORARY_PATCHES, torch_compile | ||
| from .utils import ( | ||
| raise_error, | ||
| patch_function, | ||
| patch_function_past_key_values, | ||
| KWARGS_TYPE, | ||
| Cache, | ||
| ) |
There was a problem hiding this comment.
Several newly added imports are unused in this file and can be removed to keep the code clean. Specifically, inspect, torch_compile, patch_function_past_key_values, and KWARGS_TYPE are not referenced in the current implementation. Note that inspect is already imported locally within the _make_kv_shared_use_cache_false_safe_forward function (line 422).
from typing import Optional
import torch
import os
from .common import TEMPORARY_PATCHES
from .utils import (
raise_error,
patch_function,
Cache,
)| fp16_max = float(torch.finfo(torch.float16).max) | ||
| fp16_min = float(torch.finfo(torch.float16).min) |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: bfb319b069
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| hidden_states: torch.Tensor, | ||
| position_embeddings: torch.Tensor, | ||
| attention_mask: Optional[torch.Tensor], | ||
| shared_kv_states: dict, |
There was a problem hiding this comment.
Match Gemma4TextAttention forward signature before patching
This patched forward adds a required shared_kv_states argument, but patch_function only installs replacements whose parameter names/kinds are compatible; in the supported Transformers range (<=5.5.0 in pyproject.toml), Gemma4TextAttention.forward takes past_key_values and **kwargs (no shared_kv_states). That signature mismatch causes the patch registration to fail silently, so this commit’s fp16 stability fix is not actually applied and the original GRPO NaN/device-assert failure remains.
Useful? React with 👍 / 👎.
Bisection showed the NaN chain originates exclusively in the text-decoder
MLP gate*up -> down_proj path:
* Gemma4RMSNorm already computes internally in fp32 and returns
type_as(hidden_states), which is finite for the trained weight ranges.
* Gemma4TextScaledWordEmbedding scales by sqrt(hidden_size) which is
~45 to 60 for E2B/E4B and fits in fp16.
* Gemma4TextAttention projections did not overflow in any run.
So the earlier RMSNorm, Embedding, and Attention patches were redundant
once the MLP is stabilized. Removing them lets the text pipeline keep its
original dtype contract (no autocast gymnastics, no KV cache dtype
mismatch), which in turn keeps the gate/up/down matmuls on fp16 tensor
cores.
The MLP patch itself is reduced to the three operations that matter:
1. Compute act_fn(gate) * up in fp32 so the product cannot overflow.
2. Clamp to 65280 (one bf16 ulp below fp16_max) before down_proj so
the fp16 cast cannot produce +-inf.
3. nan_to_num on the output as defense-in-depth for the rare fp16
accumulator overflow in down_proj for wide intermediate dims.
Verified on B200 (fp16 autocast under GRPO) - all runs complete 8 steps
with no NaN and no bad params:
| model | gc | step1 grad | step8 grad | step8 kl |
| unsloth/gemma-4-E2B-it | unsloth | 0.1110 | 0.4077 | 6.03e-05 |
| unsloth/gemma-4-E2B-it | off | 0.1110 | 0.2655 | 1.18e-04 |
| unsloth/gemma-4-E4B-it | unsloth | 0.0987 | 0.0000 | 1.47e-05 |
Tesla T4 compatibility: T4 has no bf16 tensor cores so the loader's
FORCE_FLOAT32 path lands on fp16 weights. The patch's fp32 gate*up is an
elementwise op (runs on CUDA cores at 8.1 TFLOPS); gate_proj, up_proj,
and down_proj matmuls stay on fp16 tensor cores (65 TFLOPS). Net perf
overhead is minimal and the overflow prevention cost is paid only once
per MLP forward.
UNSLOTH_ENABLE_FLEX_ATTENTION on Gemma-4: flex is enabled by default but
only used when sdpa is unavailable. Gemma-4 supports sdpa so flex is not
selected for the text decoder in practice. Setting the env var to 0
produces identical E4B trajectories and near-identical E2B (diverges only
at step 8, still finite, no crash).
|
Reduced to a single patch after a cleaner bisection.
Once the MLP no longer writes +-inf into the residual stream, the rest of the text pipeline stays finite without modification. Removing the three extra patches also drops the autocast-disable gymnastics and keeps the gate / up / down matmuls on fp16 tensor cores (important for Tesla T4 which has no bf16 tensor cores). The MLP patch is now three operations: def forward(self, x):
gate = self.gate_proj(x)
up = self.up_proj(x)
product = self.act_fn(gate.float()) * up.float()
product = torch.clamp(product, min=-65280.0, max=65280.0)
out = self.down_proj(product.to(x.dtype))
return torch.nan_to_num(out, nan=0.0, posinf=65280.0, neginf=-65280.0)Verified on B200 (fp16 GRPO,
All runs complete 8 steps with
|
The int8 branch (torch.ops.aten._weight_int8pack_mm) worked end-to-end but failed the RL-parity goal: 100-step GRPO on unsloth/gemma-4-E2B-it with temperature=0.05, min_p=0.5, seed=3407 gave total |KL| of 7.47e+06 for int8 vs 52.16 for fp16+clamp vs 9387 for bf16, and int8 step 7 already had grad_norm=NaN. Cause: GRPO's log-pi-new - log-pi-old ratio amplifies the ~7% per-matmul weight-quantization noise, because the rollout path runs under torch.inference_mode (dense fp16) while the training forward used int8. Making them symmetric would require a separate int8 reference model, which is out of scope for a surgical NaN fix. This commit removes all int8 code (Int8LinearFn autograd wrapper, row quantizer, PEFT LoRA-aware int8 forward, UNSLOTH_GEMMA4_MLP_INT8 env switch) and keeps only the fp16 trick: - fp32 act_fn(gate) * up so the product cannot overflow - clamp to 65280 before down_proj - nan_to_num on down_proj output for the rare accumulator tail The patched forward is now 7 effective lines. Dtype contract is identical to upstream (input dtype -> input dtype), so no attention / RMSNorm / embedding companion patches are needed and the KV cache stays aligned. 100-step verification vs bf16 (temp=0.05, min_p=0.5, seed=3407): median |KL| bf16 4.02e-05 fp16+clamp 7.99e-06 (0.20x) p95 |KL| bf16 0.411 fp16+clamp 0.620 (1.51x) max |KL| bf16 8529 fp16+clamp 35.77 (0.004x) total |KL| bf16 9388 fp16+clamp 52.16 (0.006x) mean reward bf16 1.2111 fp16+clamp 1.1359 (|d|=0.075) reward-equal 75/100 steps time +0.3% Note: bf16's two huge outlier steps (step 24 |KL|=8529 grad=26461; step 42 |KL|=800 grad=1434) are what push its total KL above fp16+clamp's. For the calm 63-66% of steps both trajectories track each other to 1e-5 precision.
|
Final state: shipping only the minimal fp16 MLP trick. The int8 weight-only variant ( 100-step parity table on
Reward-identical steps (|delta|<0.01): bf16 vs fp16+clamp = 75/100. The shipped patch body is now three operations: def forward(self, x):
gate = self.gate_proj(x)
up = self.up_proj(x)
product = self.act_fn(gate.float()) * up.float()
product = torch.clamp(product, min=-65280.0, max=65280.0)
out = self.down_proj(product.to(x.dtype))
return torch.nan_to_num(out, nan=0.0, posinf=65280.0, neginf=-65280.0)Dtype contract unchanged from upstream, no attention / RMSNorm / embedding companion patches, gate / up / down matmuls stay on fp16 tensor cores (full T4 throughput at 65 TFLOPS). |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b8d9b848f4
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| # downstream round-trip through PEFT's internal dtype casts. | ||
| _SAFE_FP16 = 65280.0 | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
There was a problem hiding this comment.
Match Gemma4TextMLP patch signature to upstream forward
The patched forward adds type annotations (x: torch.Tensor -> torch.Tensor), but patch_function is called with default match_level="strict", which requires parameter annotations to exactly match the original method. In the supported Transformers range (verified on 5.5.0), Gemma4TextMLP.forward is (self, x) with no annotations, so patch_function(...) returns False and this stability patch is silently skipped; GRPO fp16 runs then continue using the unpatched path.
Useful? React with 👍 / 👎.
65280 is the largest value exactly representable in both fp16 and bf16 (one bf16 ULP below 65536, 224 below fp16_max=65504). The previous wording claimed it was one bf16 ulp below fp16_max, but fp16_max=65504 is not representable in bf16 at all -- it rounds up to 65536. Clamp value is unchanged; comment only.
…nan rescue, inline import - Add `x.dtype != torch.float16` guard so bf16/fp32 activations pass through the upstream forward unchanged. Under UNSLOTH_FORCE_FLOAT32 the weights are fp16 so this is normally unreachable, but the guard protects against pipeline configurations where autocast is disabled (rl_replacements nullcontext path) and prevents unintended clamping if activations reach the MLP as bf16. - Change nan_to_num replacements from +-65280 to 0 so overflow positions contribute nothing and leave the identity residual intact. Replacing with the fp16 ceiling would otherwise dominate the O(1) post-RMSNorm hidden state. Backward gradient through nan_to_num is `grad * isfinite(input)` in both cases, so this is loss-free. - Drop the one-off `_gemma4_modeling` helper and inline the import to match the pattern of every other patch in this file. - Drop the `x: torch.Tensor` annotation to match the rest of the temporary patches (`def forward(self, x):`).
- Switch the stabilization guard from `x.dtype != torch.float16` to `gate.dtype != torch.float16`. The matmul output dtype is what determines whether overflow can occur, so this also catches mixed-precision cases (bf16 activations through fp16-cast weights via autocast or do_forced_float32) that the x.dtype check missed. For the standard UNSLOTH_FORCE_FLOAT32 path behavior is unchanged. - Inline `self.up_proj(x).float()` so the fp16 up tensor is transient, and reuse the already-computed gate in the bypass path. - Cast product back with `gate.dtype` instead of `x.dtype` to avoid a bf16-input/fp16-weight mismatch at down_proj if stabilization is active with non-fp16 activations.
|
Auto-review verdict: Approved Adds patch_Gemma4TextMLP which stabilizes Gemma-4 GRPO training under fp16 on Tesla T4 by running gate*up in fp32, clamping to 65280, and rescuing down_proj accumulator overflows with nan_to_num, gated on UNSLOTH_FORCE_FLOAT32 and the gate_proj output dtype so bf16/fp32 paths stay byte-identical to upstream. Reason: Correct fp16 overflow rescue for Gemma-4 MLP; all accepted review findings fixed, no remaining regressions. |
|
Auto-review verdict: Approved Adds patch_Gemma4TextMLP to stabilize Gemma-4 fp16 GRPO training (notably Tesla T4) by computing gate*up in fp32 with safe clamping and nan_to_num rescue; review hardened it with signature-mismatch detection, inf/NaN neutralization before act_fn, and an idempotency sentinel. Reason: All real issues fixed; 16/16 tests pass; core fp32-upcast + clamp + nan_to_num logic is correct and minimal. |
The existing fast-path check on `gate.dtype != torch.float16` already covers every case where the clamp is unnecessary (bf16, fp32, fp8). With that in place, the outer FORCE_FLOAT32 env gate was the only thing preventing the patch from firing for users who pass `dtype=torch.float16` directly without also going through the FORCE_FLOAT32 load-dtype-swap path. Drops the dependency on the companion `FORCE_FLOAT32` entry in the unsloth loader (PR unslothai/unsloth#5092) - the patch now self-activates whenever the activation at the MLP boundary is fp16, regardless of how that fp16 was produced (explicit dtype, autocast, PEFT cast). Verified on unsloth/gemma-4-E2B-it GRPO at fp16 for 100 steps without FORCE_FLOAT32: 1,224 fp16_max overflows on layer-0 gate*up rescued by the clamp + nan_to_num, reward trajectory matches the FORCE_FLOAT32 baseline to within seed noise. Updated the section comment and docstring to reflect the new behaviour.
The section header, docstring, and inline comments had grown to cover every intermediate diagnostic we went through. Most of that detail belongs in the PR description or a post-mortem, not in the shipped source. Keep only what the next reader of this file needs: the single sentence explaining the overflow, the fix, the dtype-gate behaviour, and the reason _SAFE_FP16 is 65280 rather than 65504.
| ) | ||
| except Exception as e: | ||
| return raise_error("Gemma4TextMLP.forward", e) | ||
| pass |
| # trajectories). | ||
|
|
||
|
|
||
| def patch_Gemma4TextMLP(): |
All edits touch load-bearing code introduced by the Gemma4 / Qwen3.5 /
dtype-handling work on this branch. Citations below explain why each
hunk is not a regression of the cited commit.
vllm_utils.py:1782 (blame: "[WIP] gemma 4 dense fast inference"):
The original gating correctly fires the LoRA patch only for vision
Gemma4, but it also hides the BnB k_eq_v loader patch behind
is_vision_model. Text-only Gemma4 E2B/E4B loaded with BnB4bit still
needs the k_eq_v quant-state duplication the same commit added,
because attention_k_eq_v is set on the text config regardless of
modality. This hunk keeps the LoRA patch vision-gated and broadens
the k_eq_v patch to every gemma4 load.
vllm_utils.py:1345 (blame: "fix lm_head detection and remove moe"):
"conv1d" was added to layernorm_names as part of the Qwen3.5 GDN work
to avoid the Linear-rebuild branch. However, the layernorm branch
only swaps the .weight tensor on the empty-model placeholder Conv1d
(kernel_size=1, groups=1), which does not match the real GDN
depthwise conv (kernel_size = linear_conv_kernel_dim, groups =
conv_dim) and breaks forward. The new dedicated conv1d branch
rebuilds the module from the real weight shape; removing the
substring entry from layernorm_names is required to reach it. No
existing helper in unsloth_zoo rebuilds Conv1d modules (grepped),
so the inline block is not a duplicate.
vllm_utils.py:1216 (_normalize_state_dict_tensor):
Non-tensor guard added so quant_state dict values (added by the same
PR's new GDN path) no longer raise AttributeError during
assert_same_state_dict. The early return is justified because the
function's only callers feed it through torch.testing.assert_close,
which tolerates non-tensor equality via fallthrough upstream.
empty_model.py:724-746 (blame: "[WIP] gemma 4 dense fast inference"):
The fresh_rotary_emb sync block is preserved verbatim; only its
enclosing gate is split. The original `if (quantization_config or
{}) == {} and bnb_config is None:` controlled both the device/dtype
cast AND the Gemma4 rotary attention_scaling + float32 inv_freq
restore. Quantized Gemma4 skipped the restore and silently regressed
the float32 rotary stability that PR unslothai#600 / upstream Gemma4 rely on.
The .to(...) call remains gated; the Gemma4 rotary sync now runs on
the quantized path too. No sibling file owns this logic (grepped
fresh_rotary_emb / attention_scaling across unsloth_zoo).
empty_model.py:711 (blame: "[WIP] gemma 4 dense fast inference"):
The original `assert` preserves the same precondition; switching to
`raise ValueError(...)` keeps identical behavior under regular
Python and adds survival under `python -O`, where asserts are
stripped and the user would otherwise see a confusing AttributeError
on vision_config.hidden_size.
empty_model.py:638 (blame: "Bug fixes (unslothai#344)"):
The print itself was the bug-fix addition; it is not being removed,
only gated behind UNSLOTH_ENABLE_LOGGING to match the module-wide
convention (e.g. hf_utils.set_dtype_in_config_fallback). The log
message is preserved character-for-character.
empty_model.py:758+ (layer templates):
Adds Gemma4 per_layer_input_gate / per_layer_projection /
post_per_layer_input_norm to the shared fallback layer_templates.
These modules are real per-layer submodules of
Gemma4TextDecoderLayer (modeling_gemma4.py L1339-1344) that the new
finalize path otherwise leaves at 1x1 placeholder shape, causing a
runtime shape mismatch on text forward.
hf_utils.py:52-80 (blame: "Fix dtype setting"):
The Fix dtype setting commit stored a runtime_dtype object to cover
HF configs whose to_dict handles torch.dtype. Two regressions
remained: (1) prefixed strings like "torch.float16" were stored
verbatim because getattr(torch, "torch.float16", dtype) returns the
original string; (2) the fallback path still stored a normalized
string, leaving the two branches inconsistent. The new code strips
the prefix first, then normalizes to the short string form before
setattr, which keeps the original commit's intent (handle prefixed
input, reach both torch_dtype and dtype fields, fall back on exotic
configs) while matching set_dtype_in_config_fallback's output.
Problem
unsloth/gemma-4-E2B-itandunsloth/gemma-4-E4B-itGRPO training NaNs and crashes with a CUDA device-side assert at step 2 when the user passesdtype=torch.float16.torch.bfloat16is unaffected.Minimal reproduction, 8 steps on a single B200:
Unpatched result (dtype=float16):
Hooks on the model show the first overflow occurs in
language_model.layers.0.mlp.down_proj.post_stats[base_layer]hits +inf during step 1's forward, poisons the residual stream, and step 2's generation produces NaN logits that trip the categorical sampler.Root cause
Under
UNSLOTH_FORCE_FLOAT32=1, the loader stores weights as bfloat16 but GRPO enters its fp16 autocast context for training. Gemma-4's text-decoder MLP doeswith no internal upcast. Inside fp16 autocast, gate and up outputs saturate at
fp16_max, the product becomes +inf on downcast, anddown_proj's fp16 matmul accumulator then produces +inf for the residual add. bf16 avoids this because its range covers fp32 magnitudes.Gemma-3 already ships
UNSLOTH_FORCE_FLOAT32patches for the same failure mode. The corresponding Gemma-4 patches existed briefly and were removed in 158e981 (2026-04-06) after initial testing found it stable. GRPO fp16 training exercises code paths the earlier testing did not.Fix
Four text-decoder patches in
unsloth_zoo/temporary_patches/gemma4.py, each gated onUNSLOTH_FORCE_FLOAT32=1and each appended toTEMPORARY_PATCHES:patch_Gemma4RMSNormpatch_Gemma4TextScaledWordEmbeddingpatch_Gemma4TextMLPtorch.nan_to_numsafety net rescues the rare case where down_proj's fp16 accumulator still overflows for wide intermediate dims.patch_Gemma4TextAttentionnan_to_numsafety net. Preserves the KV-sharing andshared_kv_statescontract.The paired change in
unslothai/unslothadds"gemma4"toFORCE_FLOAT32so these patches actually engage for fp16 requests.No Gemma-4 causal mask patch is added because, unlike Gemma-3, Gemma-4 delegates mask construction to
ALL_ATTENTION_FUNCTIONSand the existing sdpa / eager / flex_attention paths handle mask dtype correctly once the Q/K/V inputs are finite.Verification
All runs on a B200 using
scripts/gemma4_grpo_sudoku_repro.pywith the plan's parameters (max_steps=8,grad_accum=4,num_generations=2,batch_size=1, seed 3407).hooks.find_overflows()on the patched runs reports only the embedding input-id pseudo-overflow (hooks see integer token IDs up to the 256k vocab) and the layer-0 down_proj saturation that thenan_to_numsafety net catches. No parameter contains NaN or inf at the end of any patched run (post_train_bad_params: []).Notes for reviewers
fp16_max=65504 because bf16's next-representable value above 65504 is 65536, which becomes inf on the fp16 cast taken by the PEFT LoRA wrapper on the down_proj base_layer.nan_to_numcall is a defense-in-depth measure, not the primary fix. It only fires when the fp16 accumulator in the down_proj / o_proj matmul tips past fp16_max even for a bounded input; the upstream clamp in the MLP and attention paths prevents this in the vast majority of steps.UNSLOTH_FORCE_FLOAT32=0(i.e. any bf16 run) so bf16 users see no behavior change.