Skip to content

Add Gemma-4 float16 UNSLOTH_FORCE_FLOAT32 patches for GRPO stability#1

Closed
danielhanchen wants to merge 7 commits into
mainfrom
pr-600-head
Closed

Add Gemma-4 float16 UNSLOTH_FORCE_FLOAT32 patches for GRPO stability#1
danielhanchen wants to merge 7 commits into
mainfrom
pr-600-head

Conversation

@danielhanchen

Copy link
Copy Markdown
Owner

Staging mirror of unslothai#600

Original PR: unslothai#600
Author: danielhanchen

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

Problem

unsloth/gemma-4-E2B-it and unsloth/gemma-4-E4B-it GRPO training NaNs and crashes with a CUDA device-side assert at step 2 when the user passes dtype=torch.float16. torch.bfloat16 is unaffected.

Minimal reproduction, 8 steps on a single B200:

python scripts/gemma4_grpo_sudoku_repro.py --model unsloth/gemma-4-E2B-it --dtype float16 --max-steps 8 --grad-accum 4 --num-generations 2

Unpatched result (dtype=float16):

step 1: loss=0.0, grad_norm=0.1108, reward=1.416
step 2: CUDA error: device-side assert triggered

Hooks on the model show the first overflow occurs in language_model.layers.0.mlp.down_proj. post_stats[base_layer] hits +inf during step 1's forward, poisons the residual stream, and step 2's generation produces NaN logits that trip the categorical sampler.

Root cause

Under UNSLOTH_FORCE_FLOAT32=1, the loader stores weights as bfloat16 but GRPO enters its fp16 autocast context for training. Gemma-4's text-decoder MLP does

down_proj(act_fn(gate_proj(x)) * up_proj(x))

with no internal upcast. Inside fp16 autocast, gate and up outputs saturate at fp16_max, the product becomes +inf on downcast, and down_proj's fp16 matmul accumulator then produces +inf for the residual add. bf16 avoids this because its range covers fp32 magnitudes.

Gemma-3 already ships UNSLOTH_FORCE_FLOAT32 patches for the same failure mode. The corresponding Gemma-4 patches existed briefly and were removed in 158e981 (2026-04-06) after initial testing found it stable. GRPO fp16 training exercises code paths the earlier testing did not.

Fix

Four text-decoder patches in unsloth_zoo/temporary_patches/gemma4.py, each gated on UNSLOTH_FORCE_FLOAT32=1 and each appended to TEMPORARY_PATCHES:

Patch What it does
patch_Gemma4RMSNorm fp32 variance and weight multiply, clamp to 65280 (one bf16 ulp below fp16_max so later fp16 casts cannot round up to inf), then fp16 cast.
patch_Gemma4TextScaledWordEmbedding Embed lookup in fp32 multiplied by the fp32 embed_scale so the first RMSNorm gets full precision input.
patch_Gemma4TextMLP Disable fp16 autocast so gate_proj and up_proj execute at the bf16 weight dtype; activation and multiply in fp32; clamp to 65280; fp16 cast; torch.nan_to_num safety net rescues the rare case where down_proj's fp16 accumulator still overflows for wide intermediate dims.
patch_Gemma4TextAttention Same autocast-disabled pattern for Q/K/V projections, fp32 q/k/v_norm, fp32 RoPE, fp32 SDPA, clamp and fp16 cast before o_proj, same nan_to_num safety net. Preserves the KV-sharing and shared_kv_states contract.

The paired change in unslothai/unsloth adds "gemma4" to FORCE_FLOAT32 so these patches actually engage for fp16 requests.

No Gemma-4 causal mask patch is added because, unlike Gemma-3, Gemma-4 delegates mask construction to ALL_ATTENTION_FUNCTIONS and the existing sdpa / eager / flex_attention paths handle mas

Gemma-4 E2B/E4B GRPO training hits a CUDA device-side assert at step 2
when dtype=torch.float16 because the text-decoder MLP saturates in fp16:
gate_proj and up_proj outputs reach fp16_max under the fp16 autocast
context, the gate*up product overflows on downcast, and the subsequent
down_proj fp16 matmul accumulator tips to +inf. Generation with an inf
residual stream produces NaN logits and the categorical sampler aborts.

