From 1dbc1ed12d94a537ad7970ecb7655834c4f79ffb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 9 Apr 2026 14:29:45 +0000 Subject: [PATCH 1/2] Fix Gemma-4 GRPO catastrophic KL divergence with TRL 1.0.0+ Two compounding bugs caused Gemma-4 GRPO training to diverge with KL ~10^12 at step 1 against TRL 1.0.0+. Both fixes are runtime patches in the existing TRL/model patch flow and are no-ops for models and TRL versions that are not affected. Fix 1 (rl.py): replace trl.models.utils.disable_gradient_checkpointing with a no-op context manager. TRL 1.0.0+ wraps generation in `with torch.no_grad(), disable_gradient_checkpointing(self.model, ...):` purely to suppress a cosmetic PyTorch warning ("None of the inputs have requires_grad=True"). Inside torch.no_grad() the gradient checkpointing state has no functional effect on the forward pass. On context exit, TRL calls model.gradient_checkpointing_enable() which dispatches to HF's generic implementation and overwrites Unsloth's custom `use_gradient_checkpointing="unsloth"` wrapper, corrupting Gemma-4 forward numerics. Replacing the toggle with a no-op preserves Unsloth's custom GC wrapper across generation passes. The patch walks sys.modules dynamically to also rebind the symbol on every trl.* module that already imported it (grpo_trainer, dpo_trainer, rloo_trainer, dppo_trainer, gfpo_trainer, grpo_with_replay_buffer_trainer, and any future trainer module). Fix 2 (vision.py): inject `final_logit_softcapping` from `config.text_config` into the top-level `model.config` for multimodal models. Unsloth's GRPO trainer reads `getattr(model.config, "final_logit_softcapping", 0)` but for Gemma-4 the attribute lives only on the nested `Gemma4TextConfig`, so the lookup silently defaults to 0 instead of 30. Backwards compatibility: - trl 0.22.2: no `disable_gradient_checkpointing` symbol exists, the patch early-returns via `hasattr` guard. - trl 0.27.1: same broken pattern as 1.0.0, the noop replacement is correct. 
- trl 1.0.0+: end-to-end verified on `unsloth/gemma-4-E2B-it` GRPO with TRL 1.0.0 and transformers 5.5.0. Step 1 loss=2.46e-08, kl=2.92e-05 (machine zero) vs broken baseline loss=1.37e+06, kl=1.76e+09. - Llama / non-VLM text models: Fix 2 is a no-op (no `text_config`); Fix 1 is functionally identical (Unsloth's GC wrapper is preserved). - Qwen3-VL and other VLMs without final_logit_softcapping: Fix 2 is a no-op (text_config.final_logit_softcapping is None). --- unsloth/models/rl.py | 85 ++++++++++++++++++++++++++++++++++++++++ unsloth/models/vision.py | 35 +++++++++++++++++ 2 files changed, 120 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 5651a7da41..77e0f98986 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -1947,6 +1947,86 @@ def patch_trl_rl_trainers(): return +def patch_trl_disable_gradient_checkpointing(): + # TRL 0.27.1 and 1.0.0+ wrap generation in: + # with torch.no_grad(), disable_gradient_checkpointing(self.model, ...): + # The toggle exists only to suppress a cosmetic PyTorch warning + # ("None of the inputs have requires_grad=True"). Inside torch.no_grad() + # the gradient checkpointing state has no functional effect on the + # forward pass. + # + # On exit, the context manager calls model.gradient_checkpointing_enable() + # which dispatches to HuggingFace's generic implementation and overwrites + # Unsloth's custom `use_gradient_checkpointing="unsloth"` wrapper. For + # Gemma-4 (and likely other models) this corrupts the forward numerics + # enough to make GRPO KL divergence explode to ~10^12 at step 1. + # + # Replacing the context manager with a no-op preserves Unsloth's custom + # gradient checkpointing wrapper across generation/inference passes. + # + # Backwards compatibility: + # - trl 0.22.2 (no disable_gradient_checkpointing): early return. + # - trl 0.27.1 / 1.0.0+: noop is functionally equivalent for forward + # correctness. 
The only loss is a cosmetic warning being emitted + # by PyTorch when use_reentrant=True (which is exactly the warning + # TRL added the toggle to suppress in the first place). + try: + import trl.models.utils as _tmu + except ImportError: + return + if not hasattr(_tmu, "disable_gradient_checkpointing"): + return + if getattr( + _tmu.disable_gradient_checkpointing, "_unsloth_noop_patched", False, + ): + return + + from contextlib import contextmanager + + @contextmanager + def _noop_disable_gradient_checkpointing( + model, gradient_checkpointing_kwargs = None, + ): + yield + _noop_disable_gradient_checkpointing._unsloth_noop_patched = True + + _tmu.disable_gradient_checkpointing = _noop_disable_gradient_checkpointing + + # Also rebind any trl.* module that already imported the symbol by + # reference, so the noop applies even when the trainer module cached the + # original at import time. We walk sys.modules dynamically rather than + # hardcoding a list, so this picks up: + # trl.trainer.grpo_trainer, trl.trainer.dpo_trainer, + # trl.trainer.rloo_trainer, trl.experimental.dppo.dppo_trainer, + # trl.experimental.gfpo.gfpo_trainer, + # trl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_trainer + # and any future TRL module that adds `from ...models.utils import + # disable_gradient_checkpointing`. 
+ import sys as _sys + for _mod_name, _mod in list(_sys.modules.items()): + if _mod is None: + continue + if not _mod_name.startswith("trl."): + continue + try: + _bound = getattr(_mod, "disable_gradient_checkpointing", None) + except Exception: + continue + if _bound is None: + continue + if getattr(_bound, "_unsloth_noop_patched", False): + continue + try: + setattr( + _mod, + "disable_gradient_checkpointing", + _noop_disable_gradient_checkpointing, + ) + except Exception: + pass + return + + def patch_trl_openenv(): for function in RL_ADDITIONAL_FUNCTIONS["openenv"]: logger.info(f"Unsloth: Patching trl openenv with function: {function.__name__}") @@ -1981,6 +2061,11 @@ def patch_trl_vllm_generation(): def PatchFastRL(algorithm = None, FastLanguageModel = None): if FastLanguageModel is not None: PatchRL(FastLanguageModel) + # Install the disable_gradient_checkpointing noop BEFORE + # patch_trl_rl_trainers so the compiled cache picks up the noop + # at its `from trl.trainer.grpo_trainer import disable_gradient_checkpointing` + # binding time. + patch_trl_disable_gradient_checkpointing() patch_trl_rl_trainers() patch_trl_openenv() patch_trl_vllm_generation() diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 5abeb3a81a..b450ebf270 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -1094,6 +1094,41 @@ def from_pretrained( # Log Unsloth version for future fastpaths for inference if hasattr(model, "config"): model.config.update({"unsloth_version": __version__}) + + # For multimodal models (e.g. Gemma-4) the + # `final_logit_softcapping` attribute lives on + # `config.text_config`, but Unsloth's GRPO trainer reads it from + # the top-level `model.config`. Inject it at the top level so the + # lookup finds the correct value (e.g. 30.0 for Gemma-4) instead + # of silently defaulting to 0. No-op for models that already + # expose it at the top level or do not use softcapping. 
+ try: + _top_config = model.config + if getattr(_top_config, "final_logit_softcapping", None) is None: + _softcap = None + _text_cfg = getattr(_top_config, "text_config", None) + if _text_cfg is not None: + _softcap = getattr( + _text_cfg, "final_logit_softcapping", None, + ) + if _softcap is None: + _get_text = getattr(_top_config, "get_text_config", None) + if callable(_get_text): + try: + _softcap = getattr( + _get_text(), + "final_logit_softcapping", + None, + ) + except Exception: + pass + if _softcap is not None: + try: + setattr(_top_config, "final_logit_softcapping", _softcap) + except Exception: + pass + except Exception: + pass patch_saving_functions(model, vision = True) if tokenizer is None: # Last resort: try loading tokenizer via AutoTokenizer, then PreTrainedTokenizerFast From 1f27883ac0408fb5bf531322d7ef87361fdf7c63 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Apr 2026 14:30:36 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/rl.py | 9 +++++++-- unsloth/models/vision.py | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 77e0f98986..45206045e1 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -1977,7 +1977,9 @@ def patch_trl_disable_gradient_checkpointing(): if not hasattr(_tmu, "disable_gradient_checkpointing"): return if getattr( - _tmu.disable_gradient_checkpointing, "_unsloth_noop_patched", False, + _tmu.disable_gradient_checkpointing, + "_unsloth_noop_patched", + False, ): return @@ -1985,9 +1987,11 @@ def patch_trl_disable_gradient_checkpointing(): @contextmanager def _noop_disable_gradient_checkpointing( - model, gradient_checkpointing_kwargs = None, + model, + gradient_checkpointing_kwargs = None, ): yield + _noop_disable_gradient_checkpointing._unsloth_noop_patched = True 
_tmu.disable_gradient_checkpointing = _noop_disable_gradient_checkpointing @@ -2003,6 +2007,7 @@ def _noop_disable_gradient_checkpointing( # and any future TRL module that adds `from ...models.utils import # disable_gradient_checkpointing`. import sys as _sys + for _mod_name, _mod in list(_sys.modules.items()): if _mod is None: continue diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index b450ebf270..88587a4d10 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -1109,7 +1109,9 @@ def from_pretrained( _text_cfg = getattr(_top_config, "text_config", None) if _text_cfg is not None: _softcap = getattr( - _text_cfg, "final_logit_softcapping", None, + _text_cfg, + "final_logit_softcapping", + None, ) if _softcap is None: _get_text = getattr(_top_config, "get_text_config", None)