def patch_trl_disable_gradient_checkpointing():
    """Replace TRL's ``disable_gradient_checkpointing`` with a no-op context manager.

    TRL 1.0.0+ wraps generation in::

        with torch.no_grad(), disable_gradient_checkpointing(self.model, ...):

    The toggle exists only to suppress a cosmetic PyTorch warning
    ("None of the inputs have requires_grad=True"). Inside ``torch.no_grad()``
    the gradient checkpointing state has no functional effect on the forward
    pass.

    On exit, the context manager calls ``model.gradient_checkpointing_enable()``
    which dispatches to HuggingFace's generic implementation and overwrites
    Unsloth's custom ``use_gradient_checkpointing="unsloth"`` wrapper. For
    Gemma-4 (and likely other models) this corrupts the forward numerics
    enough to make GRPO KL divergence explode to ~10^12 at step 1.

    Replacing the context manager with a no-op preserves Unsloth's custom
    gradient checkpointing wrapper across generation/inference passes.

    Backwards compatibility:
      - trl not importable, or trl < 1.0.0 (no
        ``disable_gradient_checkpointing``): returns without doing anything.
      - trl >= 1.0.0: the no-op is functionally equivalent for forward
        correctness. The only loss is a cosmetic warning being emitted by
        PyTorch when ``use_reentrant=True`` (which is exactly the warning TRL
        added the toggle to suppress in the first place).

    The function is idempotent: the no-op is tagged with
    ``_unsloth_noop_patched`` and is created only once, while every call
    re-sweeps the loaded ``trl.*`` modules so a binding that still points at
    the original context manager gets repaired.

    Returns:
        None.
    """
    import sys
    from contextlib import contextmanager

    try:
        import trl.models.utils as trl_model_utils
    except ImportError:
        return
    if not hasattr(trl_model_utils, "disable_gradient_checkpointing"):
        return

    installed = trl_model_utils.disable_gradient_checkpointing
    if getattr(installed, "_unsloth_noop_patched", False):
        # An earlier call already installed the no-op; reuse that object so
        # every module ends up bound to the same function.
        noop = installed
    else:

        @contextmanager
        def _noop_disable_gradient_checkpointing(*args, **kwargs):
            # Accept (and ignore) any arguments so that a TRL release which
            # extends the signature cannot raise TypeError during generation.
            yield

        _noop_disable_gradient_checkpointing._unsloth_noop_patched = True
        noop = _noop_disable_gradient_checkpointing
        trl_model_utils.disable_gradient_checkpointing = noop

    # Also rebind any trl.* module that already imported the symbol by
    # reference, so the no-op applies even when the trainer module cached the
    # original at import time. We walk sys.modules dynamically rather than
    # hardcoding a list, so this picks up:
    #   trl.trainer.grpo_trainer, trl.trainer.dpo_trainer,
    #   trl.trainer.rloo_trainer, trl.experimental.dppo.dppo_trainer,
    #   trl.experimental.gfpo.gfpo_trainer,
    #   trl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_trainer
    # and any future TRL module that adds `from ...models.utils import
    # disable_gradient_checkpointing`.
    # Iterate over a snapshot: attribute access on a module may import further
    # modules and mutate sys.modules while we loop.
    for module_name, module in list(sys.modules.items()):
        if module is None or not module_name.startswith("trl."):
            continue
        try:
            bound = getattr(module, "disable_gradient_checkpointing", None)
        except Exception:
            # Attribute lookup on a module can run arbitrary code; skip any
            # module that fails rather than aborting the whole patch.
            continue
        if bound is None or getattr(bound, "_unsloth_noop_patched", False):
            continue
        try:
            setattr(module, "disable_gradient_checkpointing", noop)
        except Exception:
            # Deliberately best effort: a module that refuses attribute
            # assignment must not prevent the remaining modules from being
            # patched.
            pass
+ try: + _top_config = model.config + if getattr(_top_config, "final_logit_softcapping", None) is None: + _softcap = None + _text_cfg = getattr(_top_config, "text_config", None) + if _text_cfg is not None: + _softcap = getattr( + _text_cfg, + "final_logit_softcapping", + None, + ) + if _softcap is None: + _get_text = getattr(_top_config, "get_text_config", None) + if callable(_get_text): + try: + _softcap = getattr( + _get_text(), + "final_logit_softcapping", + None, + ) + except Exception: + pass + if _softcap is not None: + try: + setattr(_top_config, "final_logit_softcapping", _softcap) + except Exception: + pass + except Exception: + pass patch_saving_functions(model, vision = True) if tokenizer is None: # Last resort: try loading tokenizer via AutoTokenizer, then PreTrainedTokenizerFast