Fix Gemma-4 GRPO catastrophic KL divergence with TRL 1.0.0+ #4934
Two compounding bugs caused Gemma-4 GRPO training to diverge with KL ~10^12
at step 1 against TRL 1.0.0+. Both fixes are runtime patches in the existing
TRL/model patch flow and are no-ops for models and TRL versions that are not
affected.
Fix 1 (rl.py): replace trl.models.utils.disable_gradient_checkpointing with
a no-op context manager. TRL 1.0.0+ wraps generation in
`with torch.no_grad(), disable_gradient_checkpointing(self.model, ...):`
purely to suppress a cosmetic PyTorch warning ("None of the inputs have
requires_grad=True"). Inside torch.no_grad() the gradient checkpointing
state has no functional effect on the forward pass. On context exit, TRL
calls model.gradient_checkpointing_enable() which dispatches to HF's
generic implementation and overwrites Unsloth's custom
`use_gradient_checkpointing="unsloth"` wrapper, corrupting Gemma-4 forward
numerics. Replacing the toggle with a no-op preserves Unsloth's custom GC
wrapper across generation passes. The patch walks sys.modules dynamically
to also rebind the symbol on every trl.* module that already imported it
(grpo_trainer, dpo_trainer, rloo_trainer, dppo_trainer, gfpo_trainer,
grpo_with_replay_buffer_trainer, and any future trainer module).
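A minimal sketch of the rebinding idea in Fix 1, assuming nothing about Unsloth's actual implementation. The `fake_trl.*` module names, `rebind_everywhere`, and the dict used as a model are hypothetical stand-ins so that the example runs without TRL installed. It shows why patching only the defining module is not enough: a module that did `from ... import disable_gradient_checkpointing` holds its own reference and has to be rebound as well.

```python
import sys
import types
from contextlib import contextmanager


@contextmanager
def noop_disable_gradient_checkpointing(model, *args, **kwargs):
    # Leave the model's gradient checkpointing state untouched.
    yield


def rebind_everywhere(root_module, symbol, replacement, prefix):
    """Rebind `symbol` on root_module and on every loaded `prefix.*` module
    that still holds a reference to the original object."""
    original = getattr(root_module, symbol, None)
    if original is None:
        return []  # symbol absent (older versions): nothing to patch
    setattr(root_module, symbol, replacement)
    patched = [root_module.__name__]
    for name, mod in list(sys.modules.items()):
        if mod is None or mod is root_module:
            continue
        if not name.startswith(prefix + "."):
            continue
        if getattr(mod, symbol, None) is original:
            setattr(mod, symbol, replacement)
            patched.append(name)
    return patched


@contextmanager
def original_toggle(model):
    # Stand-in for the toggle: on exit it re-enables a generic wrapper,
    # clobbering whatever custom wrapper was installed before.
    model["gc"] = "disabled"
    try:
        yield
    finally:
        model["gc"] = "hf_generic"


# Hypothetical modules that mimic "defined here, imported by reference there".
utils = types.ModuleType("fake_trl.models.utils")
utils.disable_gradient_checkpointing = original_toggle
trainer = types.ModuleType("fake_trl.trainer.grpo_trainer")
trainer.disable_gradient_checkpointing = original_toggle
sys.modules[utils.__name__] = utils
sys.modules[trainer.__name__] = trainer

patched = rebind_everywhere(
    utils,
    "disable_gradient_checkpointing",
    noop_disable_gradient_checkpointing,
    prefix="fake_trl",
)

model = {"gc": "unsloth"}
with trainer.disable_gradient_checkpointing(model):
    pass  # generation would happen here
```

After the rebind, the trainer module's reference is the no-op, so the custom wrapper state on `model` survives the `with` block.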
Fix 2 (vision.py): inject `final_logit_softcapping` from `config.text_config`
into the top-level `model.config` for multimodal models. Unsloth's GRPO
trainer reads `getattr(model.config, "final_logit_softcapping", 0)` but
for Gemma-4 the attribute lives only on the nested `Gemma4TextConfig`,
so the lookup silently defaults to 0 instead of 30.
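To illustrate why the missed lookup matters, here is a small sketch using the tanh-based softcapping formula used by Gemma-family models, `cap * tanh(logit / cap)`. The `softcap` function and the `SimpleNamespace` configs are illustrative stand-ins, and treating a falsy cap as "skip softcapping" is an assumption about how the default of 0 is consumed.

```python
import math
from types import SimpleNamespace


def softcap(logit, cap):
    # Assumption: a falsy cap (0 or None) means softcapping is skipped.
    if not cap:
        return logit
    return cap * math.tanh(logit / cap)


# Multimodal-style config: the attribute lives only on the nested text config.
text_config = SimpleNamespace(final_logit_softcapping=30.0)
config = SimpleNamespace(text_config=text_config)

cap_seen = getattr(config, "final_logit_softcapping", 0)  # top-level lookup misses
cap_real = config.text_config.final_logit_softcapping

raw_logit = 50.0
uncapped = softcap(raw_logit, cap_seen)  # passes through unchanged
capped = softcap(raw_logit, cap_real)    # squashed below the cap
```

With the top-level lookup, large logits pass through unchanged, so the resulting log-probabilities differ from what the model computes with its configured cap.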
Backwards compatibility:
- trl 0.22.2: no `disable_gradient_checkpointing` symbol exists, the patch
early-returns via `hasattr` guard.
- trl 0.27.1: same broken pattern as 1.0.0, the noop replacement is correct.
- trl 1.0.0+: end-to-end verified on `unsloth/gemma-4-E2B-it` GRPO with TRL
1.0.0 and transformers 5.5.0. Step 1 loss=2.46e-08, kl=2.92e-05 (machine
zero) vs broken baseline loss=1.37e+06, kl=1.76e+09.
- Llama / non-VLM text models: Fix 2 is a no-op (no `text_config`); Fix 1
is functionally identical (Unsloth's GC wrapper is preserved).
- Qwen3-VL and other VLMs without final_logit_softcapping: Fix 2 is a no-op
(text_config.final_logit_softcapping is None).
for more information, see https://pre-commit.ci
Code Review
This pull request introduces a patch that replaces TRL's disable_gradient_checkpointing context manager with a no-op, preventing numerical issues with Unsloth's custom gradient checkpointing, particularly for models like Gemma-4. This patch is applied early in PatchFastRL. Additionally, logic is added to from_pretrained for vision models to ensure final_logit_softcapping is correctly exposed at the top-level model.config for multimodal models, which is essential for the GRPO trainer. The review comments suggest logging exceptions in the newly added code blocks instead of silently swallowing them, in line with the rule against broad, silent exception handlers.
    setattr(
        _mod,
Silently passing on exceptions here can hide potential issues where the monkey-patching fails. As per the rule 'Avoid using broad, silent exception handlers like except Exception: pass. Instead, log the exception, even if at a debug level, to aid in future debugging.', it's better to log these exceptions, even at a debug level, to help with future troubleshooting. This will make the patching process more transparent without crashing on unexpected errors.
    setattr(
        _mod,
except Exception as e:
    logger.debug(f"Unsloth: Could not patch disable_gradient_checkpointing on {_mod_name}: {e}")
References
- Avoid using broad, silent exception handlers like `except Exception: pass`. Instead, log the exception, even if at a debug level, to aid in future debugging.
try:
    _top_config = model.config
    if getattr(_top_config, "final_logit_softcapping", None) is None:
        _softcap = None
        _text_cfg = getattr(_top_config, "text_config", None)
        if _text_cfg is not None:
            _softcap = getattr(
                _text_cfg,
                "final_logit_softcapping",
                None,
            )
        if _softcap is None:
            _get_text = getattr(_top_config, "get_text_config", None)
            if callable(_get_text):
                try:
                    _softcap = getattr(
                        _get_text(),
                        "final_logit_softcapping",
                        None,
                    )
                except Exception:
                    pass
        if _softcap is not None:
            try:
                setattr(_top_config, "final_logit_softcapping", _softcap)
            except Exception:
                pass
This block uses multiple nested try...except Exception: pass statements, which can make the code harder to debug and understand. As per the rule 'Avoid using broad, silent exception handlers like except Exception: pass. Instead, log the exception, even if at a debug level, to aid in future debugging.', it's better to flatten the logic and log exceptions when they occur. This will improve maintainability and help diagnose issues if this patching logic fails for certain model configurations.
try:
    _top_config = model.config
    if getattr(_top_config, "final_logit_softcapping", None) is None:
        _softcap = None
        _text_cfg = getattr(_top_config, "text_config", None)
        if _text_cfg is not None:
            _softcap = getattr(
                _text_cfg, "final_logit_softcapping", None,
            )
        if _softcap is None:
            _get_text = getattr(_top_config, "get_text_config", None)
            if callable(_get_text):
                try:
                    _softcap = getattr(
                        _get_text(),
                        "final_logit_softcapping",
                        None,
                    )
                except Exception as e:
                    logger.debug(f"Unsloth: Failed to get text_config for softcapping: {e}")
        if _softcap is not None:
            try:
                setattr(_top_config, "final_logit_softcapping", _softcap)
            except Exception as e:
                logger.warning(f"Unsloth: Could not set final_logit_softcapping on model config: {e}")
except Exception as e:
    logger.warning(f"Unsloth: An unexpected error occurred during final_logit_softcapping patch: {e}")
References
- Avoid using broad, silent exception handlers like `except Exception: pass`. Instead, log the exception, even if at a debug level, to aid in future debugging.
Yeah I too noticed trl messing up with gradient checkpointing. This seems like a necessary change.
- Move Fix 2 from vision.py to rl_replacements.py:858 and :1110 at the actual consumer sites. This avoids mutating model.config (which could leak into save_pretrained output) and covers text-only Gemma-4 paths that do not flow through FastBaseModel.from_pretrained.
- Revert the vision.py injection block entirely.
- Narrow the bare except blocks in patch_trl_disable_gradient_checkpointing from `except Exception:` to `(AttributeError, ImportError)` and `(AttributeError, TypeError)` to avoid masking unrelated bugs.
- Add logger.warning_once when the noop patch is installed, matching patch_trl_openenv and patch_trl_vllm_generation convention.
- Remove the dead per-module `_unsloth_noop_patched` sentinel check inside the sys.modules walk. The function-level early return already covers this case.
- Move `import sys` and `from contextlib import contextmanager` to the module-level imports instead of inside the function body.
- Rewrite the ordering comment in PatchFastRL to accurately describe why patch_trl_disable_gradient_checkpointing must run before patch_trl_rl_trainers.
- Fix keyword default spacing to match surrounding rl.py style.
End-to-end verified: Gemma-4-E2B GRPO on TRL 1.0.0 + transformers 5.5.0 step 1 loss=2.464e-08 kl=2.921e-05, all 5 steps succeed.
Extract the final_logit_softcapping fallback logic into a shared helper `_unsloth_get_final_logit_softcapping(config)` defined in rl_replacements.py and injected into the compiled cache via RL_PRE_ITEMS["grpo_trainer"]. Both call sites (`grpo_trainer__generate_and_score_completions` and `grpo_trainer_compute_loss`) now use the helper instead of inlining the same text_config fallback block twice. Verified: compiled cache file lists the helper at module scope and both consumer sites call it. Gemma-4-E2B GRPO step 1 loss=2.464e-08 kl=2.921e-05 (unchanged), all 5 steps pass.
Extend _unsloth_get_final_logit_softcapping to also fall back to config.get_text_config() for composite configs such as T5GemmaConfig where the text sub-config is not exposed via the text_config attribute but only via the get_text_config() method. Guard against (TypeError, ValueError) raised by ambiguous composite configs, and skip the self-referential case where get_text_config() returns self. This addresses the 6/7 reviewer consensus from the third review loop.
Verified:
- Helper returns 30.0 for Gemma-4, T5Gemma, and Gemma 1/2 configs.
- Helper returns 0 for Llama, Qwen, Mistral, Cohere, Granite, and ambiguous configs raising ValueError.
- Gemma-4-E2B GRPO step 1 loss=2.464e-08 kl=2.921e-05 (unchanged).
- Llama-3.2-1B GRPO all 5 steps loss=0 kl=0 (no regression).
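The fallback order described in this commit can be sketched roughly as follows. This is not the code of `_unsloth_get_final_logit_softcapping`; the function name `resolve_softcap` and the toy config classes are hypothetical, and only the lookup order (top-level attribute, then `text_config`, then `get_text_config()` with the two guards) is taken from the commit message.

```python
from types import SimpleNamespace


def resolve_softcap(config, default=0):
    """Hypothetical helper: find final_logit_softcapping on a possibly nested config."""
    cap = getattr(config, "final_logit_softcapping", None)
    if cap is not None:
        return cap
    # Fallback 1: nested text_config attribute (multimodal configs).
    text_cfg = getattr(config, "text_config", None)
    cap = getattr(text_cfg, "final_logit_softcapping", None)
    if cap is not None:
        return cap
    # Fallback 2: get_text_config() method (composite configs).
    get_text = getattr(config, "get_text_config", None)
    if callable(get_text):
        try:
            sub = get_text()
        except (TypeError, ValueError):
            sub = None  # ambiguous composite config
        if sub is not None and sub is not config:  # skip the self-referential case
            cap = getattr(sub, "final_logit_softcapping", None)
            if cap is not None:
                return cap
    return default


class CompositeConfig:
    # Text sub-config reachable only through the method.
    def get_text_config(self):
        return SimpleNamespace(final_logit_softcapping=30.0)


class PlainTextConfig:
    # Text-only model without softcapping; the method returns self.
    def get_text_config(self):
        return self


class AmbiguousConfig:
    def get_text_config(self):
        raise ValueError("multiple text sub-configs")


nested = SimpleNamespace(text_config=SimpleNamespace(final_logit_softcapping=30.0))
results = {
    "nested": resolve_softcap(nested),
    "composite": resolve_softcap(CompositeConfig()),
    "plain": resolve_softcap(PlainTextConfig()),
    "ambiguous": resolve_softcap(AmbiguousConfig()),
}
```

Resolving the value at the consumer site, rather than writing it onto `model.config`, keeps the config unchanged.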
New step "MoE per-family coverage + GRPO patches + grouped_gemm AST" that hardens the matrix against the recurring MoE bug class behind unslothai/unsloth-zoo#624 / #612 / #607 / #601 and unslothai/unsloth #4934 / #3598.
Five clusters of pytest cases inside one shim:
1. Per-MoE-family side-effect contract (8 parametrized cases): For each `patch_*_moe` in unsloth_zoo.temporary_patches.{qwen3_moe, qwen3_5_moe, qwen3_next_moe, qwen3_vl_moe, gemma4_moe, glm4_moe, deepseek_v3_moe, gpt_oss}, look up the transformers target classes, skip when none import on this matrix cell, run the patch fn, and assert at least one importable target now carries an unsloth "patched" marker. Accepts five marker conventions used across the codebase (_unsloth_already_patched, _unsloth_lora_patched, _unsloth_lora_extractor_fn, _original_<modeling_tail>_<cls>_forward, plain _original_forward). Surfaces silent early-returns (PR #612) that escape the registration-coverage test. gpt_oss specifically reads UNSLOTH_MODEL_NAME and only runs on transformers >= 5; the shim sets the env var via monkeypatch and skips on the 4.57.6 cell with a documented reason.
2. PR #4934 (TRL 1.0 GRPO disable_gradient_checkpointing): rebinding contract. After patch_trl_disable_gradient_checkpointing(), the no-op decorated function MUST be the symbol on trl.models.utils AND every trl.* module that imported it by reference. Skips on TRL < 1.0 (no symbol present).
3. PR #3598 (gradient_accumulation): patch_gradient_accumulation_fix on a vanilla transformers.Trainer must run cleanly without raising AND be idempotent. Catches future double-scale or import-injection regressions in the source rewriter.
4. unsloth/kernels/moe/grouped_gemm AST smoke: walks every .py under the directory (12 files) and asserts ast.parse succeeds. Triton kernels are GPU-only at runtime, but a syntax error in source surfaces as ImportError on every install.
Also sanity-checks the directory layout (interface.py, kernels/forward.py, kernels/backward.py, reference/moe_block.py, reference/moe_ops.py must exist).
Local verification on host TRL 0.25.1 + transformers 4.57.6: 4 pass (qwen3_moe, qwen3_vl_moe, GRPO disable-GC, grad-accum, grouped_gemm AST), 7 skip legitimately (qwen3_5/qwen3_next/gemma4/glm4/deepseek/gpt_oss absent or version-gated). Wall-time ~10s on host; budget ~30-60s per matrix cell.
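The AST smoke check in cluster 4 can be approximated with the standard library alone. The sketch below is not the shim from this commit; `find_syntax_errors` and the temporary files are hypothetical, and the example parses throwaway sources instead of the real grouped_gemm directory so that it runs anywhere.

```python
import ast
import tempfile
from pathlib import Path


def find_syntax_errors(root):
    """Parse every .py file under root; return {relative_path: error message}."""
    failures = {}
    for path in sorted(Path(root).rglob("*.py")):
        source = path.read_text(encoding="utf-8")
        try:
            # Parsing never imports or executes the file, so GPU-only code is safe to check.
            ast.parse(source, filename=str(path))
        except SyntaxError as exc:
            failures[str(path.relative_to(root))] = exc.msg
    return failures


with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "good.py").write_text("def f(x):\n    return x + 1\n", encoding="utf-8")
    (Path(tmp) / "bad.py").write_text("def f(:\n    pass\n", encoding="utf-8")
    failures = find_syntax_errors(tmp)
```

A pytest case would then assert that `failures` is empty for the directory under test.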
Summary
Fixes Gemma-4 GRPO training diverging with KL ~10^12 at step 1 against TRL 1.0.0+, by adding two runtime patches to the existing TRL/model patch flow. Both patches are no-ops for models and TRL versions that are not affected.
The bugs
Bug 1 (primary): TRL `disable_gradient_checkpointing` overwrites Unsloth's custom GC wrapper
TRL 1.0.0+ wraps generation in `with torch.no_grad(), disable_gradient_checkpointing(self.model, ...):`.
The toggle exists only to suppress a cosmetic PyTorch warning (`None of the inputs have requires_grad=True`). Inside `torch.no_grad()` the gradient checkpointing state has no functional effect on the forward pass.
On context exit, TRL calls `model.gradient_checkpointing_enable(...)`, which dispatches to HuggingFace's generic implementation and overwrites Unsloth's custom `use_gradient_checkpointing="unsloth"` wrapper. For Gemma-4 (and likely other models) this corrupts the forward numerics enough to make the training-step forward diverge from the reference forward, producing KL ~10^12 at step 1.
Bug 2 (secondary): `final_logit_softcapping` lookup misses for multimodal Gemma-4
`UnslothGRPOTrainer` reads `getattr(model.config, "final_logit_softcapping", 0)`. For `Gemma4ForConditionalGeneration` the attribute lives only on the nested `Gemma4TextConfig`, so the lookup silently defaults to `0` instead of `30`. Both ref and policy paths hit the same bug for LoRA so KL cancels, but full fine-tuning with a separate `ref_model` produces numerically incorrect logps.
The fix
Fix 1: `unsloth/models/rl.py` - new `patch_trl_disable_gradient_checkpointing()`
Replaces `trl.models.utils.disable_gradient_checkpointing` with a no-op context manager. The patch dynamically walks `sys.modules` for any `trl.*` module that already imported the symbol by reference and rebinds it, so it picks up:
- `trl.trainer.grpo_trainer`
- `trl.trainer.dpo_trainer`
- `trl.trainer.rloo_trainer`
- `trl.experimental.dppo.dppo_trainer`
- `trl.experimental.gfpo.gfpo_trainer`
- `trl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_trainer`
The patch is wired into `PatchFastRL` BEFORE `patch_trl_rl_trainers` so the compiled cache picks up the noop at its `from trl.trainer.grpo_trainer import disable_gradient_checkpointing` binding time.
Fix 2: `unsloth/models/vision.py` - inject `final_logit_softcapping` from `text_config`
In `FastBaseModel.from_pretrained`, after the model is loaded, lifts `final_logit_softcapping` from `config.text_config` (or `config.get_text_config()`) to the top-level `model.config` if and only if the top-level config does not already expose it. Skips silently for models that already have it or do not use softcapping.
Backwards compatibility
- trl 0.22.2: the `disable_gradient_checkpointing` symbol does not exist. The `hasattr` guard early-returns. Verified by installing trl 0.22.2 in a clean venv and inspecting the symbol.
- trl 1.0.0+: end-to-end verified on `unsloth/gemma-4-E2B-it` GRPO with TRL 1.0.0 and transformers 5.5.0.
- Llama / non-VLM text models: Fix 2 is a no-op (no `text_config`).
- Qwen3-VL and other VLMs without `final_logit_softcapping`: Fix 2 is a no-op (`text_config.final_logit_softcapping` is `None`).
Test plan
- `hasattr` early-return path
- `sys.modules` walker covers `grpo_trainer`, `dpo_trainer`, `rloo_trainer`, `dppo_trainer`, `gfpo_trainer`, `grpo_with_replay_buffer_trainer`
- `final_logit_softcapping` is `30.0` after `FastModel.from_pretrained` and survives `get_peft_model`
Empirical numbers