Respect GC for GRPO #5269
Code Review
This pull request updates the grpo_trainer__generate_and_score_completions function in unsloth/models/rl_replacements.py to ensure that the for_training method call includes the use_gradient_checkpointing parameter, which is retrieved from the trainer's arguments. This modification ensures that gradient checkpointing settings are correctly applied during the generation and scoring phase of the GRPO trainer. I have no feedback to provide as there were no review comments to evaluate.
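A minimal sketch of the change summarized above, using stand-in objects rather than Unsloth's real classes. `ToyModel`, the toy `for_training`, and the `gradient_checkpointing` attribute on `args` are illustrative assumptions; only the idea of passing `use_gradient_checkpointing` through to the restore call comes from the PR.

```python
from types import SimpleNamespace


class ToyModel:
    """Hypothetical stand-in for a model; not an Unsloth class."""

    def __init__(self):
        self.training = False
        self.gradient_checkpointing = False


def for_training(model, use_gradient_checkpointing=True):
    # Stand-in restore helper. The default of True is what causes the
    # problem when the caller had gradient checkpointing turned off.
    model.training = True
    model.gradient_checkpointing = use_gradient_checkpointing
    return model


def restore_after_generation(model, args):
    # Read the caller's setting from the trainer arguments (assumed
    # attribute name) and pass it through instead of using the default.
    use_gc = getattr(args, "gradient_checkpointing", False)
    return for_training(model, use_gradient_checkpointing=use_gc)


args = SimpleNamespace(gradient_checkpointing=False)
restored_with_default = for_training(ToyModel())
restored_with_args = restore_after_generation(ToyModel(), args)
```

In this sketch the default restore turns checkpointing back on, while the restore that reads `args` leaves it off.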
This PR appears to address open issue(s). The duplicate detector matched the following open issues with HIGH confidence:
If this PR fixes any of them, consider adding
Possible duplicate of a trusted maintainer's PR. This PR looks like it solves the same underlying problem as unslothai/unsloth#4934 by @danielhanchen (trusted maintainer).
Canonical PR summary: This PR fixes Gemma-4 GRPO instability with newer TRL by preserving Unsloth gradient checkpointing during generation and correctly reading final logit softcapping from nested text configs. It patches TRL’s checkpoint-disabling context manager to a no-op and adds a softcap lookup helper for GRPO log-prob computation.
The auto-review is still running against this PR; reviewers will factor in the canonical PR above. If this PR is genuinely different, call out the delta in the review discussion so the maintainer can decide which to merge.
Auto-review verdict: Approved
The PR makes GRPO's post-generation training-mode restore respect the user's gradient_checkpointing setting instead of forcing it back to True. The review extended the same correctness fix to unsloth_unwrap_model_for_generation (rl.py) and unsloth_fast_generate (llama.py), with the snapshot taken before TRL's unwrap context manager disables GC and using a value-preserving form so the documented use_gradient_checkpointing='unsloth' smart-GC mode is not silently downgraded.
Reason: GRPO GC bug correctly fixed; review-found parallel-path leaks in unsloth_unwrap_model_for_generation and unsloth_fast_generate addressed and merged back cleanly to datta0/fix_gc_grpo.
* Respect GC for GRPO
* Preserve gradient_checkpointing across post-generate training-mode restores
Two sibling generation paths put the model into inference mode and then
unconditionally restored training with the for_training default, which
re-enabled gradient checkpointing even when the caller had it disabled:
- unsloth/models/rl.py: unsloth_unwrap_model_for_generation, installed
  onto every TRL *_trainer module that exposes unwrap_model_for_generation.
- unsloth/models/llama.py: unsloth_fast_generate, bound onto model.generate.
Snapshot the active gradient_checkpointing state from the model modules
before for_inference clears it, then thread the snapshot through the
matching for_training call. Same one-line restore semantics already used
by prepare_for_training_mode and the GRPO replacement at rl_replacements.py.
The for_training(...) call on each line is preserved; only the kwarg is
added. The pre-existing post-generate guards (the conditional restore in
unsloth_fast_generate and the finally restore in
unsloth_unwrap_model_for_generation) continue to run unchanged.
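The snapshot-then-restore pattern described above can be sketched as follows. This is a toy under stated assumptions, not the code in rl.py or llama.py: `ToyLayer`, `ToyModel`, `generate_with_restore`, and the toy `for_inference` / `for_training` helpers are hypothetical stand-ins.

```python
class ToyLayer:
    """Hypothetical module carrying a gradient_checkpointing flag."""

    def __init__(self, gc):
        self.gradient_checkpointing = gc


class ToyModel:
    def __init__(self, gc):
        self.layers = [ToyLayer(gc), ToyLayer(gc)]
        self.training = True

    def modules(self):
        return iter(self.layers)


def for_inference(model):
    # Stand-in: switching to inference clears the flag on every module.
    model.training = False
    for m in model.modules():
        m.gradient_checkpointing = False


def for_training(model, use_gradient_checkpointing=True):
    # Stand-in: the default re-enables checkpointing unless told otherwise.
    model.training = True
    for m in model.modules():
        m.gradient_checkpointing = use_gradient_checkpointing


def generate_with_restore(model, generate_fn):
    # 1. Snapshot BEFORE for_inference clears the flags.
    saved = any(getattr(m, "gradient_checkpointing", False) for m in model.modules())
    for_inference(model)
    try:
        return generate_fn(model)
    finally:
        # 2. Thread the snapshot through the matching restore call.
        for_training(model, use_gradient_checkpointing=saved)


gc_off = ToyModel(gc=False)
gc_on = ToyModel(gc=True)
out_off = generate_with_restore(gc_off, lambda m: "tokens")
out_on = generate_with_restore(gc_on, lambda m: "tokens")
```

The `any()` snapshot here only records on or off; the follow-up commit in this message replaces it with a form that keeps the actual mode value.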
* Snapshot pre-disable, preserve unsloth smart-GC mode across generation restores
Two follow-ups to the post-generate gradient_checkpointing restore:
1. unsloth/models/rl.py: TRL's _unwrap_model_for_generation calls
   unwrapped_model.gradient_checkpointing_disable() before yielding
   (trl/models/utils.py:124-127 in 0.22.2, 0.27.1, and 1.3.0). The
   previous snapshot was taken inside the with-block and therefore read
   the post-disable state, restoring for_training with
   use_gradient_checkpointing=False even when the caller had it on. Move
   the snapshot above the with-block so it observes the caller's
   pre-disable configuration.
2. unsloth/models/{rl.py,llama.py}: any(getattr(m, "gradient_checkpointing"))
   collapses Unsloth's smart-GC mode value "unsloth" (a documented loader
   default at unsloth/models/_utils.py:212 and unsloth/models/llama.py
   2824/3314, loader.py:248/854) into a plain True. After generation, the
   restore would silently downgrade "unsloth" smart GC to standard HF GC.
   Replace any() with a value-preserving next((v for ... if v), False) so
   the actual mode value survives the round-trip.
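The difference between the two snapshot expressions in item 2 can be seen with plain Python objects. `ToyModule` and the two helper names are hypothetical; only `any()`, `next()`, and the "unsloth" mode value come from the text above.

```python
class ToyModule:
    """Hypothetical module; the flag may be False, True, or a mode string."""

    def __init__(self, gc):
        self.gradient_checkpointing = gc


def snapshot_any(modules):
    # Collapses every truthy value, including "unsloth", to plain True.
    return any(getattr(m, "gradient_checkpointing", False) for m in modules)


def snapshot_value(modules):
    # Returns the first truthy value itself, or False if there is none.
    values = (getattr(m, "gradient_checkpointing", False) for m in modules)
    return next((v for v in values if v), False)


smart_gc = [ToyModule(False), ToyModule("unsloth"), ToyModule("unsloth")]
no_gc = [ToyModule(False), ToyModule(False)]

collapsed = snapshot_any(smart_gc)      # True
preserved = snapshot_value(smart_gc)    # "unsloth"
disabled = snapshot_value(no_gc)        # False
```

Passing `collapsed` to a restore call would select standard checkpointing, while `preserved` carries the original mode string through unchanged.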
The for_training(...) calls on each line are preserved; only the snapshot
expression and its position change. The pre-existing post-generate restore
guards continue to run unchanged.
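Why the snapshot position matters (item 1) can be illustrated with a toy context manager that, like the TRL helper described above, disables gradient checkpointing before yielding. `toy_unwrap`, `ToyModel`, and both snapshot helpers are hypothetical stand-ins and do not reproduce TRL's code.

```python
from contextlib import contextmanager


class ToyModel:
    """Hypothetical model with a single gradient_checkpointing flag."""

    def __init__(self, gc):
        self.gradient_checkpointing = gc

    def gradient_checkpointing_disable(self):
        self.gradient_checkpointing = False


@contextmanager
def toy_unwrap(model):
    # Stand-in for a context manager that disables GC before yielding.
    model.gradient_checkpointing_disable()
    yield model


def snapshot_inside(model):
    # Snapshot taken inside the with-block: reads the post-disable state.
    with toy_unwrap(model) as unwrapped:
        return unwrapped.gradient_checkpointing


def snapshot_outside(model):
    # Snapshot taken above the with-block: reads the caller's setting.
    saved = model.gradient_checkpointing
    with toy_unwrap(model):
        pass
    return saved


seen_inside = snapshot_inside(ToyModel("unsloth"))    # False
seen_outside = snapshot_outside(ToyModel("unsloth"))  # "unsloth"
```

Only the outside snapshot holds a value worth passing back to the restore call after generation.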
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>