From 355700a58c47c2a47d34f6aeee52b66d42c3f260 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 15 Apr 2026 16:02:14 +0000 Subject: [PATCH 1/2] Support VLM processors in `is_chat_template_prefix_preserving` --- tests/test_chat_template_utils.py | 65 ++++++++++++++++++++++++++++++- trl/chat_template_utils.py | 22 +++++++---- 2 files changed, 79 insertions(+), 8 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index cd3a653a7b2..8a0e2d0c09d 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -18,7 +18,7 @@ import pytest import transformers from packaging.version import Version -from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoProcessor, AutoTokenizer from trl import clone_chat_template from trl.chat_template_utils import ( @@ -311,6 +311,69 @@ def test_non_prefix_preserving_template(self): {%- endif %}""") assert is_chat_template_prefix_preserving(tokenizer) is False + def test_prefix_preserving_template_processor(self): + processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration") + # Simple prefix-preserving template that mirrors how Qwen-VL templates emit image tokens: a list-of-blocks + # content is iterated, and `{"type": "image"}` blocks are rendered as `<|vision_start|><|image_pad|><|vision_end|>`. + # docstyle-ignore + processor.chat_template = textwrap.dedent(r""" + {%- for message in messages %} + + {%- if message.role == 'user' %} + {{- '<|im_start|>user\n' }} + {%- if message.content is string %} + {{- message.content }} + {%- else %} + {%- for content in message.content %} + {%- if content.type == 'image' or 'image' in content %} + {{- '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- elif 'text' in content %} + {{- content.text }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == 'assistant' %} + {{- '<|im_start|>assistant\n' }} + {%- if message.content is string %} + {{- message.content }} + {%- else %} + {%- for content in message.content %} + {%- if 'text' in content %} + {{- content.text }} + {%- endif %} + {%- endfor %} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '' + tool_call.name + '' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == 'tool' %} + {{- '<|im_start|>tool\n' }} + {%- if message.content is string %} + {{- message.content }} + {%- else %} + {%- for content in message.content %} + {%- if 'text' in content %} + {{- content.text }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- endif %} + + {%- endfor %} + + {%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- endif %}""") + assert is_chat_template_prefix_preserving(processor) is True + @pytest.mark.parametrize( "tokenizer_name", diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index c7cbbb1e02e..1e961945f23 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -391,7 +391,7 @@ def supports_tool_calling(processing_class) -> bool: return all(s in rendered for s in (_name_sentinel, _arg_key_sentinel, _arg_val_sentinel, _content_sentinel)) -def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: +def is_chat_template_prefix_preserving(processing_class: PreTrainedTokenizer | ProcessorMixin) -> bool: """ Check whether the chat template preserves prefixes when applied. @@ -400,8 +400,8 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: tokenizations with and without tool messages appended. Args: - tokenizer (`PreTrainedTokenizer`): - Tokenizer instance to check. + processing_class (`PreTrainedTokenizer` or `ProcessorMixin`): + Tokenizer or processor instance to check. Returns: `bool`: @@ -418,18 +418,26 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, {"role": "tool", "name": "dummy", "content": "dummy"}, ] + # VLM processors expect structured list-of-blocks content, and image-token expansion only kicks in when an image + # is actually present, so include a dummy image to exercise the real code path. + if isinstance(processing_class, ProcessorMixin): + from PIL import Image + + dummy_image = Image.new("RGB", (8, 8)) + messages1 = prepare_multimodal_messages(messages1, images=[dummy_image]) + messages2 = prepare_multimodal_messages(messages2, images=[dummy_image]) try: - text1 = tokenizer.apply_chat_template(messages1, tokenize=False) - text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) + text1 = processing_class.apply_chat_template(messages1, tokenize=False) + text2 = processing_class.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) except TypeError: # Best-effort fallback for templates that reject dict args (e.g. DeepSeek-V3). This is a chat template # bug (see transformers#45419), and the training chat template fixes it to avoid blocking users. dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": "{}"}}] messages1[1]["tool_calls"] = dummy_tool_calls messages2[1]["tool_calls"] = dummy_tool_calls - text1 = tokenizer.apply_chat_template(messages1, tokenize=False) - text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) + text1 = processing_class.apply_chat_template(messages1, tokenize=False) + text2 = processing_class.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) return text2.startswith(text1) From b08cc7ce68b01b41ece80bf521980fbe358d8547 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 15 Apr 2026 17:49:36 +0000 Subject: [PATCH 2/2] requires vision --- tests/test_chat_template_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 8a0e2d0c09d..f08992a1697 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -29,7 +29,7 @@ supports_tool_calling, ) -from .testing_utils import TrlTestCase, require_jmespath +from .testing_utils import TrlTestCase, require_jmespath, require_vision class TestCloneChatTemplate(TrlTestCase): @@ -311,6 +311,7 @@ def test_non_prefix_preserving_template(self): {%- endif %}""") assert is_chat_template_prefix_preserving(tokenizer) is False + @require_vision def test_prefix_preserving_template_processor(self): processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration") # Simple prefix-preserving template that mirrors how Qwen-VL templates emit image tokens: a list-of-blocks