From cb6fb1bb4cba5d74f641284d21aae8572a4c000f Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Mon, 2 Mar 2026 12:51:32 -0500 Subject: [PATCH 1/9] Refactor loss computation to include completion_mask --- unsloth/models/rl_replacements.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b3a55440f9..44a8d3a68a 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -1013,7 +1013,7 @@ def compute_loss( max_left_pad = inputs.get("max_left_pad", 0) if per_token_logps is not None: - loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( + loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, completion_mask = ( grpo_compute_loss_slow( ref_logps, per_token_logps, @@ -1043,7 +1043,7 @@ def compute_loss( ) else: if hasattr(self.args, "loss_type"): - loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( + loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, completion_mask = ( grpo_accumulated_loss( trainer = self, input_ids = _input_ids, @@ -1075,7 +1075,7 @@ def compute_loss( ) else: # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 - loss, completion_length, mean_kl, coef_1 = grpo_accumulated_loss( + loss, completion_length, mean_kl, coef_1, completion_mask = grpo_accumulated_loss( trainer = self, input_ids = _input_ids, logits_to_keep = logits_to_keep, From a863d0f84f6fbd9cc4ab925e7003245efadd6200 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Mar 2026 17:57:43 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/rl_replacements.py | 98 ++++++++++++++++++------------- 1 file changed, 56 insertions(+), 42 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 44a8d3a68a..69e9ce3455 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -1013,17 +1013,61 @@ def compute_loss( max_left_pad = inputs.get("max_left_pad", 0) if per_token_logps is not None: - loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, completion_mask = ( - grpo_compute_loss_slow( - ref_logps, - per_token_logps, - old_logps, - input_ids, + ( + loss, + completion_length, + mean_kl, + delta, + flat_is_ratio, + coef_1, + completion_mask, + ) = grpo_compute_loss_slow( + ref_logps, + per_token_logps, + old_logps, + input_ids, + completion_mask, + self.beta, + advantages, + pixel_values = pixel_values, + image_grid_thw = image_grid_thw, + loss_type = self.args.loss_type, + importance_sampling_level = self.importance_sampling_level, + epsilon_low = self.epsilon_low, + epsilon_high = self.epsilon_high, + max_completion_length = self.args.max_completion_length, + delta = self.args.delta, + temperature = self.args.temperature, + max_left_pad = max_left_pad, + logit_softcapping = logit_softcapping, + logit_scale_multiply = logit_scale_multiply, + logit_scale_divide = logit_scale_divide, + num_items_in_batch = num_items_in_batch, + current_gradient_accumulation_steps = current_gradient_accumulation_steps, + num_processes = num_processes, + sampling_per_token_logps = sampling_per_token_logps, + ) + else: + if hasattr(self.args, "loss_type"): + ( + loss, + completion_length, + mean_kl, + delta, + flat_is_ratio, + coef_1, completion_mask, - self.beta, - advantages, + ) = grpo_accumulated_loss( + trainer = self, + input_ids = _input_ids, pixel_values = pixel_values, image_grid_thw = image_grid_thw, + logits_to_keep = logits_to_keep, + completion_mask = completion_mask, + advantages = advantages, + old_logps = old_logps, + ref_logps = ref_logps, + n_chunks = self.args.unsloth_num_chunks, loss_type = self.args.loss_type, importance_sampling_level = self.importance_sampling_level, epsilon_low = self.epsilon_low, @@ -1035,61 +1079,31 @@ def compute_loss( logit_softcapping = logit_softcapping, logit_scale_multiply = logit_scale_multiply, logit_scale_divide = logit_scale_divide, + attention_mask = attention_mask, num_items_in_batch = num_items_in_batch, current_gradient_accumulation_steps = current_gradient_accumulation_steps, num_processes = num_processes, sampling_per_token_logps = sampling_per_token_logps, ) - ) - else: - if hasattr(self.args, "loss_type"): - loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, completion_mask = ( + else: + # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 + loss, completion_length, mean_kl, coef_1, completion_mask = ( grpo_accumulated_loss( trainer = self, input_ids = _input_ids, - pixel_values = pixel_values, - image_grid_thw = image_grid_thw, logits_to_keep = logits_to_keep, completion_mask = completion_mask, advantages = advantages, old_logps = old_logps, ref_logps = ref_logps, n_chunks = self.args.unsloth_num_chunks, - loss_type = self.args.loss_type, - importance_sampling_level = self.importance_sampling_level, - epsilon_low = self.epsilon_low, - epsilon_high = self.epsilon_high, - max_completion_length = self.args.max_completion_length, - delta = self.args.delta, temperature = self.args.temperature, - max_left_pad = max_left_pad, logit_softcapping = logit_softcapping, logit_scale_multiply = logit_scale_multiply, logit_scale_divide = logit_scale_divide, attention_mask = attention_mask, - num_items_in_batch = num_items_in_batch, - current_gradient_accumulation_steps = current_gradient_accumulation_steps, - num_processes = num_processes, - sampling_per_token_logps = sampling_per_token_logps, ) ) - else: - # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 - loss, completion_length, mean_kl, coef_1, completion_mask = grpo_accumulated_loss( - trainer = self, - input_ids = _input_ids, - logits_to_keep = logits_to_keep, - completion_mask = completion_mask, - advantages = advantages, - old_logps = old_logps, - ref_logps = ref_logps, - n_chunks = self.args.unsloth_num_chunks, - temperature = self.args.temperature, - logit_softcapping = logit_softcapping, - logit_scale_multiply = logit_scale_multiply, - logit_scale_divide = logit_scale_divide, - attention_mask = attention_mask, - ) if "train" in self._metrics: mode = "eval" if self.control.should_evaluate else "train" self._metrics[mode]["completion_length"].append(completion_length.item()) From 35721ee0ffff74501c085f7b4120e67cccd78e4b Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Wed, 4 Mar 2026 10:57:21 +0000 Subject: [PATCH 3/9] Fixes for trl 0.28 and above Remove sync/reload weights calls , remove vllm.LLM instantiation --- unsloth/models/rl.py | 14 ++- unsloth/models/rl_replacements.py | 155 ++++++++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index e4f34c908e..4d9e747f8d 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -389,8 +389,8 @@ class Unsloth{RLConfig_name}({RLConfig_name}): def __init__({RLConfig_arguments}, vllm_sampling_params = None, unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, + unsloth_logit_chunk_multiplier = None, + unsloth_grpo_mini_batch = None, {max_seq_length_call} **kwargs, ): @@ -1875,11 +1875,21 @@ def patch_trl_openenv(): function() # Call the function to apply the patch return +def patch_trl_vllm_generation(): + # trl moved vllm stuff to trl/generation/vllm_generation.py + # We need to min_p patch it to not instantiate another vLLM instance if we already have one with fast_inference + # Find the instance of self.llm = LLM(..) (multiline) and wrap it around an if clause + for function in RL_ADDITIONAL_FUNCTIONS["vllm_generation"]: + logger.info(f"Unsloth: Patching trl VLLMGeneration with function: {function.__name__}") + function() + return + def PatchFastRL(algorithm = None, FastLanguageModel = None): if FastLanguageModel is not None: PatchRL(FastLanguageModel) patch_trl_rl_trainers() patch_trl_openenv() + patch_trl_vllm_generation() if type(algorithm) is str and algorithm.islower(): PatchRLStatistics(algorithm) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 69e9ce3455..f2e4b050f2 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -24,6 +24,7 @@ import re import torch import inspect +import linecache from collections import defaultdict from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding from unsloth_zoo.utils import Version @@ -264,12 +265,28 @@ def grpo_trainer__generate_single_turn(function_name, function): # Remove the reload_weights collective RPC call from the generate function's source # function = function.replace('self.llm.collective_rpc("reload_weights")', "") # The regex below does the same thing but is more flexible and can handle single or double quotes + # This is for older versions. function = re.sub( r"self\.llm\.collective_rpc\(\s*(['\"])reload_weights\1\s*\)", "", function, ) + # Current TRL versions call vllm_generation.sync_weights() every step. + # When Unsloth fast inference LoRA is active, weights are already shared. + sync_weights_block = re.compile( + r"(?P[ \t]*)with profiling_context\(self,\s*(['\"])sync_weights\2\s*\):\n" + r"(?P=indent)[ \t]+self\.vllm_generation\.sync_weights\(\)\n", + re.MULTILINE, + ) + def remove_sync_weights_block(match): + indent = match.group("indent") + return ( + f"{indent}# Unsloth fast inference LoRA shares weights with vLLM already.\n" + f"{indent}# Skipping per-step vLLM sync_weights().\n" + ) + function = sync_weights_block.sub(remove_sync_weights_block, function) + # TRL 0.24.0-0.25.1 truncation regression fix # # TRL 0.22.2-0.23.1 used smart truncation via truncate_with_protected_tokens(): @@ -1352,3 +1369,141 @@ def openenv_vllm_reload_weights(): RL_ADDITIONAL_FUNCTIONS["openenv"].append(openenv_vllm_reload_weights) + + +def vllm_generation_init_patch(): + # trl moved vllm stuff to trl/generation/vllm_generation.py + # We need to patch it to not instantiate another vLLM instance if we already have one with fast_inference + # Edit the TRL source directly and install the patched function in the TRL module. + # https://github.com/huggingface/trl/commit/0eb66d8f2fc63b3d00d8dbc18f99c3f48750bd16 + # This exists in trl versions 0.28.0 and above + + if importlib.util.find_spec("trl") is None: + return + if Version(importlib_version("trl")) < Version("0.28.0"): + return + + try: + import trl.generation.vllm_generation as vllm_generation + except (ImportError, NameError, Exception) as e: + logger.info(f"Unsloth: Failed to import trl.generation.vllm_generation: {e}") + return + + def patch_vllm_generation_method(method_name, transform, marker, filename_suffix): + method = getattr(vllm_generation.VLLMGeneration, method_name, None) + if method is None: + logger.info(f"Unsloth: Could not find VLLMGeneration.{method_name}") + return False + + try: + src = inspect.getsource(method) + except Exception as e: + logger.info(f"Unsloth: Could not get source of VLLMGeneration.{method_name}: {e}") + return False + + src = textwrap.dedent(src) + if marker in src: + return True + + src = transform(src) + filename = f"" + source_lines = [line + "\n" for line in src.splitlines()] + linecache.cache[filename] = ( + len(src), + None, + source_lines, + filename, + ) + + local_ns = {} + exec(compile(src, filename, "exec"), vllm_generation.__dict__, local_ns) + setattr(vllm_generation.VLLMGeneration, method_name, local_ns[method_name]) + return True + + # Patch init to remove vLLM.LLM instantiation + def patch_init_vllm(src): + pattern = re.compile( + r"(?P^(?P[ \t]*)self\.llm\s*=\s*LLM\s*\(\n(?:.*\n)*?^(?P=indent)\))", + re.MULTILINE, + ) + def replace_llm_block(match): + indent = match.group("indent") + llm_block = textwrap.dedent(match.group("llm_block")) + return ( + f"{indent}if hasattr(model, 'vllm_engine'):\n" + f"{indent} # Unsloth already inits vLLM in fast inference mode. Do not redo :)\n" + f"{indent} self.llm = model.vllm_engine\n" + f"{indent} self.unsloth_fast_inference_lora = True\n" + f"{indent}else:\n" + + textwrap.indent(llm_block, indent + " ") + ) + patched_src, num_replacements = pattern.subn(replace_llm_block, src, count = 1) + if num_replacements == 0: + raise RuntimeError("Unsloth: Warning - regex did not match, VLLMGeneration._init_vllm patch may have failed") + return patched_src + + # has some sync_weights or reload rpc calls. + # we patched the grpo_trainer to strip them for prev versions + # Ref: grpo_trainer__generate_single_turn above around L270-280 + def patch_sync_weights(src): + pattern = re.compile( + r"^(?Pdef sync_weights\(self\):\n)(?P(?:.*\n)*)", + re.MULTILINE, + ) + def replace_sync_weights(match): + body = match.group("body") + guard = ( + " if getattr(self, 'unsloth_fast_inference_lora', False):\n" + " # Unsloth fast inference LoRA shares weights with vLLM already.\n" + " return\n\n" + ) + return match.group("def_line") + guard + body + patched_src, num_replacements = pattern.subn(replace_sync_weights, src, count = 1) + if num_replacements == 0: + raise RuntimeError("Unsloth: Warning - regex did not match, VLLMGeneration.sync_weights patch may have failed") + return patched_src + + def patch_generate(src): + pattern = re.compile( + r"^(?P[ \t]*)self\.llm\.collective_rpc\(\s*(['\"])reload_weights\2\s*\)\s*$", + re.MULTILINE, + ) + def replace_reload_weights(match): + indent = match.group("indent") + return f'{indent}pass # self.llm.collective_rpc("reload_weights")' + patched_src, num_replacements = pattern.subn(replace_reload_weights, src, count = 1) + if num_replacements == 0: + raise RuntimeError("Unsloth: Warning - regex did not match, VLLMGeneration.generate patch may have failed") + return patched_src + + try: + init_patched = patch_vllm_generation_method( + "_init_vllm", + patch_init_vllm, + "self.unsloth_fast_inference_lora = True", + "init_vllm", + ) + sync_patched = patch_vllm_generation_method( + "sync_weights", + patch_sync_weights, + "if getattr(self, 'unsloth_fast_inference_lora', False):", + "sync_weights", + ) + generate_patched = patch_vllm_generation_method( + "generate", + patch_generate, + 'pass # self.llm.collective_rpc("reload_weights")', + "generate", + ) + except RuntimeError as e: + logger.warning(str(e)) + return + + if init_patched: + logger.info("Unsloth: Patched trl VLLMGeneration._init_vllm") + if sync_patched: + logger.info("Unsloth: Patched trl VLLMGeneration.sync_weights") + if generate_patched: + logger.info("Unsloth: Patched trl VLLMGeneration.generate") + +RL_ADDITIONAL_FUNCTIONS["vllm_generation"].append(vllm_generation_init_patch) From c75d8aa4afc4a46edde18d1fe402dddfd4eb5375 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 10:59:01 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/rl.py | 5 ++++- unsloth/models/rl_replacements.py | 32 ++++++++++++++++++++++++------- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 4d9e747f8d..285227f174 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -1875,12 +1875,15 @@ def patch_trl_openenv(): function() # Call the function to apply the patch return + def patch_trl_vllm_generation(): # trl moved vllm stuff to trl/generation/vllm_generation.py # We need to min_p patch it to not instantiate another vLLM instance if we already have one with fast_inference # Find the instance of self.llm = LLM(..) (multiline) and wrap it around an if clause for function in RL_ADDITIONAL_FUNCTIONS["vllm_generation"]: - logger.info(f"Unsloth: Patching trl VLLMGeneration with function: {function.__name__}") + logger.info( + f"Unsloth: Patching trl VLLMGeneration with function: {function.__name__}" + ) function() return diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index f2e4b050f2..453fe61930 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -279,12 +279,14 @@ def grpo_trainer__generate_single_turn(function_name, function): r"(?P=indent)[ \t]+self\.vllm_generation\.sync_weights\(\)\n", re.MULTILINE, ) + def remove_sync_weights_block(match): indent = match.group("indent") return ( f"{indent}# Unsloth fast inference LoRA shares weights with vLLM already.\n" f"{indent}# Skipping per-step vLLM sync_weights().\n" ) + function = sync_weights_block.sub(remove_sync_weights_block, function) # TRL 0.24.0-0.25.1 truncation regression fix @@ -1398,7 +1400,9 @@ def patch_vllm_generation_method(method_name, transform, marker, filename_suffix try: src = inspect.getsource(method) except Exception as e: - logger.info(f"Unsloth: Could not get source of VLLMGeneration.{method_name}: {e}") + logger.info( + f"Unsloth: Could not get source of VLLMGeneration.{method_name}: {e}" + ) return False src = textwrap.dedent(src) @@ -1426,6 +1430,7 @@ def patch_init_vllm(src): r"(?P^(?P[ \t]*)self\.llm\s*=\s*LLM\s*\(\n(?:.*\n)*?^(?P=indent)\))", re.MULTILINE, ) + def replace_llm_block(match): indent = match.group("indent") llm_block = textwrap.dedent(match.group("llm_block")) @@ -1434,12 +1439,14 @@ def replace_llm_block(match): f"{indent} # Unsloth already inits vLLM in fast inference mode. Do not redo :)\n" f"{indent} self.llm = model.vllm_engine\n" f"{indent} self.unsloth_fast_inference_lora = True\n" - f"{indent}else:\n" - + textwrap.indent(llm_block, indent + " ") + f"{indent}else:\n" + textwrap.indent(llm_block, indent + " ") ) + patched_src, num_replacements = pattern.subn(replace_llm_block, src, count = 1) if num_replacements == 0: - raise RuntimeError("Unsloth: Warning - regex did not match, VLLMGeneration._init_vllm patch may have failed") + raise RuntimeError( + "Unsloth: Warning - regex did not match, VLLMGeneration._init_vllm patch may have failed" + ) return patched_src # has some sync_weights or reload rpc calls. @@ -1450,6 +1457,7 @@ def patch_sync_weights(src): r"^(?Pdef sync_weights\(self\):\n)(?P(?:.*\n)*)", re.MULTILINE, ) + def replace_sync_weights(match): body = match.group("body") guard = ( @@ -1458,9 +1466,12 @@ def replace_sync_weights(match): " return\n\n" ) return match.group("def_line") + guard + body + patched_src, num_replacements = pattern.subn(replace_sync_weights, src, count = 1) if num_replacements == 0: - raise RuntimeError("Unsloth: Warning - regex did not match, VLLMGeneration.sync_weights patch may have failed") + raise RuntimeError( + "Unsloth: Warning - regex did not match, VLLMGeneration.sync_weights patch may have failed" + ) return patched_src def patch_generate(src): @@ -1468,12 +1479,18 @@ def patch_generate(src): r"^(?P[ \t]*)self\.llm\.collective_rpc\(\s*(['\"])reload_weights\2\s*\)\s*$", re.MULTILINE, ) + def replace_reload_weights(match): indent = match.group("indent") return f'{indent}pass # self.llm.collective_rpc("reload_weights")' - patched_src, num_replacements = pattern.subn(replace_reload_weights, src, count = 1) + + patched_src, num_replacements = pattern.subn( + replace_reload_weights, src, count = 1 + ) if num_replacements == 0: - raise RuntimeError("Unsloth: Warning - regex did not match, VLLMGeneration.generate patch may have failed") + raise RuntimeError( + "Unsloth: Warning - regex did not match, VLLMGeneration.generate patch may have failed" + ) return patched_src try: @@ -1506,4 +1523,5 @@ def replace_reload_weights(match): if generate_patched: logger.info("Unsloth: Patched trl VLLMGeneration.generate") + RL_ADDITIONAL_FUNCTIONS["vllm_generation"].append(vllm_generation_init_patch) From 785e30889a403055a182749e3e34947b95a70c57 Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Mon, 2 Mar 2026 12:51:32 -0500 Subject: [PATCH 5/9] Refactor loss computation to include completion_mask --- unsloth/models/rl_replacements.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b3a55440f9..44a8d3a68a 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -1013,7 +1013,7 @@ def compute_loss( max_left_pad = inputs.get("max_left_pad", 0) if per_token_logps is not None: - loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( + loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, completion_mask = ( grpo_compute_loss_slow( ref_logps, per_token_logps, @@ -1043,7 +1043,7 @@ def compute_loss( ) else: if hasattr(self.args, "loss_type"): - loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = ( + loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, completion_mask = ( grpo_accumulated_loss( trainer = self, input_ids = _input_ids, @@ -1075,7 +1075,7 @@ def compute_loss( ) else: # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 - loss, completion_length, mean_kl, coef_1 = grpo_accumulated_loss( + loss, completion_length, mean_kl, coef_1, completion_mask = grpo_accumulated_loss( trainer = self, input_ids = _input_ids, logits_to_keep = logits_to_keep, From deafd914154adb8fec63a4bee3a7a143ee275c83 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Mar 2026 17:57:43 +0000 Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/rl_replacements.py | 98 ++++++++++++++++++------------- 1 file changed, 56 insertions(+), 42 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 44a8d3a68a..69e9ce3455 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -1013,17 +1013,61 @@ def compute_loss( max_left_pad = inputs.get("max_left_pad", 0) if per_token_logps is not None: - loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, completion_mask = ( - grpo_compute_loss_slow( - ref_logps, - per_token_logps, - old_logps, - input_ids, + ( + loss, + completion_length, + mean_kl, + delta, + flat_is_ratio, + coef_1, + completion_mask, + ) = grpo_compute_loss_slow( + ref_logps, + per_token_logps, + old_logps, + input_ids, + completion_mask, + self.beta, + advantages, + pixel_values = pixel_values, + image_grid_thw = image_grid_thw, + loss_type = self.args.loss_type, + importance_sampling_level = self.importance_sampling_level, + epsilon_low = self.epsilon_low, + epsilon_high = self.epsilon_high, + max_completion_length = self.args.max_completion_length, + delta = self.args.delta, + temperature = self.args.temperature, + max_left_pad = max_left_pad, + logit_softcapping = logit_softcapping, + logit_scale_multiply = logit_scale_multiply, + logit_scale_divide = logit_scale_divide, + num_items_in_batch = num_items_in_batch, + current_gradient_accumulation_steps = current_gradient_accumulation_steps, + num_processes = num_processes, + sampling_per_token_logps = sampling_per_token_logps, + ) + else: + if hasattr(self.args, "loss_type"): + ( + loss, + completion_length, + mean_kl, + delta, + flat_is_ratio, + coef_1, completion_mask, - self.beta, - advantages, + ) = grpo_accumulated_loss( + trainer = self, + input_ids = _input_ids, pixel_values = pixel_values, image_grid_thw = image_grid_thw, + logits_to_keep = logits_to_keep, + completion_mask = completion_mask, + advantages = advantages, + old_logps = old_logps, + ref_logps = ref_logps, + n_chunks = self.args.unsloth_num_chunks, loss_type = self.args.loss_type, importance_sampling_level = self.importance_sampling_level, epsilon_low = self.epsilon_low, @@ -1035,61 +1079,31 @@ def compute_loss( logit_softcapping = logit_softcapping, logit_scale_multiply = logit_scale_multiply, logit_scale_divide = logit_scale_divide, + attention_mask = attention_mask, num_items_in_batch = num_items_in_batch, current_gradient_accumulation_steps = current_gradient_accumulation_steps, num_processes = num_processes, sampling_per_token_logps = sampling_per_token_logps, ) - ) - else: - if hasattr(self.args, "loss_type"): - loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1, completion_mask = ( + else: + # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 + loss, completion_length, mean_kl, coef_1, completion_mask = ( grpo_accumulated_loss( trainer = self, input_ids = _input_ids, - pixel_values = pixel_values, - image_grid_thw = image_grid_thw, logits_to_keep = logits_to_keep, completion_mask = completion_mask, advantages = advantages, old_logps = old_logps, ref_logps = ref_logps, n_chunks = self.args.unsloth_num_chunks, - loss_type = self.args.loss_type, - importance_sampling_level = self.importance_sampling_level, - epsilon_low = self.epsilon_low, - epsilon_high = self.epsilon_high, - max_completion_length = self.args.max_completion_length, - delta = self.args.delta, temperature = self.args.temperature, - max_left_pad = max_left_pad, logit_softcapping = logit_softcapping, logit_scale_multiply = logit_scale_multiply, logit_scale_divide = logit_scale_divide, attention_mask = attention_mask, - num_items_in_batch = num_items_in_batch, - current_gradient_accumulation_steps = current_gradient_accumulation_steps, - num_processes = num_processes, - sampling_per_token_logps = sampling_per_token_logps, ) ) - else: - # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 - loss, completion_length, mean_kl, coef_1, completion_mask = grpo_accumulated_loss( - trainer = self, - input_ids = _input_ids, - logits_to_keep = logits_to_keep, - completion_mask = completion_mask, - advantages = advantages, - old_logps = old_logps, - ref_logps = ref_logps, - n_chunks = self.args.unsloth_num_chunks, - temperature = self.args.temperature, - logit_softcapping = logit_softcapping, - logit_scale_multiply = logit_scale_multiply, - logit_scale_divide = logit_scale_divide, - attention_mask = attention_mask, - ) if "train" in self._metrics: mode = "eval" if self.control.should_evaluate else "train" self._metrics[mode]["completion_length"].append(completion_length.item()) From 50618b3e9a0d5503121f495d6c314962afe03f9d Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Wed, 4 Mar 2026 10:57:21 +0000 Subject: [PATCH 7/9] Fixes for trl 0.28 and above Remove sync/reload weights calls , remove vllm.LLM instantiation --- unsloth/models/rl.py | 14 ++- unsloth/models/rl_replacements.py | 155 ++++++++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index e4f34c908e..4d9e747f8d 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -389,8 +389,8 @@ class Unsloth{RLConfig_name}({RLConfig_name}): def __init__({RLConfig_arguments}, vllm_sampling_params = None, unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, + unsloth_logit_chunk_multiplier = None, + unsloth_grpo_mini_batch = None, {max_seq_length_call} **kwargs, ): @@ -1875,11 +1875,21 @@ def patch_trl_openenv(): function() # Call the function to apply the patch return +def patch_trl_vllm_generation(): + # trl moved vllm stuff to trl/generation/vllm_generation.py + # We need to min_p patch it to not instantiate another vLLM instance if we already have one with fast_inference + # Find the instance of self.llm = LLM(..) (multiline) and wrap it around an if clause + for function in RL_ADDITIONAL_FUNCTIONS["vllm_generation"]: + logger.info(f"Unsloth: Patching trl VLLMGeneration with function: {function.__name__}") + function() + return + def PatchFastRL(algorithm = None, FastLanguageModel = None): if FastLanguageModel is not None: PatchRL(FastLanguageModel) patch_trl_rl_trainers() patch_trl_openenv() + patch_trl_vllm_generation() if type(algorithm) is str and algorithm.islower(): PatchRLStatistics(algorithm) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 69e9ce3455..f2e4b050f2 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -24,6 +24,7 @@ import re import torch import inspect +import linecache from collections import defaultdict from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding from unsloth_zoo.utils import Version @@ -264,12 +265,28 @@ def grpo_trainer__generate_single_turn(function_name, function): # Remove the reload_weights collective RPC call from the generate function's source # function = function.replace('self.llm.collective_rpc("reload_weights")', "") # The regex below does the same thing but is more flexible and can handle single or double quotes + # This is for older versions. function = re.sub( r"self\.llm\.collective_rpc\(\s*(['\"])reload_weights\1\s*\)", "", function, ) + # Current TRL versions call vllm_generation.sync_weights() every step. + # When Unsloth fast inference LoRA is active, weights are already shared. + sync_weights_block = re.compile( + r"(?P[ \t]*)with profiling_context\(self,\s*(['\"])sync_weights\2\s*\):\n" + r"(?P=indent)[ \t]+self\.vllm_generation\.sync_weights\(\)\n", + re.MULTILINE, + ) + def remove_sync_weights_block(match): + indent = match.group("indent") + return ( + f"{indent}# Unsloth fast inference LoRA shares weights with vLLM already.\n" + f"{indent}# Skipping per-step vLLM sync_weights().\n" + ) + function = sync_weights_block.sub(remove_sync_weights_block, function) + # TRL 0.24.0-0.25.1 truncation regression fix # # TRL 0.22.2-0.23.1 used smart truncation via truncate_with_protected_tokens(): @@ -1352,3 +1369,141 @@ def openenv_vllm_reload_weights(): RL_ADDITIONAL_FUNCTIONS["openenv"].append(openenv_vllm_reload_weights) + + +def vllm_generation_init_patch(): + # trl moved vllm stuff to trl/generation/vllm_generation.py + # We need to patch it to not instantiate another vLLM instance if we already have one with fast_inference + # Edit the TRL source directly and install the patched function in the TRL module. + # https://github.com/huggingface/trl/commit/0eb66d8f2fc63b3d00d8dbc18f99c3f48750bd16 + # This exists in trl versions 0.28.0 and above + + if importlib.util.find_spec("trl") is None: + return + if Version(importlib_version("trl")) < Version("0.28.0"): + return + + try: + import trl.generation.vllm_generation as vllm_generation + except (ImportError, NameError, Exception) as e: + logger.info(f"Unsloth: Failed to import trl.generation.vllm_generation: {e}") + return + + def patch_vllm_generation_method(method_name, transform, marker, filename_suffix): + method = getattr(vllm_generation.VLLMGeneration, method_name, None) + if method is None: + logger.info(f"Unsloth: Could not find VLLMGeneration.{method_name}") + return False + + try: + src = inspect.getsource(method) + except Exception as e: + logger.info(f"Unsloth: Could not get source of VLLMGeneration.{method_name}: {e}") + return False + + src = textwrap.dedent(src) + if marker in src: + return True + + src = transform(src) + filename = f"" + source_lines = [line + "\n" for line in src.splitlines()] + linecache.cache[filename] = ( + len(src), + None, + source_lines, + filename, + ) + + local_ns = {} + exec(compile(src, filename, "exec"), vllm_generation.__dict__, local_ns) + setattr(vllm_generation.VLLMGeneration, method_name, local_ns[method_name]) + return True + + # Patch init to remove vLLM.LLM instantiation + def patch_init_vllm(src): + pattern = re.compile( + r"(?P^(?P[ \t]*)self\.llm\s*=\s*LLM\s*\(\n(?:.*\n)*?^(?P=indent)\))", + re.MULTILINE, + ) + def replace_llm_block(match): + indent = match.group("indent") + llm_block = textwrap.dedent(match.group("llm_block")) + return ( + f"{indent}if hasattr(model, 'vllm_engine'):\n" + f"{indent} # Unsloth already inits vLLM in fast inference mode. Do not redo :)\n" + f"{indent} self.llm = model.vllm_engine\n" + f"{indent} self.unsloth_fast_inference_lora = True\n" + f"{indent}else:\n" + + textwrap.indent(llm_block, indent + " ") + ) + patched_src, num_replacements = pattern.subn(replace_llm_block, src, count = 1) + if num_replacements == 0: + raise RuntimeError("Unsloth: Warning - regex did not match, VLLMGeneration._init_vllm patch may have failed") + return patched_src + + # has some sync_weights or reload rpc calls. + # we patched the grpo_trainer to strip them for prev versions + # Ref: grpo_trainer__generate_single_turn above around L270-280 + def patch_sync_weights(src): + pattern = re.compile( + r"^(?Pdef sync_weights\(self\):\n)(?P(?:.*\n)*)", + re.MULTILINE, + ) + def replace_sync_weights(match): + body = match.group("body") + guard = ( + " if getattr(self, 'unsloth_fast_inference_lora', False):\n" + " # Unsloth fast inference LoRA shares weights with vLLM already.\n" + " return\n\n" + ) + return match.group("def_line") + guard + body + patched_src, num_replacements = pattern.subn(replace_sync_weights, src, count = 1) + if num_replacements == 0: + raise RuntimeError("Unsloth: Warning - regex did not match, VLLMGeneration.sync_weights patch may have failed") + return patched_src + + def patch_generate(src): + pattern = re.compile( + r"^(?P[ \t]*)self\.llm\.collective_rpc\(\s*(['\"])reload_weights\2\s*\)\s*$", + re.MULTILINE, + ) + def replace_reload_weights(match): + indent = match.group("indent") + return f'{indent}pass # self.llm.collective_rpc("reload_weights")' + patched_src, num_replacements = pattern.subn(replace_reload_weights, src, count = 1) + if num_replacements == 0: + raise RuntimeError("Unsloth: Warning - regex did not match, VLLMGeneration.generate patch may have failed") + return patched_src + + try: + init_patched = patch_vllm_generation_method( + "_init_vllm", + patch_init_vllm, + "self.unsloth_fast_inference_lora = True", + "init_vllm", + ) + sync_patched = patch_vllm_generation_method( + "sync_weights", + patch_sync_weights, + "if getattr(self, 'unsloth_fast_inference_lora', False):", + "sync_weights", + ) + generate_patched = patch_vllm_generation_method( + "generate", + patch_generate, + 'pass # self.llm.collective_rpc("reload_weights")', + "generate", + ) + except RuntimeError as e: + logger.warning(str(e)) + return + + if init_patched: + logger.info("Unsloth: Patched trl VLLMGeneration._init_vllm") + if sync_patched: + logger.info("Unsloth: Patched trl VLLMGeneration.sync_weights") + if generate_patched: + logger.info("Unsloth: Patched trl VLLMGeneration.generate") + +RL_ADDITIONAL_FUNCTIONS["vllm_generation"].append(vllm_generation_init_patch) From 5bc60d6e068732d07232b94c7d12423ac79bbc0b Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 5 Mar 2026 14:27:42 +0000 Subject: [PATCH 8/9] patch rpc in openenv for newer trl --- unsloth/models/rl_replacements.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index f2e4b050f2..7af3104a77 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -1342,12 +1342,21 @@ def openenv_vllm_reload_weights(): ) return - src = inspect.getsource(openenv_utils.generate_rollout_completions) + # trl 0.28 changed the function name yet again! Thanks trl :) + patch_target_name = "_generate_rollout_completions_colocate" + if hasattr(openenv_utils, patch_target_name): + patch_target = getattr(openenv_utils, patch_target_name) + else: + # Older TRL versions may keep sleep/wake logic in the public dispatcher. + patch_target_name = "generate_rollout_completions" + patch_target = getattr(openenv_utils, patch_target_name) + + src = inspect.getsource(patch_target) src = textwrap.dedent(src) original_src = src # Remove the reload_weights call - unsloth handles this differently - src = re.sub(r'.*\.collective_rpc\("reload_weights"\).*\n?', "", src) + src = re.sub(r'.*\.collective_rpc\(\s*([\'"])reload_weights\1\s*\).*\n?', "", src) # Change wake_up(tags=["kv_cache"]) to wake_up() - wake everything to set is_sleeping=False # This prevents double wake_up issues. Unsloth's allocator skips weights anyway. @@ -1360,12 +1369,13 @@ def openenv_vllm_reload_weights(): # Execute and explicitly assign to module local_ns = {} exec(compile(src, "", "exec"), openenv_utils.__dict__, local_ns) - patched_func = local_ns["generate_rollout_completions"] + patched_func = local_ns[patch_target_name] - # Patch both the utils module and the parent openenv module - openenv_utils.generate_rollout_completions = patched_func - openenv.generate_rollout_completions = patched_func - logger.info("Unsloth: Patched trl openenv generate_rollout_completions") + # Patch the target function in utils; if dispatcher was patched also update parent module alias. + setattr(openenv_utils, patch_target_name, patched_func) + if patch_target_name == "generate_rollout_completions": + openenv.generate_rollout_completions = patched_func + logger.info(f"Unsloth: Patched trl openenv {patch_target_name}") RL_ADDITIONAL_FUNCTIONS["openenv"].append(openenv_vllm_reload_weights) From f7e7e821f7fb9b71a35f713cf4cf761b9391d525 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:39:57 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/rl.py | 5 ++++- unsloth/models/rl_replacements.py | 32 ++++++++++++++++++++++++------- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 98e8ea2053..30546a048d 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -1875,12 +1875,15 @@ def patch_trl_openenv(): function() # Call the function to apply the patch return + def patch_trl_vllm_generation(): # trl moved vllm stuff to trl/generation/vllm_generation.py # We need to min_p patch it to not instantiate another vLLM instance if we already have one with fast_inference # Find the instance of self.llm = LLM(..) (multiline) and wrap it around an if clause for function in RL_ADDITIONAL_FUNCTIONS["vllm_generation"]: - logger.info(f"Unsloth: Patching trl VLLMGeneration with function: {function.__name__}") + logger.info( + f"Unsloth: Patching trl VLLMGeneration with function: {function.__name__}" + ) function() return diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 7af3104a77..314feb5d2a 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -279,12 +279,14 @@ def grpo_trainer__generate_single_turn(function_name, function): r"(?P=indent)[ \t]+self\.vllm_generation\.sync_weights\(\)\n", re.MULTILINE, ) + def remove_sync_weights_block(match): indent = match.group("indent") return ( f"{indent}# Unsloth fast inference LoRA shares weights with vLLM already.\n" f"{indent}# Skipping per-step vLLM sync_weights().\n" ) + function = sync_weights_block.sub(remove_sync_weights_block, function) # TRL 0.24.0-0.25.1 truncation regression fix @@ -1408,7 +1410,9 @@ def patch_vllm_generation_method(method_name, transform, marker, filename_suffix try: src = inspect.getsource(method) except Exception as e: - logger.info(f"Unsloth: Could not get source of VLLMGeneration.{method_name}: {e}") + logger.info( + f"Unsloth: Could not get source of VLLMGeneration.{method_name}: {e}" + ) return False src = textwrap.dedent(src) @@ -1436,6 +1440,7 @@ def patch_init_vllm(src): r"(?P^(?P[ \t]*)self\.llm\s*=\s*LLM\s*\(\n(?:.*\n)*?^(?P=indent)\))", re.MULTILINE, ) + def replace_llm_block(match): indent = match.group("indent") llm_block = textwrap.dedent(match.group("llm_block")) @@ -1444,12 +1449,14 @@ def replace_llm_block(match): f"{indent} # Unsloth already inits vLLM in fast inference mode. Do not redo :)\n" f"{indent} self.llm = model.vllm_engine\n" f"{indent} self.unsloth_fast_inference_lora = True\n" - f"{indent}else:\n" - + textwrap.indent(llm_block, indent + " ") + f"{indent}else:\n" + textwrap.indent(llm_block, indent + " ") ) + patched_src, num_replacements = pattern.subn(replace_llm_block, src, count = 1) if num_replacements == 0: - raise RuntimeError("Unsloth: Warning - regex did not match, VLLMGeneration._init_vllm patch may have failed") + raise RuntimeError( + "Unsloth: Warning - regex did not match, VLLMGeneration._init_vllm patch may have failed" + ) return patched_src # has some sync_weights or reload rpc calls. @@ -1460,6 +1467,7 @@ def patch_sync_weights(src): r"^(?Pdef sync_weights\(self\):\n)(?P(?:.*\n)*)", re.MULTILINE, ) + def replace_sync_weights(match): body = match.group("body") guard = ( @@ -1468,9 +1476,12 @@ def replace_sync_weights(match): " return\n\n" ) return match.group("def_line") + guard + body + patched_src, num_replacements = pattern.subn(replace_sync_weights, src, count = 1) if num_replacements == 0: - raise RuntimeError("Unsloth: Warning - regex did not match, VLLMGeneration.sync_weights patch may have failed") + raise RuntimeError( + "Unsloth: Warning - regex did not match, VLLMGeneration.sync_weights patch may have failed" + ) return patched_src def patch_generate(src): @@ -1478,12 +1489,18 @@ def patch_generate(src): r"^(?P[ \t]*)self\.llm\.collective_rpc\(\s*(['\"])reload_weights\2\s*\)\s*$", re.MULTILINE, ) + def replace_reload_weights(match): indent = match.group("indent") return f'{indent}pass # self.llm.collective_rpc("reload_weights")' - patched_src, num_replacements = pattern.subn(replace_reload_weights, src, count = 1) + + patched_src, num_replacements = pattern.subn( + replace_reload_weights, src, count = 1 + ) if num_replacements == 0: - raise RuntimeError("Unsloth: Warning - regex did not match, VLLMGeneration.generate patch may have failed") + raise RuntimeError( + "Unsloth: Warning - regex did not match, VLLMGeneration.generate patch may have failed" + ) return patched_src try: @@ -1516,4 +1533,5 @@ def replace_reload_weights(match): if generate_patched: logger.info("Unsloth: Patched trl VLLMGeneration.generate") + RL_ADDITIONAL_FUNCTIONS["vllm_generation"].append(vllm_generation_init_patch)