Skip to content

Respect GC for GRPO#5269

Merged
danielhanchen merged 6 commits into
unslothai:mainfrom
Datta0:fix_gc_grpo
May 22, 2026
Merged

Respect GC for GRPO#5269
danielhanchen merged 6 commits into
unslothai:mainfrom
Datta0:fix_gc_grpo

Conversation

@Datta0
Copy link
Copy Markdown
Collaborator

@Datta0 Datta0 commented May 4, 2026

No description provided.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the grpo_trainer__generate_and_score_completions function in unsloth/models/rl_replacements.py to ensure that the for_training method call includes the use_gradient_checkpointing parameter, which is retrieved from the trainer's arguments. This modification ensures that gradient checkpointing settings are correctly applied during the generation and scoring phase of the GRPO trainer. I have no feedback to provide as there were no review comments to evaluate.

@Datta0 Datta0 marked this pull request as ready for review May 6, 2026 05:46
@danielhanchen danielhanchen added the auto-addresses-issue Pre-flight: appears to address an open issue label May 6, 2026
@danielhanchen
Copy link
Copy Markdown
Member

This PR appears to address open issue(s). The duplicate detector matched the following open issues with HIGH confidence:

  • unslothai/unsloth#4886@jonahsamost — GRPO gradient checkpointing bug; PR makes for_training respect args.gradient_checkpointing, avoiding inconsistent GC enablement during GRPO transitions.
  • unslothai/unsloth#3828@nafee-ahmed — GRPO evaluation crash is explicitly tied to Unsloth gradient checkpointing; PR changes GRPO training-mode transition to honor gradient_checkpointing.

If this PR fixes any of them, consider adding closes #N / resolves #N to the description so the issue auto-closes on merge. If the match is wrong, ignore this comment.

@danielhanchen danielhanchen added the auto-has-duplicate Pre-flight: similar to a trusted maintainer's PR label May 6, 2026
@danielhanchen
Copy link
Copy Markdown
Member

Possible duplicate of a trusted maintainer's PR. This PR looks like it solves the same underlying problem as unslothai/unsloth#4934 by @danielhanchen (trusted maintainer).

Same GRPO training-mode restoration fix in rl_replacements.py: generation preserves configured gradient_checkpointing instead of defaulting behavior.

Canonical PR summary: This PR fixes Gemma-4 GRPO instability with newer TRL by preserving Unsloth gradient checkpointing during generation and correctly reading final logit softcapping from nested text configs. It patches TRL’s checkpoint-disabling context manager to a no-op and adds a softcap lookup helper for GRPO log-prob computation.

The auto-review is still running against this PR — reviewers will factor in the canonical above. If this PR is genuinely different, call out the delta in the review discussion so the maintainer can decide which to merge.

@danielhanchen danielhanchen added the auto-reviewing Auto-review in progress label May 6, 2026
…stores

Two sibling generation paths put the model into inference mode and then
unconditionally restored training with the for_training default, which
re-enabled gradient checkpointing even when the caller had it disabled:

- unsloth/models/rl.py: unsloth_unwrap_model_for_generation, installed
  onto every TRL *_trainer module that exposes unwrap_model_for_generation.
- unsloth/models/llama.py: unsloth_fast_generate, bound onto model.generate.

Snapshot the active gradient_checkpointing state from the model modules
before for_inference clears it, then thread the snapshot through the
matching for_training call. Same one-line restore semantics already used
by prepare_for_training_mode and the GRPO replacement at rl_replacements.py.

The for_training(...) call on each line is preserved; only the kwarg is
added. The pre-existing post-generate guards (the conditional restore in
unsloth_fast_generate and the finally restore in
unsloth_unwrap_model_for_generation) continue to run unchanged.
…n restores

Two follow-ups to the post-generate gradient_checkpointing restore:

1. unsloth/models/rl.py: TRL's _unwrap_model_for_generation calls
   unwrapped_model.gradient_checkpointing_disable() before yielding
   (trl/models/utils.py:124-127 in 0.22.2, 0.27.1, and 1.3.0). The
   previous snapshot was taken inside the with-block and therefore read
   the post-disable state, restoring for_training with
   use_gradient_checkpointing=False even when the caller had it on. Move
   the snapshot above the with-block so it observes the caller's
   pre-disable configuration.

