# Reconstructed from a unified diff (unsloth/models/rl.py and
# unsloth/models/rl_replacements.py) whose line breaks had been lost. Only the
# two definitions that are fully visible in that diff are reproduced below;
# the remaining hunks touch definitions that are only partially visible and
# are summarised in the NOTE comments next to the function they relate to.


# --- unsloth/models/rl.py ---------------------------------------------------
#
# NOTE: PatchFastRL() must call patch_trl_disable_gradient_checkpointing()
# BEFORE patch_trl_rl_trainers(). patch_trl_rl_trainers imports extra trl.*
# trainer submodules while generating the compiled cache; modules imported
# after the sys.modules walk below would otherwise keep their original binding
# of disable_gradient_checkpointing. Installing the no-op first means the
# canonical trl.models.utils symbol is already replaced when they bind it.
def patch_trl_disable_gradient_checkpointing():
    """Replace TRL's ``disable_gradient_checkpointing`` with a no-op.

    TRL 1.0.0+ wraps generation in::

        with torch.no_grad(), disable_gradient_checkpointing(self.model, ...):

    The toggle exists only to suppress a cosmetic PyTorch warning ("None of
    the inputs have requires_grad=True"); inside ``torch.no_grad()`` the
    gradient checkpointing state has no functional effect on the forward pass.

    On exit, however, the context manager calls
    ``model.gradient_checkpointing_enable()``, which dispatches to
    HuggingFace's generic implementation and overwrites Unsloth's custom
    ``use_gradient_checkpointing="unsloth"`` wrapper. For Gemma-4 (and likely
    other models) this corrupts the forward numerics enough to make the GRPO
    KL divergence explode to ~10^12 at step 1. A no-op context manager keeps
    Unsloth's wrapper intact across generation/inference passes.

    Backwards compatibility:
      - trl not installed, or trl < 1.0.0 (symbol absent): returns early.
      - trl >= 1.0.0: functionally equivalent for forward correctness; the
        only loss is the cosmetic PyTorch warning under ``use_reentrant=True``.

    The function is idempotent: once the no-op is installed on
    ``trl.models.utils`` a second call returns immediately.

    Returns:
        None.
    """
    # Function-scope imports keep this block self-contained regardless of
    # which names the surrounding module imports at the top of the file.
    import sys
    from contextlib import contextmanager

    try:
        import trl.models.utils as _tmu
    except ImportError:
        return
    if not hasattr(_tmu, "disable_gradient_checkpointing"):
        return
    if getattr(
        _tmu.disable_gradient_checkpointing,
        "_unsloth_noop_patched",
        False,
    ):
        return

    @contextmanager
    def _noop_disable_gradient_checkpointing(model, gradient_checkpointing_kwargs = None):
        # Same call signature as the TRL original, but deliberately leaves the
        # model's gradient checkpointing state untouched on entry and exit.
        yield

    # Marker consulted by the idempotence check above.
    _noop_disable_gradient_checkpointing._unsloth_noop_patched = True

    _tmu.disable_gradient_checkpointing = _noop_disable_gradient_checkpointing

    # Also rebind every already-imported trl.* module that holds the symbol by
    # reference (`from ...models.utils import disable_gradient_checkpointing`).
    # Walking sys.modules instead of hardcoding a list covers grpo, dpo, rloo,
    # dppo, gfpo, grpo_with_replay_buffer and any future TRL trainer module.
    # list(...) snapshots the dict because attribute access on a lazy module
    # may import further modules and mutate sys.modules during the walk.
    for _mod_name, _mod in list(sys.modules.items()):
        if _mod is None or not _mod_name.startswith("trl."):
            continue
        try:
            _bound = getattr(_mod, "disable_gradient_checkpointing", None)
        except (AttributeError, ImportError, RuntimeError):
            # Lazy-loading module proxies can report a failed deferred import
            # as RuntimeError rather than ImportError. This walk is
            # best-effort: one unloadable module must not stop the remaining
            # trainers from being patched.
            continue
        if _bound is None or _bound is _noop_disable_gradient_checkpointing:
            continue
        try:
            setattr(
                _mod,
                "disable_gradient_checkpointing",
                _noop_disable_gradient_checkpointing,
            )
        except (AttributeError, TypeError):
            # Module object refuses attribute assignment; skip it.
            continue

    logger.warning_once(
        "Unsloth: Patched trl.models.utils.disable_gradient_checkpointing with "
        "a no-op to preserve Unsloth gradient checkpointing across TRL "
        "generation passes."
    )
    return


# --- unsloth/models/rl_replacements.py --------------------------------------
#
# NOTE: both GRPO call sites that previously did
#     logit_softcapping = getattr(model.config, "final_logit_softcapping", 0)
#     if logit_softcapping is None: logit_softcapping = 0
# now call _unsloth_get_final_logit_softcapping(model.config) instead, and the
# helper's source is registered with
#     RL_PRE_ITEMS["grpo_trainer"].append(
#         inspect.getsource(_unsloth_get_final_logit_softcapping)
#     )
# ahead of the grpo_compute_loss source. Because the source text is injected
# that way, the helper must stay self-contained (no module-level helpers or
# imports of its own).
def _unsloth_get_final_logit_softcapping(config):
    """Return ``final_logit_softcapping`` for a model config.

    Falls back to the nested text sub-config for composite models. Handles:
      - Gemma-4-style configs where the attribute lives on
        ``config.text_config``
      - T5Gemma-style composite configs where the text sub-config is only
        reachable via ``config.get_text_config()``

    Args:
        config: model config object; only attribute access is performed on it.

    Returns:
        The softcapping value, or 0 if it is unset (missing or ``None``) on
        both the config and its text sub-config.
    """
    softcap = getattr(config, "final_logit_softcapping", None)
    if softcap is not None:
        return softcap

    # Top-level value is missing/None: look for a text sub-config.
    text_cfg = getattr(config, "text_config", None)
    if text_cfg is None:
        get_text_config = getattr(config, "get_text_config", None)
        if callable(get_text_config):
            try:
                text_cfg = get_text_config()
            except (TypeError, ValueError):
                # Treat a failing accessor as "no text sub-config".
                text_cfg = None

    # get_text_config() may hand back the config itself; that was already
    # inspected above, so there is nothing more to find.
    if text_cfg is None or text_cfg is config:
        return 0

    softcap = getattr(text_cfg, "final_logit_softcapping", None)
    return 0 if softcap is None else softcap