Fix Gemma-4 GRPO catastrophic KL divergence with TRL 1.0.0+#21
Closed
danielhanchen wants to merge 2 commits into
Closed
Fix Gemma-4 GRPO catastrophic KL divergence with TRL 1.0.0+#21danielhanchen wants to merge 2 commits into
danielhanchen wants to merge 2 commits into
Conversation
Two compounding bugs caused Gemma-4 GRPO training to diverge with KL ~10^12
at step 1 against TRL 1.0.0+. Both fixes are runtime patches in the existing
TRL/model patch flow and are no-ops for models and TRL versions that are not
affected.
Fix 1 (rl.py): replace trl.models.utils.disable_gradient_checkpointing with
a no-op context manager. TRL 1.0.0+ wraps generation in
`with torch.no_grad(), disable_gradient_checkpointing(self.model, ...):`
purely to suppress a cosmetic PyTorch warning ("None of the inputs have
requires_grad=True"). Inside torch.no_grad() the gradient checkpointing
state has no functional effect on the forward pass. On context exit, TRL
calls model.gradient_checkpointing_enable() which dispatches to HF's
generic implementation and overwrites Unsloth's custom
`use_gradient_checkpointing="unsloth"` wrapper, corrupting Gemma-4 forward
numerics. Replacing the toggle with a no-op preserves Unsloth's custom GC
wrapper across generation passes. The patch walks sys.modules dynamically
to also rebind the symbol on every trl.* module that already imported it
(grpo_trainer, dpo_trainer, rloo_trainer, dppo_trainer, gfpo_trainer,
grpo_with_replay_buffer_trainer, and any future trainer module).
Fix 2 (vision.py): inject `final_logit_softcapping` from `config.text_config`
into the top-level `model.config` for multimodal models. Unsloth's GRPO
trainer reads `getattr(model.config, "final_logit_softcapping", 0)` but
for Gemma-4 the attribute lives only on the nested `Gemma4TextConfig`,
so the lookup silently defaults to 0 instead of 30.
Backwards compatibility:
- trl 0.22.2: no `disable_gradient_checkpointing` symbol exists, the patch
early-returns via `hasattr` guard.
- trl 0.27.1: same broken pattern as 1.0.0, the noop replacement is correct.
- trl 1.0.0+: end-to-end verified on `unsloth/gemma-4-E2B-it` GRPO with TRL
1.0.0 and transformers 5.5.0. Step 1 loss=2.46e-08, kl=2.92e-05 (machine
zero) vs broken baseline loss=1.37e+06, kl=1.76e+09.
- Llama / non-VLM text models: Fix 2 is a no-op (no `text_config`); Fix 1
is functionally identical (Unsloth's GC wrapper is preserved).
- Qwen3-VL and other VLMs without final_logit_softcapping: Fix 2 is a no-op
(text_config.final_logit_softcapping is None).
for more information, see https://pre-commit.ci
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Staging mirror of unslothai#4934
Original PR: unslothai#4934
Author: danielhanchen
This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.
Original description
Summary
Fixes Gemma-4 GRPO training diverging with KL ~10^12 at step 1 against TRL 1.0.0+, by adding two runtime patches to the existing TRL/model patch flow. Both patches are no-ops for models and TRL versions that are not affected.
The bugs
Bug 1 (primary): TRL
disable_gradient_checkpointingoverwrites Unsloth's custom GC wrapperTRL 1.0.0+ wraps generation in:
The toggle exists only to suppress a cosmetic PyTorch warning (
None of the inputs have requires_grad=True). Insidetorch.no_grad()the gradient checkpointing state has no functional effect on the forward pass.On context exit, TRL calls
model.gradient_checkpointing_enable(...)which dispatches to HuggingFace's generic implementation and overwrites Unsloth's customuse_gradient_checkpointing="unsloth"wrapper. For Gemma-4 (and likely other models) this corrupts the forward numerics enough to make the training-step forward diverge from the reference forward, producing KL ~10^12 at step 1.Bug 2 (secondary):
final_logit_softcappinglookup misses for multimodal Gemma-4UnslothGRPOTrainerreadsgetattr(model.config, "final_logit_softcapping", 0). ForGemma4ForConditionalGenerationthe attribute lives only on the nestedGemma4TextConfig, so the lookup silently defaults to0instead of30. Both ref and policy paths hit the same bug for LoRA so KL cancels, but full fine-tuning with a separateref_modelproduces numerically incorrect logps.The fix
Fix 1:
unsloth/models/rl.py- newpatch_trl_disable_gradient_checkpointing()Replaces
trl.models.utils.disable_gradient_checkpointingwith a no-op context manager. The patch dynamically walkssys.modulesfor anytrl.*module that already imported the symbol by reference and rebinds it, so it picks up:trl.trainer.grpo_trainertrl.trainer.dpo_trainertrl.trainer.rloo_trainertrl.experimental.dppo.dppo_trainertrl.experimental.gfpo.gfpo_trainertrl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_trainerThe patch is wired into
PatchFastRLBEFOREpatch_trl_rl_trainersso the compiled cache picks up the noop at itsfrom trl.trainer.grpo_trainer import disable_gradient_checkpointingbinding time.Fix 2:
unsloth/models/vision.py- injectfinal_logit_softcappingfromtext_configIn
FastBaseModel.from_pretrained, after the model is loaded, liftsfinal_logit_softcappingfromconfig.text_config(orconfig.get_text_config()) to the top-levelmodel.configif and only if the top-level config does not already expose it. Skips silently for models that already have it or do not use softcapping.Backwards compatibility
disable_gradient_checkpointingsymbol does not exist. Thehasattrguard early-returns. Veri