forked from unslothai/unsloth-zoo
Add Gemma-4 float16 UNSLOTH_FORCE_FLOAT32 patches for GRPO stability #3
Open
danielhanchen wants to merge 11 commits into main from pr-600-code
Commits (11, all by danielhanchen)
bfb319b Add Gemma-4 UNSLOTH_FORCE_FLOAT32 patches for float16 GRPO stability
513ce6d Reduce Gemma-4 fp16 stability to a single MLP patch
b8d9b84 Drop int8 experimental path, ship only the minimal fp16 MLP patch
1ae66ac Correct 65280 comment in patch_Gemma4TextMLP
2c1c307 Harden patch_Gemma4TextMLP: fp16-only dtype guard, identity-residual …
d943395 patch_Gemma4TextMLP: gate on matmul output dtype, inline up projection
fcd23f4 Add review tests
a512197 Shorten comments in patch_Gemma4TextMLP
f21085d Consolidate review tests into test_patch_gemma4_text_mlp.py
a709dfc Merge remote-tracking branch 'staging/pr-600-tests' into pr-600-head
1321af7 Split: keep only 1 file(s)
@@ -17,7 +17,7 @@
 import torch
 import os
 from .common import TEMPORARY_PATCHES
-from .utils import raise_error
+from .utils import raise_error, patch_function
 
 
 # ============================================================================
@@ -606,3 +606,101 @@ def forward(self, hidden_states, position_embeddings, attention_mask=None, **kwa
     Gemma4AudioAttention.forward = forward
 pass
 TEMPORARY_PATCHES.append(patch_Gemma4AudioAttention)
+
+
+# ============================================================================
+# Gemma-4 float16 stability patch.
+#
+# Goal: make float16 training (notably GRPO on E2B and E4B) numerically
+# stable without paying blanket fp32 compute cost, so that Tesla T4 - which
+# has no bf16 tensor cores - can run Gemma-4 GRPO at full fp16 matmul speed.
+#
+# The NaN chain (verified via TensorStatisticsHooks on E2B GRPO):
+#
+#   1. Under UNSLOTH_FORCE_FLOAT32 the loader stores linear weights as fp16
+#      (see unsloth_zoo/patching_utils.py do_forced_float32 block).
+#   2. `down_proj(act_fn(gate_proj(x)) * up_proj(x))` saturates in fp16 at
+#      `layers.0.mlp.down_proj` (E2B) / `layers.1.mlp.down_proj` (E4B).
+#   3. The +-inf in down_proj output poisons the residual stream; the next
+#      generation step samples NaN logits and the CUDA categorical sampler
+#      trips a device-side assert.
+#
+# Minimal fix: a single patch on Gemma4TextMLP. The product is computed in
+# fp32, clamped to a safe fp16 bound, and down_proj's output is rescued
+# with `nan_to_num` for the rare tail overflow from its fp16 accumulator.
+#
+# Why we do NOT patch RMSNorm / Attention / Embedding:
+#   * Gemma4RMSNorm already casts to fp32 internally and returns
+#     `type_as(hidden_states)` - that dtype contract is fine for fp16.
+#   * Gemma4TextScaledWordEmbedding multiplies by sqrt(hidden_size) which
+#     is ~45-60 for E2B/E4B, well within fp16 range.
+#   * Gemma4TextAttention projections did not overflow in any run; the
+#     failing path is exclusively the MLP gate*up product + down_proj
+#     accumulator.
+#
+# Bisection evidence: 8-step GRPO on unsloth/gemma-4-E2B-it and
+# unsloth/gemma-4-E4B-it, with gradient_checkpointing on and off, completes
+# cleanly with just patch_Gemma4TextMLP. Adding Attention, RMSNorm, or
+# Embedding patches on top produces byte-identical loss / grad_norm / kl
+# trajectories.
+# ============================================================================
+
+
+def patch_Gemma4TextMLP():
+    """Stabilize Gemma-4 MLP under fp16 autocast (GRPO on fp16, Tesla T4).
+
+    Root cause: `down_proj(act_fn(gate_proj(x)) * up_proj(x))` summed over
+    the wide intermediate dimension can exceed fp16_max = 65504 so the fp16
+    matmul cast produces +-inf. That inf poisons the residual stream and
+    generation then samples NaN logits, tripping the categorical assert at
+    GRPO step ~2 on E2B/E4B with dtype=torch.float16.
+
+    Fix, three cheap operations, single patch:
+
+      1. act_fn(gate) * up in fp32 so the product cannot overflow.
+      2. Clamp to 65280 (largest value exactly representable in both fp16
+         and bf16) before down_proj.
+      3. nan_to_num on the output, rescuing the rare down_proj fp16
+         accumulator overflow on wide intermediate dims.
+
+    Dtype contract is unchanged from upstream (input dtype -> input dtype),
+    so RMSNorm / Attention / Embedding need no companion patches and the KV
+    cache stays aligned with the text attention output. gate_proj, up_proj
+    and down_proj remain fp16 tensor-core matmuls (full T4 throughput at
+    65 TFLOPS).
+    """
+    if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0":
+        return
+    try:
+        import transformers.models.gemma4.modeling_gemma4 as mod
+    except ImportError:
+        return
+    try:
+        Gemma4TextMLP = mod.Gemma4TextMLP
+    except AttributeError as e:
+        return raise_error("Gemma4TextMLP.forward", e)
+
+    # Largest value exactly representable in both fp16 and bf16 (one bf16
+    # ULP below 65536, which would round to fp16 inf).
+    _SAFE_FP16 = 65280.0
+
+    def forward(self, x):
+        gate = self.gate_proj(x)
+        # Gate on matmul output dtype, not x.dtype, to catch fp16-cast
+        # projections over bf16/fp32 activations (autocast or forced fp16).
+        if gate.dtype != torch.float16:
+            return self.down_proj(self.act_fn(gate) * self.up_proj(x))
+        product = self.act_fn(gate.float()) * self.up_proj(x).float()
+        product = torch.clamp(product, min=-_SAFE_FP16, max=_SAFE_FP16)
+        out = self.down_proj(product.to(gate.dtype))
+        # Replace overflow with 0 so the residual keeps the identity path
+        # instead of being dominated by a near-fp16_max injection.
+        return torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
+    try:
+        patch_function(
+            Gemma4TextMLP, "forward", forward, fullgraph=False,
+        )
+    except Exception as e:
+        return raise_error("Gemma4TextMLP.forward", e)
+pass
+TEMPORARY_PATCHES.append(patch_Gemma4TextMLP)
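The three-step fix described in the `patch_Gemma4TextMLP` docstring can be illustrated on single scalars without a GPU. The sketch below is not part of the PR: it emulates fp16 and bf16 rounding with the standard library, uses Python floats as a stand-in for fp32, and the helper names (`to_fp16`, `to_bf16`, `naive_mlp_product`, `patched_mlp_product`, `rescue`) are hypothetical. It only aims to show why a product of two ordinary fp16 values can overflow, and why 65280 is a bound that survives both an fp16 and a bf16 cast while 65504 does not.

```python
import math
import struct

SAFE_FP16 = 65280.0  # clamp bound used by the patch


def to_fp16(value: float) -> float:
    """Round to fp16; overflow becomes +-inf, as a tensor cast would give."""
    if math.isnan(value) or math.isinf(value):
        return value
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        return math.copysign(math.inf, value)


def to_bf16(value: float) -> float:
    """Round a finite value to bf16 (nearest-even) by keeping the top
    16 bits of its fp32 bit pattern."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def naive_mlp_product(gate_act: float, up: float) -> float:
    """Multiply with the result rounded to fp16: may overflow to inf."""
    return to_fp16(to_fp16(gate_act) * to_fp16(up))


def patched_mlp_product(gate_act: float, up: float) -> float:
    """Multiply in wider precision, clamp, then cast back to fp16."""
    product = to_fp16(gate_act) * to_fp16(up)  # Python float stands in for fp32
    product = max(-SAFE_FP16, min(SAFE_FP16, product))
    return to_fp16(product)


def rescue(value: float) -> float:
    """Scalar analogue of nan_to_num(nan=0.0, posinf=0.0, neginf=0.0)."""
    return 0.0 if math.isnan(value) or math.isinf(value) else value


print(naive_mlp_product(300.0, 300.0))       # 90000 does not fit in fp16
print(patched_mlp_product(300.0, 300.0))     # clamped to the safe bound
print(to_bf16(SAFE_FP16), to_bf16(65504.0))  # 65504 rounds up in bf16
```

The last line is the reason for preferring 65280 over fp16_max in this sketch: 65504 needs more significand bits than bf16 has, so a bf16 round trip pushes it to 65536, which an fp16 cast then turns into inf.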
Review comment on `product = self.act_fn(gate.float()) * self.up_proj(x).float()`:

The PR description mentions that the fix should "Disable fp16 autocast so gate_proj and up_proj execute at the bf16 weight dtype". However, the current implementation calls `self.gate_proj(x)` and `self.up_proj(x)` within the existing autocast context. If `fp16` autocast is active, these matmuls will still execute in `fp16` and potentially saturate at `fp16_max` (65504) before the `.float()` call. While the subsequent `clamp` and `nan_to_num` provide stability against product and accumulator overflows, they do not recover the range lost during the initial projections if they saturated. Consider wrapping the projections in `with torch.cuda.amp.autocast(enabled=False):` if the intent was to leverage the full `bf16` range of the weights.
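The reviewer's point about ordering can be made concrete with a scalar emulation. This is only a sketch under stated assumptions: Python floats stand in for a wider-range dtype, the activation is treated as the identity for large positive inputs, and the helper names (`to_fp16`, `clamp`, `product_after_fp16_projection`, `product_after_wide_projection`) are hypothetical rather than anything from the PR. It shows that once a projection output has been rounded to fp16 inf, a later upcast and clamp yields the clamp bound instead of the true product.

```python
import math
import struct

SAFE_FP16 = 65280.0  # clamp bound used by the patch


def to_fp16(value: float) -> float:
    """Round to fp16; overflow becomes +-inf, as a tensor cast would give."""
    if math.isnan(value) or math.isinf(value):
        return value
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        return math.copysign(math.inf, value)


def clamp(value: float) -> float:
    """Clamp to the symmetric safe fp16 range."""
    return max(-SAFE_FP16, min(SAFE_FP16, value))


def product_after_fp16_projection(gate_true: float, up_true: float) -> float:
    """Projection outputs are rounded to fp16 first, then upcast and multiplied."""
    gate = to_fp16(gate_true)  # may already be inf at this point
    up = to_fp16(up_true)
    return clamp(gate * up)    # the upcast cannot undo the earlier saturation


def product_after_wide_projection(gate_true: float, up_true: float) -> float:
    """Projection outputs keep their full range before the multiply."""
    return clamp(gate_true * up_true)


# A gate pre-activation of 80000 exceeds fp16_max, but the product with 0.5 fits.
print(product_after_fp16_projection(80000.0, 0.5))
print(product_after_wide_projection(80000.0, 0.5))
```

Whether this case occurs in practice depends on the projection magnitudes; the PR's own comments state that the attention projections did not overflow in any run and attribute the failure to the gate*up product and the down_proj accumulator. In recent PyTorch releases the context manager named in the comment is also available as `torch.amp.autocast("cuda", enabled=False)`.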