Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import inspect
import os
import re
import sys
from contextlib import contextmanager
from unsloth_zoo.compiler import create_new_function
from unsloth_zoo.log import logger
from unsloth_zoo.logging_utils import PatchRLStatistics
Expand Down Expand Up @@ -1947,6 +1949,83 @@ def patch_trl_rl_trainers():
return


def patch_trl_disable_gradient_checkpointing():
    """Replace TRL's ``disable_gradient_checkpointing`` context manager with a no-op.

    The no-op is installed on ``trl.models.utils`` and on every already
    imported ``trl.*`` module that holds its own reference to the symbol.
    The function is idempotent: once the no-op is installed, later calls
    return immediately. It returns ``None`` in all cases and does nothing
    when TRL is not importable or does not define the context manager.
    """
    # TRL 1.0.0+ wraps generation in:
    #   with torch.no_grad(), disable_gradient_checkpointing(self.model, ...):
    # The toggle exists only to suppress a cosmetic PyTorch warning
    # ("None of the inputs have requires_grad=True"). Inside torch.no_grad()
    # the gradient checkpointing state has no functional effect on the
    # forward pass.
    #
    # On exit, the context manager calls model.gradient_checkpointing_enable()
    # which dispatches to HuggingFace's generic implementation and overwrites
    # Unsloth's custom `use_gradient_checkpointing="unsloth"` wrapper. For
    # Gemma-4 (and likely other models) this corrupts the forward numerics
    # enough to make GRPO KL divergence explode to ~10^12 at step 1.
    #
    # Replacing the context manager with a no-op preserves Unsloth's custom
    # gradient checkpointing wrapper across generation/inference passes.
    #
    # Backwards compatibility:
    #   - trl < 1.0.0 (no disable_gradient_checkpointing): early return.
    #   - trl >= 1.0.0: noop is functionally equivalent for forward
    #     correctness. The only loss is a cosmetic warning being emitted
    #     by PyTorch when use_reentrant=True (which is exactly the warning
    #     TRL added the toggle to suppress in the first place).
    try:
        import trl.models.utils as _tmu
    except ImportError:
        return
    if not hasattr(_tmu, "disable_gradient_checkpointing"):
        return
    # Idempotency guard: the marker attribute is set on our replacement below.
    if getattr(
        _tmu.disable_gradient_checkpointing,
        "_unsloth_noop_patched",
        False,
    ):
        return

    @contextmanager
    def _noop_disable_gradient_checkpointing(model, gradient_checkpointing_kwargs = None):
        # Same call signature as the TRL context manager, but leaves the
        # model's gradient checkpointing state untouched.
        yield

    _noop_disable_gradient_checkpointing._unsloth_noop_patched = True

    _tmu.disable_gradient_checkpointing = _noop_disable_gradient_checkpointing

    # Also rebind any trl.* module that already imported the symbol by
    # reference, so the noop applies even when the trainer module cached the
    # original at import time. We walk sys.modules dynamically rather than
    # hardcoding a list, so this picks up every trainer that does
    #   `from ...models.utils import disable_gradient_checkpointing`
    # (grpo, dpo, rloo, dppo, gfpo, grpo_with_replay_buffer, and any future
    # TRL trainer module).
    # Snapshot the items: attribute access on a module may trigger imports
    # that mutate sys.modules while we iterate.
    for _mod_name, _mod in list(sys.modules.items()):
        if _mod is None or not _mod_name.startswith("trl."):
            continue
        try:
            _bound = getattr(_mod, "disable_gradient_checkpointing", None)
        except (AttributeError, ImportError) as e:
            # Best-effort: skip modules whose attribute lookup fails, but
            # leave a trace for debugging instead of failing silently.
            logger.debug(
                f"Unsloth: Could not inspect disable_gradient_checkpointing on {_mod_name}: {e}"
            )
            continue
        if _bound is None or _bound is _noop_disable_gradient_checkpointing:
            # Nothing to rebind, or already pointing at the no-op.
            continue
        try:
            setattr(
                _mod,
                "disable_gradient_checkpointing",
                _noop_disable_gradient_checkpointing,
            )
        except (AttributeError, TypeError) as e:
            # Best-effort: a module that rejects the rebind must not abort
            # patching, but the failure should be visible at debug level.
            logger.debug(
                f"Unsloth: Could not patch disable_gradient_checkpointing on {_mod_name}: {e}"
            )

    logger.warning_once(
        "Unsloth: Patched trl.models.utils.disable_gradient_checkpointing with "
        "a no-op to preserve Unsloth gradient checkpointing across TRL "
        "generation passes."
    )
    return