2. unsloth/models/{rl.py,llama.py}: any(getattr(m, "gradient_checkpointing"))
   collapses Unsloth's smart-GC mode value "unsloth" (a documented loader
   default at unsloth/models/_utils.py:212 and unsloth/models/llama.py
   2824/3314, loader.py:248/854) into a plain True. After generation, the
   restore would silently downgrade "unsloth" smart GC to standard HF GC.
   Replace any() with a value-preserving next((v for ... if v), False) so
   the actual mode value survives the round-trip.

The for_training(...) calls on each line are preserved; only the snapshot
expression and its position change. The pre-existing post-generate restore
guards continue to run unchanged.
@danielhanchen danielhanchen requested a review from mmathew23 as a code owner May 6, 2026 13:30
@danielhanchen danielhanchen added auto-approved Auto-review approved the PR and removed auto-reviewing Auto-review in progress labels May 6, 2026
@danielhanchen
Copy link
Copy Markdown
Member

Auto-review verdict: Approved

PR makes GRPO's post-generation training-mode restore respect the user's gradient_checkpointing setting instead of forcing it back to True; the review extended the same correctness fix to unsloth_unwrap_model_for_generation (rl.py) and unsloth_fast_generate (llama.py), with the snapshot taken before TRL's unwrap CM disables GC and using a value-preserving form so the documented use_gradient_checkpointing='unsloth' smart-GC mode is not silently downgraded.

Reason: GRPO GC bug correctly fixed; review-found parallel-path leaks in unsloth_unwrap_model_for_generation and unsloth_fast_generate addressed and merged back cleanly to datta0/fix_gc_grpo.

@danielhanchen danielhanchen merged commit 867fe63 into unslothai:main May 22, 2026
39 checks passed
rsd-darshan pushed a commit to rsd-darshan/unsloth that referenced this pull request Jun 3, 2026
* Respect GC for GRPO

* Preserve gradient_checkpointing across post-generate training-mode restores

Two sibling generation paths put the model into inference mode and then
unconditionally restored training with the for_training default, which
re-enabled gradient checkpointing even when the caller had it disabled:

- unsloth/models/rl.py: unsloth_unwrap_model_for_generation, installed
  onto every TRL *_trainer module that exposes unwrap_model_for_generation.
- unsloth/models/llama.py: unsloth_fast_generate, bound onto model.generate.

Snapshot the active gradient_checkpointing state from the model modules
before for_inference clears it, then thread the snapshot through the
matching for_training call. Same one-line restore semantics already used
by prepare_for_training_mode and the GRPO replacement at rl_replacements.py.

The for_training(...) call on each line is preserved; only the kwarg is
added. The pre-existing post-generate guards (the conditional restore in
unsloth_fast_generate and the finally restore in
unsloth_unwrap_model_for_generation) continue to run unchanged.

* Snapshot pre-disable, preserve unsloth smart-GC mode across generation restores

Two follow-ups to the post-generate gradient_checkpointing restore:

1. unsloth/models/rl.py: TRL's _unwrap_model_for_generation calls
   unwrapped_model.gradient_checkpointing_disable() before yielding
   (trl/models/utils.py:124-127 in 0.22.2, 0.27.1, and 1.3.0). The
   previous snapshot was taken inside the with-block and therefore read
   the post-disable state, restoring for_training with
   use_gradient_checkpointing=False even when the caller had it on. Move
   the snapshot above the with-block so it observes the caller's
   pre-disable configuration.

2. unsloth/models/{rl.py,llama.py}: any(getattr(m, "gradient_checkpointing"))
   collapses Unsloth's smart-GC mode value "unsloth" (a documented loader
   default at unsloth/models/_utils.py:212 and unsloth/models/llama.py
   2824/3314, loader.py:248/854) into a plain True. After generation, the
   restore would silently downgrade "unsloth" smart GC to standard HF GC.
   Replace any() with a value-preserving next((v for ... if v), False) so
   the actual mode value survives the round-trip.

The for_training(...) calls on each line are preserved; only the snapshot
expression and its position change. The pre-existing post-generate restore
guards continue to run unchanged.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

auto-addresses-issue Pre-flight: appears to address an open issue auto-approved Auto-review approved the PR auto-has-duplicate Pre-flight: similar to a trusted maintainer's PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants