Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 66 additions & 25 deletions unsloth/models/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
RL_REPLACEMENTS,
left_pack_padding,
chunked_selective_log_softmax,
_unsloth_get_mm_token_id,
_unsloth_fix_mm_token_type_ids,
)
from unsloth_zoo.utils import Version
from trl import __version__ as trl_version_raw
Expand Down Expand Up @@ -595,6 +597,23 @@ def remove_sync_weights_block(match):
]:
function = re.sub(pattern, "", function)

string_to_find = (
" generate_inputs = super()._prepare_inputs(generate_inputs)"
)
replacement_string = (
string_to_find
+ """
if "mm_token_type_ids" in generate_inputs or "image_grid_thw" in generate_inputs:
mm_token_type_ids = _unsloth_fix_mm_token_type_ids(
self.processing_class,
generate_inputs["input_ids"],
generate_inputs.get("mm_token_type_ids", None),
)
if mm_token_type_ids is not None:
generate_inputs["mm_token_type_ids"] = mm_token_type_ids"""
)
function = function.replace(string_to_find, replacement_string)

return function


Expand Down Expand Up @@ -814,36 +833,45 @@ def grpo_trainer__generate_and_score_completions(function_name, function):

function = patched

# Transformers 5.x: Extend mm_token_type_ids for completion tokens (Qwen3VL M-RoPE).
# TRL handles token_type_ids but not mm_token_type_ids.
_tt_search = (
'if "token_type_ids" in forward_kwargs:\n'
' token_type_ids = forward_kwargs["token_type_ids"]\n'
' forward_kwargs["token_type_ids"] = torch.cat(\n'
" [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1\n"
" )"
)
_tt_replace = (
_tt_search + "\n"
' if "mm_token_type_ids" in forward_kwargs:\n'
' mm_tti = forward_kwargs["mm_token_type_ids"]\n'
' forward_kwargs["mm_token_type_ids"] = torch.cat(\n'
" [mm_tti, mm_tti.new_zeros(completion_ids.shape)], dim=1\n"
" )"
)
function = function.replace(_tt_search, _tt_replace)
_mm_alignment = """
if "mm_token_type_ids" in forward_kwargs or "image_grid_thw" in forward_kwargs:
_mm_token_type_ids = _unsloth_fix_mm_token_type_ids(
self.processing_class,
prompt_completion_ids,
forward_kwargs.get("mm_token_type_ids", None),
completion_ids = completion_ids,
)
if _mm_token_type_ids is not None:
forward_kwargs["mm_token_type_ids"] = _mm_token_type_ids
"""
_tool_image_marker = " # For VLM tool images: build token type IDs from the full prompt_completion_ids."
if _tool_image_marker in function:
function = function.replace(
_tool_image_marker, _mm_alignment + "\n" + _tool_image_marker
)
else:
_tt_search = (
'if "token_type_ids" in forward_kwargs:\n'
' token_type_ids = forward_kwargs["token_type_ids"]\n'
' forward_kwargs["token_type_ids"] = torch.cat(\n'
" [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1\n"
" )"
)
function = function.replace(
_tt_search, _tt_search + "\n" + _mm_alignment.rstrip()
)

# Save mm_token_type_ids to output dict alongside token_type_ids
_save_search = (
'if "token_type_ids" in forward_kwargs:\n'
' output["token_type_ids"] = forward_kwargs["token_type_ids"]'
)
_save_replace = (
_save_search + "\n"
' if "mm_token_type_ids" in forward_kwargs:\n'
' output["mm_token_type_ids"] = forward_kwargs["mm_token_type_ids"]'
)
function = function.replace(_save_search, _save_replace)
if 'output["mm_token_type_ids"]' not in function:
_save_replace = (
_save_search + "\n"
' if "mm_token_type_ids" in forward_kwargs:\n'
' output["mm_token_type_ids"] = forward_kwargs["mm_token_type_ids"]'
)
function = function.replace(_save_search, _save_replace)

return function

Expand Down Expand Up @@ -1020,6 +1048,10 @@ def _get_per_token_logps_and_entropies(
# Transformers 5.x needs token_type_ids/mm_token_type_ids for some vision models
token_type_ids = kwargs.get("token_type_ids", None)
mm_token_type_ids = kwargs.get("mm_token_type_ids", None)
if mm_token_type_ids is not None or image_grid_thw is not None:
mm_token_type_ids = _unsloth_fix_mm_token_type_ids(
self.processing_class, input_ids, mm_token_type_ids
)

unwrapped_model = self.accelerator.unwrap_model(
model, keep_fp32_wrapper = False
Expand Down Expand Up @@ -1306,6 +1338,8 @@ def _unsloth_get_final_logit_softcapping(config):
# Register the source text of each GRPO helper under the "grpo_trainer" key.
# The helpers are appended in exactly the order listed below.
# NOTE(review): presumably these sources are emitted ahead of the generated
# trainer code, so a helper must be listed before anything that calls it
# (e.g. _unsloth_get_mm_token_id before _unsloth_fix_mm_token_type_ids) — confirm.
RL_PRE_ITEMS["grpo_trainer"].extend(
    inspect.getsource(helper)
    for helper in (
        _unsloth_get_final_logit_softcapping,
        _unsloth_get_mm_token_id,
        _unsloth_fix_mm_token_type_ids,
        grpo_compute_loss,
        UnslothEfficientGRPO,
        grpo_accumulated_loss,
    )
)
Expand Down Expand Up @@ -1352,6 +1386,13 @@ def compute_loss(
input_ids = torch.cat([prompt_ids, completion_ids], dim = 1)
bsz, qlen = input_ids.shape
attention_mask = torch.cat([prompt_mask, completion_mask], dim = 1)
if mm_token_type_ids is not None or image_grid_thw is not None:
mm_token_type_ids = _unsloth_fix_mm_token_type_ids(
self.processing_class,
input_ids,
mm_token_type_ids,
completion_ids = completion_ids,
)
# attention_mask = None
logits_to_keep = completion_ids.size(
1
Expand Down