diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index e4f34c908e..30546a048d 100755 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -389,8 +389,8 @@ class Unsloth{RLConfig_name}({RLConfig_name}): def __init__({RLConfig_arguments}, vllm_sampling_params = None, unsloth_num_chunks = -1, - unsloth_logit_chunk_multiplier = None, - unsloth_grpo_mini_batch = None, + unsloth_logit_chunk_multiplier = None, + unsloth_grpo_mini_batch = None, {max_seq_length_call} **kwargs, ): @@ -1876,10 +1876,35 @@ def patch_trl_openenv(): return +def patch_trl_vllm_generation(): + # trl moved vllm stuff to trl/generation/vllm_generation.py + # We need to min_p patch it to not instantiate another vLLM instance if we already have one with fast_inference + # Find the instance of self.llm = LLM(..) (multiline) and wrap it around an if clause + for function in RL_ADDITIONAL_FUNCTIONS["vllm_generation"]: + logger.info( + f"Unsloth: Patching trl VLLMGeneration with function: {function.__name__}" + ) + function() + return + + +def patch_trl_vllm_generation(): + # trl moved vllm stuff to trl/generation/vllm_generation.py + # We need to min_p patch it to not instantiate another vLLM instance if we already have one with fast_inference + # Find the instance of self.llm = LLM(..) (multiline) and wrap it around an if clause + for function in RL_ADDITIONAL_FUNCTIONS["vllm_generation"]: + logger.info( + f"Unsloth: Patching trl VLLMGeneration with function: {function.__name__}" + ) + function() + return + + def PatchFastRL(algorithm = None, FastLanguageModel = None): if FastLanguageModel is not None: PatchRL(FastLanguageModel) patch_trl_rl_trainers() patch_trl_openenv() + patch_trl_vllm_generation() if type(algorithm) is str and algorithm.islower(): PatchRLStatistics(algorithm) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 69e9ce3455..314feb5d2a 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -24,6 +24,7 @@ import re import torch import inspect +import linecache from collections import defaultdict from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding from unsloth_zoo.utils import Version @@ -264,12 +265,30 @@ def grpo_trainer__generate_single_turn(function_name, function): # Remove the reload_weights collective RPC call from the generate function's source # function = function.replace('self.llm.collective_rpc("reload_weights")', "") # The regex below does the same thing but is more flexible and can handle single or double quotes + # This is for older versions. function = re.sub( r"self\.llm\.collective_rpc\(\s*(['\"])reload_weights\1\s*\)", "", function, ) + # Current TRL versions call vllm_generation.sync_weights() every step. + # When Unsloth fast inference LoRA is active, weights are already shared. + sync_weights_block = re.compile( + r"(?P[ \t]*)with profiling_context\(self,\s*(['\"])sync_weights\2\s*\):\n" + r"(?P=indent)[ \t]+self\.vllm_generation\.sync_weights\(\)\n", + re.MULTILINE, + ) + + def remove_sync_weights_block(match): + indent = match.group("indent") + return ( + f"{indent}# Unsloth fast inference LoRA shares weights with vLLM already.\n" + f"{indent}# Skipping per-step vLLM sync_weights().\n" + ) + + function = sync_weights_block.sub(remove_sync_weights_block, function) + # TRL 0.24.0-0.25.1 truncation regression fix # # TRL 0.22.2-0.23.1 used smart truncation via truncate_with_protected_tokens(): @@ -1325,12 +1344,21 @@ def openenv_vllm_reload_weights(): ) return - src = inspect.getsource(openenv_utils.generate_rollout_completions) + # trl 0.28 changed the function name yet again! Thanks trl :) + patch_target_name = "_generate_rollout_completions_colocate" + if hasattr(openenv_utils, patch_target_name): + patch_target = getattr(openenv_utils, patch_target_name) + else: + # Older TRL versions may keep sleep/wake logic in the public dispatcher. + patch_target_name = "generate_rollout_completions" + patch_target = getattr(openenv_utils, patch_target_name) + + src = inspect.getsource(patch_target) src = textwrap.dedent(src) original_src = src # Remove the reload_weights call - unsloth handles this differently - src = re.sub(r'.*\.collective_rpc\("reload_weights"\).*\n?', "", src) + src = re.sub(r'.*\.collective_rpc\(\s*([\'"])reload_weights\1\s*\).*\n?', "", src) # Change wake_up(tags=["kv_cache"]) to wake_up() - wake everything to set is_sleeping=False # This prevents double wake_up issues. Unsloth's allocator skips weights anyway. @@ -1343,12 +1371,167 @@ def openenv_vllm_reload_weights(): # Execute and explicitly assign to module local_ns = {} exec(compile(src, "", "exec"), openenv_utils.__dict__, local_ns) - patched_func = local_ns["generate_rollout_completions"] + patched_func = local_ns[patch_target_name] - # Patch both the utils module and the parent openenv module - openenv_utils.generate_rollout_completions = patched_func - openenv.generate_rollout_completions = patched_func - logger.info("Unsloth: Patched trl openenv generate_rollout_completions") + # Patch the target function in utils; if dispatcher was patched also update parent module alias. + setattr(openenv_utils, patch_target_name, patched_func) + if patch_target_name == "generate_rollout_completions": + openenv.generate_rollout_completions = patched_func + logger.info(f"Unsloth: Patched trl openenv {patch_target_name}") RL_ADDITIONAL_FUNCTIONS["openenv"].append(openenv_vllm_reload_weights) + + +def vllm_generation_init_patch(): + # trl moved vllm stuff to trl/generation/vllm_generation.py + # We need to patch it to not instantiate another vLLM instance if we already have one with fast_inference + # Edit the TRL source directly and install the patched function in the TRL module. + # https://github.com/huggingface/trl/commit/0eb66d8f2fc63b3d00d8dbc18f99c3f48750bd16 + # This exists in trl versions 0.28.0 and above + + if importlib.util.find_spec("trl") is None: + return + if Version(importlib_version("trl")) < Version("0.28.0"): + return + + try: + import trl.generation.vllm_generation as vllm_generation + except (ImportError, NameError, Exception) as e: + logger.info(f"Unsloth: Failed to import trl.generation.vllm_generation: {e}") + return + + def patch_vllm_generation_method(method_name, transform, marker, filename_suffix): + method = getattr(vllm_generation.VLLMGeneration, method_name, None) + if method is None: + logger.info(f"Unsloth: Could not find VLLMGeneration.{method_name}") + return False + + try: + src = inspect.getsource(method) + except Exception as e: + logger.info( + f"Unsloth: Could not get source of VLLMGeneration.{method_name}: {e}" + ) + return False + + src = textwrap.dedent(src) + if marker in src: + return True + + src = transform(src) + filename = f"" + source_lines = [line + "\n" for line in src.splitlines()] + linecache.cache[filename] = ( + len(src), + None, + source_lines, + filename, + ) + + local_ns = {} + exec(compile(src, filename, "exec"), vllm_generation.__dict__, local_ns) + setattr(vllm_generation.VLLMGeneration, method_name, local_ns[method_name]) + return True + + # Patch init to remove vLLM.LLM instantiation + def patch_init_vllm(src): + pattern = re.compile( + r"(?P^(?P[ \t]*)self\.llm\s*=\s*LLM\s*\(\n(?:.*\n)*?^(?P=indent)\))", + re.MULTILINE, + ) + + def replace_llm_block(match): + indent = match.group("indent") + llm_block = textwrap.dedent(match.group("llm_block")) + return ( + f"{indent}if hasattr(model, 'vllm_engine'):\n" + f"{indent} # Unsloth already inits vLLM in fast inference mode. Do not redo :)\n" + f"{indent} self.llm = model.vllm_engine\n" + f"{indent} self.unsloth_fast_inference_lora = True\n" + f"{indent}else:\n" + textwrap.indent(llm_block, indent + " ") + ) + + patched_src, num_replacements = pattern.subn(replace_llm_block, src, count = 1) + if num_replacements == 0: + raise RuntimeError( + "Unsloth: Warning - regex did not match, VLLMGeneration._init_vllm patch may have failed" + ) + return patched_src + + # has some sync_weights or reload rpc calls. + # we patched the grpo_trainer to strip them for prev versions + # Ref: grpo_trainer__generate_single_turn above around L270-280 + def patch_sync_weights(src): + pattern = re.compile( + r"^(?Pdef sync_weights\(self\):\n)(?P(?:.*\n)*)", + re.MULTILINE, + ) + + def replace_sync_weights(match): + body = match.group("body") + guard = ( + " if getattr(self, 'unsloth_fast_inference_lora', False):\n" + " # Unsloth fast inference LoRA shares weights with vLLM already.\n" + " return\n\n" + ) + return match.group("def_line") + guard + body + + patched_src, num_replacements = pattern.subn(replace_sync_weights, src, count = 1) + if num_replacements == 0: + raise RuntimeError( + "Unsloth: Warning - regex did not match, VLLMGeneration.sync_weights patch may have failed" + ) + return patched_src + + def patch_generate(src): + pattern = re.compile( + r"^(?P[ \t]*)self\.llm\.collective_rpc\(\s*(['\"])reload_weights\2\s*\)\s*$", + re.MULTILINE, + ) + + def replace_reload_weights(match): + indent = match.group("indent") + return f'{indent}pass # self.llm.collective_rpc("reload_weights")' + + patched_src, num_replacements = pattern.subn( + replace_reload_weights, src, count = 1 + ) + if num_replacements == 0: + raise RuntimeError( + "Unsloth: Warning - regex did not match, VLLMGeneration.generate patch may have failed" + ) + return patched_src + + try: + init_patched = patch_vllm_generation_method( + "_init_vllm", + patch_init_vllm, + "self.unsloth_fast_inference_lora = True", + "init_vllm", + ) + sync_patched = patch_vllm_generation_method( + "sync_weights", + patch_sync_weights, + "if getattr(self, 'unsloth_fast_inference_lora', False):", + "sync_weights", + ) + generate_patched = patch_vllm_generation_method( + "generate", + patch_generate, + 'pass # self.llm.collective_rpc("reload_weights")', + "generate", + ) + except RuntimeError as e: + logger.warning(str(e)) + return + + if init_patched: + logger.info("Unsloth: Patched trl VLLMGeneration._init_vllm") + if sync_patched: + logger.info("Unsloth: Patched trl VLLMGeneration.sync_weights") + if generate_patched: + logger.info("Unsloth: Patched trl VLLMGeneration.generate") + + +RL_ADDITIONAL_FUNCTIONS["vllm_generation"].append(vllm_generation_init_patch)