-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[trl] Trl v0.28 (and above) rl fixes #4156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cb6fb1b
a863d0f
690fecf
35721ee
c75d8aa
785e308
deafd91
50618b3
5bc60d6
0af28ae
f7e7e82
dc55ac7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -24,6 +24,7 @@ | |||||
| import re | ||||||
| import torch | ||||||
| import inspect | ||||||
| import linecache | ||||||
| from collections import defaultdict | ||||||
| from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding | ||||||
| from unsloth_zoo.utils import Version | ||||||
|
|
@@ -264,12 +265,30 @@ def grpo_trainer__generate_single_turn(function_name, function): | |||||
| # Remove the reload_weights collective RPC call from the generate function's source | ||||||
| # function = function.replace('self.llm.collective_rpc("reload_weights")', "") | ||||||
| # The regex below does the same thing but is more flexible and can handle single or double quotes | ||||||
| # This is for older versions. | ||||||
| function = re.sub( | ||||||
| r"self\.llm\.collective_rpc\(\s*(['\"])reload_weights\1\s*\)", | ||||||
| "", | ||||||
| function, | ||||||
| ) | ||||||
|
|
||||||
| # Current TRL versions call vllm_generation.sync_weights() every step. | ||||||
| # When Unsloth fast inference LoRA is active, weights are already shared. | ||||||
| sync_weights_block = re.compile( | ||||||
| r"(?P<indent>[ \t]*)with profiling_context\(self,\s*(['\"])sync_weights\2\s*\):\n" | ||||||
| r"(?P=indent)[ \t]+self\.vllm_generation\.sync_weights\(\)\n", | ||||||
| re.MULTILINE, | ||||||
| ) | ||||||
|
|
||||||
| def remove_sync_weights_block(match): | ||||||
| indent = match.group("indent") | ||||||
| return ( | ||||||
| f"{indent}# Unsloth fast inference LoRA shares weights with vLLM already.\n" | ||||||
| f"{indent}# Skipping per-step vLLM sync_weights().\n" | ||||||
| ) | ||||||
|
|
||||||
| function = sync_weights_block.sub(remove_sync_weights_block, function) | ||||||
|
|
||||||
| # TRL 0.24.0-0.25.1 truncation regression fix | ||||||
| # | ||||||
| # TRL 0.22.2-0.23.1 used smart truncation via truncate_with_protected_tokens(): | ||||||
|
|
@@ -1325,12 +1344,21 @@ def openenv_vllm_reload_weights(): | |||||
| ) | ||||||
| return | ||||||
|
|
||||||
| src = inspect.getsource(openenv_utils.generate_rollout_completions) | ||||||
| # trl 0.28 changed the function name yet again! Thanks trl :) | ||||||
| patch_target_name = "_generate_rollout_completions_colocate" | ||||||
| if hasattr(openenv_utils, patch_target_name): | ||||||
| patch_target = getattr(openenv_utils, patch_target_name) | ||||||
| else: | ||||||
| # Older TRL versions may keep sleep/wake logic in the public dispatcher. | ||||||
| patch_target_name = "generate_rollout_completions" | ||||||
| patch_target = getattr(openenv_utils, patch_target_name) | ||||||
|
|
||||||
| src = inspect.getsource(patch_target) | ||||||
| src = textwrap.dedent(src) | ||||||
| original_src = src | ||||||
|
|
||||||
| # Remove the reload_weights call - unsloth handles this differently | ||||||
| src = re.sub(r'.*\.collective_rpc\("reload_weights"\).*\n?', "", src) | ||||||
| src = re.sub(r'.*\.collective_rpc\(\s*([\'"])reload_weights\1\s*\).*\n?', "", src) | ||||||
|
|
||||||
| # Change wake_up(tags=["kv_cache"]) to wake_up() - wake everything to set is_sleeping=False | ||||||
| # This prevents double wake_up issues. Unsloth's allocator skips weights anyway. | ||||||
|
|
@@ -1343,12 +1371,167 @@ def openenv_vllm_reload_weights(): | |||||
| # Execute and explicitly assign to module | ||||||
| local_ns = {} | ||||||
| exec(compile(src, "<unsloth>", "exec"), openenv_utils.__dict__, local_ns) | ||||||
| patched_func = local_ns["generate_rollout_completions"] | ||||||
| patched_func = local_ns[patch_target_name] | ||||||
|
|
||||||
| # Patch both the utils module and the parent openenv module | ||||||
| openenv_utils.generate_rollout_completions = patched_func | ||||||
| openenv.generate_rollout_completions = patched_func | ||||||
| logger.info("Unsloth: Patched trl openenv generate_rollout_completions") | ||||||
| # Patch the target function in utils; if dispatcher was patched also update parent module alias. | ||||||
| setattr(openenv_utils, patch_target_name, patched_func) | ||||||
| if patch_target_name == "generate_rollout_completions": | ||||||
| openenv.generate_rollout_completions = patched_func | ||||||
| logger.info(f"Unsloth: Patched trl openenv {patch_target_name}") | ||||||
|
|
||||||
|
|
||||||
| RL_ADDITIONAL_FUNCTIONS["openenv"].append(openenv_vllm_reload_weights) | ||||||
|
|
||||||
|
|
||||||
def vllm_generation_init_patch():
    """Patch TRL's ``VLLMGeneration`` to reuse an existing ``model.vllm_engine``.

    Three methods of ``trl.generation.vllm_generation.VLLMGeneration`` are
    rewritten by editing their source text and exec-ing the result inside the
    TRL module's namespace:

    * ``_init_vllm``   - wrap the ``self.llm = LLM(...)`` statement so that
      ``model.vllm_engine`` is used when the model has one, and flag the
      instance with ``unsloth_fast_inference_lora = True``.
    * ``sync_weights`` - return immediately when that flag is set.
    * ``generate``     - replace the ``collective_rpc("reload_weights")`` call
      with ``pass``.

    The patch is best-effort: it returns silently when TRL is not installed or
    is older than 0.28.0, logs and returns when the module cannot be imported,
    and logs a warning and stops at the first method whose source no longer
    matches the expected pattern. Methods patched before such a failure stay
    patched and are reported.

    Returns:
        None.
    """
    # trl moved vllm stuff to trl/generation/vllm_generation.py
    # We need to patch it to not instantiate another vLLM instance if we already have one with fast_inference
    # Edit the TRL source directly and install the patched function in the TRL module.
    # https://github.com/huggingface/trl/commit/0eb66d8f2fc63b3d00d8dbc18f99c3f48750bd16
    # This exists in trl versions 0.28.0 and above

    if importlib.util.find_spec("trl") is None:
        return
    if Version(importlib_version("trl")) < Version("0.28.0"):
        return

    try:
        import trl.generation.vllm_generation as vllm_generation
    except Exception as e:
        # Deliberately broad: this is a best-effort patch and importing the
        # module pulls in third-party code that may fail in arbitrary ways.
        # `Exception` already covers ImportError and NameError.
        logger.info(f"Unsloth: Failed to import trl.generation.vllm_generation: {e}")
        return

    # One indentation step used by all generated code below.
    # NOTE(review): assumes the TRL source is indented with 4 spaces.
    step = " " * 4

    def patch_vllm_generation_method(method_name, transform, marker, filename_suffix):
        """Rewrite one ``VLLMGeneration`` method from its source text.

        Args:
            method_name: attribute name of the method on ``VLLMGeneration``.
            transform: callable taking the dedented source and returning the
                patched source; it raises ``RuntimeError`` when it cannot
                apply its edit.
            marker: text that only the patched source contains; when already
                present the method is treated as patched.
            filename_suffix: used to build the pseudo filename of the
                compiled code.

        Returns:
            True when the method is (or already was) patched, False when the
            method or its source could not be found.
        """
        method = getattr(vllm_generation.VLLMGeneration, method_name, None)
        if method is None:
            logger.info(f"Unsloth: Could not find VLLMGeneration.{method_name}")
            return False

        try:
            src = inspect.getsource(method)
        except Exception as e:
            logger.info(
                f"Unsloth: Could not get source of VLLMGeneration.{method_name}: {e}"
            )
            return False

        src = textwrap.dedent(src)
        if marker in src:
            # Already patched by an earlier call; nothing to do.
            return True

        src = transform(src)
        filename = f"<unsloth_trl_vllm_generation_{filename_suffix}_patch>"
        # Register the patched text under the pseudo filename so the source of
        # the exec'd function stays retrievable (the marker check above relies
        # on inspect.getsource working on an already patched method).
        source_lines = [line + "\n" for line in src.splitlines()]
        linecache.cache[filename] = (
            len(src),
            None,
            source_lines,
            filename,
        )

        local_ns = {}
        exec(compile(src, filename, "exec"), vllm_generation.__dict__, local_ns)
        setattr(vllm_generation.VLLMGeneration, method_name, local_ns[method_name])
        return True

    # Patch init to remove vLLM.LLM instantiation
    def patch_init_vllm(src):
        """Wrap ``self.llm = LLM(...)`` in an ``if hasattr(model, 'vllm_engine')`` branch."""
        pattern = re.compile(
            r"(?P<llm_block>^(?P<indent>[ \t]*)self\.llm\s*=\s*LLM\s*\(\n(?:.*\n)*?^(?P=indent)\))",
            re.MULTILINE,
        )

        def replace_llm_block(match):
            indent = match.group("indent")
            llm_block = textwrap.dedent(match.group("llm_block"))
            # The original statement is re-indented one step deeper so that it
            # becomes the body of the `else:` branch.
            return (
                f"{indent}if hasattr(model, 'vllm_engine'):\n"
                f"{indent}{step}# Unsloth already inits vLLM in fast inference mode. Do not redo :)\n"
                f"{indent}{step}self.llm = model.vllm_engine\n"
                f"{indent}{step}self.unsloth_fast_inference_lora = True\n"
                f"{indent}else:\n" + textwrap.indent(llm_block, indent + step)
            )

        patched_src, num_replacements = pattern.subn(replace_llm_block, src, count = 1)
        if num_replacements == 0:
            raise RuntimeError(
                "Unsloth: Warning - regex did not match, VLLMGeneration._init_vllm patch may have failed"
            )
        return patched_src

    # has some sync_weights or reload rpc calls.
    # we patched the grpo_trainer to strip them for prev versions
    # Ref: grpo_trainer__generate_single_turn above around L270-280
    def patch_sync_weights(src):
        """Insert an early-return guard as the first statement of ``sync_weights``."""
        pattern = re.compile(
            r"^(?P<def_line>def sync_weights\(self\):\n)(?P<body>(?:.*\n)*)",
            re.MULTILINE,
        )

        def replace_sync_weights(match):
            body = match.group("body")
            # `src` was dedented, so the def sits at column 0 and the guard
            # needs exactly one step; its own body needs two.
            guard = (
                f"{step}if getattr(self, 'unsloth_fast_inference_lora', False):\n"
                f"{step}{step}# Unsloth fast inference LoRA shares weights with vLLM already.\n"
                f"{step}{step}return\n\n"
            )
            return match.group("def_line") + guard + body

        patched_src, num_replacements = pattern.subn(replace_sync_weights, src, count = 1)
        if num_replacements == 0:
            raise RuntimeError(
                "Unsloth: Warning - regex did not match, VLLMGeneration.sync_weights patch may have failed"
            )
        return patched_src

    def patch_generate(src):
        """Replace the first ``self.llm.collective_rpc("reload_weights")`` line with ``pass``."""
        pattern = re.compile(
            r"^(?P<indent>[ \t]*)self\.llm\.collective_rpc\(\s*(['\"])reload_weights\2\s*\)\s*$",
            re.MULTILINE,
        )

        def replace_reload_weights(match):
            indent = match.group("indent")
            # `pass` keeps the enclosing block non-empty.
            return f'{indent}pass # self.llm.collective_rpc("reload_weights")'

        patched_src, num_replacements = pattern.subn(
            replace_reload_weights, src, count = 1
        )
        if num_replacements == 0:
            raise RuntimeError(
                "Unsloth: Warning - regex did not match, VLLMGeneration.generate patch may have failed"
            )
        return patched_src

    # (method name, source transform, already-patched marker, filename suffix)
    patch_specs = (
        (
            "_init_vllm",
            patch_init_vllm,
            "self.unsloth_fast_inference_lora = True",
            "init_vllm",
        ),
        (
            "sync_weights",
            patch_sync_weights,
            "if getattr(self, 'unsloth_fast_inference_lora', False):",
            "sync_weights",
        ),
        (
            "generate",
            patch_generate,
            'pass # self.llm.collective_rpc("reload_weights")',
            "generate",
        ),
    )

    # Patches are applied in order and processing stops at the first regex
    # mismatch. Each success is logged right away so that a later failure
    # does not hide patches that are already installed.
    for method_name, transform, marker, filename_suffix in patch_specs:
        try:
            patched = patch_vllm_generation_method(
                method_name, transform, marker, filename_suffix
            )
        except RuntimeError as e:
            logger.warning(str(e))
            return
        if patched:
            logger.info(f"Unsloth: Patched trl VLLMGeneration.{method_name}")
|
|
||||||
|
|
||||||
| RL_ADDITIONAL_FUNCTIONS["vllm_generation"].append(vllm_generation_init_patch) | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The exception handling `(ImportError, NameError, Exception)` is redundant because `Exception` is a base class for both `ImportError` and `NameError`. It's better to be more specific about the exceptions you expect to catch. In this case, `ImportError` is the most likely exception if the module path is incorrect or the `trl` version is not as expected. Catching the broad `Exception` can mask other unexpected issues during the import process.