Skip to content

Fix Gemma-4 GRPO catastrophic KL divergence with TRL 1.0.0+#21

Closed
danielhanchen wants to merge 2 commits into
mainfrom
pr-4934-head
Closed

Fix Gemma-4 GRPO catastrophic KL divergence with TRL 1.0.0+#21
danielhanchen wants to merge 2 commits into
mainfrom
pr-4934-head

Conversation

@danielhanchen
Copy link
Copy Markdown
Owner

Staging mirror of unslothai#4934

Original PR: unslothai#4934
Author: danielhanchen

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

Summary

Fixes Gemma-4 GRPO training diverging with KL ~10^12 at step 1 against TRL 1.0.0+, by adding two runtime patches to the existing TRL/model patch flow. Both patches are no-ops for models and TRL versions that are not affected.

The bugs

Bug 1 (primary): TRL disable_gradient_checkpointing overwrites Unsloth's custom GC wrapper

TRL 1.0.0+ wraps generation in:

with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):

The toggle exists only to suppress a cosmetic PyTorch warning (None of the inputs have requires_grad=True). Inside torch.no_grad() the gradient checkpointing state has no functional effect on the forward pass.

On context exit, TRL calls model.gradient_checkpointing_enable(...) which dispatches to HuggingFace's generic implementation and overwrites Unsloth's custom use_gradient_checkpointing="unsloth" wrapper. For Gemma-4 (and likely other models) this corrupts the forward numerics enough to make the training-step forward diverge from the reference forward, producing KL ~10^12 at step 1.

Bug 2 (secondary): final_logit_softcapping lookup misses for multimodal Gemma-4

UnslothGRPOTrainer reads getattr(model.config, "final_logit_softcapping", 0). For Gemma4ForConditionalGeneration the attribute lives only on the nested Gemma4TextConfig, so the lookup silently defaults to 0 instead of 30. Both ref and policy paths hit the same bug for LoRA so KL cancels, but full fine-tuning with a separate ref_model produces numerically incorrect logps.

The fix

Fix 1: unsloth/models/rl.py - new patch_trl_disable_gradient_checkpointing()

Replaces trl.models.utils.disable_gradient_checkpointing with a no-op context manager. The patch dynamically walks sys.modules for any trl.* module that already imported the symbol by reference and rebinds it, so it picks up:

  • trl.trainer.grpo_trainer
  • trl.trainer.dpo_trainer
  • trl.trainer.rloo_trainer
  • trl.experimental.dppo.dppo_trainer
  • trl.experimental.gfpo.gfpo_trainer
  • trl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_trainer
  • and any future TRL trainer module

The patch is wired into PatchFastRL BEFORE patch_trl_rl_trainers so the compiled cache picks up the noop at its from trl.trainer.grpo_trainer import disable_gradient_checkpointing binding time.

Fix 2: unsloth/models/vision.py - inject final_logit_softcapping from text_config

In FastBaseModel.from_pretrained, after the model is loaded, lifts final_logit_softcapping from config.text_config (or config.get_text_config()) to the top-level model.config if and only if the top-level config does not already expose it. Skips silently for models that already have it or do not use softcapping.

Backwards compatibility

TRL version Behavior
0.22.2 disable_gradient_checkpointing symbol does not exist. The hasattr guard early-returns. Veri

danielhanchen and others added 2 commits April 9, 2026 14:29
Two compounding bugs caused Gemma-4 GRPO training to diverge with KL ~10^12
at step 1 against TRL 1.0.0+. Both fixes are runtime patches in the existing
TRL/model patch flow and are no-ops for models and TRL versions that are not
affected.

Fix 1 (rl.py): replace trl.models.utils.disable_gradient_checkpointing with
a no-op context manager. TRL 1.0.0+ wraps generation in
`with torch.no_grad(), disable_gradient_checkpointing(self.model, ...):`
purely to suppress a cosmetic PyTorch warning ("None of the inputs have
requires_grad=True"). Inside torch.no_grad() the gradient checkpointing
state has no functional effect on the forward pass. On context exit, TRL
calls model.gradient_checkpointing_enable() which dispatches to HF's
generic implementation and overwrites Unsloth's custom
`use_gradient_checkpointing="unsloth"` wrapper, corrupting Gemma-4 forward
numerics. Replacing the toggle with a no-op preserves Unsloth's custom GC
wrapper across generation passes. The patch walks sys.modules dynamically
to also rebind the symbol on every trl.* module that already imported it
(grpo_trainer, dpo_trainer, rloo_trainer, dppo_trainer, gfpo_trainer,
grpo_with_replay_buffer_trainer, and any future trainer module).

Fix 2 (vision.py): inject `final_logit_softcapping` from `config.text_config`
into the top-level `model.config` for multimodal models. Unsloth's GRPO
trainer reads `getattr(model.config, "final_logit_softcapping", 0)` but
for Gemma-4 the attribute lives only on the nested `Gemma4TextConfig`,
so the lookup silently defaults to 0 instead of 30.

Backwards compatibility:
- trl 0.22.2: no `disable_gradient_checkpointing` symbol exists, the patch
  early-returns via `hasattr` guard.
- trl 0.27.1: same broken pattern as 1.0.0, the noop replacement is correct.
- trl 1.0.0+: end-to-end verified on `unsloth/gemma-4-E2B-it` GRPO with TRL
  1.0.0 and transformers 5.5.0. Step 1 loss=2.46e-08, kl=2.92e-05 (machine
  zero) vs broken baseline loss=1.37e+06, kl=1.76e+09.
- Llama / non-VLM text models: Fix 2 is a no-op (no `text_config`); Fix 1
  is functionally identical (Unsloth's GC wrapper is preserved).
- Qwen3-VL and other VLMs without final_logit_softcapping: Fix 2 is a no-op
  (text_config.final_logit_softcapping is None).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant