Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,33 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
pattern, new_options, RLTrainer_source, flags = re.DOTALL
)

if trl_version >= Version("0.27.0"):
peft_pattern = (
r"\s*if is_peft_available\(\) and is_peft_model\(model\) and args\.beta != 0\.0:"
r".*?"
r"param\.data = param\.data\.to\(torch\.bfloat16\)"
)

replacement_comment = "\n # PEFT initialization logic removed via script for trl >= 0.27.0\n"

RLTrainer_source = re.sub(
peft_pattern, replacement_comment, RLTrainer_source, flags = re.DOTALL
)

elif trl_version >= Version("0.26.0"):
peft_block_pattern = (
r"\s*if is_peft_available\(\) and isinstance\(model, PeftModel\) and peft_config is not None:"
r".*?"
r"param\.data = param\.data\.to\(torch\.bfloat16\)"
)

RLTrainer_source = re.sub(
peft_block_pattern,
"\n # TRL PEFT 0.26.0 initialization logic removed on unsloth side.\n",
RLTrainer_source,
flags = re.DOTALL,
)

if RLTrainer_name == "SFTTrainer":
original_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]'
new_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask","labels"]'
Expand Down
167 changes: 70 additions & 97 deletions unsloth/models/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from collections import defaultdict
from unsloth_zoo.rl_replacements import RL_REPLACEMENTS, left_pack_padding
from unsloth_zoo.utils import Version
from trl import __version__ as trl_version_raw
from importlib.metadata import version as importlib_version
from unsloth_zoo.log import logger
from unsloth_zoo.device_type import device_synchronize
Expand Down Expand Up @@ -57,6 +58,14 @@
"triton.cudagraphs": False,
}

try:
trl_version = Version(trl_version_raw)
except Exception:
try:
trl_version = Version(importlib_version("trl"))
except Exception:
trl_version = Version("0.0.0")
Comment on lines +61 to +67
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a broad except Exception: can hide unexpected errors and make debugging harder. It's better to catch more specific exceptions. For example, Version() might raise a ValueError, and importlib_version() can raise PackageNotFoundError (which you'd need to import). Using more specific exceptions would make this code more robust and maintainable.



# Check untrained tokens
def sft_trainer_fix_untrained_tokens(call_args, extra_args):
Expand Down Expand Up @@ -434,99 +443,6 @@ def grpo_trainer__generate_and_score_completions(function_name, function):
_target_line + _metadata_extraction,
)

# Unsloth: Skip prepare_multimodal_messages when prompts are pre-templated strings.
# When notebooks pre-apply apply_chat_template(), prompts become strings with image tokens
# already embedded. Calling prepare_multimodal_messages on strings crashes with TypeError.
# Skipping it keeps prompts as strings so TRL uses the non-conversational path, which
# ensures completions are strings and reward functions work correctly.
string_to_find_vision = """ if images is not None:
prompts = [
prepare_multimodal_messages(prompt, image_list)
for prompt, image_list in zip(prompts, images, strict=True)
]"""

replacement_string_vision = """ if images is not None:
# Unsloth: skip prepare_multimodal_messages for pre-templated string prompts
if not prompts or not isinstance(prompts[0], str):
prompts = [
prepare_multimodal_messages(prompt, image_list)
for prompt, image_list in zip(prompts, images, strict=True)
]"""

function = function.replace(string_to_find_vision, replacement_string_vision)

# Unsloth: Skip apply_chat_template in the forward_kwargs block for pre-templated
# string prompts. When prompts are already strings (from notebooks that pre-applied
# apply_chat_template), calling it again crashes because strings aren't dicts.
# We use prompts directly as prompts_text instead.

# TRL 0.26.2+ variant (has tools=self.tools)
string_to_find_fwd = """ if images is not None:
prompts_text = [
apply_chat_template(
{"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs
)["prompt"]
for prompt in prompts
]"""

replacement_string_fwd = """ if images is not None:
# Unsloth: skip apply_chat_template for pre-templated string prompts
if prompts and isinstance(prompts[0], str):
prompts_text = prompts
else:
prompts_text = [
apply_chat_template(
{"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs
)["prompt"]
for prompt in prompts
]"""

function = function.replace(string_to_find_fwd, replacement_string_fwd)

# TRL 0.25.x variant (no tools parameter)
string_to_find_fwd_old = """ if images is not None:
prompts_text = [
apply_chat_template(
{"prompt": prompt}, self.processing_class, **self.chat_template_kwargs
)["prompt"]
for prompt in prompts
]"""

replacement_string_fwd_old = """ if images is not None:
# Unsloth: skip apply_chat_template for pre-templated string prompts
if prompts and isinstance(prompts[0], str):
prompts_text = prompts
else:
prompts_text = [
apply_chat_template(
{"prompt": prompt}, self.processing_class, **self.chat_template_kwargs
)["prompt"]
for prompt in prompts
]"""

function = function.replace(string_to_find_fwd_old, replacement_string_fwd_old)

# TRL 0.25.1 single-line variant (no tools, single-line apply_chat_template call)
string_to_find_fwd_single = """ if images is not None:
prompts_text = [
apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
for prompt in prompts
]"""

replacement_string_fwd_single = """ if images is not None:
# Unsloth: skip apply_chat_template for pre-templated string prompts
if prompts and isinstance(prompts[0], str):
prompts_text = prompts
else:
prompts_text = [
apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
for prompt in prompts
]"""

function = function.replace(
string_to_find_fwd_single, replacement_string_fwd_single
)

# This path is for TRL 0.24.0 images is a variable exclusive to this version
string_to_find = """ if images is not None:
output["num_images"] = num_images"""
Expand All @@ -543,6 +459,17 @@ def grpo_trainer__generate_and_score_completions(function_name, function):

function = function.replace(string_to_find, replacement_string)

if trl_version >= Version("0.25.0"):
# We replace the call using 'completions' with one using 'completions_text'
string_to_find = " rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)"
replacement_string = (
" if images is not None:\n"
" rewards_per_func = self._calculate_rewards(inputs, prompts_text, completions_text, completion_ids_list)\n"
" else:\n"
" rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)"
)
function = function.replace(string_to_find, replacement_string)

if "wake_up()" not in function:
# Sleep functionality has been added to trl in v0.23.0. We do not want to redo this.
# https://github.com/huggingface/trl/commit/edbe8234bc7e528f72ac76607de9d3e4753e2709
Expand Down Expand Up @@ -1072,7 +999,7 @@ def compute_loss(

max_left_pad = inputs.get("max_left_pad", 0)
if per_token_logps is not None:
loss, completion_length, mean_kl, delta, flat_is_ratio = (
loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = (
grpo_compute_loss_slow(
ref_logps,
per_token_logps,
Expand Down Expand Up @@ -1102,7 +1029,7 @@ def compute_loss(
)
else:
if hasattr(self.args, "loss_type"):
loss, completion_length, mean_kl, delta, flat_is_ratio = (
loss, completion_length, mean_kl, delta, flat_is_ratio, coef_1 = (
grpo_accumulated_loss(
trainer = self,
input_ids = _input_ids,
Expand Down Expand Up @@ -1134,7 +1061,7 @@ def compute_loss(
)
else:
# to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17
loss, completion_length, mean_kl = grpo_accumulated_loss(
loss, completion_length, mean_kl, coef_1 = grpo_accumulated_loss(
trainer = self,
Comment on lines 1063 to 1065
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve backward-compatibility return arity

The backward-compatibility branch now unpacks four values from grpo_accumulated_loss, but that path exists specifically for older TRL versions (“0.15.2 and maybe even 0.17”) where the function historically returned only three values. In those environments, this will raise ValueError: not enough values to unpack at runtime, defeating the compatibility fallback. Consider guarding on the returned tuple length or only unpacking coef_1 when the underlying implementation provides it.

Useful? React with 👍 / 👎.

input_ids = _input_ids,
logits_to_keep = logits_to_keep,
Expand All @@ -1149,7 +1076,6 @@ def compute_loss(
logit_scale_divide = logit_scale_divide,
attention_mask = attention_mask,
)

if "train" in self._metrics:
mode = "eval" if self.control.should_evaluate else "train"
self._metrics[mode]["completion_length"].append(completion_length.item())
Expand Down Expand Up @@ -1211,6 +1137,53 @@ def compute_loss(
.item()
)

completion_token_count = completion_mask.sum().clamp(min = 1.0)

def masked_batch_mean(x):
if x.shape[1] == 1: # when importance_sampling_level == "sequence"
return x.mean()
else:
return (x * completion_mask).sum() / completion_token_count

if advantages.dim() == 1:
advantages = advantages.unsqueeze(1)

if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
# Compute the clipped probability ratios
is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0)
is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0)
is_region_clipped = is_low_clipped | is_high_clipped

low_clip = masked_batch_mean(is_low_clipped.float())
high_clip = masked_batch_mean(is_high_clipped.float())
clip_ratio = masked_batch_mean(is_region_clipped.float())

gathered_low_clip = self.accelerator.gather(low_clip)
self._metrics[mode]["clip_ratio/low_mean"].append(
gathered_low_clip.nanmean().item()
Comment on lines +1160 to +1163
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The functions nanmin and nanmax are used here without being defined or imported. This will lead to a NameError at runtime. You should use torch.nanmin and torch.nanmax instead, as torch is available in the execution context of this code.

Suggested change
gathered_low_clip = self.accelerator.gather(low_clip)
self._metrics[mode]["clip_ratio/low_mean"].append(
gathered_low_clip.nanmean().item()
self._metrics[mode]["clip_ratio/low_min"].append(torch.nanmin(gathered_low_clip).item())
gathered_high_clip = self.accelerator.gather(high_clip)
self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
self._metrics[mode]["clip_ratio/high_max"].append(torch.nanmax(gathered_high_clip).item())

)
self._metrics[mode]["clip_ratio/low_min"].append(
nanmin(gathered_low_clip).item()
)
gathered_high_clip = self.accelerator.gather(high_clip)
self._metrics[mode]["clip_ratio/high_mean"].append(
gathered_high_clip.nanmean().item()
)
self._metrics[mode]["clip_ratio/high_max"].append(
nanmax(gathered_high_clip).item()
)
gathered_clip_ratio = self.accelerator.gather(clip_ratio)
self._metrics[mode]["clip_ratio/region_mean"].append(
gathered_clip_ratio.nanmean().item()
)
elif self.loss_type == "cispo":
is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0)
cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
self._metrics[mode]["cispo_clip_ratio"].append(
gathered_cispo_clip_ratio.nanmean().item()
)

return loss

function = inspect.getsource(compute_loss)
Expand Down