Skip to content

Add Gemma-4 float16 UNSLOTH_FORCE_FLOAT32 patches for GRPO stability#3

Open
danielhanchen wants to merge 11 commits into
mainfrom
pr-600-code
Open

Add Gemma-4 float16 UNSLOTH_FORCE_FLOAT32 patches for GRPO stability#3
danielhanchen wants to merge 11 commits into
mainfrom
pr-600-code

Conversation

@danielhanchen

Copy link
Copy Markdown
Owner

Staging mirror of unslothai#600

Original PR: unslothai#600
Author: danielhanchen

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

Problem

unsloth/gemma-4-E2B-it and unsloth/gemma-4-E4B-it GRPO training NaNs and crashes with a CUDA device-side assert at step 2 when the user passes dtype=torch.float16. torch.bfloat16 is unaffected.

Minimal reproduction, 8 steps on a single B200:

python scripts/gemma4_grpo_sudoku_repro.py --model unsloth/gemma-4-E2B-it --dtype float16 --max-steps 8 --grad-accum 4 --num-generations 2

Unpatched result (dtype=float16):

step 1: loss=0.0, grad_norm=0.1108, reward=1.416
step 2: CUDA error: device-side assert triggered

Hooks on the model show the first overflow occurs in language_model.layers.0.mlp.down_proj. post_stats[base_layer] hits +inf during step 1's forward, poisons the residual stream, and step 2's generation produces NaN logits that trip the categorical sampler.

Root cause

Under UNSLOTH_FORCE_FLOAT32=1, the loader stores weights as bfloat16 but GRPO enters its fp16 autocast context for training. Gemma-4's text-decoder MLP does

down_proj(act_fn(gate_proj(x)) * up_proj(x))

with no internal upcast. Inside fp16 autocast, gate and up outputs saturate at fp16_max, the product becomes +inf on downcast, and down_proj's fp16 matmul accumulator then produces +inf for the residual add. bf16 avoids this because its range covers fp32 magnitudes.

Gemma-3 already ships UNSLOTH_FORCE_FLOAT32 patches for the same failure mode. The corresponding Gemma-4 patches existed briefly and were removed in 158e981 (2026-04-06) after initial testing found it stable. GRPO fp16 training exercises code paths the earlier testing did not.

Fix

Four text-decoder patches in unsloth_zoo/temporary_patches/gemma4.py, each gated on UNSLOTH_FORCE_FLOAT32=1 and each appended to TEMPORARY_PATCHES:

Patch What it does
patch_Gemma4RMSNorm fp32 variance and weight multiply, clamp to 65280 (one bf16 ulp below fp16_max so later fp16 casts cannot round up to inf), then fp16 cast.
patch_Gemma4TextScaledWordEmbedding Embed lookup in fp32 multiplied by the fp32 embed_scale so the first RMSNorm gets full precision input.
patch_Gemma4TextMLP Disable fp16 autocast so gate_proj and up_proj execute at the bf16 weight dtype; activation and multiply in fp32; clamp to 65280; fp16 cast; torch.nan_to_num safety net rescues the rare case where down_proj's fp16 accumulator still overflows for wide intermediate dims.
patch_Gemma4TextAttention Same autocast-disabled pattern for Q/K/V projections, fp32 q/k/v_norm, fp32 RoPE, fp32 SDPA, clamp and fp16 cast before o_proj, same nan_to_num safety net. Preserves the KV-sharing and shared_kv_states contract.

The paired change in unslothai/unsloth adds "gemma4" to FORCE_FLOAT32 so these patches actually engage for fp16 requests.

No Gemma-4 causal mask patch is added because, unlike Gemma-3, Gemma-4 delegates mask construction to ALL_ATTENTION_FUNCTIONS and the existing sdpa / eager / flex_attention paths handle mas


This PR contains code changes only (1 files). Test changes are in a separate PR.

Changed files:

  • unsloth_zoo/temporary_patches/gemma4.py

Gemma-4 E2B/E4B GRPO training hits a CUDA device-side assert at step 2
when dtype=torch.float16 because the text-decoder MLP saturates in fp16:
gate_proj and up_proj outputs reach fp16_max under the fp16 autocast
context, the gate*up product overflows on downcast, and the subsequent
down_proj fp16 matmul accumulator tips to +inf. Generation with an inf
residual stream produces NaN logits and the categorical sampler aborts.

