Remove Gemma-4 temporary patches #576

Merged: danielhanchen merged 1 commit into main from gemma4-remove-temporary-patches on Apr 6, 2026

@danielhanchen
Member

Summary

  • Remove all 10 FORCE_FLOAT32 and generic temporary patches for Gemma-4
  • Gemma-4 works correctly in both float16 and bfloat16 without any patching

Background

The FORCE_FLOAT32 patches were originally added to guard against potential float16 overflow in Gemma-4. However, testing shows that the model's activation magnitudes stay well within float16 range (observed max ~2080 versus the float16 maximum of 65504). The patches themselves were causing training divergence when the compiler interacted with the forced float32 path.
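As a rough illustration of the headroom argument above, the following standard-library sketch derives the largest finite float16 value and compares it with the approximate activation maximum quoted in this PR. The constant name `OBSERVED_MAX_ACTIVATION` is introduced here for illustration only, and the check covers just the quoted forward-pass magnitude, not gradients or other intermediates.

```python
import struct

# Largest finite IEEE 754 binary16 (float16) value: (2 - 2**-10) * 2**15.
FP16_MAX = (2 - 2 ** -10) * 2 ** 15

# Sanity check: the value survives a round trip through Python's half-precision
# struct format ("e"), so it is exactly representable in float16.
round_tripped = struct.unpack("<e", struct.pack("<e", FP16_MAX))[0]

# Approximate maximum activation magnitude quoted in the PR description.
# The variable name is hypothetical; only the number comes from the PR.
OBSERVED_MAX_ACTIVATION = 2080.0

# How many times larger the activations could grow before overflowing float16.
headroom = FP16_MAX / OBSERVED_MAX_ACTIVATION

print(f"float16 max: {FP16_MAX}")
print(f"headroom over observed activations: {headroom:.1f}x")
```

Under these numbers the activations could grow by roughly a factor of 31 before reaching the float16 limit.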

Test results

Inference (greedy, enable_thinking=False):

  • float16 and bfloat16 produce identical outputs across multiple prompts

Training (Gemma-4 E2B, 4-bit LoRA, SFT on FineTome-100k, 100 steps):

| Metric | float16 | bfloat16 |
|---|---|---|
| Final loss (step 100) | 3.048 | 3.065 |
| Min loss | 2.389 (step 76) | 2.396 (step 76) |
| Avg loss (last 20 steps) | 3.198 | 3.211 |
| Grad norms | Healthy (~3.0) | Healthy (~3.0) |

Losses converge to within 0.02 of each other by step 60. float16 averages a slightly lower loss over the final 20 steps (3.198 vs 3.211).
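The agreement between the two dtypes can be checked with a few lines of Python. This is only a sketch over the summary figures reported above (the per-step loss curves are not included in this PR), and the dictionary keys and the `TOLERANCE` name are illustrative choices rather than anything from the codebase.

```python
# Summary metrics reported in the PR: (float16, bfloat16).
results = {
    "final_loss": (3.048, 3.065),
    "min_loss": (2.389, 2.396),
    "avg_loss_last_20": (3.198, 3.211),
}

# Tolerance taken from the "within 0.02" statement in the PR description.
TOLERANCE = 0.02

# Absolute gap between the float16 and bfloat16 value for each metric.
gaps = {name: abs(fp16 - bf16) for name, (fp16, bf16) in results.items()}

for name, gap in gaps.items():
    print(f"{name}: |float16 - bfloat16| = {gap:.3f}")

# True when every summary metric agrees to within the tolerance.
all_within = all(gap <= TOLERANCE for gap in gaps.values())
print("all within tolerance:", all_within)
```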

Companion PR

  • unslothai/unsloth -- removes gemma4 from FORCE_FLOAT32 list in loader.py

Test plan

  • Verify float16 inference produces correct output
  • Verify bfloat16 inference produces correct output
  • Verify float16 training converges (100 steps)
  • Verify bfloat16 training converges (100 steps)
  • Verify losses match between float16 and bfloat16
  • Test on Tesla T4 (float16-only GPU)

Gemma-4 does not need FORCE_FLOAT32 temporary patches. The model works
correctly in both float16 and bfloat16 without any intervention.

Removed all 10 patch functions:
- patch_Gemma4RMSNorm (FORCE_FLOAT32 path)
- patch_Gemma4RMSNorm_generic (bf16 path)
- patch_Gemma4TextScaledWordEmbedding
- patch_Gemma4TextAttention (FORCE_FLOAT32 path)
- patch_Gemma4TextAttention_generic (bf16 path)
- patch_Gemma4ForConditionalGeneration_causal_mask
- patch_Gemma4TextMLP
- patch_Gemma4TextDecoderLayer
- patch_Gemma4AudioAttention
- patch_Gemma4TextModel_project_per_layer_inputs

Testing (100 steps, Gemma-4 E2B, 4-bit LoRA, SFT on FineTome-100k):
- float16 final loss: 3.048, bfloat16 final loss: 3.065
- Losses converge to within 0.02 by step 60
- Inference outputs identical between float16 and bfloat16

The FORCE_FLOAT32 patches were causing compiled float32 training to
diverge at step ~28. Without them, training works correctly.

Companion PR: unslothai/unsloth (remove gemma4 from FORCE_FLOAT32 list)
Contributor

@gemini-code-assist (bot) left a comment

Code Review

This pull request removes extensive temporary patches and workarounds for Gemma-4 in unsloth_zoo/temporary_patches/gemma4.py, including custom implementations for RMSNorm, Attention, and MLP layers. These changes reflect that Gemma-4 now functions correctly with standard float16 and bfloat16 without additional intervention. I have no feedback to provide.

@danielhanchen danielhanchen merged commit 158e981 into main Apr 6, 2026
3 checks passed

@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 8b83700ead

Comment on lines +17 to +18
# Gemma-4 does not need FORCE_FLOAT32 or temporary patches.
# float16 and bfloat16 both work correctly without intervention.

[P2] Restore per-layer input dtype guard for Gemma4

This change removes patch_Gemma4TextModel_project_per_layer_inputs, which previously cast inputs_embeds to per_layer_model_projection.weight.dtype before the projection call. Upstream Gemma4 still calls self.per_layer_model_projection(inputs_embeds) directly, so when callers pass float32 inputs_embeds to an fp16/bf16 model (a supported forward path used by custom and multimodal embedding flows), F.linear fails with a dtype mismatch at runtime. With all Gemma4 temporary patches deleted, that failure mode returns for those inputs.
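To make the concern concrete, here is a minimal PyTorch sketch of the kind of dtype guard described in this comment. It is not the removed patch itself: `project_with_dtype_guard` is a hypothetical helper, and a small bfloat16 `nn.Linear` stands in for `per_layer_model_projection`. Quantized (for example 4-bit) weights may need different handling.

```python
import torch
import torch.nn as nn


def project_with_dtype_guard(projection: nn.Linear, inputs_embeds: torch.Tensor) -> torch.Tensor:
    """Cast inputs to the projection weight's dtype before projecting.

    Hypothetical helper illustrating the guard described in the review comment.
    """
    target_dtype = projection.weight.dtype
    if inputs_embeds.dtype != target_dtype:
        inputs_embeds = inputs_embeds.to(target_dtype)
    return projection(inputs_embeds)


# Stand-in for per_layer_model_projection in a bf16 model (sizes are arbitrary).
projection = nn.Linear(8, 4, bias=False).to(torch.bfloat16)

# Caller-supplied embeddings in float32, as in the scenario from the comment.
inputs_embeds = torch.randn(2, 3, 8, dtype=torch.float32)

# Without the guard, the mixed-dtype matmul is expected to raise a RuntimeError.
try:
    projection(inputs_embeds)
    print("unguarded call succeeded")
except RuntimeError as err:
    print("unguarded call failed:", type(err).__name__)

# With the guard, the input is cast first and the projection runs in bf16.
out = project_with_dtype_guard(projection, inputs_embeds)
print(out.dtype, tuple(out.shape))
```

Casting to the weight dtype keeps the projection in the model's compute precision; the alternative of upcasting the weight would cost extra memory on every call.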
