From acf90f0b3e6f592f89092f8860b080ddc1b56d8b Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Mon, 26 Jan 2026 12:35:00 -0500 Subject: [PATCH 01/13] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 42 ++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ff36da125d..5b6c7b870d 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -914,7 +914,7 @@ def compute_loss( max_left_pad = inputs.get("max_left_pad", 0) if per_token_logps is not None: - loss, completion_length, mean_kl, delta, flat_is_ratio = ( + loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( grpo_compute_loss_slow( ref_logps, per_token_logps, @@ -944,7 +944,7 @@ def compute_loss( ) else: if hasattr(self.args, "loss_type"): - loss, completion_length, mean_kl, delta, flat_is_ratio = ( + loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( grpo_accumulated_loss( trainer = self, input_ids = _input_ids, @@ -976,7 +976,7 @@ def compute_loss( ) else: # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 - loss, completion_length, mean_kl = grpo_accumulated_loss( + loss, completion_length, mean_kl, coef_1 = grpo_accumulated_loss( trainer = self, input_ids = _input_ids, logits_to_keep = logits_to_keep, @@ -991,7 +991,6 @@ def compute_loss( logit_scale_divide = logit_scale_divide, attention_mask = attention_mask, ) - if "train" in self._metrics: mode = "eval" if self.control.should_evaluate else "train" self._metrics[mode]["completion_length"].append(completion_length.item()) @@ -1053,8 +1052,43 @@ def compute_loss( .item() ) + completion_token_count = completion_mask.sum().clamp(min=1.0) + def masked_batch_mean(x): + if x.shape[1] == 1: # when importance_sampling_level == "sequence" + return x.mean() + else: + return (x * completion_mask).sum() / completion_token_count + #breakpoint() + if advantages.dim() == 1: + advantages = advantages.unsqueeze(1) + + if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]: + # Compute the clipped probability ratios + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0) + is_region_clipped = is_low_clipped | is_high_clipped + + low_clip = masked_batch_mean(is_low_clipped.float()) + high_clip = masked_batch_mean(is_high_clipped.float()) + clip_ratio = masked_batch_mean(is_region_clipped.float()) + + gathered_low_clip = self.accelerator.gather(low_clip) + self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) + gathered_high_clip = self.accelerator.gather(high_clip) + self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) + gathered_clip_ratio = self.accelerator.gather(clip_ratio) + self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) + elif self.loss_type == "cispo": + is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0) + cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) + gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) + self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item()) + return loss + function = inspect.getsource(compute_loss) return function From 707fa3805741cdedb30bd223b8fe2249bce29a44 Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Mon, 26 Jan 2026 15:03:34 -0500 Subject: [PATCH 02/13] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 54 +++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 5b6c7b870d..d41d00ae71 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -27,6 +27,7 @@ from collections import defaultdict from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding from unsloth_zoo.utils import Version +from trl import __version__ as trl_version_raw from importlib.metadata import version as importlib_version from unsloth_zoo.log import logger import importlib.util @@ -56,6 +57,13 @@ "triton.cudagraphs": False, } +try: + trl_version = Version(trl_version_raw) +except Exception: + try: + trl_version = Version(importlib_version("trl")) + except Exception: + trl_version = Version("0.0.0") # Check untrained tokens def sft_trainer_fix_untrained_tokens(call_args, extra_args): @@ -220,6 +228,41 @@ def grpo_trainer__prepare_inputs(function_name, function): RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs) +if trl_version >= Version("0.27.0"): + def grpo_trainer__calculate_rewards(function_name, function): + if function_name != "_calculate_rewards": + return function + # For TRL 0.27.0 where the function has weird chat template logic + target_line = ' reward_kwargs["trainer_state"] = self.state' + + replacement_block = """ reward_kwargs["trainer_state"] = self.state + + # --- Monkey Patch: Batch Decode for Rewards --- + try: + reward_kwargs["completions_text"] = [ + self.processing_class.apply_chat_template( + completion, + tokenize=False, + ) + for completion in completions + ] + + reward_kwargs["prompts_text"] = [ + self.processing_class.apply_chat_template( + prompt, + tokenize=False, + ) + for prompt in prompts + ] + except Exception as e: + # Fallback or logging if chat template application fails (e.g. non-conversational format) + pass + """ + function = function.replace(target_line, replacement_block) + return function + + + RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__calculate_rewards) # Remove collective RPC of reload weights from generate # trl added reload weights (potentially for quantized models), we don't need it for our use case (LoRA primarily) @@ -390,6 +433,13 @@ def grpo_trainer__generate_and_score_completions(function_name, function): function = function.replace(string_to_find, replacement_string) + if trl_version >= Version("0.27.0"): + # We replace the call using 'completions' with one using 'completions_text' + string_to_find = " rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)" + replacement_string = " rewards_per_func = self._calculate_rewards(inputs, prompts, completions_text, completion_ids_list)" + + function = function.replace(string_to_find, replacement_string) + if "wake_up()" not in function: # Sleep functionality has been added to trl in v0.23.0. We do not want to redo this. # https://github.com/huggingface/trl/commit/edbe8234bc7e528f72ac76607de9d3e4753e2709 @@ -911,7 +961,7 @@ def compute_loss( logit_scale_divide = getattr(model.config, "logits_scaling", 0) # Granite if logit_scale_divide is None: logit_scale_divide = 0 - + max_left_pad = inputs.get("max_left_pad", 0) if per_token_logps is not None: loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( @@ -1058,7 +1108,7 @@ def masked_batch_mean(x): return x.mean() else: return (x * completion_mask).sum() / completion_token_count - #breakpoint() + if advantages.dim() == 1: advantages = advantages.unsqueeze(1) From 44beb3039fdbbbdc2e9a3f89fbd38cf7c826162b Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Mon, 2 Feb 2026 15:41:49 -0500 Subject: [PATCH 03/13] Update rl.py --- unsloth/models/rl.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 803153e608..8f1b4c86a2 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -1109,7 +1109,22 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source = re.sub( pattern, new_options, RLTrainer_source, flags = re.DOTALL ) + + if trl_version >= Version("0.26.0"): + peft_block_pattern = ( + r"\s*if is_peft_available\(\) and isinstance\(model, PeftModel\) and peft_config is not None:" + r".*?" + r"param\.data = param\.data\.to\(torch\.bfloat16\)" + ) + RLTrainer_source = re.sub( + peft_block_pattern, + "\n # TRL PEFT 0.26.0 initialization logic removed on unsloth side.\n", + RLTrainer_source, + flags=re.DOTALL + ) + + if RLTrainer_name == "SFTTrainer": original_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]' new_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask","labels"]' From 48c3c1925d85a56384a2b367c1f25fdf620f45a9 Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Mon, 2 Feb 2026 15:44:32 -0500 Subject: [PATCH 04/13] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 9fe62b4032..07419fa7b6 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -443,10 +443,10 @@ def grpo_trainer__generate_and_score_completions(function_name, function): function = function.replace(string_to_find, replacement_string) - if trl_version >= Version("0.27.0"): + if trl_version >= Version("0.25.0"): # We replace the call using 'completions' with one using 'completions_text' string_to_find = " rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)" - replacement_string = " rewards_per_func = self._calculate_rewards(inputs, prompts, completions_text, completion_ids_list)" + replacement_string = " rewards_per_func = self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)" function = function.replace(string_to_find, replacement_string) From c2a6fc1b4fb8f4a67c0f5178c6ad9e0cd0eced78 Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Mon, 2 Feb 2026 16:32:23 -0500 Subject: [PATCH 05/13] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 36 ------------------------------- 1 file changed, 36 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 07419fa7b6..349eea8933 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -238,42 +238,6 @@ def grpo_trainer__prepare_inputs(function_name, function): RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs) -if trl_version >= Version("0.27.0"): - def grpo_trainer__calculate_rewards(function_name, function): - if function_name != "_calculate_rewards": - return function - # For TRL 0.27.0 where the function has weird chat template logic - target_line = ' reward_kwargs["trainer_state"] = self.state' - - replacement_block = """ reward_kwargs["trainer_state"] = self.state - - # --- Monkey Patch: Batch Decode for Rewards --- - try: - reward_kwargs["completions_text"] = [ - self.processing_class.apply_chat_template( - completion, - tokenize=False, - ) - for completion in completions - ] - - reward_kwargs["prompts_text"] = [ - self.processing_class.apply_chat_template( - prompt, - tokenize=False, - ) - for prompt in prompts - ] - except Exception as e: - # Fallback or logging if chat template application fails (e.g. non-conversational format) - pass - """ - function = function.replace(target_line, replacement_block) - return function - - - RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__calculate_rewards) - # Remove collective RPC of reload weights from generate # trl added reload weights (potentially for quantized models), we don't need it for our use case (LoRA primarily) # https://github.com/huggingface/trl/commit/7856d3b1f6518601732f489883b341bb6dd36434#diff-964e6fd373aa93037604064cb2b822d7f8e2735e33f791065acf2c4c3552d393R1168-R1169 From b470807d465743c0a84bfb1d8b08a4b6035e4cfb Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Mon, 2 Feb 2026 19:48:37 -0500 Subject: [PATCH 06/13] Update rl.py --- unsloth/models/rl.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 8f1b4c86a2..038df86dd1 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -476,6 +476,20 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): old_RLTrainer_source = inspect.getsource(RLTrainer) old_RLConfig_source = inspect.getsource(RLConfig) + if trl_version >= Version("0.27.0"): + checkpoint_pattern = ( + r"\s*if self\.gradient_checkpointing and Version\(transformers\.__version__\) < Version\(\"5\.0\.0\"\):" + r".*?" + r"self\.gradient_checkpointing_kwargs\.setdefault\(\"use_reentrant\", False\)" + ) + + old_RLConfig_source = re.sub( + checkpoint_pattern, + "\n # Gradient checkpointing version check removed via script\n", + old_RLConfig_source, + flags=re.DOTALL + ) + all_imports = dir(trainer) # Fix _deprecate_arguments not getting imported so stop __ but not _ imports = [x for x in all_imports if not x.startswith("__")] From e09f7165d761b30088019bec9444fee648080206 Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Mon, 2 Feb 2026 19:49:50 -0500 Subject: [PATCH 07/13] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 038df86dd1..5a01c7090e 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -476,7 +476,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): old_RLTrainer_source = inspect.getsource(RLTrainer) old_RLConfig_source = inspect.getsource(RLConfig) - if trl_version >= Version("0.27.0"): + if trl_version >= Version("0.27.0") and RLConfig_name == "GRPOConfig": checkpoint_pattern = ( r"\s*if self\.gradient_checkpointing and Version\(transformers\.__version__\) < Version\(\"5\.0\.0\"\):" r".*?" From 58a1fd455a76b0dbc55b350833d60e876576d8b6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 03:44:29 +0000 Subject: [PATCH 08/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/rl.py | 23 ++++++++++----------- unsloth/models/rl_replacements.py | 34 ++++++++++++++++++++++--------- 2 files changed, 35 insertions(+), 22 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 5a01c7090e..dacfdb5c56 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -487,9 +487,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): checkpoint_pattern, "\n # Gradient checkpointing version check removed via script\n", old_RLConfig_source, - flags=re.DOTALL + flags = re.DOTALL, ) - + all_imports = dir(trainer) # Fix _deprecate_arguments not getting imported so stop __ but not _ imports = [x for x in all_imports if not x.startswith("__")] @@ -1123,22 +1123,21 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source = re.sub( pattern, new_options, RLTrainer_source, flags = re.DOTALL ) - + if trl_version >= Version("0.26.0"): peft_block_pattern = ( - r"\s*if is_peft_available\(\) and isinstance\(model, PeftModel\) and peft_config is not None:" - r".*?" - r"param\.data = param\.data\.to\(torch\.bfloat16\)" + r"\s*if is_peft_available\(\) and isinstance\(model, PeftModel\) and peft_config is not None:" + r".*?" + r"param\.data = param\.data\.to\(torch\.bfloat16\)" ) RLTrainer_source = re.sub( - peft_block_pattern, - "\n # TRL PEFT 0.26.0 initialization logic removed on unsloth side.\n", - RLTrainer_source, - flags=re.DOTALL + peft_block_pattern, + "\n # TRL PEFT 0.26.0 initialization logic removed on unsloth side.\n", + RLTrainer_source, + flags = re.DOTALL, ) - - + if RLTrainer_name == "SFTTrainer": original_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]' new_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask","labels"]' diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 349eea8933..dfff8b8373 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -65,6 +65,7 @@ except Exception: trl_version = Version("0.0.0") + # Check untrained tokens def sft_trainer_fix_untrained_tokens(call_args, extra_args): if "model" in call_args and "train_dataset" in call_args: @@ -238,6 +239,7 @@ def grpo_trainer__prepare_inputs(function_name, function): RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs) + # Remove collective RPC of reload weights from generate # trl added reload weights (potentially for quantized models), we don't need it for our use case (LoRA primarily) # https://github.com/huggingface/trl/commit/7856d3b1f6518601732f489883b341bb6dd36434#diff-964e6fd373aa93037604064cb2b822d7f8e2735e33f791065acf2c4c3552d393R1168-R1169 @@ -411,7 +413,7 @@ def grpo_trainer__generate_and_score_completions(function_name, function): # We replace the call using 'completions' with one using 'completions_text' string_to_find = " rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)" replacement_string = " rewards_per_func = self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)" - + function = function.replace(string_to_find, replacement_string) if "wake_up()" not in function: @@ -935,7 +937,7 @@ def compute_loss( logit_scale_divide = getattr(model.config, "logits_scaling", 0) # Granite if logit_scale_divide is None: logit_scale_divide = 0 - + max_left_pad = inputs.get("max_left_pad", 0) if per_token_logps is not None: loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( @@ -1076,7 +1078,8 @@ def compute_loss( .item() ) - completion_token_count = completion_mask.sum().clamp(min=1.0) + completion_token_count = completion_mask.sum().clamp(min = 1.0) + def masked_batch_mean(x): if x.shape[1] == 1: # when importance_sampling_level == "sequence" return x.mean() @@ -1097,22 +1100,33 @@ def masked_batch_mean(x): clip_ratio = masked_batch_mean(is_region_clipped.float()) gathered_low_clip = self.accelerator.gather(low_clip) - self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) - self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) + self._metrics[mode]["clip_ratio/low_mean"].append( + gathered_low_clip.nanmean().item() + ) + self._metrics[mode]["clip_ratio/low_min"].append( + nanmin(gathered_low_clip).item() + ) gathered_high_clip = self.accelerator.gather(high_clip) - self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) - self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) + self._metrics[mode]["clip_ratio/high_mean"].append( + gathered_high_clip.nanmean().item() + ) + self._metrics[mode]["clip_ratio/high_max"].append( + nanmax(gathered_high_clip).item() + ) gathered_clip_ratio = self.accelerator.gather(clip_ratio) - self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) + self._metrics[mode]["clip_ratio/region_mean"].append( + gathered_clip_ratio.nanmean().item() + ) elif self.loss_type == "cispo": is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0) cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) - self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item()) + self._metrics[mode]["cispo_clip_ratio"].append( + gathered_cispo_clip_ratio.nanmean().item() + ) return loss - function = inspect.getsource(compute_loss) return function From d8a7c7de9d4a488c8ced67627a30e721e9872905 Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Tue, 3 Feb 2026 14:00:19 -0500 Subject: [PATCH 09/13] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index dfff8b8373..bcd691a445 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -412,8 +412,12 @@ def grpo_trainer__generate_and_score_completions(function_name, function): if trl_version >= Version("0.25.0"): # We replace the call using 'completions' with one using 'completions_text' string_to_find = " rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)" - replacement_string = " rewards_per_func = self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)" - + replacement_string = ( + " if images is not None:\n" + " rewards_per_func = self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)\n" + " else:\n" + " rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)" + ) function = function.replace(string_to_find, replacement_string) if "wake_up()" not in function: From e7db06e3f1ba6a80d87a52706b1b09847f6575cc Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Tue, 3 Feb 2026 22:43:37 -0500 Subject: [PATCH 10/13] Update rl.py --- unsloth/models/rl.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 02ddff7a61..61eab6f202 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -1164,18 +1164,36 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): pattern, new_options, RLTrainer_source, flags = re.DOTALL ) - if trl_version >= Version("0.26.0"): + if trl_version >= Version("0.27.0"): + peft_pattern = ( + r"\s*if is_peft_available\(\) and is_peft_model\(model\) and args\.beta != 0\.0:" + r".*?" + r"param\.data = param\.data\.to\(torch\.bfloat16\)" + ) + + replacement_comment = ( + "\n # PEFT initialization logic removed via script for trl >= 0.27.0\n" + ) + + RLTrainer_source = re.sub( + peft_pattern, + replacement_comment, + RLTrainer_source, + flags=re.DOTALL + ) + + elif trl_version >= Version("0.26.0"): peft_block_pattern = ( - r"\s*if is_peft_available\(\) and isinstance\(model, PeftModel\) and peft_config is not None:" - r".*?" - r"param\.data = param\.data\.to\(torch\.bfloat16\)" + r"\s*if is_peft_available\(\) and isinstance\(model, PeftModel\) and peft_config is not None:" + r".*?" + r"param\.data = param\.data\.to\(torch\.bfloat16\)" ) RLTrainer_source = re.sub( - peft_block_pattern, - "\n # TRL PEFT 0.26.0 initialization logic removed on unsloth side.\n", - RLTrainer_source, - flags = re.DOTALL, + peft_block_pattern, + "\n # TRL PEFT 0.26.0 initialization logic removed on unsloth side.\n", + RLTrainer_source, + flags=re.DOTALL ) if RLTrainer_name == "SFTTrainer": From 832c03b7eecfccd9aca50745c7c105952aab7cb0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 03:44:40 +0000 Subject: [PATCH 11/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/rl.py | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 61eab6f202..b4dbaa3271 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -1166,34 +1166,29 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if trl_version >= Version("0.27.0"): peft_pattern = ( - r"\s*if is_peft_available\(\) and is_peft_model\(model\) and args\.beta != 0\.0:" - r".*?" - r"param\.data = param\.data\.to\(torch\.bfloat16\)" + r"\s*if is_peft_available\(\) and is_peft_model\(model\) and args\.beta != 0\.0:" + r".*?" + r"param\.data = param\.data\.to\(torch\.bfloat16\)" ) - replacement_comment = ( - "\n # PEFT initialization logic removed via script for trl >= 0.27.0\n" - ) + replacement_comment = "\n # PEFT initialization logic removed via script for trl >= 0.27.0\n" RLTrainer_source = re.sub( - peft_pattern, - replacement_comment, - RLTrainer_source, - flags=re.DOTALL + peft_pattern, replacement_comment, RLTrainer_source, flags = re.DOTALL ) - - elif trl_version >= Version("0.26.0"): + + elif trl_version >= Version("0.26.0"): peft_block_pattern = ( - r"\s*if is_peft_available\(\) and isinstance\(model, PeftModel\) and peft_config is not None:" - r".*?" - r"param\.data = param\.data\.to\(torch\.bfloat16\)" + r"\s*if is_peft_available\(\) and isinstance\(model, PeftModel\) and peft_config is not None:" + r".*?" + r"param\.data = param\.data\.to\(torch\.bfloat16\)" ) RLTrainer_source = re.sub( - peft_block_pattern, - "\n # TRL PEFT 0.26.0 initialization logic removed on unsloth side.\n", - RLTrainer_source, - flags=re.DOTALL + peft_block_pattern, + "\n # TRL PEFT 0.26.0 initialization logic removed on unsloth side.\n", + RLTrainer_source, + flags = re.DOTALL, ) if RLTrainer_name == "SFTTrainer": From fa354fc1987f83f06dd46b12d2468e6799cef22a Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Wed, 4 Feb 2026 20:31:30 -0500 Subject: [PATCH 12/13] Update rl_replacements.py, remove chat template from codexes commits --- unsloth/models/rl_replacements.py | 93 ------------------------------- 1 file changed, 93 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 36164b5e42..8208dc922a 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -443,99 +443,6 @@ def grpo_trainer__generate_and_score_completions(function_name, function): _target_line + _metadata_extraction, ) - # Unsloth: Skip prepare_multimodal_messages when prompts are pre-templated strings. - # When notebooks pre-apply apply_chat_template(), prompts become strings with image tokens - # already embedded. Calling prepare_multimodal_messages on strings crashes with TypeError. - # Skipping it keeps prompts as strings so TRL uses the non-conversational path, which - # ensures completions are strings and reward functions work correctly. - string_to_find_vision = """ if images is not None: - prompts = [ - prepare_multimodal_messages(prompt, image_list) - for prompt, image_list in zip(prompts, images, strict=True) - ]""" - - replacement_string_vision = """ if images is not None: - # Unsloth: skip prepare_multimodal_messages for pre-templated string prompts - if not prompts or not isinstance(prompts[0], str): - prompts = [ - prepare_multimodal_messages(prompt, image_list) - for prompt, image_list in zip(prompts, images, strict=True) - ]""" - - function = function.replace(string_to_find_vision, replacement_string_vision) - - # Unsloth: Skip apply_chat_template in the forward_kwargs block for pre-templated - # string prompts. When prompts are already strings (from notebooks that pre-applied - # apply_chat_template), calling it again crashes because strings aren't dicts. - # We use prompts directly as prompts_text instead. - - # TRL 0.26.2+ variant (has tools=self.tools) - string_to_find_fwd = """ if images is not None: - prompts_text = [ - apply_chat_template( - {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs - )["prompt"] - for prompt in prompts - ]""" - - replacement_string_fwd = """ if images is not None: - # Unsloth: skip apply_chat_template for pre-templated string prompts - if prompts and isinstance(prompts[0], str): - prompts_text = prompts - else: - prompts_text = [ - apply_chat_template( - {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs - )["prompt"] - for prompt in prompts - ]""" - - function = function.replace(string_to_find_fwd, replacement_string_fwd) - - # TRL 0.25.x variant (no tools parameter) - string_to_find_fwd_old = """ if images is not None: - prompts_text = [ - apply_chat_template( - {"prompt": prompt}, self.processing_class, **self.chat_template_kwargs - )["prompt"] - for prompt in prompts - ]""" - - replacement_string_fwd_old = """ if images is not None: - # Unsloth: skip apply_chat_template for pre-templated string prompts - if prompts and isinstance(prompts[0], str): - prompts_text = prompts - else: - prompts_text = [ - apply_chat_template( - {"prompt": prompt}, self.processing_class, **self.chat_template_kwargs - )["prompt"] - for prompt in prompts - ]""" - - function = function.replace(string_to_find_fwd_old, replacement_string_fwd_old) - - # TRL 0.25.1 single-line variant (no tools, single-line apply_chat_template call) - string_to_find_fwd_single = """ if images is not None: - prompts_text = [ - apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"] - for prompt in prompts - ]""" - - replacement_string_fwd_single = """ if images is not None: - # Unsloth: skip apply_chat_template for pre-templated string prompts - if prompts and isinstance(prompts[0], str): - prompts_text = prompts - else: - prompts_text = [ - apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"] - for prompt in prompts - ]""" - - function = function.replace( - string_to_find_fwd_single, replacement_string_fwd_single - ) - # This path is for TRL 0.24.0 images is a variable exclusive to this version string_to_find = """ if images is not None: output["num_images"] = num_images""" From f6002a6bda9d1b5fc07e3eb712b2f51595fef99e Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Wed, 4 Feb 2026 21:15:03 -0500 Subject: [PATCH 13/13] Update rl.py, got rid of gradient checkpointing code that did not work --- unsloth/models/rl.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b4dbaa3271..3776752f4b 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -479,20 +479,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): old_RLTrainer_source = inspect.getsource(RLTrainer) old_RLConfig_source = inspect.getsource(RLConfig) - if trl_version >= Version("0.27.0") and RLConfig_name == "GRPOConfig": - checkpoint_pattern = ( - r"\s*if self\.gradient_checkpointing and Version\(transformers\.__version__\) < Version\(\"5\.0\.0\"\):" - r".*?" - r"self\.gradient_checkpointing_kwargs\.setdefault\(\"use_reentrant\", False\)" - ) - - old_RLConfig_source = re.sub( - checkpoint_pattern, - "\n # Gradient checkpointing version check removed via script\n", - old_RLConfig_source, - flags = re.DOTALL, - ) - all_imports = dir(trainer) # Fix _deprecate_arguments not getting imported so stop __ but not _ imports = [x for x in all_imports if not x.startswith("__")]