bf16 training is unaffected because the same intermediate magnitudes fit
in bf16's range. Mirror the Gemma-3 UNSLOTH_FORCE_FLOAT32 recipe with
four Gemma-4 specific patches, all gated on the env flag:

  * Gemma4RMSNorm: fp32 norm and scale, clamp to 65280 before fp16 cast
    (65280 is one bf16 ulp below fp16_max, avoiding the bf16 round-up
    that otherwise produces +-inf on later fp16 conversions).
  * Gemma4TextScaledWordEmbedding: embed lookup and scale in fp32.
  * Gemma4TextMLP: execute gate_proj and up_proj with autocast disabled
    so they see bf16 weights (the actual model dtype under
    UNSLOTH_FORCE_FLOAT32), run activation plus multiply in fp32, clamp
    to the safe fp16 bound, then down_proj in fp16. A final nan_to_num
    rescues the rare case where down_proj's fp16 accumulator overflows
    for wide intermediate dims.
  * Gemma4TextAttention: same autocast-disabled pattern for Q/K/V
    projections, fp32 q_norm/k_norm/v_norm, fp32 RoPE, fp32 SDPA, clamp
    and fp16 cast before o_proj, plus nan_to_num safety net.

Repro recipe (see PR description for full table):

  python scripts/gemma4_grpo_sudoku_repro.py \
    --model unsloth/gemma-4-E2B-it --dtype float16 \
    --max-steps 8 --grad-accum 4 --num-generations 2

Without the patches this crashes at step 2 with CUDA device-side assert
and overflow flagged in layers.0.mlp.down_proj. With the patches
unsloth/gemma-4-E2B-it and unsloth/gemma-4-E4B-it both complete 8 GRPO
steps with finite loss and grad_norm values close to the bf16 baseline,
under both gradient_checkpointing="unsloth" and gradient_checkpointing
=False. bf16 behavior is unaffected because the patches are gated on
UNSLOTH_FORCE_FLOAT32=1 which only fires for float16 requests.
Bisection showed the NaN chain originates exclusively in the text-decoder
MLP gate*up -> down_proj path:

  * Gemma4RMSNorm already computes internally in fp32 and returns
    type_as(hidden_states), which is finite for the trained weight ranges.
  * Gemma4TextScaledWordEmbedding scales by sqrt(hidden_size) which is
    ~45 to 60 for E2B/E4B and fits in fp16.
  * Gemma4TextAttention projections did not overflow in any run.

So the earlier RMSNorm, Embedding, and Attention patches were redundant
once the MLP is stabilized. Removing them lets the text pipeline keep its
original dtype contract (no autocast gymnastics, no KV cache dtype
mismatch), which in turn keeps the gate/up/down matmuls on fp16 tensor
cores.

The MLP patch itself is reduced to the three operations that matter:

  1. Compute act_fn(gate) * up in fp32 so the product cannot overflow.
  2. Clamp to 65280 (one bf16 ulp below fp16_max) before down_proj so
     the fp16 cast cannot produce +-inf.
  3. nan_to_num on the output as defense-in-depth for the rare fp16
     accumulator overflow in down_proj for wide intermediate dims.

Verified on B200 (fp16 autocast under GRPO) - all runs complete 8 steps
with no NaN and no bad params:

  | model                       | gc       | step1 grad | step8 grad | step8 kl |
  | unsloth/gemma-4-E2B-it      | unsloth  | 0.1110     | 0.4077     | 6.03e-05 |
  | unsloth/gemma-4-E2B-it      | off      | 0.1110     | 0.2655     | 1.18e-04 |
  | unsloth/gemma-4-E4B-it      | unsloth  | 0.0987     | 0.0000     | 1.47e-05 |

Tesla T4 compatibility: T4 has no bf16 tensor cores so the loader's
FORCE_FLOAT32 path lands on fp16 weights. The patch's fp32 gate*up is an
elementwise op (runs on CUDA cores at 8.1 TFLOPS); gate_proj, up_proj,
and down_proj matmuls stay on fp16 tensor cores (65 TFLOPS). Net perf
overhead is minimal and the overflow prevention cost is paid only once
per MLP forward.