def patch_trl_openenv():
for function in RL_ADDITIONAL_FUNCTIONS["openenv"]:
logger.info(f"Unsloth: Patching trl openenv with function: {function.__name__}")
Expand Down Expand Up @@ -1981,6 +2060,14 @@ def patch_trl_vllm_generation():
def PatchFastRL(algorithm = None, FastLanguageModel = None):
if FastLanguageModel is not None:
PatchRL(FastLanguageModel)
# Install the disable_gradient_checkpointing noop BEFORE
# patch_trl_rl_trainers. patch_trl_rl_trainers imports extra trl.* trainer
# submodules while generating the compiled cache; any new trl.* modules
# imported after the sys.modules walk would keep their original (broken)
# binding of disable_gradient_checkpointing. Running the noop install
# first ensures the canonical trl.models.utils symbol is already replaced
# before those submodules bind it.
patch_trl_disable_gradient_checkpointing()
patch_trl_rl_trainers()
patch_trl_openenv()
patch_trl_vllm_generation()
Expand Down
35 changes: 29 additions & 6 deletions unsloth/models/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,9 +855,7 @@ def chunk_optional(tensor, chunks):
image_sizes_chunks = chunk_optional(image_sizes, B)

temperature = self.temperature
logit_softcapping = getattr(model.config, "final_logit_softcapping", 0)
if logit_softcapping is None:
logit_softcapping = 0
logit_softcapping = _unsloth_get_final_logit_softcapping(model.config)
logit_scale_multiply = getattr(model.config, "logit_scale", 0)
if logit_scale_multiply is None:
logit_scale_multiply = 0
Expand Down Expand Up @@ -1004,11 +1002,38 @@ def chunk_optional(tensor, chunks):

RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps_and_entropies)


def _unsloth_get_final_logit_softcapping(config):
"""Return final_logit_softcapping for a model config, falling back to the
nested text sub-config for composite models. Handles both:
- Gemma-4-style configs where the attribute lives on ``config.text_config``
- T5Gemma-style composite configs where the text sub-config is only
reachable via ``config.get_text_config()``
Returns 0 if unset, matching the previous behaviour.
"""
softcap = getattr(config, "final_logit_softcapping", None)
if softcap is None:
text_cfg = getattr(config, "text_config", None)
if text_cfg is None:
get_text_config = getattr(config, "get_text_config", None)
if callable(get_text_config):
try:
text_cfg = get_text_config()
except (TypeError, ValueError):
text_cfg = None
if text_cfg is not None and text_cfg is not config:
softcap = getattr(text_cfg, "final_logit_softcapping", None)
return 0 if softcap is None else softcap


grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"]
grpo_compute_loss_slow = RL_REPLACEMENTS["grpo_compute_loss_slow"]
UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"]
grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"]
grpo_update_SamplingParams = RL_REPLACEMENTS["grpo_update_SamplingParams"]
RL_PRE_ITEMS["grpo_trainer"].append(
inspect.getsource(_unsloth_get_final_logit_softcapping)
)
RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss))
RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(UnslothEfficientGRPO))
RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss))
Expand Down Expand Up @@ -1107,9 +1132,7 @@ def compute_loss(
input_ids = input_ids[:, -logits_to_keep:]

# Get logit softcapping and logit scale
logit_softcapping = getattr(model.config, "final_logit_softcapping", 0) # Gemma
if logit_softcapping is None:
logit_softcapping = 0
logit_softcapping = _unsloth_get_final_logit_softcapping(model.config) # Gemma
logit_scale_multiply = getattr(model.config, "logit_scale", 0) # Cohere
if logit_scale_multiply is None:
logit_scale_multiply = 0
Expand Down
Loading