From 5ee2ea5f3473e75d8d4b4c9f4b2f70e322baea2a Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 27 Apr 2026 05:25:36 +0000 Subject: [PATCH 1/2] MROPE for VLM GRPO --- unsloth/models/rl_replacements.py | 82 +++++++++++++++++++++---------- 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 4d36af62cc..da9d00b1c2 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -30,6 +30,8 @@ RL_REPLACEMENTS, left_pack_padding, chunked_selective_log_softmax, + _unsloth_get_mm_token_id, + _unsloth_fix_mm_token_type_ids, ) from unsloth_zoo.utils import Version from trl import __version__ as trl_version_raw @@ -323,6 +325,18 @@ def remove_sync_weights_block(match): ]: function = re.sub(pattern, "", function) + string_to_find = " generate_inputs = super()._prepare_inputs(generate_inputs)" + replacement_string = string_to_find + """ + if "mm_token_type_ids" in generate_inputs or "image_grid_thw" in generate_inputs: + mm_token_type_ids = _unsloth_fix_mm_token_type_ids( + self.processing_class, + generate_inputs["input_ids"], + generate_inputs.get("mm_token_type_ids", None), + ) + if mm_token_type_ids is not None: + generate_inputs["mm_token_type_ids"] = mm_token_type_ids""" + function = function.replace(string_to_find, replacement_string) + return function @@ -542,36 +556,41 @@ def grpo_trainer__generate_and_score_completions(function_name, function): function = patched - # Transformers 5.x: Extend mm_token_type_ids for completion tokens (Qwen3VL M-RoPE). - # TRL handles token_type_ids but not mm_token_type_ids. - _tt_search = ( - 'if "token_type_ids" in forward_kwargs:\n' - ' token_type_ids = forward_kwargs["token_type_ids"]\n' - ' forward_kwargs["token_type_ids"] = torch.cat(\n' - " [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1\n" - " )" - ) - _tt_replace = ( - _tt_search + "\n" - ' if "mm_token_type_ids" in forward_kwargs:\n' - ' mm_tti = forward_kwargs["mm_token_type_ids"]\n' - ' forward_kwargs["mm_token_type_ids"] = torch.cat(\n' - " [mm_tti, mm_tti.new_zeros(completion_ids.shape)], dim=1\n" - " )" - ) - function = function.replace(_tt_search, _tt_replace) + _mm_alignment = """ + if "mm_token_type_ids" in forward_kwargs or "image_grid_thw" in forward_kwargs: + _mm_token_type_ids = _unsloth_fix_mm_token_type_ids( + self.processing_class, + prompt_completion_ids, + forward_kwargs.get("mm_token_type_ids", None), + completion_ids = completion_ids, + ) + if _mm_token_type_ids is not None: + forward_kwargs["mm_token_type_ids"] = _mm_token_type_ids +""" + _tool_image_marker = " # For VLM tool images: build token type IDs from the full prompt_completion_ids." + if _tool_image_marker in function: + function = function.replace(_tool_image_marker, _mm_alignment + "\n" + _tool_image_marker) + else: + _tt_search = ( + 'if "token_type_ids" in forward_kwargs:\n' + ' token_type_ids = forward_kwargs["token_type_ids"]\n' + ' forward_kwargs["token_type_ids"] = torch.cat(\n' + " [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1\n" + " )" + ) + function = function.replace(_tt_search, _tt_search + "\n" + _mm_alignment.rstrip()) - # Save mm_token_type_ids to output dict alongside token_type_ids _save_search = ( 'if "token_type_ids" in forward_kwargs:\n' ' output["token_type_ids"] = forward_kwargs["token_type_ids"]' ) - _save_replace = ( - _save_search + "\n" - ' if "mm_token_type_ids" in forward_kwargs:\n' - ' output["mm_token_type_ids"] = forward_kwargs["mm_token_type_ids"]' - ) - function = function.replace(_save_search, _save_replace) + if 'output["mm_token_type_ids"]' not in function: + _save_replace = ( + _save_search + "\n" + ' if "mm_token_type_ids" in forward_kwargs:\n' + ' output["mm_token_type_ids"] = forward_kwargs["mm_token_type_ids"]' + ) + function = function.replace(_save_search, _save_replace) return function @@ -748,6 +767,10 @@ def _get_per_token_logps_and_entropies( # Transformers 5.x needs token_type_ids/mm_token_type_ids for some vision models token_type_ids = kwargs.get("token_type_ids", None) mm_token_type_ids = kwargs.get("mm_token_type_ids", None) + if mm_token_type_ids is not None or image_grid_thw is not None: + mm_token_type_ids = _unsloth_fix_mm_token_type_ids( + self.processing_class, input_ids, mm_token_type_ids + ) unwrapped_model = self.accelerator.unwrap_model( model, keep_fp32_wrapper = False @@ -1034,6 +1057,8 @@ def _unsloth_get_final_logit_softcapping(config): RL_PRE_ITEMS["grpo_trainer"].append( inspect.getsource(_unsloth_get_final_logit_softcapping) ) +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(_unsloth_get_mm_token_id)) +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(_unsloth_fix_mm_token_type_ids)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(UnslothEfficientGRPO)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss)) @@ -1080,6 +1105,13 @@ def compute_loss( input_ids = torch.cat([prompt_ids, completion_ids], dim = 1) bsz, qlen = input_ids.shape attention_mask = torch.cat([prompt_mask, completion_mask], dim = 1) + if mm_token_type_ids is not None or image_grid_thw is not None: + mm_token_type_ids = _unsloth_fix_mm_token_type_ids( + self.processing_class, + input_ids, + mm_token_type_ids, + completion_ids = completion_ids, + ) # attention_mask = None logits_to_keep = completion_ids.size( 1 From a5eb7525dfe5475fedffd6948c678028ad62b148 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Apr 2026 05:28:01 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/rl_replacements.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index da9d00b1c2..bf0da0f730 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -325,8 +325,12 @@ def remove_sync_weights_block(match): ]: function = re.sub(pattern, "", function) - string_to_find = " generate_inputs = super()._prepare_inputs(generate_inputs)" - replacement_string = string_to_find + """ + string_to_find = ( + " generate_inputs = super()._prepare_inputs(generate_inputs)" + ) + replacement_string = ( + string_to_find + + """ if "mm_token_type_ids" in generate_inputs or "image_grid_thw" in generate_inputs: mm_token_type_ids = _unsloth_fix_mm_token_type_ids( self.processing_class, @@ -335,6 +339,7 @@ def remove_sync_weights_block(match): ) if mm_token_type_ids is not None: generate_inputs["mm_token_type_ids"] = mm_token_type_ids""" + ) function = function.replace(string_to_find, replacement_string) return function @@ -569,7 +574,9 @@ def grpo_trainer__generate_and_score_completions(function_name, function): """ _tool_image_marker = " # For VLM tool images: build token type IDs from the full prompt_completion_ids." if _tool_image_marker in function: - function = function.replace(_tool_image_marker, _mm_alignment + "\n" + _tool_image_marker) + function = function.replace( + _tool_image_marker, _mm_alignment + "\n" + _tool_image_marker + ) else: _tt_search = ( 'if "token_type_ids" in forward_kwargs:\n' @@ -578,7 +585,9 @@ def grpo_trainer__generate_and_score_completions(function_name, function): " [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1\n" " )" ) - function = function.replace(_tt_search, _tt_search + "\n" + _mm_alignment.rstrip()) + function = function.replace( + _tt_search, _tt_search + "\n" + _mm_alignment.rstrip() + ) _save_search = ( 'if "token_type_ids" in forward_kwargs:\n'