UNSLOTH_ENABLE_FLEX_ATTENTION on Gemma-4: flex is enabled by default but
only used when sdpa is unavailable. Gemma-4 supports sdpa so flex is not
selected for the text decoder in practice. Setting the env var to 0
produces identical E4B trajectories and near-identical E2B (diverges only
at step 8, still finite, no crash).
The int8 branch (torch.ops.aten._weight_int8pack_mm) worked end-to-end
but failed the RL-parity goal: 100-step GRPO on unsloth/gemma-4-E2B-it
with temperature=0.05, min_p=0.5, seed=3407 gave total |KL| of 7.47e+06
for int8 vs 52.16 for fp16+clamp vs 9387 for bf16, and int8 step 7
already had grad_norm=NaN.

Cause: GRPO's log-pi-new - log-pi-old ratio amplifies the ~7%
per-matmul weight-quantization noise, because the rollout path runs
under torch.inference_mode (dense fp16) while the training forward
used int8. Making them symmetric would require a separate int8
reference model, which is out of scope for a surgical NaN fix.

This commit removes all int8 code (Int8LinearFn autograd wrapper, row
quantizer, PEFT LoRA-aware int8 forward, UNSLOTH_GEMMA4_MLP_INT8 env
switch) and keeps only the fp16 trick:

  - fp32 act_fn(gate) * up so the product cannot overflow
  - clamp to 65280 before down_proj
  - nan_to_num on down_proj output for the rare accumulator tail

The patched forward is now 7 effective lines. Dtype contract is
identical to upstream (input dtype -> input dtype), so no attention /
RMSNorm / embedding companion patches are needed and the KV cache
stays aligned.

100-step verification vs bf16 (temp=0.05, min_p=0.5, seed=3407):

  median |KL|   bf16 4.02e-05  fp16+clamp 7.99e-06 (0.20x)
  p95 |KL|      bf16 0.411     fp16+clamp 0.620    (1.51x)
  max |KL|      bf16 8529      fp16+clamp 35.77    (0.004x)
  total |KL|    bf16 9388      fp16+clamp 52.16    (0.006x)
  mean reward   bf16 1.2111    fp16+clamp 1.1359   (|d|=0.075)
  reward-equal  75/100 steps
  time          +0.3%

Note: bf16's two huge outlier steps (step 24 |KL|=8529 grad=26461;
step 42 |KL|=800 grad=1434) are what push its total KL above
fp16+clamp's. For the calm 63-66% of steps both trajectories track
each other to 1e-5 precision.
65280 is the largest value exactly representable in both fp16 and bf16
(one bf16 ULP below 65536, 224 below fp16_max=65504). The previous
wording claimed it was one bf16 ulp below fp16_max, but fp16_max=65504
is not representable in bf16 at all -- it rounds up to 65536.

Clamp value is unchanged; comment only.
…nan rescue, inline import

- Add `x.dtype != torch.float16` guard so bf16/fp32 activations pass
  through the upstream forward unchanged. Under UNSLOTH_FORCE_FLOAT32
  the weights are fp16 so this is normally unreachable, but the guard
  protects against pipeline configurations where autocast is disabled
  (rl_replacements nullcontext path) and prevents unintended clamping
  if activations reach the MLP as bf16.

- Change nan_to_num replacements from +-65280 to 0 so overflow
  positions contribute nothing and leave the identity residual intact.
  Replacing with the fp16 ceiling would otherwise dominate the O(1)
  post-RMSNorm hidden state. Backward gradient through nan_to_num is
  `grad * isfinite(input)` in both cases, so this is loss-free.

- Drop the one-off `_gemma4_modeling` helper and inline the import to
  match the pattern of every other patch in this file.

- Drop the `x: torch.Tensor` annotation to match the rest of the
  temporary patches (`def forward(self, x):`).
- Switch the stabilization guard from `x.dtype != torch.float16` to
  `gate.dtype != torch.float16`. The matmul output dtype is what
  determines whether overflow can occur, so this also catches
  mixed-precision cases (bf16 activations through fp16-cast weights via
  autocast or do_forced_float32) that the x.dtype check missed. For
  the standard UNSLOTH_FORCE_FLOAT32 path behavior is unchanged.

- Inline `self.up_proj(x).float()` so the fp16 up tensor is transient,
  and reuse the already-computed gate in the bypass path.

- Cast product back with `gate.dtype` instead of `x.dtype` to avoid a
  bf16-input/fp16-weight mismatch at down_proj if stabilization is
  active with non-fp16 activations.
@danielhanchen