bf16 training is unaffected because the same intermediate magnitudes fit
in bf16's range. Mirror the Gemma-3 UNSLOTH_FORCE_FLOAT32 recipe with
four Gemma-4 specific patches, all gated on the env flag:

  * Gemma4RMSNorm: fp32 norm and scale, clamp to 65280 before fp16 cast
    (65280 is one bf16 ulp below fp16_max, avoiding the bf16 round-up
    that otherwise produces +-inf on later fp16 conversions).
  * Gemma4TextScaledWordEmbedding: embed lookup and scale in fp32.
  * Gemma4TextMLP: execute gate_proj and up_proj with autocast disabled
    so they see bf16 weights (the actual model dtype under
    UNSLOTH_FORCE_FLOAT32), run activation plus multiply in fp32, clamp
    to the safe fp16 bound, then down_proj in fp16. A final nan_to_num
    rescues the rare case where down_proj's fp16 accumulator overflows
    for wide intermediate dims.
  * Gemma4TextAttention: same autocast-disabled pattern for Q/K/V
    projections, fp32 q_norm/k_norm/v_norm, fp32 RoPE, fp32 SDPA, clamp
    and fp16 cast before o_proj, plus nan_to_num safety net.

Repro recipe (see PR description for full table):

  python scripts/gemma4_grpo_sudoku_repro.py \
    --model unsloth/gemma-4-E2B-it --dtype float16 \
    --max-steps 8 --grad-accum 4 --num-generations 2

Without the patches this crashes at step 2 with CUDA device-side assert
and overflow flagged in layers.0.mlp.down_proj. With the patches
unsloth/gemma-4-E2B-it and unsloth/gemma-4-E4B-it both complete 8 GRPO
steps with finite loss and grad_norm values close to the bf16 baseline,
under both gradient_checkpointing="unsloth" and gradient_checkpointing
=False. bf16 behavior is unaffected because the patches are gated on
UNSLOTH_FORCE_FLOAT32=1 which only fires for float16 requests.
Bisection showed the NaN chain originates exclusively in the text-decoder
MLP gate*up -> down_proj path:

  * Gemma4RMSNorm already computes internally in fp32 and returns
    type_as(hidden_states), which is finite for the trained weight ranges.
  * Gemma4TextScaledWordEmbedding scales by sqrt(hidden_size) which is
    ~45 to 60 for E2B/E4B and fits in fp16.
  * Gemma4TextAttention projections did not overflow in any run.

So the earlier RMSNorm, Embedding, and Attention patches were redundant
once the MLP is stabilized. Removing them lets the text pipeline keep its
original dtype contract (no autocast gymnastics, no KV cache dtype
mismatch), which in turn keeps the gate/up/down matmuls on fp16 tensor
cores.

The MLP patch itself is reduced to the three operations that matter:

  1. Compute act_fn(gate) * up in fp32 so the product cannot overflow.
  2. Clamp to 65280 (one bf16 ulp below fp16_max) before down_proj so
     the fp16 cast cannot produce +-inf.
  3. nan_to_num on the output as defense-in-depth for the rare fp16
     accumulator overflow in down_proj for wide intermediate dims.

Verified on B200 (fp16 autocast under GRPO) - all runs complete 8 steps
with no NaN and no bad params:

  | model                       | gc       | step1 grad | step8 grad | step8 kl |
  | unsloth/gemma-4-E2B-it      | unsloth  | 0.1110     | 0.4077     | 6.03e-05 |
  | unsloth/gemma-4-E2B-it      | off      | 0.1110     | 0.2655     | 1.18e-04 |
  | unsloth/gemma-4-E4B-it      | unsloth  | 0.0987     | 0.0000     | 1.47e-05 |

Tesla T4 compatibility: T4 has no bf16 tensor cores so the loader's
FORCE_FLOAT32 path lands on fp16 weights. The patch's fp32 gate*up is an
elementwise op (runs on CUDA cores at 8.1 TFLOPS); gate_proj, up_proj,
and down_proj matmuls stay on fp16 tensor cores (65 TFLOPS). Net perf
overhead is minimal and the overflow prevention cost is paid only once
per MLP forward.

