Skip to content

Fix Gemma-4 GRPO catastrophic KL divergence with TRL 1.0.0+#4934

Merged
danielhanchen merged 6 commits into
mainfrom
gemma4-grpo-trl1.0-compat
Apr 10, 2026
Merged

Fix Gemma-4 GRPO catastrophic KL divergence with TRL 1.0.0+#4934
danielhanchen merged 6 commits into
mainfrom
gemma4-grpo-trl1.0-compat

Conversation

@danielhanchen
Copy link
Copy Markdown
Member

Summary

Fixes Gemma-4 GRPO training diverging with KL ~10^12 at step 1 against TRL 1.0.0+, by adding two runtime patches to the existing TRL/model patch flow. Both patches are no-ops for models and TRL versions that are not affected.

The bugs

Bug 1 (primary): TRL disable_gradient_checkpointing overwrites Unsloth's custom GC wrapper

TRL 1.0.0+ wraps generation in:

with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):

The toggle exists only to suppress a cosmetic PyTorch warning (None of the inputs have requires_grad=True). Inside torch.no_grad() the gradient checkpointing state has no functional effect on the forward pass.

On context exit, TRL calls model.gradient_checkpointing_enable(...) which dispatches to HuggingFace's generic implementation and overwrites Unsloth's custom use_gradient_checkpointing="unsloth" wrapper. For Gemma-4 (and likely other models) this corrupts the forward numerics enough to make the training-step forward diverge from the reference forward, producing KL ~10^12 at step 1.

Bug 2 (secondary): final_logit_softcapping lookup misses for multimodal Gemma-4

UnslothGRPOTrainer reads getattr(model.config, "final_logit_softcapping", 0). For Gemma4ForConditionalGeneration the attribute lives only on the nested Gemma4TextConfig, so the lookup silently defaults to 0 instead of 30. Both ref and policy paths hit the same bug for LoRA so KL cancels, but full fine-tuning with a separate ref_model produces numerically incorrect logps.

The fix

Fix 1: unsloth/models/rl.py - new patch_trl_disable_gradient_checkpointing()

Replaces trl.models.utils.disable_gradient_checkpointing with a no-op context manager. The patch dynamically walks sys.modules for any trl.* module that already imported the symbol by reference and rebinds it, so it picks up:

  • trl.trainer.grpo_trainer
  • trl.trainer.dpo_trainer
  • trl.trainer.rloo_trainer
  • trl.experimental.dppo.dppo_trainer
  • trl.experimental.gfpo.gfpo_trainer
  • trl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_trainer
  • and any future TRL trainer module

The patch is wired into PatchFastRL BEFORE patch_trl_rl_trainers so the compiled cache picks up the noop at its from trl.trainer.grpo_trainer import disable_gradient_checkpointing binding time.

Fix 2: unsloth/models/vision.py - inject final_logit_softcapping from text_config

In FastBaseModel.from_pretrained, after the model is loaded, lifts final_logit_softcapping from config.text_config (or config.get_text_config()) to the top-level model.config if and only if the top-level config does not already expose it. Skips silently for models that already have it or do not use softcapping.

Backwards compatibility

TRL version Behavior
0.22.2 disable_gradient_checkpointing symbol does not exist. The hasattr guard early-returns. Verified by installing trl 0.22.2 in a clean venv and inspecting the symbol.
0.27.1 Same broken positional-arg pattern as 1.0.0. The noop replacement applies. Verified by installing trl 0.27.1 and running the patch logic against it.
1.0.0+ End-to-end verified on unsloth/gemma-4-E2B-it GRPO with TRL 1.0.0 and transformers 5.5.0.
Model Behavior
Gemma-4 (E2B/E4B) Fix 1 fixes the KL blow-up. Fix 2 lifts softcap=30 to the top-level config.
Llama / Qwen / text models Fix 1 is functionally identical (Unsloth's GC wrapper is preserved). Fix 2 is a no-op (no text_config).
Qwen3-VL and other VLMs without softcapping Fix 2 is a no-op (text_config.final_logit_softcapping is None).

Test plan

  • Gemma-4-E2B GRPO with TRL 1.0.0 + transformers 5.5.0: train 5 steps, expect step 1 KL near zero
  • Llama-3.2-1B GRPO with the patches applied: train 5 steps, expect identical numerics to baseline
  • Verify trl 0.22.2 hits the hasattr early-return path
  • Verify trl 0.27.1 patch logic replaces the symbol and the noop context manager works
  • Verify the dynamic sys.modules walker covers grpo_trainer, dpo_trainer, rloo_trainer, dppo_trainer, gfpo_trainer, grpo_with_replay_buffer_trainer
  • Confirm Gemma-4-E2B final_logit_softcapping is 30.0 after FastModel.from_pretrained and survives get_peft_model

Empirical numbers

Run Step 1 loss Step 1 KL
Gemma-4-E2B without fix 1.37e+06 1.76e+09
Gemma-4-E2B with fix 2.46e-08 2.92e-05
Llama-3.2-1B with fix 0 0

Two compounding bugs caused Gemma-4 GRPO training to diverge with KL ~10^12
at step 1 against TRL 1.0.0+. Both fixes are runtime patches in the existing
TRL/model patch flow and are no-ops for models and TRL versions that are not
affected.

Fix 1 (rl.py): replace trl.models.utils.disable_gradient_checkpointing with
a no-op context manager. TRL 1.0.0+ wraps generation in
`with torch.no_grad(), disable_gradient_checkpointing(self.model, ...):`
purely to suppress a cosmetic PyTorch warning ("None of the inputs have
requires_grad=True"). Inside torch.no_grad() the gradient checkpointing
state has no functional effect on the forward pass. On context exit, TRL
calls model.gradient_checkpointing_enable() which dispatches to HF's
generic implementation and overwrites Unsloth's custom
`use_gradient_checkpointing="unsloth"` wrapper, corrupting Gemma-4 forward
numerics. Replacing the toggle with a no-op preserves Unsloth's custom GC
wrapper across generation passes. The patch walks sys.modules dynamically
to also rebind the symbol on every trl.* module that already imported it
(grpo_trainer, dpo_trainer, rloo_trainer, dppo_trainer, gfpo_trainer,
grpo_with_replay_buffer_trainer, and any future trainer module).

Fix 2 (vision.py): inject `final_logit_softcapping` from `config.text_config`
into the top-level `model.config` for multimodal models. Unsloth's GRPO
trainer reads `getattr(model.config, "final_logit_softcapping", 0)` but
for Gemma-4 the attribute lives only on the nested `Gemma4TextConfig`,
so the lookup silently defaults to 0 instead of 30.

Backwards compatibility:
- trl 0.22.2: no `disable_gradient_checkpointing` symbol exists, the patch
  early-returns via `hasattr` guard.
- trl 0.27.1: same broken pattern as 1.0.0, the noop replacement is correct.
- trl 1.0.0+: end-to-end verified on `unsloth/gemma-4-E2B-it` GRPO with TRL
  1.0.0 and transformers 5.5.0. Step 1 loss=2.46e-08, kl=2.92e-05 (machine
  zero) vs broken baseline loss=1.37e+06, kl=1.76e+09.
- Llama / non-VLM text models: Fix 2 is a no-op (no `text_config`); Fix 1
  is functionally identical (Unsloth's GC wrapper is preserved).
- Qwen3-VL and other VLMs without final_logit_softcapping: Fix 2 is a no-op
  (text_config.final_logit_softcapping is None).
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a patch to disable TRL's disable_gradient_checkpointing context manager with a no-op, preventing numerical issues with Unsloth's custom gradient checkpointing, particularly for models like Gemma-4. This patch is applied early in PatchFastRL. Additionally, logic is added to from_pretrained for vision models to ensure final_logit_softcapping is correctly exposed at the top-level model.config for multimodal models, which is essential for the GRPO trainer. The review comments suggest improving error handling by logging exceptions instead of silently passing them in the newly added code blocks, adhering to the rule against broad, silent exception handlers.

Comment thread unsloth/models/rl.py
Comment on lines +2025 to +2026
setattr(
_mod,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Silently passing on exceptions here can hide potential issues where the monkey-patching fails. As per the rule 'Avoid using broad, silent exception handlers like except Exception: pass. Instead, log the exception, even if at a debug level, to aid in future debugging.', it's better to log these exceptions, even at a debug level, to help with future troubleshooting. This will make the patching process more transparent without crashing on unexpected errors.

Suggested change
setattr(
_mod,
except Exception as e:
logger.debug(f"Unsloth: Could not patch disable_gradient_checkpointing on {_mod_name}: {e}")
References
  1. Avoid using broad, silent exception handlers like except Exception: pass. Instead, log the exception, even if at a debug level, to aid in future debugging.

Comment thread unsloth/models/vision.py Outdated
Comment on lines +1105 to +1131
try:
_top_config = model.config
if getattr(_top_config, "final_logit_softcapping", None) is None:
_softcap = None
_text_cfg = getattr(_top_config, "text_config", None)
if _text_cfg is not None:
_softcap = getattr(
_text_cfg,
"final_logit_softcapping",
None,
)
if _softcap is None:
_get_text = getattr(_top_config, "get_text_config", None)
if callable(_get_text):
try:
_softcap = getattr(
_get_text(),
"final_logit_softcapping",
None,
)
except Exception:
pass
if _softcap is not None:
try:
setattr(_top_config, "final_logit_softcapping", _softcap)
except Exception:
pass
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block uses multiple nested try...except Exception: pass statements, which can make the code harder to debug and understand. As per the rule 'Avoid using broad, silent exception handlers like except Exception: pass. Instead, log the exception, even if at a debug level, to aid in future debugging.', it's better to flatten the logic and log exceptions when they occur. This will improve maintainability and help diagnose issues if this patching logic fails for certain model configurations.

            try:
                _top_config = model.config
                if getattr(_top_config, "final_logit_softcapping", None) is None:
                    _softcap = None
                    _text_cfg = getattr(_top_config, "text_config", None)
                    if _text_cfg is not None:
                        _softcap = getattr(
                            _text_cfg, "final_logit_softcapping", None,
                        )
                    if _softcap is None:
                        _get_text = getattr(_top_config, "get_text_config", None)
                        if callable(_get_text):
                            try:
                                _softcap = getattr(
                                    _get_text(),
                                    "final_logit_softcapping",
                                    None,
                                )
                            except Exception as e:
                                logger.debug(f"Unsloth: Failed to get text_config for softcapping: {e}")
                    if _softcap is not None:
                        try:
                            setattr(_top_config, "final_logit_softcapping", _softcap)
                        except Exception as e:
                            logger.warning(f"Unsloth: Could not set final_logit_softcapping on model config: {e}")
            except Exception as e:
                logger.warning(f"Unsloth: An unexpected error occurred during final_logit_softcapping patch: {e}")
References
  1. Avoid using broad, silent exception handlers like except Exception: pass. Instead, log the exception, even if at a debug level, to aid in future debugging.

@Datta0
Copy link
Copy Markdown
Collaborator

Datta0 commented Apr 9, 2026

Yeah I too noticed trl messing up with gradient checkpointing. This seems like a necessary change

danielhanchen and others added 4 commits April 9, 2026 15:18
- Move Fix 2 from vision.py to rl_replacements.py:858 and :1110 at the
  actual consumer sites. This avoids mutating model.config (which could
  leak into save_pretrained output) and covers text-only Gemma-4 paths
  that do not flow through FastBaseModel.from_pretrained.
- Revert the vision.py injection block entirely.
- Narrow the bare except blocks in patch_trl_disable_gradient_checkpointing
  from `except Exception:` to `(AttributeError, ImportError)` and
  `(AttributeError, TypeError)` to avoid masking unrelated bugs.
- Add logger.warning_once when the noop patch is installed, matching
  patch_trl_openenv and patch_trl_vllm_generation convention.
- Remove the dead per-module `_unsloth_noop_patched` sentinel check inside
  the sys.modules walk. The function-level early return already covers
  this case.
- Move `import sys` and `from contextlib import contextmanager` to the
  module-level imports instead of inside the function body.
- Rewrite the ordering comment in PatchFastRL to accurately describe
  why patch_trl_disable_gradient_checkpointing must run before
  patch_trl_rl_trainers.
- Fix keyword default spacing to match surrounding rl.py style.

End-to-end verified: Gemma-4-E2B GRPO on TRL 1.0.0 + transformers 5.5.0
step 1 loss=2.464e-08 kl=2.921e-05, all 5 steps succeed.
Extract the final_logit_softcapping fallback logic into a shared helper
`_unsloth_get_final_logit_softcapping(config)` defined in rl_replacements.py
and injected into the compiled cache via RL_PRE_ITEMS["grpo_trainer"]. Both
call sites (`grpo_trainer__generate_and_score_completions` and
`grpo_trainer_compute_loss`) now use the helper instead of inlining the
same text_config fallback block twice.

Verified: compiled cache file lists the helper at module scope and both
consumer sites call it. Gemma-4-E2B GRPO step 1 loss=2.464e-08 kl=2.921e-05
(unchanged), all 5 steps pass.
Extend _unsloth_get_final_logit_softcapping to also fall back to
config.get_text_config() for composite configs such as T5GemmaConfig
where the text sub-config is not exposed via the text_config attribute
but only via the get_text_config() method. Guard against (TypeError,
ValueError) raised by ambiguous composite configs, and skip the
self-referential case where get_text_config() returns self.

This addresses the 6/7 reviewer consensus from the third review loop.

Verified:
- Helper returns 30.0 for Gemma-4, T5Gemma, and Gemma 1/2 configs.
- Helper returns 0 for Llama, Qwen, Mistral, Cohere, Granite, and
  ambiguous configs raising ValueError.
- Gemma-4-E2B GRPO step 1 loss=2.464e-08 kl=2.921e-05 (unchanged).
- Llama-3.2-1B GRPO all 5 steps loss=0 kl=0 (no regression).
@danielhanchen danielhanchen merged commit 53af4a1 into main Apr 10, 2026
5 checks passed
@danielhanchen danielhanchen deleted the gemma4-grpo-trl1.0-compat branch April 10, 2026 14:58
danielhanchen added a commit that referenced this pull request May 7, 2026
New step "MoE per-family coverage + GRPO patches + grouped_gemm AST"
that hardens the matrix against the recurring MoE bug class behind
unslothai/unsloth-zoo#624 / #612 / #607 / #601 and unslothai/unsloth
#4934 / #3598. Five clusters of pytest cases inside one shim:

1. Per-MoE-family side-effect contract (8 parametrized cases):
   For each `patch_*_moe` in unsloth_zoo.temporary_patches.{qwen3_moe,
   qwen3_5_moe, qwen3_next_moe, qwen3_vl_moe, gemma4_moe, glm4_moe,
   deepseek_v3_moe, gpt_oss}, look up the transformers target classes,
   skip when none import on this matrix cell, run the patch fn, and
   assert at least one importable target now carries an unsloth
   "patched" marker. Accepts five marker conventions used across the
   codebase (_unsloth_already_patched, _unsloth_lora_patched,
   _unsloth_lora_extractor_fn, _original_<modeling_tail>_<cls>_forward,
   plain _original_forward). Surfaces silent early-returns (PR #612)
   that escape the registration-coverage test.

   gpt_oss specifically reads UNSLOTH_MODEL_NAME and only runs on
   transformers >= 5; the shim sets the env var via monkeypatch and
   skips on the 4.57.6 cell with a documented reason.

2. PR #4934 (TRL 1.0 GRPO disable_gradient_checkpointing): rebinding
   contract. After patch_trl_disable_gradient_checkpointing(), the
   no-op decorated function MUST be the symbol on
   trl.models.utils AND every trl.* module that imported it by
   reference. Skips on TRL < 1.0 (no symbol present).

3. PR #3598 (gradient_accumulation): patch_gradient_accumulation_fix
   on a vanilla transformers.Trainer must run cleanly without raising
   AND be idempotent. Catches future double-scale or import-injection
   regressions in the source rewriter.

4. unsloth/kernels/moe/grouped_gemm AST smoke: walks every .py under
   the directory (12 files) and asserts ast.parse succeeds. Triton
   kernels are GPU-only at runtime, but a syntax error in source
   surfaces as ImportError on every install. Also sanity-checks the
   directory layout (interface.py, kernels/forward.py,
   kernels/backward.py, reference/moe_block.py, reference/moe_ops.py
   must exist).

Local verification on host TRL 0.25.1 + transformers 4.57.6: 4 pass
(qwen3_moe, qwen3_vl_moe, GRPO disable-GC, grad-accum, grouped_gemm
AST), 7 skip legitimately (qwen3_5/qwen3_next/gemma4/glm4/deepseek/
gpt_oss absent or version-gated). Wall-time ~10s on host; budget
~30-60s per matrix cell.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants