Add Gemma-4 float16 UNSLOTH_FORCE_FLOAT32 patches for GRPO stability#3
Add Gemma-4 float16 UNSLOTH_FORCE_FLOAT32 patches for GRPO stability#3danielhanchen wants to merge 11 commits into
Conversation
Gemma-4 E2B/E4B GRPO training hits a CUDA device-side assert at step 2
when dtype=torch.float16 because the text-decoder MLP saturates in fp16:
gate_proj and up_proj outputs reach fp16_max under the fp16 autocast
context, the gate*up product overflows on downcast, and the subsequent
down_proj fp16 matmul accumulator tips to +inf. Generation with an inf
residual stream produces NaN logits and the categorical sampler aborts.
bf16 training is unaffected because the same intermediate magnitudes fit
in bf16's range. Mirror the Gemma-3 UNSLOTH_FORCE_FLOAT32 recipe with
four Gemma-4 specific patches, all gated on the env flag:
* Gemma4RMSNorm: fp32 norm and scale, clamp to 65280 before fp16 cast
(65280 is one bf16 ulp below fp16_max, avoiding the bf16 round-up
that otherwise produces +-inf on later fp16 conversions).
* Gemma4TextScaledWordEmbedding: embed lookup and scale in fp32.
* Gemma4TextMLP: execute gate_proj and up_proj with autocast disabled
so they see bf16 weights (the actual model dtype under
UNSLOTH_FORCE_FLOAT32), run activation plus multiply in fp32, clamp
to the safe fp16 bound, then down_proj in fp16. A final nan_to_num
rescues the rare case where down_proj's fp16 accumulator overflows
for wide intermediate dims.
* Gemma4TextAttention: same autocast-disabled pattern for Q/K/V
projections, fp32 q_norm/k_norm/v_norm, fp32 RoPE, fp32 SDPA, clamp
and fp16 cast before o_proj, plus nan_to_num safety net.
Repro recipe (see PR description for full table):
python scripts/gemma4_grpo_sudoku_repro.py \
--model unsloth/gemma-4-E2B-it --dtype float16 \
--max-steps 8 --grad-accum 4 --num-generations 2
Without the patches this crashes at step 2 with CUDA device-side assert
and overflow flagged in layers.0.mlp.down_proj. With the patches
unsloth/gemma-4-E2B-it and unsloth/gemma-4-E4B-it both complete 8 GRPO
steps with finite loss and grad_norm values close to the bf16 baseline,
under both gradient_checkpointing="unsloth" and gradient_checkpointing
=False. bf16 behavior is unaffected because the patches are gated on
UNSLOTH_FORCE_FLOAT32=1 which only fires for float16 requests.
Bisection showed the NaN chain originates exclusively in the text-decoder
MLP gate*up -> down_proj path:
* Gemma4RMSNorm already computes internally in fp32 and returns
type_as(hidden_states), which is finite for the trained weight ranges.
* Gemma4TextScaledWordEmbedding scales by sqrt(hidden_size) which is
~45 to 60 for E2B/E4B and fits in fp16.
* Gemma4TextAttention projections did not overflow in any run.
So the earlier RMSNorm, Embedding, and Attention patches were redundant
once the MLP is stabilized. Removing them lets the text pipeline keep its
original dtype contract (no autocast gymnastics, no KV cache dtype
mismatch), which in turn keeps the gate/up/down matmuls on fp16 tensor
cores.
The MLP patch itself is reduced to the three operations that matter:
1. Compute act_fn(gate) * up in fp32 so the product cannot overflow.
2. Clamp to 65280 (one bf16 ulp below fp16_max) before down_proj so
the fp16 cast cannot produce +-inf.
3. nan_to_num on the output as defense-in-depth for the rare fp16
accumulator overflow in down_proj for wide intermediate dims.
Verified on B200 (fp16 autocast under GRPO) - all runs complete 8 steps
with no NaN and no bad params:
| model | gc | step1 grad | step8 grad | step8 kl |
| unsloth/gemma-4-E2B-it | unsloth | 0.1110 | 0.4077 | 6.03e-05 |
| unsloth/gemma-4-E2B-it | off | 0.1110 | 0.2655 | 1.18e-04 |
| unsloth/gemma-4-E4B-it | unsloth | 0.0987 | 0.0000 | 1.47e-05 |
Tesla T4 compatibility: T4 has no bf16 tensor cores so the loader's
FORCE_FLOAT32 path lands on fp16 weights. The patch's fp32 gate*up is an
elementwise op (runs on CUDA cores at 8.1 TFLOPS); gate_proj, up_proj,
and down_proj matmuls stay on fp16 tensor cores (65 TFLOPS). Net perf
overhead is minimal and the overflow prevention cost is paid only once
per MLP forward.
UNSLOTH_ENABLE_FLEX_ATTENTION on Gemma-4: flex is enabled by default but
only used when sdpa is unavailable. Gemma-4 supports sdpa so flex is not
selected for the text decoder in practice. Setting the env var to 0
produces identical E4B trajectories and near-identical E2B (diverges only
at step 8, still finite, no crash).
The int8 branch (torch.ops.aten._weight_int8pack_mm) worked end-to-end but failed the RL-parity goal: 100-step GRPO on unsloth/gemma-4-E2B-it with temperature=0.05, min_p=0.5, seed=3407 gave total |KL| of 7.47e+06 for int8 vs 52.16 for fp16+clamp vs 9387 for bf16, and int8 step 7 already had grad_norm=NaN. Cause: GRPO's log-pi-new - log-pi-old ratio amplifies the ~7% per-matmul weight-quantization noise, because the rollout path runs under torch.inference_mode (dense fp16) while the training forward used int8. Making them symmetric would require a separate int8 reference model, which is out of scope for a surgical NaN fix. This commit removes all int8 code (Int8LinearFn autograd wrapper, row quantizer, PEFT LoRA-aware int8 forward, UNSLOTH_GEMMA4_MLP_INT8 env switch) and keeps only the fp16 trick: - fp32 act_fn(gate) * up so the product cannot overflow - clamp to 65280 before down_proj - nan_to_num on down_proj output for the rare accumulator tail The patched forward is now 7 effective lines. Dtype contract is identical to upstream (input dtype -> input dtype), so no attention / RMSNorm / embedding companion patches are needed and the KV cache stays aligned. 100-step verification vs bf16 (temp=0.05, min_p=0.5, seed=3407): median |KL| bf16 4.02e-05 fp16+clamp 7.99e-06 (0.20x) p95 |KL| bf16 0.411 fp16+clamp 0.620 (1.51x) max |KL| bf16 8529 fp16+clamp 35.77 (0.004x) total |KL| bf16 9388 fp16+clamp 52.16 (0.006x) mean reward bf16 1.2111 fp16+clamp 1.1359 (|d|=0.075) reward-equal 75/100 steps time +0.3% Note: bf16's two huge outlier steps (step 24 |KL|=8529 grad=26461; step 42 |KL|=800 grad=1434) are what push its total KL above fp16+clamp's. For the calm 63-66% of steps both trajectories track each other to 1e-5 precision.
65280 is the largest value exactly representable in both fp16 and bf16 (one bf16 ULP below 65536, 224 below fp16_max=65504). The previous wording claimed it was one bf16 ulp below fp16_max, but fp16_max=65504 is not representable in bf16 at all -- it rounds up to 65536. Clamp value is unchanged; comment only.
…nan rescue, inline import - Add `x.dtype != torch.float16` guard so bf16/fp32 activations pass through the upstream forward unchanged. Under UNSLOTH_FORCE_FLOAT32 the weights are fp16 so this is normally unreachable, but the guard protects against pipeline configurations where autocast is disabled (rl_replacements nullcontext path) and prevents unintended clamping if activations reach the MLP as bf16. - Change nan_to_num replacements from +-65280 to 0 so overflow positions contribute nothing and leave the identity residual intact. Replacing with the fp16 ceiling would otherwise dominate the O(1) post-RMSNorm hidden state. Backward gradient through nan_to_num is `grad * isfinite(input)` in both cases, so this is loss-free. - Drop the one-off `_gemma4_modeling` helper and inline the import to match the pattern of every other patch in this file. - Drop the `x: torch.Tensor` annotation to match the rest of the temporary patches (`def forward(self, x):`).
- Switch the stabilization guard from `x.dtype != torch.float16` to `gate.dtype != torch.float16`. The matmul output dtype is what determines whether overflow can occur, so this also catches mixed-precision cases (bf16 activations through fp16-cast weights via autocast or do_forced_float32) that the x.dtype check missed. For the standard UNSLOTH_FORCE_FLOAT32 path behavior is unchanged. - Inline `self.up_proj(x).float()` so the fp16 up tensor is transient, and reuse the already-computed gate in the bypass path. - Cast product back with `gate.dtype` instead of `x.dtype` to avoid a bf16-input/fp16-weight mismatch at down_proj if stabilization is active with non-fp16 activations.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces a stability patch for Gemma4TextMLP to address numerical instability issues during float16 training, particularly for GRPO on Tesla T4 GPUs. The patch modifies the MLP's forward pass to compute the gate and up-projection product in fp32, clamps the result to a safe fp16 boundary, and applies nan_to_num to the down_proj output to prevent residual stream poisoning. A review comment suggests that the initial projections might still saturate if fp16 autocast is enabled, potentially losing precision before the stability logic is applied, and recommends disabling autocast for those specific operations to leverage the full range of the weights.
| # projections over bf16/fp32 activations (autocast or forced fp16). | ||
| if gate.dtype != torch.float16: | ||
| return self.down_proj(self.act_fn(gate) * self.up_proj(x)) | ||
| product = self.act_fn(gate.float()) * self.up_proj(x).float() |
There was a problem hiding this comment.
The PR description mentions that the fix should "Disable fp16 autocast so gate_proj and up_proj execute at the bf16 weight dtype". However, the current implementation calls self.gate_proj(x) and self.up_proj(x) within the existing autocast context. If fp16 autocast is active, these matmuls will still execute in fp16 and potentially saturate at fp16_max (65504) before the .float() call. While the subsequent clamp and nan_to_num provide stability against product and accumulator overflows, they do not recover the range lost during the initial projections if they saturated. Consider wrapping the projections in with torch.cuda.amp.autocast(enabled=False): if the intent was to leverage the full bf16 range of the weights.
There was a problem hiding this comment.
Code Review
This pull request introduces a stability patch for the Gemma-4 MLP to address numerical instability during fp16 training, particularly for GRPO on hardware like the Tesla T4. The patch ensures that the MLP product is computed in fp32, clamped to a safe range, and that any remaining overflows in the down_proj output are handled via nan_to_num. Feedback was provided to optimize the float16 path by reusing the result of the up_proj calculation instead of computing it twice.
| # projections over bf16/fp32 activations (autocast or forced fp16). | ||
| if gate.dtype != torch.float16: | ||
| return self.down_proj(self.act_fn(gate) * self.up_proj(x)) | ||
| product = self.act_fn(gate.float()) * self.up_proj(x).float() |
There was a problem hiding this comment.
The current implementation calls self.up_proj(x) twice in the float16 path: once implicitly via the matmul and once explicitly to cast to float. This can be optimized by computing the projection once and reusing the result.
| product = self.act_fn(gate.float()) * self.up_proj(x).float() | |
| up = self.up_proj(x) | |
| product = self.act_fn(gate.float()) * up.float() |
…n, finalize_huggingface_model
- patch_gemma4_vllm_lora_support: use functools.wraps on patched_create_lora_manager so
_call_create_lora_manager's signature inspection still sees vllm_config; pass model
positionally to lora_manager_cls to avoid "multiple values for 'model'".
- patch_gemma4_vllm_k_eq_v_support: also handle split k_proj/v_proj layout (current
upstream Gemma4) by duplicating k quant-state to synthetic v entry; keep packed
qkv_proj path as fallback.
- load_vllm: gate Gemma4 patches on enable_lora / use_bitsandbytes (not is_vision_model),
so text-only Gemma4 + LoRA / BnB also works.
- extract_gdn_layers: derive qkvz offsets from gdn.key_dim/value_dim when
ColumnParallelLinear has no output_sizes; manually split in_proj_ba into b/a instead
of calling get_state_dict with kk=1 (IndexError); preserve BnB quant_state sidecars;
handle FP8 weight_scale (not only weight_scale_inv) and dynamic/row-wise FP8;
export linear_attn.norm.weight.
- finalize_huggingface_model: fix layer_idx for standard causal LMs (not only VLM path);
rebuild Gemma4 vision rotary_emb from vision_config with fp32 buffers; guard
rotary_pos_emb on vision_config availability; mirror language_model detection from
set_additional_modules.
- get_model_layer_config: register Gemma4 per_layer_input_gate / per_layer_projection /
post_per_layer_input_norm; add Qwen3.5 visual.merger.linear_fc1 / linear_fc2 and drop
the broken linear_fc{kk} template.
- set_dtype_in_config (hf_utils): prefer the modern 'dtype' field; fall back to
'torch_dtype' only when 'dtype' is absent, avoiding the deprecation warning on
current transformers.
- vllm_utils state-dict loop: skip layer.mlp extraction for linear-attn-only layers
(defensive) while still capturing layer_scalar.
- _normalize_state_dict_tensor: guard is_sparse behind isinstance(value, torch.Tensor)
so non-tensor state-dict values pass through.
Staging mirror of unslothai#600
Original PR: unslothai#600
Author: danielhanchen
This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.
Original description
Problem
unsloth/gemma-4-E2B-itandunsloth/gemma-4-E4B-itGRPO training NaNs and crashes with a CUDA device-side assert at step 2 when the user passesdtype=torch.float16.torch.bfloat16is unaffected.Minimal reproduction, 8 steps on a single B200:
Unpatched result (dtype=float16):
Hooks on the model show the first overflow occurs in
language_model.layers.0.mlp.down_proj.post_stats[base_layer]hits +inf during step 1's forward, poisons the residual stream, and step 2's generation produces NaN logits that trip the categorical sampler.Root cause
Under
UNSLOTH_FORCE_FLOAT32=1, the loader stores weights as bfloat16 but GRPO enters its fp16 autocast context for training. Gemma-4's text-decoder MLP doeswith no internal upcast. Inside fp16 autocast, gate and up outputs saturate at
fp16_max, the product becomes +inf on downcast, anddown_proj's fp16 matmul accumulator then produces +inf for the residual add. bf16 avoids this because its range covers fp32 magnitudes.Gemma-3 already ships
UNSLOTH_FORCE_FLOAT32patches for the same failure mode. The corresponding Gemma-4 patches existed briefly and were removed in 158e981 (2026-04-06) after initial testing found it stable. GRPO fp16 training exercises code paths the earlier testing did not.Fix
Four text-decoder patches in
unsloth_zoo/temporary_patches/gemma4.py, each gated onUNSLOTH_FORCE_FLOAT32=1and each appended toTEMPORARY_PATCHES:patch_Gemma4RMSNormpatch_Gemma4TextScaledWordEmbeddingpatch_Gemma4TextMLPtorch.nan_to_numsafety net rescues the rare case where down_proj's fp16 accumulator still overflows for wide intermediate dims.patch_Gemma4TextAttentionnan_to_numsafety net. Preserves the KV-sharing andshared_kv_statescontract.The paired change in
unslothai/unslothadds"gemma4"toFORCE_FLOAT32so these patches actually engage for fp16 requests.No Gemma-4 causal mask patch is added because, unlike Gemma-3, Gemma-4 delegates mask construction to
ALL_ATTENTION_FUNCTIONSand the existing sdpa / eager / flex_attention paths handle masThis PR contains code changes only (1 files). Test changes are in a separate PR.
Changed files:
unsloth_zoo/temporary_patches/gemma4.py