UNSLOTH_ENABLE_FLEX_ATTENTION on Gemma-4: flex is enabled by default but
only used when sdpa is unavailable. Gemma-4 supports sdpa so flex is not
selected for the text decoder in practice. Setting the env var to 0
produces identical E4B trajectories and near-identical E2B (diverges only
at step 8, still finite, no crash).
The int8 branch (torch.ops.aten._weight_int8pack_mm) worked end-to-end
but failed the RL-parity goal: 100-step GRPO on unsloth/gemma-4-E2B-it
with temperature=0.05, min_p=0.5, seed=3407 gave total |KL| of 7.47e+06
for int8 vs 52.16 for fp16+clamp vs 9387 for bf16, and int8 step 7
already had grad_norm=NaN.

Cause: GRPO's log-pi-new - log-pi-old ratio amplifies the ~7%
per-matmul weight-quantization noise, because the rollout path runs
under torch.inference_mode (dense fp16) while the training forward
used int8. Making them symmetric would require a separate int8
reference model, which is out of scope for a surgical NaN fix.

This commit removes all int8 code (Int8LinearFn autograd wrapper, row
quantizer, PEFT LoRA-aware int8 forward, UNSLOTH_GEMMA4_MLP_INT8 env
switch) and keeps only the fp16 trick:

  - fp32 act_fn(gate) * up so the product cannot overflow
  - clamp to 65280 before down_proj
  - nan_to_num on down_proj output for the rare accumulator tail

The patched forward is now 7 effective lines. Dtype contract is
identical to upstream (input dtype -> input dtype), so no attention /
RMSNorm / embedding companion patches are needed and the KV cache
stays aligned.

100-step verification vs bf16 (temp=0.05, min_p=0.5, seed=3407):

  median |KL|   bf16 4.02e-05  fp16+clamp 7.99e-06 (0.20x)
  p95 |KL|      bf16 0.411     fp16+clamp 0.620    (1.51x)
  max |KL|      bf16 8529      fp16+clamp 35.77    (0.004x)
  total |KL|    bf16 9388      fp16+clamp 52.16    (0.006x)
  mean reward   bf16 1.2111    fp16+clamp 1.1359   (|d|=0.075)
  reward-equal  75/100 steps
  time          +0.3%

Note: bf16's two huge outlier steps (step 24 |KL|=8529 grad=26461;
step 42 |KL|=800 grad=1434) are what push its total KL above
fp16+clamp's. For the calm 63-66% of steps both trajectories track
each other to 1e-5 precision.
@danielhanchen

Copy link
Copy Markdown
Owner Author

/gemini review

65280 is the largest value exactly representable in both fp16 and bf16
(one bf16 ULP below 65536, 224 below fp16_max=65504). The previous
wording claimed it was one bf16 ulp below fp16_max, but fp16_max=65504
is not representable in bf16 at all -- it rounds up to 65536.

Clamp value is unchanged; comment only.
…nan rescue, inline import

- Add `x.dtype != torch.float16` guard so bf16/fp32 activations pass
  through the upstream forward unchanged. Under UNSLOTH_FORCE_FLOAT32
  the weights are fp16 so this is normally unreachable, but the guard
  protects against pipeline configurations where autocast is disabled
  (rl_replacements nullcontext path) and prevents unintended clamping
  if activations reach the MLP as bf16.

- Change nan_to_num replacements from +-65280 to 0 so overflow
  positions contribute nothing and leave the identity residual intact.
  Replacing with the fp16 ceiling would otherwise dominate the O(1)
  post-RMSNorm hidden state. Backward gradient through nan_to_num is
  `grad * isfinite(input)` in both cases, so this is loss-free.

- Drop the one-off `_gemma4_modeling` helper and inline the import to
  match the pattern of every other patch in this file.

- Drop the `x: torch.Tensor` annotation to match the rest of the
  temporary patches (`def forward(self, x):`).
- Switch the stabilization guard from `x.dtype != torch.float16` to
  `gate.dtype != torch.float16`. The matmul output dtype is what
  determines whether overflow can occur, so this also catches
  mixed-precision cases (bf16 activations through fp16-cast weights via
  autocast or do_forced_float32) that the x.dtype check missed. For
  the standard UNSLOTH_FORCE_FLOAT32 path behavior is unchanged.

- Inline `self.up_proj(x).float()` so the fp16 up tensor is transient,
  and reuse the already-computed gate in the bypass path.

- Cast product back with `gate.dtype` instead of `x.dtype` to avoid a
  bf16-input/fp16-weight mismatch at down_proj if stabilization is
  active with non-fp16 activations.
@danielhanchen

Copy link
Copy Markdown
Owner Author

Fixes pushed to unslothai#600.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant