Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,8 @@ class Unsloth{RLConfig_name}({RLConfig_name}):
def __init__({RLConfig_arguments},
vllm_sampling_params = None,
unsloth_num_chunks = -1,
unsloth_logit_chunk_multiplier = None,
unsloth_grpo_mini_batch = None,
unsloth_logit_chunk_multiplier = None,
unsloth_grpo_mini_batch = None,
{max_seq_length_call}
**kwargs,
):
Expand Down Expand Up @@ -1876,10 +1876,35 @@ def patch_trl_openenv():
return


def patch_trl_vllm_generation():
# trl moved vllm stuff to trl/generation/vllm_generation.py
# We need to min_p patch it to not instantiate another vLLM instance if we already have one with fast_inference
# Find the instance of self.llm = LLM(..) (multiline) and wrap it around an if clause
for function in RL_ADDITIONAL_FUNCTIONS["vllm_generation"]:
logger.info(
f"Unsloth: Patching trl VLLMGeneration with function: {function.__name__}"
)
function()
return


def patch_trl_vllm_generation():
# trl moved vllm stuff to trl/generation/vllm_generation.py
# We need to min_p patch it to not instantiate another vLLM instance if we already have one with fast_inference
# Find the instance of self.llm = LLM(..) (multiline) and wrap it around an if clause
for function in RL_ADDITIONAL_FUNCTIONS["vllm_generation"]:
logger.info(
f"Unsloth: Patching trl VLLMGeneration with function: {function.__name__}"
)
function()
return


def PatchFastRL(algorithm = None, FastLanguageModel = None):
if FastLanguageModel is not None:
PatchRL(FastLanguageModel)
patch_trl_rl_trainers()
patch_trl_openenv()
patch_trl_vllm_generation()
if type(algorithm) is str and algorithm.islower():
PatchRLStatistics(algorithm)
197 changes: 190 additions & 7 deletions unsloth/models/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import re
import torch
import inspect
import linecache
from collections import defaultdict
from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding
from unsloth_zoo.utils import Version
Expand Down Expand Up @@ -264,12 +265,30 @@ def grpo_trainer__generate_single_turn(function_name, function):
# Remove the reload_weights collective RPC call from the generate function's source
# function = function.replace('self.llm.collective_rpc("reload_weights")', "")
# The regex below does the same thing but is more flexible and can handle single or double quotes
# This is for older versions.
function = re.sub(
r"self\.llm\.collective_rpc\(\s*(['\"])reload_weights\1\s*\)",
"",
function,
)

# Current TRL versions call vllm_generation.sync_weights() every step.
# When Unsloth fast inference LoRA is active, weights are already shared.
sync_weights_block = re.compile(
r"(?P<indent>[ \t]*)with profiling_context\(self,\s*(['\"])sync_weights\2\s*\):\n"
r"(?P=indent)[ \t]+self\.vllm_generation\.sync_weights\(\)\n",
re.MULTILINE,
)

def remove_sync_weights_block(match):
indent = match.group("indent")
return (
f"{indent}# Unsloth fast inference LoRA shares weights with vLLM already.\n"
f"{indent}# Skipping per-step vLLM sync_weights().\n"
)

function = sync_weights_block.sub(remove_sync_weights_block, function)

# TRL 0.24.0-0.25.1 truncation regression fix
#
# TRL 0.22.2-0.23.1 used smart truncation via truncate_with_protected_tokens():
Expand Down Expand Up @@ -1325,12 +1344,21 @@ def openenv_vllm_reload_weights():
)
return

src = inspect.getsource(openenv_utils.generate_rollout_completions)
# trl 0.28 changed the function name yet again! Thanks trl :)
patch_target_name = "_generate_rollout_completions_colocate"
if hasattr(openenv_utils, patch_target_name):
patch_target = getattr(openenv_utils, patch_target_name)
else:
# Older TRL versions may keep sleep/wake logic in the public dispatcher.
patch_target_name = "generate_rollout_completions"
patch_target = getattr(openenv_utils, patch_target_name)

src = inspect.getsource(patch_target)
src = textwrap.dedent(src)
original_src = src

# Remove the reload_weights call - unsloth handles this differently
src = re.sub(r'.*\.collective_rpc\("reload_weights"\).*\n?', "", src)
src = re.sub(r'.*\.collective_rpc\(\s*([\'"])reload_weights\1\s*\).*\n?', "", src)

# Change wake_up(tags=["kv_cache"]) to wake_up() - wake everything to set is_sleeping=False
# This prevents double wake_up issues. Unsloth's allocator skips weights anyway.
Expand All @@ -1343,12 +1371,167 @@ def openenv_vllm_reload_weights():
# Execute and explicitly assign to module
local_ns = {}
exec(compile(src, "<unsloth>", "exec"), openenv_utils.__dict__, local_ns)
patched_func = local_ns["generate_rollout_completions"]
patched_func = local_ns[patch_target_name]

# Patch both the utils module and the parent openenv module
openenv_utils.generate_rollout_completions = patched_func
openenv.generate_rollout_completions = patched_func
logger.info("Unsloth: Patched trl openenv generate_rollout_completions")
# Patch the target function in utils; if dispatcher was patched also update parent module alias.
setattr(openenv_utils, patch_target_name, patched_func)
if patch_target_name == "generate_rollout_completions":
openenv.generate_rollout_completions = patched_func
logger.info(f"Unsloth: Patched trl openenv {patch_target_name}")


RL_ADDITIONAL_FUNCTIONS["openenv"].append(openenv_vllm_reload_weights)


def vllm_generation_init_patch():
# trl moved vllm stuff to trl/generation/vllm_generation.py
# We need to patch it to not instantiate another vLLM instance if we already have one with fast_inference
# Edit the TRL source directly and install the patched function in the TRL module.
# https://github.com/huggingface/trl/commit/0eb66d8f2fc63b3d00d8dbc18f99c3f48750bd16
# This exists in trl versions 0.28.0 and above

if importlib.util.find_spec("trl") is None:
return
if Version(importlib_version("trl")) < Version("0.28.0"):
return

try:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The exception handling (ImportError, NameError, Exception) is redundant because Exception is a base class for both ImportError and NameError. It's better to be more specific about the exceptions you expect to catch. In this case, ImportError is the most likely exception if the module path is incorrect or the trl version is not as expected. Catching the broad Exception can mask other unexpected issues during the import process.

Suggested change
try:
except ImportError as e:

import trl.generation.vllm_generation as vllm_generation
except (ImportError, NameError, Exception) as e:
logger.info(f"Unsloth: Failed to import trl.generation.vllm_generation: {e}")
return

def patch_vllm_generation_method(method_name, transform, marker, filename_suffix):
method = getattr(vllm_generation.VLLMGeneration, method_name, None)
if method is None:
logger.info(f"Unsloth: Could not find VLLMGeneration.{method_name}")
return False

try:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Catching the broad Exception class can hide unexpected errors. The inspect.getsource() function is documented to raise TypeError for unsupported object types and OSError if the source file cannot be retrieved. It is better practice to catch these specific exceptions to make the error handling more precise and robust.

Suggested change
try:
except (TypeError, OSError) as e:

src = inspect.getsource(method)
except Exception as e:
logger.info(
f"Unsloth: Could not get source of VLLMGeneration.{method_name}: {e}"
)
return False

src = textwrap.dedent(src)
if marker in src:
return True

src = transform(src)
filename = f"<unsloth_trl_vllm_generation_{filename_suffix}_patch>"
source_lines = [line + "\n" for line in src.splitlines()]
linecache.cache[filename] = (
len(src),
None,
source_lines,
filename,
)

local_ns = {}
exec(compile(src, filename, "exec"), vllm_generation.__dict__, local_ns)
setattr(vllm_generation.VLLMGeneration, method_name, local_ns[method_name])
return True

# Patch init to remove vLLM.LLM instantiation
def patch_init_vllm(src):
pattern = re.compile(
r"(?P<llm_block>^(?P<indent>[ \t]*)self\.llm\s*=\s*LLM\s*\(\n(?:.*\n)*?^(?P=indent)\))",
re.MULTILINE,
)

def replace_llm_block(match):
indent = match.group("indent")
llm_block = textwrap.dedent(match.group("llm_block"))
return (
f"{indent}if hasattr(model, 'vllm_engine'):\n"
f"{indent} # Unsloth already inits vLLM in fast inference mode. Do not redo :)\n"
f"{indent} self.llm = model.vllm_engine\n"
f"{indent} self.unsloth_fast_inference_lora = True\n"
f"{indent}else:\n" + textwrap.indent(llm_block, indent + " ")
)

patched_src, num_replacements = pattern.subn(replace_llm_block, src, count = 1)
if num_replacements == 0:
raise RuntimeError(
"Unsloth: Warning - regex did not match, VLLMGeneration._init_vllm patch may have failed"
)
return patched_src

# has some sync_weights or reload rpc calls.
# we patched the grpo_trainer to strip them for prev versions
# Ref: grpo_trainer__generate_single_turn above around L270-280
def patch_sync_weights(src):
pattern = re.compile(
r"^(?P<def_line>def sync_weights\(self\):\n)(?P<body>(?:.*\n)*)",
re.MULTILINE,
)

def replace_sync_weights(match):
body = match.group("body")
guard = (
" if getattr(self, 'unsloth_fast_inference_lora', False):\n"
" # Unsloth fast inference LoRA shares weights with vLLM already.\n"
" return\n\n"
)
return match.group("def_line") + guard + body

patched_src, num_replacements = pattern.subn(replace_sync_weights, src, count = 1)
if num_replacements == 0:
raise RuntimeError(
"Unsloth: Warning - regex did not match, VLLMGeneration.sync_weights patch may have failed"
)
return patched_src

def patch_generate(src):
pattern = re.compile(
r"^(?P<indent>[ \t]*)self\.llm\.collective_rpc\(\s*(['\"])reload_weights\2\s*\)\s*$",
re.MULTILINE,
)

def replace_reload_weights(match):
indent = match.group("indent")
return f'{indent}pass # self.llm.collective_rpc("reload_weights")'

patched_src, num_replacements = pattern.subn(
replace_reload_weights, src, count = 1
)
if num_replacements == 0:
raise RuntimeError(
"Unsloth: Warning - regex did not match, VLLMGeneration.generate patch may have failed"
)
return patched_src

try:
init_patched = patch_vllm_generation_method(
"_init_vllm",
patch_init_vllm,
"self.unsloth_fast_inference_lora = True",
"init_vllm",
)
sync_patched = patch_vllm_generation_method(
"sync_weights",
patch_sync_weights,
"if getattr(self, 'unsloth_fast_inference_lora', False):",
"sync_weights",
)
generate_patched = patch_vllm_generation_method(
"generate",
patch_generate,
'pass # self.llm.collective_rpc("reload_weights")',
"generate",
)
except RuntimeError as e:
logger.warning(str(e))
return

if init_patched:
logger.info("Unsloth: Patched trl VLLMGeneration._init_vllm")
if sync_patched:
logger.info("Unsloth: Patched trl VLLMGeneration.sync_weights")
if generate_patched:
logger.info("Unsloth: Patched trl VLLMGeneration.generate")


RL_ADDITIONAL_FUNCTIONS["vllm_generation"].append(vllm_generation_init_patch)