Copy link
Copy Markdown
Owner Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a stability patch for Gemma4TextMLP to address numerical instability issues during float16 training, particularly for GRPO on Tesla T4 GPUs. The patch modifies the MLP's forward pass to compute the gate and up-projection product in fp32, clamps the result to a safe fp16 boundary, and applies nan_to_num to the down_proj output to prevent residual stream poisoning. A review comment suggests that the initial projections might still saturate if fp16 autocast is enabled, potentially losing precision before the stability logic is applied, and recommends disabling autocast for those specific operations to leverage the full range of the weights.

# projections over bf16/fp32 activations (autocast or forced fp16).
if gate.dtype != torch.float16:
return self.down_proj(self.act_fn(gate) * self.up_proj(x))
product = self.act_fn(gate.float()) * self.up_proj(x).float()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The PR description mentions that the fix should "Disable fp16 autocast so gate_proj and up_proj execute at the bf16 weight dtype". However, the current implementation calls self.gate_proj(x) and self.up_proj(x) within the existing autocast context. If fp16 autocast is active, these matmuls will still execute in fp16 and potentially saturate at fp16_max (65504) before the .float() call. While the subsequent clamp and nan_to_num provide stability against product and accumulator overflows, they do not recover the range lost during the initial projections if they saturated. Consider wrapping the projections in with torch.cuda.amp.autocast(enabled=False): if the intent was to leverage the full bf16 range of the weights.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a stability patch for the Gemma-4 MLP to address numerical instability during fp16 training, particularly for GRPO on hardware like the Tesla T4. The patch ensures that the MLP product is computed in fp32, clamped to a safe range, and that any remaining overflows in the down_proj output are handled via nan_to_num. Feedback was provided to optimize the float16 path by reusing the result of the up_proj calculation instead of computing it twice.

# projections over bf16/fp32 activations (autocast or forced fp16).
if gate.dtype != torch.float16:
return self.down_proj(self.act_fn(gate) * self.up_proj(x))
product = self.act_fn(gate.float()) * self.up_proj(x).float()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation calls self.up_proj(x) twice in the float16 path: once implicitly via the matmul and once explicitly to cast to float. This can be optimized by computing the projection once and reusing the result.

Suggested change
product = self.act_fn(gate.float()) * self.up_proj(x).float()
up = self.up_proj(x)
product = self.act_fn(gate.float()) * up.float()

danielhanchen added a commit that referenced this pull request Apr 19, 2026
…n, finalize_huggingface_model

- patch_gemma4_vllm_lora_support: use functools.wraps on patched_create_lora_manager so
  _call_create_lora_manager's signature inspection still sees vllm_config; pass model
  positionally to lora_manager_cls to avoid "multiple values for 'model'".
- patch_gemma4_vllm_k_eq_v_support: also handle split k_proj/v_proj layout (current
  upstream Gemma4) by duplicating k quant-state to synthetic v entry; keep packed
  qkv_proj path as fallback.
- load_vllm: gate Gemma4 patches on enable_lora / use_bitsandbytes (not is_vision_model),
  so text-only Gemma4 + LoRA / BnB also works.
- extract_gdn_layers: derive qkvz offsets from gdn.key_dim/value_dim when
  ColumnParallelLinear has no output_sizes; manually split in_proj_ba into b/a instead
  of calling get_state_dict with kk=1 (IndexError); preserve BnB quant_state sidecars;
  handle FP8 weight_scale (not only weight_scale_inv) and dynamic/row-wise FP8;
  export linear_attn.norm.weight.
- finalize_huggingface_model: fix layer_idx for standard causal LMs (not only VLM path);
  rebuild Gemma4 vision rotary_emb from vision_config with fp32 buffers; guard
  rotary_pos_emb on vision_config availability; mirror language_model detection from
  set_additional_modules.
- get_model_layer_config: register Gemma4 per_layer_input_gate / per_layer_projection /
  post_per_layer_input_norm; add Qwen3.5 visual.merger.linear_fc1 / linear_fc2 and drop
  the broken linear_fc{kk} template.
- set_dtype_in_config (hf_utils): prefer the modern 'dtype' field; fall back to
  'torch_dtype' only when 'dtype' is absent, avoiding the deprecation warning on
  current transformers.
- vllm_utils state-dict loop: skip layer.mlp extraction for linear-attn-only layers
  (defensive) while still capturing layer_scalar.
- _normalize_state_dict_tensor: guard is_sparse behind isinstance(value, torch.Tensor)
  so non-tensor state-dict values pass through.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant