diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 30546a048d..5c61a6eae2 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -1232,7 +1232,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Unsloth gradient checkpointing requires use_reentrant=True, so we remove # the setting after super().__init__() when it gets auto-applied. RLConfig_post = "" - if trl_version >= Version("0.27.0") and RLConfig_name == "GRPOConfig": + if trl_version >= Version("0.27.0"): RLConfig_post = ( " # Unsloth: Remove use_reentrant=False forced by TRL 0.27.0+\n" " if getattr(self, 'gradient_checkpointing_kwargs', None) is not None:\n"