Remove Gemma-4 temporary patches#576
Conversation
Gemma-4 does not need FORCE_FLOAT32 temporary patches. The model works correctly in both float16 and bfloat16 without any intervention. Removed all 10 patch functions: - patch_Gemma4RMSNorm (FORCE_FLOAT32 path) - patch_Gemma4RMSNorm_generic (bf16 path) - patch_Gemma4TextScaledWordEmbedding - patch_Gemma4TextAttention (FORCE_FLOAT32 path) - patch_Gemma4TextAttention_generic (bf16 path) - patch_Gemma4ForConditionalGeneration_causal_mask - patch_Gemma4TextMLP - patch_Gemma4TextDecoderLayer - patch_Gemma4AudioAttention - patch_Gemma4TextModel_project_per_layer_inputs Testing (100 steps, Gemma-4 E2B, 4-bit LoRA, SFT on FineTome-100k): - float16 final loss: 3.048, bfloat16 final loss: 3.065 - Losses converge to within 0.02 by step 60 - Inference outputs identical between float16 and bfloat16 The FORCE_FLOAT32 patches were causing compiled float32 training to diverge at step ~28. Without them, training works correctly. Companion PR: unslothai/unsloth (remove gemma4 from FORCE_FLOAT32 list)
There was a problem hiding this comment.
Code Review
This pull request removes extensive temporary patches and workarounds for Gemma-4 in unsloth_zoo/temporary_patches/gemma4.py, including custom implementations for RMSNorm, Attention, and MLP layers. These changes reflect that Gemma-4 now functions correctly with standard float16 and bfloat16 without additional intervention. I have no feedback to provide.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 8b83700ead
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| # Gemma-4 does not need FORCE_FLOAT32 or temporary patches. | ||
| # float16 and bfloat16 both work correctly without intervention. |
There was a problem hiding this comment.
Restore per-layer input dtype guard for Gemma4
This change removes patch_Gemma4TextModel_project_per_layer_inputs, which was previously casting inputs_embeds to per_layer_model_projection.weight.dtype before the projection call; upstream Gemma4 still does self.per_layer_model_projection(inputs_embeds) directly, so when callers pass inputs_embeds in float32 to an fp16/bf16 model (a supported forward path used by custom/multimodal embedding flows), F.linear hits a dtype mismatch at runtime. With all Gemma4 temporary patches deleted, that failure mode is reintroduced for those inputs.
Useful? React with 👍 / 👎.
Summary
Background
The FORCE_FLOAT32 patches were originally added to handle potential float16 overflow in Gemma-4. However, testing shows that the model's activation magnitudes stay well within float16 range (max ~2080 vs fp16 max 65504), and the patches themselves were causing training divergence when the compiler interacted with the forced float32 path.
Test results
Inference (greedy,
enable_thinking=False):Training (Gemma-4 E2B, 4-bit LoRA, SFT on FineTome-100k, 100 steps):
Losses converge to within 0.02 by step 60. float16 is actually slightly better in the final 20 steps.
Companion PR
gemma4fromFORCE_FLOAT32list inloader.pyTest plan