From 1dbc1ed12d94a537ad7970ecb7655834c4f79ffb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 9 Apr 2026 14:29:45 +0000 Subject: [PATCH 1/6] Fix Gemma-4 GRPO catastrophic KL divergence with TRL 1.0.0+ Two compounding bugs caused Gemma-4 GRPO training to diverge with KL ~10^12 at step 1 against TRL 1.0.0+. Both fixes are runtime patches in the existing TRL/model patch flow and are no-ops for models and TRL versions that are not affected. Fix 1 (rl.py): replace trl.models.utils.disable_gradient_checkpointing with a no-op context manager. TRL 1.0.0+ wraps generation in `with torch.no_grad(), disable_gradient_checkpointing(self.model, ...):` purely to suppress a cosmetic PyTorch warning ("None of the inputs have requires_grad=True"). Inside torch.no_grad() the gradient checkpointing state has no functional effect on the forward pass. On context exit, TRL calls model.gradient_checkpointing_enable() which dispatches to HF's generic implementation and overwrites Unsloth's custom `use_gradient_checkpointing="unsloth"` wrapper, corrupting Gemma-4 forward numerics. Replacing the toggle with a no-op preserves Unsloth's custom GC wrapper across generation passes. The patch walks sys.modules dynamically to also rebind the symbol on every trl.* module that already imported it (grpo_trainer, dpo_trainer, rloo_trainer, dppo_trainer, gfpo_trainer, grpo_with_replay_buffer_trainer, and any future trainer module). Fix 2 (vision.py): inject `final_logit_softcapping` from `config.text_config` into the top-level `model.config` for multimodal models. Unsloth's GRPO trainer reads `getattr(model.config, "final_logit_softcapping", 0)` but for Gemma-4 the attribute lives only on the nested `Gemma4TextConfig`, so the lookup silently defaults to 0 instead of 30. Backwards compatibility: - trl 0.22.2: no `disable_gradient_checkpointing` symbol exists, the patch early-returns via `hasattr` guard. - trl 0.27.1: same broken pattern as 1.0.0, the noop replacement is correct. 
- trl 1.0.0+: end-to-end verified on `unsloth/gemma-4-E2B-it` GRPO with TRL 1.0.0 and transformers 5.5.0. Step 1 loss=2.46e-08, kl=2.92e-05 (machine zero) vs broken baseline loss=1.37e+06, kl=1.76e+09. - Llama / non-VLM text models: Fix 2 is a no-op (no `text_config`); Fix 1 is functionally identical (Unsloth's GC wrapper is preserved). - Qwen3-VL and other VLMs without final_logit_softcapping: Fix 2 is a no-op (text_config.final_logit_softcapping is None). --- unsloth/models/rl.py | 85 ++++++++++++++++++++++++++++++++++++++++ unsloth/models/vision.py | 35 +++++++++++++++++ 2 files changed, 120 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 5651a7da41..77e0f98986 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -1947,6 +1947,86 @@ def patch_trl_rl_trainers(): return +def patch_trl_disable_gradient_checkpointing(): + # TRL 1.0.0+ wraps generation in: + # with torch.no_grad(), disable_gradient_checkpointing(self.model, ...): + # The toggle exists only to suppress a cosmetic PyTorch warning + # ("None of the inputs have requires_grad=True"). Inside torch.no_grad() + # the gradient checkpointing state has no functional effect on the + # forward pass. + # + # On exit, the context manager calls model.gradient_checkpointing_enable() + # which dispatches to HuggingFace's generic implementation and overwrites + # Unsloth's custom `use_gradient_checkpointing="unsloth"` wrapper. For + # Gemma-4 (and likely other models) this corrupts the forward numerics + # enough to make GRPO KL divergence explode to ~10^12 at step 1. + # + # Replacing the context manager with a no-op preserves Unsloth's custom + # gradient checkpointing wrapper across generation/inference passes. + # + # Backwards compatibility: + # - trl without disable_gradient_checkpointing (e.g. 0.22.2): early return. + # - trl with it (0.27.1 and 1.0.0+): noop is functionally equivalent for forward + # correctness. 
The only loss is a cosmetic warning being emitted + # by PyTorch when use_reentrant=True (which is exactly the warning + # TRL added the toggle to suppress in the first place). + try: + import trl.models.utils as _tmu + except ImportError: + return + if not hasattr(_tmu, "disable_gradient_checkpointing"): + return + if getattr( + _tmu.disable_gradient_checkpointing, "_unsloth_noop_patched", False, + ): + return + + from contextlib import contextmanager + + @contextmanager + def _noop_disable_gradient_checkpointing( + model, gradient_checkpointing_kwargs = None, + ): + yield + _noop_disable_gradient_checkpointing._unsloth_noop_patched = True + + _tmu.disable_gradient_checkpointing = _noop_disable_gradient_checkpointing + + # Also rebind any trl.* module that already imported the symbol by + # reference, so the noop applies even when the trainer module cached the + # original at import time. We walk sys.modules dynamically rather than + # hardcoding a list, so this picks up: + # trl.trainer.grpo_trainer, trl.trainer.dpo_trainer, + # trl.trainer.rloo_trainer, trl.experimental.dppo.dppo_trainer, + # trl.experimental.gfpo.gfpo_trainer, + # trl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_trainer + # and any future TRL module that adds `from ...models.utils import + # disable_gradient_checkpointing`. 
+ import sys as _sys + for _mod_name, _mod in list(_sys.modules.items()): + if _mod is None: + continue + if not _mod_name.startswith("trl."): + continue + try: + _bound = getattr(_mod, "disable_gradient_checkpointing", None) + except Exception: + continue + if _bound is None: + continue + if getattr(_bound, "_unsloth_noop_patched", False): + continue + try: + setattr( + _mod, + "disable_gradient_checkpointing", + _noop_disable_gradient_checkpointing, + ) + except Exception: + pass + return + + def patch_trl_openenv(): for function in RL_ADDITIONAL_FUNCTIONS["openenv"]: logger.info(f"Unsloth: Patching trl openenv with function: {function.__name__}") @@ -1981,6 +2061,11 @@ def patch_trl_vllm_generation(): def PatchFastRL(algorithm = None, FastLanguageModel = None): if FastLanguageModel is not None: PatchRL(FastLanguageModel) + # Install the disable_gradient_checkpointing noop BEFORE + # patch_trl_rl_trainers so the compiled cache picks up the noop + # at its `from trl.trainer.grpo_trainer import disable_gradient_checkpointing` + # binding time. + patch_trl_disable_gradient_checkpointing() patch_trl_rl_trainers() patch_trl_openenv() patch_trl_vllm_generation() diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 5abeb3a81a..b450ebf270 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -1094,6 +1094,41 @@ def from_pretrained( # Log Unsloth version for future fastpaths for inference if hasattr(model, "config"): model.config.update({"unsloth_version": __version__}) + + # For multimodal models (e.g. Gemma-4) the + # `final_logit_softcapping` attribute lives on + # `config.text_config`, but Unsloth's GRPO trainer reads it from + # the top-level `model.config`. Inject it at the top level so the + # lookup finds the correct value (e.g. 30.0 for Gemma-4) instead + # of silently defaulting to 0. No-op for models that already + # expose it at the top level or do not use softcapping. 
+ try: + _top_config = model.config + if getattr(_top_config, "final_logit_softcapping", None) is None: + _softcap = None + _text_cfg = getattr(_top_config, "text_config", None) + if _text_cfg is not None: + _softcap = getattr( + _text_cfg, "final_logit_softcapping", None, + ) + if _softcap is None: + _get_text = getattr(_top_config, "get_text_config", None) + if callable(_get_text): + try: + _softcap = getattr( + _get_text(), + "final_logit_softcapping", + None, + ) + except Exception: + pass + if _softcap is not None: + try: + setattr(_top_config, "final_logit_softcapping", _softcap) + except Exception: + pass + except Exception: + pass patch_saving_functions(model, vision = True) if tokenizer is None: # Last resort: try loading tokenizer via AutoTokenizer, then PreTrainedTokenizerFast From 1f27883ac0408fb5bf531322d7ef87361fdf7c63 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Apr 2026 14:30:36 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/rl.py | 9 +++++++-- unsloth/models/vision.py | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 77e0f98986..45206045e1 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -1977,7 +1977,9 @@ def patch_trl_disable_gradient_checkpointing(): if not hasattr(_tmu, "disable_gradient_checkpointing"): return if getattr( - _tmu.disable_gradient_checkpointing, "_unsloth_noop_patched", False, + _tmu.disable_gradient_checkpointing, + "_unsloth_noop_patched", + False, ): return @@ -1985,9 +1987,11 @@ def patch_trl_disable_gradient_checkpointing(): @contextmanager def _noop_disable_gradient_checkpointing( - model, gradient_checkpointing_kwargs = None, + model, + gradient_checkpointing_kwargs = None, ): yield + _noop_disable_gradient_checkpointing._unsloth_noop_patched = True 
 _tmu.disable_gradient_checkpointing = _noop_disable_gradient_checkpointing @@ -2003,6 +2007,7 @@ def _noop_disable_gradient_checkpointing( # and any future TRL module that adds `from ...models.utils import # disable_gradient_checkpointing`. import sys as _sys + for _mod_name, _mod in list(_sys.modules.items()): if _mod is None: continue diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index b450ebf270..88587a4d10 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -1109,7 +1109,9 @@ def from_pretrained( _text_cfg = getattr(_top_config, "text_config", None) if _text_cfg is not None: _softcap = getattr( - _text_cfg, "final_logit_softcapping", None, + _text_cfg, + "final_logit_softcapping", + None, ) if _softcap is None: _get_text = getattr(_top_config, "get_text_config", None) From cfa28484be39f652ffaa8e60ffd1f4cdddff37b4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 9 Apr 2026 15:18:41 +0000 Subject: [PATCH 3/6] Apply loop 1 review fixes for PR #4934 - Move Fix 2 from vision.py to rl_replacements.py:858 and :1110 at the actual consumer sites. This avoids mutating model.config (which could leak into save_pretrained output) and covers text-only Gemma-4 paths that do not flow through FastBaseModel.from_pretrained. - Revert the vision.py injection block entirely. - Narrow the broad except blocks in patch_trl_disable_gradient_checkpointing from `except Exception:` to `(AttributeError, ImportError)` and `(AttributeError, TypeError)` to avoid masking unrelated bugs. - Add logger.warning_once when the noop patch is installed, matching patch_trl_openenv and patch_trl_vllm_generation convention. - Remove the dead per-module `_unsloth_noop_patched` sentinel check inside the sys.modules walk. The function-level early return already covers this case. - Move `import sys` and `from contextlib import contextmanager` to the module-level imports instead of inside the function body. 
- Rewrite the ordering comment in PatchFastRL to accurately describe why patch_trl_disable_gradient_checkpointing must run before patch_trl_rl_trainers. - Fix keyword default spacing to match surrounding rl.py style. End-to-end verified: Gemma-4-E2B GRPO on TRL 1.0.0 + transformers 5.5.0 step 1 loss=2.464e-08 kl=2.921e-05, all 5 steps succeed. --- unsloth/models/rl.py | 49 +++++++++++++++---------------- unsloth/models/rl_replacements.py | 16 ++++++++-- unsloth/models/vision.py | 37 ----------------------- 3 files changed, 37 insertions(+), 65 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 45206045e1..3ce9024f6a 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -22,6 +22,8 @@ import inspect import os import re +import sys +from contextlib import contextmanager from unsloth_zoo.compiler import create_new_function from unsloth_zoo.log import logger from unsloth_zoo.logging_utils import PatchRLStatistics @@ -1983,13 +1985,8 @@ def patch_trl_disable_gradient_checkpointing(): ): return - from contextlib import contextmanager - @contextmanager - def _noop_disable_gradient_checkpointing( - model, - gradient_checkpointing_kwargs = None, - ): + def _noop_disable_gradient_checkpointing(model, gradient_checkpointing_kwargs=None): yield _noop_disable_gradient_checkpointing._unsloth_noop_patched = True @@ -1999,36 +1996,33 @@ def _noop_disable_gradient_checkpointing( # Also rebind any trl.* module that already imported the symbol by # reference, so the noop applies even when the trainer module cached the # original at import time. 
We walk sys.modules dynamically rather than - # hardcoding a list, so this picks up: - # trl.trainer.grpo_trainer, trl.trainer.dpo_trainer, - # trl.trainer.rloo_trainer, trl.experimental.dppo.dppo_trainer, - # trl.experimental.gfpo.gfpo_trainer, - # trl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_trainer - # and any future TRL module that adds `from ...models.utils import - # disable_gradient_checkpointing`. - import sys as _sys - - for _mod_name, _mod in list(_sys.modules.items()): - if _mod is None: - continue - if not _mod_name.startswith("trl."): + # hardcoding a list, so this picks up every trainer that does + # `from ...models.utils import disable_gradient_checkpointing` + # (grpo, dpo, rloo, dppo, gfpo, grpo_with_replay_buffer, and any future + # TRL trainer module). + for _mod_name, _mod in list(sys.modules.items()): + if _mod is None or not _mod_name.startswith("trl."): continue try: _bound = getattr(_mod, "disable_gradient_checkpointing", None) - except Exception: + except (AttributeError, ImportError): continue if _bound is None: continue - if getattr(_bound, "_unsloth_noop_patched", False): - continue try: setattr( _mod, "disable_gradient_checkpointing", _noop_disable_gradient_checkpointing, ) - except Exception: + except (AttributeError, TypeError): pass + + logger.warning_once( + "Unsloth: Patched trl.models.utils.disable_gradient_checkpointing with " + "a no-op to preserve Unsloth gradient checkpointing across TRL " + "generation passes." + ) return @@ -2067,9 +2061,12 @@ def PatchFastRL(algorithm = None, FastLanguageModel = None): if FastLanguageModel is not None: PatchRL(FastLanguageModel) # Install the disable_gradient_checkpointing noop BEFORE - # patch_trl_rl_trainers so the compiled cache picks up the noop - # at its `from trl.trainer.grpo_trainer import disable_gradient_checkpointing` - # binding time. + # patch_trl_rl_trainers. 
patch_trl_rl_trainers imports extra trl.* trainer + # submodules while generating the compiled cache; any new trl.* modules + # imported after the sys.modules walk would keep their original (broken) + # binding of disable_gradient_checkpointing. Running the noop install + # first ensures the canonical trl.models.utils symbol is already replaced + # before those submodules bind it. patch_trl_disable_gradient_checkpointing() patch_trl_rl_trainers() patch_trl_openenv() diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 2544afe82e..dda59ba517 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -855,7 +855,13 @@ def chunk_optional(tensor, chunks): image_sizes_chunks = chunk_optional(image_sizes, B) temperature = self.temperature - logit_softcapping = getattr(model.config, "final_logit_softcapping", 0) + # Gemma-4 multimodal configs store final_logit_softcapping on the + # nested text_config; fall back to it so VLM training matches HF. + logit_softcapping = getattr(model.config, "final_logit_softcapping", None) + if logit_softcapping is None: + _text_cfg = getattr(model.config, "text_config", None) + if _text_cfg is not None: + logit_softcapping = getattr(_text_cfg, "final_logit_softcapping", None) if logit_softcapping is None: logit_softcapping = 0 logit_scale_multiply = getattr(model.config, "logit_scale", 0) @@ -1107,7 +1113,13 @@ def compute_loss( input_ids = input_ids[:, -logits_to_keep:] # Get logit softcapping and logit scale - logit_softcapping = getattr(model.config, "final_logit_softcapping", 0) # Gemma + logit_softcapping = getattr(model.config, "final_logit_softcapping", None) # Gemma + if logit_softcapping is None: + # Gemma-4 multimodal configs store final_logit_softcapping on the + # nested text_config; fall back to it so VLM training matches HF. 
+ _text_cfg = getattr(model.config, "text_config", None) + if _text_cfg is not None: + logit_softcapping = getattr(_text_cfg, "final_logit_softcapping", None) if logit_softcapping is None: logit_softcapping = 0 logit_scale_multiply = getattr(model.config, "logit_scale", 0) # Cohere diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 88587a4d10..5abeb3a81a 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -1094,43 +1094,6 @@ def from_pretrained( # Log Unsloth version for future fastpaths for inference if hasattr(model, "config"): model.config.update({"unsloth_version": __version__}) - - # For multimodal models (e.g. Gemma-4) the - # `final_logit_softcapping` attribute lives on - # `config.text_config`, but Unsloth's GRPO trainer reads it from - # the top-level `model.config`. Inject it at the top level so the - # lookup finds the correct value (e.g. 30.0 for Gemma-4) instead - # of silently defaulting to 0. No-op for models that already - # expose it at the top level or do not use softcapping. 
- try: - _top_config = model.config - if getattr(_top_config, "final_logit_softcapping", None) is None: - _softcap = None - _text_cfg = getattr(_top_config, "text_config", None) - if _text_cfg is not None: - _softcap = getattr( - _text_cfg, - "final_logit_softcapping", - None, - ) - if _softcap is None: - _get_text = getattr(_top_config, "get_text_config", None) - if callable(_get_text): - try: - _softcap = getattr( - _get_text(), - "final_logit_softcapping", - None, - ) - except Exception: - pass - if _softcap is not None: - try: - setattr(_top_config, "final_logit_softcapping", _softcap) - except Exception: - pass - except Exception: - pass patch_saving_functions(model, vision = True) if tokenizer is None: # Last resort: try loading tokenizer via AutoTokenizer, then PreTrainedTokenizerFast From c860961595cee6de5b20d441e750b6f7f9f74d53 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 9 Apr 2026 15:38:23 +0000 Subject: [PATCH 4/6] Apply loop 2 review fix for PR #4934 Extract the final_logit_softcapping fallback logic into a shared helper `_unsloth_get_final_logit_softcapping(config)` defined in rl_replacements.py and injected into the compiled cache via RL_PRE_ITEMS["grpo_trainer"]. Both call sites (`grpo_trainer__generate_and_score_completions` and `grpo_trainer_compute_loss`) now use the helper instead of inlining the same text_config fallback block twice. Verified: compiled cache file lists the helper at module scope and both consumer sites call it. Gemma-4-E2B GRPO step 1 loss=2.464e-08 kl=2.921e-05 (unchanged), all 5 steps pass. 
--- unsloth/models/rl_replacements.py | 34 +++++++++++++++---------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index dda59ba517..a8ad52c716 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -855,15 +855,7 @@ def chunk_optional(tensor, chunks): image_sizes_chunks = chunk_optional(image_sizes, B) temperature = self.temperature - # Gemma-4 multimodal configs store final_logit_softcapping on the - # nested text_config; fall back to it so VLM training matches HF. - logit_softcapping = getattr(model.config, "final_logit_softcapping", None) - if logit_softcapping is None: - _text_cfg = getattr(model.config, "text_config", None) - if _text_cfg is not None: - logit_softcapping = getattr(_text_cfg, "final_logit_softcapping", None) - if logit_softcapping is None: - logit_softcapping = 0 + logit_softcapping = _unsloth_get_final_logit_softcapping(model.config) logit_scale_multiply = getattr(model.config, "logit_scale", 0) if logit_scale_multiply is None: logit_scale_multiply = 0 @@ -1010,11 +1002,25 @@ def chunk_optional(tensor, chunks): RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps_and_entropies) +def _unsloth_get_final_logit_softcapping(config): + """Return final_logit_softcapping for a model config, falling back to the + nested ``text_config`` for multimodal models such as Gemma-4 where the + attribute only lives on ``config.text_config``. Returns 0 if unset. 
+ """ + softcap = getattr(config, "final_logit_softcapping", None) + if softcap is None: + text_cfg = getattr(config, "text_config", None) + if text_cfg is not None: + softcap = getattr(text_cfg, "final_logit_softcapping", None) + return 0 if softcap is None else softcap + + grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] grpo_compute_loss_slow = RL_REPLACEMENTS["grpo_compute_loss_slow"] UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"] grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"] grpo_update_SamplingParams = RL_REPLACEMENTS["grpo_update_SamplingParams"] +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(_unsloth_get_final_logit_softcapping)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(UnslothEfficientGRPO)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss)) @@ -1113,15 +1119,7 @@ def compute_loss( input_ids = input_ids[:, -logits_to_keep:] # Get logit softcapping and logit scale - logit_softcapping = getattr(model.config, "final_logit_softcapping", None) # Gemma - if logit_softcapping is None: - # Gemma-4 multimodal configs store final_logit_softcapping on the - # nested text_config; fall back to it so VLM training matches HF. 
- _text_cfg = getattr(model.config, "text_config", None) - if _text_cfg is not None: - logit_softcapping = getattr(_text_cfg, "final_logit_softcapping", None) - if logit_softcapping is None: - logit_softcapping = 0 + logit_softcapping = _unsloth_get_final_logit_softcapping(model.config) # Gemma logit_scale_multiply = getattr(model.config, "logit_scale", 0) # Cohere if logit_scale_multiply is None: logit_scale_multiply = 0 From 306f06fd72e44910efc6a8d6190b5f3676f54ddd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 9 Apr 2026 15:58:56 +0000 Subject: [PATCH 5/6] Apply loop 3 review fix for PR #4934 Extend _unsloth_get_final_logit_softcapping to also fall back to config.get_text_config() for composite configs such as T5GemmaConfig where the text sub-config is not exposed via the text_config attribute but only via the get_text_config() method. Guard against (TypeError, ValueError) raised by ambiguous composite configs, and skip the self-referential case where get_text_config() returns self. This addresses the 6/7 reviewer consensus from the third review loop. Verified: - Helper returns 30.0 for Gemma-4, T5Gemma, and Gemma 1/2 configs. - Helper returns 0 for Llama, Qwen, Mistral, Cohere, Granite, and ambiguous configs raising ValueError. - Gemma-4-E2B GRPO step 1 loss=2.464e-08 kl=2.921e-05 (unchanged). - Llama-3.2-1B GRPO all 5 steps loss=0 kl=0 (no regression). 
--- unsloth/models/rl_replacements.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index a8ad52c716..9aeee6337c 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -1004,13 +1004,23 @@ def chunk_optional(tensor, chunks): def _unsloth_get_final_logit_softcapping(config): """Return final_logit_softcapping for a model config, falling back to the - nested ``text_config`` for multimodal models such as Gemma-4 where the - attribute only lives on ``config.text_config``. Returns 0 if unset. + nested text sub-config for composite models. Handles both: + - Gemma-4-style configs where the attribute lives on ``config.text_config`` + - T5Gemma-style composite configs where the text sub-config is only + reachable via ``config.get_text_config()`` + Returns 0 if unset, matching the previous behaviour. """ softcap = getattr(config, "final_logit_softcapping", None) if softcap is None: text_cfg = getattr(config, "text_config", None) - if text_cfg is not None: + if text_cfg is None: + get_text_config = getattr(config, "get_text_config", None) + if callable(get_text_config): + try: + text_cfg = get_text_config() + except (TypeError, ValueError): + text_cfg = None + if text_cfg is not None and text_cfg is not config: softcap = getattr(text_cfg, "final_logit_softcapping", None) return 0 if softcap is None else softcap From 46b352238ede4d3589e3ec46600b71c7f3d0d4a6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:59:24 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/rl.py | 2 +- unsloth/models/rl_replacements.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3ce9024f6a..f444f6bd37 100755 --- 
a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -1986,7 +1986,7 @@ def patch_trl_disable_gradient_checkpointing(): return @contextmanager - def _noop_disable_gradient_checkpointing(model, gradient_checkpointing_kwargs=None): + def _noop_disable_gradient_checkpointing(model, gradient_checkpointing_kwargs = None): yield _noop_disable_gradient_checkpointing._unsloth_noop_patched = True diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 9aeee6337c..93a7f89bcb 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -1002,6 +1002,7 @@ def chunk_optional(tensor, chunks): RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps_and_entropies) + def _unsloth_get_final_logit_softcapping(config): """Return final_logit_softcapping for a model config, falling back to the nested text sub-config for composite models. Handles both: @@ -1030,7 +1031,9 @@ def _unsloth_get_final_logit_softcapping(config): UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"] grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"] grpo_update_SamplingParams = RL_REPLACEMENTS["grpo_update_SamplingParams"] -RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(_unsloth_get_final_logit_softcapping)) +RL_PRE_ITEMS["grpo_trainer"].append( + inspect.getsource(_unsloth_get_final_logit_softcapping) +) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(UnslothEfficientGRPO)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss))