diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 5d2c4151cf..0f10847282 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -30,6 +30,8 @@ RL_REPLACEMENTS, left_pack_padding, chunked_selective_log_softmax, + _unsloth_get_mm_token_id, + _unsloth_fix_mm_token_type_ids, ) from unsloth_zoo.utils import Version from trl import __version__ as trl_version_raw @@ -595,6 +597,23 @@ def remove_sync_weights_block(match): ]: function = re.sub(pattern, "", function) + string_to_find = ( + " generate_inputs = super()._prepare_inputs(generate_inputs)" + ) + replacement_string = ( + string_to_find + + """ + if "mm_token_type_ids" in generate_inputs or "image_grid_thw" in generate_inputs: + mm_token_type_ids = _unsloth_fix_mm_token_type_ids( + self.processing_class, + generate_inputs["input_ids"], + generate_inputs.get("mm_token_type_ids", None), + ) + if mm_token_type_ids is not None: + generate_inputs["mm_token_type_ids"] = mm_token_type_ids""" + ) + function = function.replace(string_to_find, replacement_string) + return function @@ -814,36 +833,45 @@ def grpo_trainer__generate_and_score_completions(function_name, function): function = patched - # Transformers 5.x: Extend mm_token_type_ids for completion tokens (Qwen3VL M-RoPE). - # TRL handles token_type_ids but not mm_token_type_ids. - _tt_search = ( - 'if "token_type_ids" in forward_kwargs:\n' - ' token_type_ids = forward_kwargs["token_type_ids"]\n' - ' forward_kwargs["token_type_ids"] = torch.cat(\n' - " [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1\n" - " )" - ) - _tt_replace = ( - _tt_search + "\n" - ' if "mm_token_type_ids" in forward_kwargs:\n' - ' mm_tti = forward_kwargs["mm_token_type_ids"]\n' - ' forward_kwargs["mm_token_type_ids"] = torch.cat(\n' - " [mm_tti, mm_tti.new_zeros(completion_ids.shape)], dim=1\n" - " )" - ) - function = function.replace(_tt_search, _tt_replace) + _mm_alignment = """ + if "mm_token_type_ids" in forward_kwargs or "image_grid_thw" in forward_kwargs: + _mm_token_type_ids = _unsloth_fix_mm_token_type_ids( + self.processing_class, + prompt_completion_ids, + forward_kwargs.get("mm_token_type_ids", None), + completion_ids = completion_ids, + ) + if _mm_token_type_ids is not None: + forward_kwargs["mm_token_type_ids"] = _mm_token_type_ids +""" + _tool_image_marker = " # For VLM tool images: build token type IDs from the full prompt_completion_ids." + if _tool_image_marker in function: + function = function.replace( + _tool_image_marker, _mm_alignment + "\n" + _tool_image_marker + ) + else: + _tt_search = ( + 'if "token_type_ids" in forward_kwargs:\n' + ' token_type_ids = forward_kwargs["token_type_ids"]\n' + ' forward_kwargs["token_type_ids"] = torch.cat(\n' + " [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1\n" + " )" + ) + function = function.replace( + _tt_search, _tt_search + "\n" + _mm_alignment.rstrip() + ) - # Save mm_token_type_ids to output dict alongside token_type_ids _save_search = ( 'if "token_type_ids" in forward_kwargs:\n' ' output["token_type_ids"] = forward_kwargs["token_type_ids"]' ) - _save_replace = ( - _save_search + "\n" - ' if "mm_token_type_ids" in forward_kwargs:\n' - ' output["mm_token_type_ids"] = forward_kwargs["mm_token_type_ids"]' - ) - function = function.replace(_save_search, _save_replace) + if 'output["mm_token_type_ids"]' not in function: + _save_replace = ( + _save_search + "\n" + ' if "mm_token_type_ids" in forward_kwargs:\n' + ' output["mm_token_type_ids"] = forward_kwargs["mm_token_type_ids"]' + ) + function = function.replace(_save_search, _save_replace) return function @@ -1020,6 +1048,10 @@ def _get_per_token_logps_and_entropies( # Transformers 5.x needs token_type_ids/mm_token_type_ids for some vision models token_type_ids = kwargs.get("token_type_ids", None) mm_token_type_ids = kwargs.get("mm_token_type_ids", None) + if mm_token_type_ids is not None or image_grid_thw is not None: + mm_token_type_ids = _unsloth_fix_mm_token_type_ids( + self.processing_class, input_ids, mm_token_type_ids + ) unwrapped_model = self.accelerator.unwrap_model( model, keep_fp32_wrapper = False @@ -1306,6 +1338,8 @@ def _unsloth_get_final_logit_softcapping(config): RL_PRE_ITEMS["grpo_trainer"].append( inspect.getsource(_unsloth_get_final_logit_softcapping) ) +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(_unsloth_get_mm_token_id)) +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(_unsloth_fix_mm_token_type_ids)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(UnslothEfficientGRPO)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss)) @@ -1352,6 +1386,13 @@ def compute_loss( input_ids = torch.cat([prompt_ids, completion_ids], dim = 1) bsz, qlen = input_ids.shape attention_mask = torch.cat([prompt_mask, completion_mask], dim = 1) + if mm_token_type_ids is not None or image_grid_thw is not None: + mm_token_type_ids = _unsloth_fix_mm_token_type_ids( + self.processing_class, + input_ids, + mm_token_type_ids, + completion_ids = completion_ids, + ) # attention_mask = None logits_to_keep = completion_ids.size( 1