From 7561aecd98b48c765d408818b8b0bb55179ff313 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 11 Apr 2026 01:37:57 +0000 Subject: [PATCH 1/7] Fix `add_response_schema` for VLM processors --- tests/test_chat_template_utils.py | 236 ++++++++++++++++++-------- trl/chat_template_utils.py | 57 ++++--- trl/experimental/dppo/dppo_trainer.py | 11 +- trl/trainer/grpo_trainer.py | 24 +-- 4 files changed, 212 insertions(+), 116 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 081f07478a8..5ce74928f72 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -17,7 +17,12 @@ import pytest import transformers from packaging.version import Version -from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoProcessor, + AutoTokenizer, +) from trl import clone_chat_template from trl.chat_template_utils import ( @@ -27,6 +32,7 @@ parse_response, supports_tool_calling, ) +from trl.data_utils import prepare_multimodal_messages from .testing_utils import TrlTestCase, require_jmespath @@ -111,16 +117,6 @@ def test_clone_with_sequence_classification_model(self): assert modified_tokenizer.eos_token == "<|im_end|>" -@pytest.mark.parametrize( - "tokenizer_name", - [ - pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"), - pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), - pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"), - pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"), - pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"), - ], -) @pytest.mark.xfail( condition=Version(transformers.__version__) < Version("5.0.0"), reason="Response parsing is not supported in transformers versions below 5.0.0", @@ -128,6 +124,16 @@ def test_clone_with_sequence_classification_model(self): ) @require_jmespath class TestAddResponseSchema: + @pytest.mark.parametrize( + "tokenizer_name", + [ + pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"), + pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), + pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.1", id="llama3.1"), + pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.2", id="llama3.2"), + pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"), + ], + ) def test_add_response_schema(self, tokenizer_name): tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) tokenizer = add_response_schema(tokenizer) @@ -146,6 +152,34 @@ def test_add_response_schema(self, tokenizer_name): # The correctness of the parsing is tested in TestParseResponse tokenizer.parse_response(response) + @pytest.mark.parametrize( + "processor_name", + [ + pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"), + pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"), + ], + ) + def test_add_response_schema_vlm(self, processor_name): + # For VLM processors, `add_response_schema` must set the schema on the inner tokenizer, since + # `parse_response` is a tokenizer method that reads `self.response_schema` from the tokenizer instance. + processor = AutoProcessor.from_pretrained(processor_name) + processor = add_response_schema(processor) + assert processor.tokenizer.response_schema is not None + messages = [ + {"role": "user", "content": "What is 3*4?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}], + }, + ] + prefix = processor.apply_chat_template(messages[:1], tokenize=False, add_generation_prompt=True) + text = processor.apply_chat_template(messages, tokenize=False) + response = text[len(prefix) :] + # Here, we just test that the parsing doesn't raise an error. + # The correctness of the parsing is tested in TestParseResponse + processor.tokenizer.parse_response(response) + class TestSupportsToolCalling: @pytest.mark.parametrize( @@ -500,7 +534,7 @@ def test_assistant_masks_multi_turn(self, tokenizer_name): @pytest.mark.parametrize( - "tokenizer_name", + "model_name", [ pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"), pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), @@ -524,100 +558,151 @@ def test_assistant_masks_multi_turn(self, tokenizer_name): ) @require_jmespath class TestParseResponse: - def test_parse_response(self, tokenizer_name): - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + def _load(self, model_name): + if "ForCausalLM" in model_name: + self.is_vlm = False + processing_class = AutoTokenizer.from_pretrained(model_name) + response_schema = getattr(processing_class, "response_schema", None) + elif "ForConditionalGeneration" in model_name: + self.is_vlm = True + processing_class = AutoProcessor.from_pretrained(model_name) + response_schema = getattr(processing_class.tokenizer, "response_schema", None) + + if response_schema is None: + processing_class = add_response_schema(processing_class) + + return processing_class + + def test_parse_response(self, model_name): + processing_class = self._load(model_name) messages = [ {"role": "user", "content": "What is 3*4?"}, {"role": "assistant", "content": "12"}, ] - prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids - text = tokenizer.apply_chat_template(messages).input_ids + expected = messages[-1] + messages = prepare_multimodal_messages(messages) if self.is_vlm else messages + prefix = processing_class.apply_chat_template( + messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True + ).input_ids + text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids + if self.is_vlm: + prefix = prefix[0] + text = text[0] response = text[len(prefix) :] - parsed = parse_response(tokenizer, response) - assert parsed == messages[-1] + parsed = parse_response(processing_class, response) + assert parsed == expected - def test_parse_response_with_reasoning_content(self, tokenizer_name): - if tokenizer_name in ( + def test_parse_response_with_reasoning_content(self, model_name): + if model_name in ( "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", "trl-internal-testing/tiny-GptOssForCausalLM", "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", ): pytest.skip("This tokenizer doesn't support inline reasoning_content.") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + processing_class = self._load(model_name) messages = [ {"role": "user", "content": "What is 3*4?"}, {"role": "assistant", "reasoning_content": "Hmmm.", "content": "12"}, ] + expected = messages[-1] + messages = prepare_multimodal_messages(messages) if self.is_vlm else messages # enable_thinking=True is required here because for Qwen3.5, the thinking is disabled by default for the # generation prompt. - prefix = tokenizer.apply_chat_template( - messages[:1], add_generation_prompt=True, enable_thinking=True + prefix = processing_class.apply_chat_template( + messages[:1], add_generation_prompt=True, enable_thinking=True, tokenize=True, return_dict=True ).input_ids - text = tokenizer.apply_chat_template(messages).input_ids + text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids + if self.is_vlm: + prefix = prefix[0] + text = text[0] response = text[len(prefix) :] - parsed = parse_response(tokenizer, response) - assert parsed == messages[-1] + parsed = parse_response(processing_class, response) + assert parsed == expected - def test_parse_response_tool_call(self, tokenizer_name): - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + def test_parse_response_tool_call(self, model_name): + processing_class = self._load(model_name) tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}] messages = [ {"role": "user", "content": "What is 3*4?"}, {"role": "assistant", "content": "", "tool_calls": tool_calls}, ] - prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids - text = tokenizer.apply_chat_template(messages).input_ids + expected = messages[-1] + messages = prepare_multimodal_messages(messages) if self.is_vlm else messages + prefix = processing_class.apply_chat_template( + messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True + ).input_ids + text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids + if self.is_vlm: + prefix = prefix[0] + text = text[0] response = text[len(prefix) :] - parsed = parse_response(tokenizer, response) - assert parsed == messages[-1] + parsed = parse_response(processing_class, response) + assert parsed == expected - def test_parse_response_tool_call_with_content(self, tokenizer_name): - if tokenizer_name == "trl-internal-testing/tiny-Gemma4ForConditionalGeneration": + def test_parse_response_tool_call_with_content(self, model_name): + if model_name == "trl-internal-testing/tiny-Gemma4ForConditionalGeneration": # Gemma4 response_schema regex doesn't capture content after tool calls. # Remove once https://huggingface.co/google/gemma-4-31B-it/discussions/19 is merged. pytest.xfail("Gemma4 response_schema regex bug.") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + if model_name in ( + "trl-internal-testing/tiny-LlamaForCausalLM-3.1", + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + ): + pytest.skip("Llama 3.1 / 3.2 templates only allow a single tool call per assistant turn, with no content.") + processing_class = self._load(model_name) tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}] messages = [ {"role": "user", "content": "What is 3*4?"}, {"role": "assistant", "content": "Let's call the tool.", "tool_calls": tool_calls}, ] - prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids - text = tokenizer.apply_chat_template(messages).input_ids + expected = messages[-1] + messages = prepare_multimodal_messages(messages) if self.is_vlm else messages + prefix = processing_class.apply_chat_template( + messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True + ).input_ids + text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids + if self.is_vlm: + prefix = prefix[0] + text = text[0] response = text[len(prefix) :] - parsed = parse_response(tokenizer, response) - assert parsed == messages[-1] + parsed = parse_response(processing_class, response) + assert parsed == expected - def test_parse_response_tool_call_without_arguments(self, tokenizer_name): - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + def test_parse_response_tool_call_without_arguments(self, model_name): + processing_class = self._load(model_name) tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}] messages = [ {"role": "user", "content": "Ping the service."}, - {"role": "assistant", "tool_calls": tool_calls}, + { + "role": "assistant", + # "content" is required here because VLM processors crash on tokenize=True without it + # (KeyError in processing_utils.py). See huggingface/transformers#45290. + "content": "", + "tool_calls": tool_calls, + }, ] - prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids - text = tokenizer.apply_chat_template(messages).input_ids + expected = messages[-1] + messages = prepare_multimodal_messages(messages) if self.is_vlm else messages + prefix = processing_class.apply_chat_template( + messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True + ).input_ids + text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids + if self.is_vlm: + prefix = prefix[0] + text = text[0] response = text[len(prefix) :] - parsed = parse_response(tokenizer, response) - assert parsed == {"role": "assistant", "content": "", "tool_calls": tool_calls} + parsed = parse_response(processing_class, response) + assert parsed == expected - def test_parse_response_multiple_tool_calls(self, tokenizer_name): - if tokenizer_name == "trl-internal-testing/tiny-GptOssForCausalLM": - pytest.skip("GPT-OSS template only renders one tool call per assistant message.") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + def test_parse_response_multiple_tool_calls(self, model_name): + if model_name in ( + "trl-internal-testing/tiny-GptOssForCausalLM", + "trl-internal-testing/tiny-LlamaForCausalLM-3.1", + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + ): + pytest.skip("This template only renders one tool call per assistant message.") + processing_class = self._load(model_name) tool_calls = [ {"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}, {"type": "function", "function": {"name": "addition", "arguments": {"a": 4, "b": 3}}}, @@ -626,21 +711,26 @@ def test_parse_response_multiple_tool_calls(self, tokenizer_name): {"role": "user", "content": "What is 3*4?"}, {"role": "assistant", "content": "", "tool_calls": tool_calls}, ] - prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids - text = tokenizer.apply_chat_template(messages).input_ids + expected = messages[-1] + messages = prepare_multimodal_messages(messages) if self.is_vlm else messages + prefix = processing_class.apply_chat_template( + messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True + ).input_ids + text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids + if self.is_vlm: + prefix = prefix[0] + text = text[0] response = text[len(prefix) :] - parsed = parse_response(tokenizer, response) - assert parsed == messages[-1] + parsed = parse_response(processing_class, response) + assert parsed == expected - def test_parse_response_malformed_tool_call(self, tokenizer_name): - if tokenizer_name != "trl-internal-testing/tiny-Qwen3MoeForCausalLM": + def test_parse_response_malformed_tool_call(self, model_name): + if model_name != "trl-internal-testing/tiny-Qwen3MoeForCausalLM": pytest.skip("For simplicity, we only test the malformed tool call case on one tokenizer.") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + processing_class = self._load(model_name) text = '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n<|im_end|>' - assistant_text = tokenizer(text)["input_ids"] - parsed = parse_response(tokenizer, assistant_text) + assistant_text = processing_class(text)["input_ids"] + parsed = parse_response(processing_class, assistant_text) expected = { "role": "assistant", "content": '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n', diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 8925ae05e9b..29a3270512e 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -280,7 +280,9 @@ def clone_chat_template( qwen3_5_chat_template_4b_and_above = (_CHAT_TEMPLATES_DIR / "qwen3_5_4b_and_above.jinja").read_text() -def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: +def add_response_schema( + processing_class: PreTrainedTokenizer | ProcessorMixin, +) -> PreTrainedTokenizer | ProcessorMixin: r""" Adds the appropriate response schema to the given tokenizer based on its chat template. @@ -288,13 +290,16 @@ def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: waiting for broader adoption, we provide this utility function to manually set the response schema for known chat templates. + When given a VLM processor, the schema is set on the inner tokenizer, since `parse_response` is a tokenizer method + and reads `self.response_schema` from the tokenizer instance. + Args: - tokenizer (`PreTrainedTokenizer`): - Tokenizer to which the response schema will be added. + processing_class (`PreTrainedTokenizer` or `ProcessorMixin`): + Tokenizer or VLM processor to which the response schema will be added. Returns: - `PreTrainedTokenizer`: - Tokenizer with the added response schema. + `PreTrainedTokenizer` or `ProcessorMixin`: + The same object that was passed in, with the response schema set on the underlying tokenizer. Examples: @@ -309,24 +314,32 @@ def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]} ``` """ - if tokenizer.chat_template == glm4moe_chat_template: + # For VLM processors, set the schema on the inner tokenizer (where `parse_response` reads it from). + # Match against the top-level chat_template, since that's what was used historically and processors + # may carry their own VLM-specific template separate from the inner tokenizer's. + chat_template = processing_class.chat_template + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + else: + tokenizer = processing_class + if chat_template == glm4moe_chat_template: tokenizer.response_schema = glm4moe_schema - return tokenizer - if tokenizer.chat_template == gptoss_chat_template: + elif chat_template == gptoss_chat_template: tokenizer.response_schema = gptoss_schema - return tokenizer - if tokenizer.chat_template in [qwen3_chat_template, qwen3_vl_chat_template]: + elif chat_template in [llama3_1_chat_template, llama3_2_chat_template]: + tokenizer.response_schema = llama3_schema + elif chat_template in [qwen3_chat_template, qwen3_vl_chat_template]: tokenizer.response_schema = qwen3_schema - return tokenizer - if tokenizer.chat_template in [qwen3_5_chat_template_2b_and_below, qwen3_5_chat_template_4b_and_above]: + elif chat_template in [qwen3_5_chat_template_2b_and_below, qwen3_5_chat_template_4b_and_above]: tokenizer.response_schema = qwen3_5_schema - return tokenizer - raise ValueError( - "Unrecognized chat template, failed to add response schema. Please manually set the response schema on the " - "tokenizer or processor. See the Transformers " - "[docs](https://huggingface.co/docs/transformers/main/en/chat_response_parsing#response-parsing) for more " - "details on response parsing." - ) + else: + raise ValueError( + "Unrecognized chat template, failed to add response schema. Please manually set the response schema on " + "the tokenizer or processor. See the Transformers " + "[docs](https://huggingface.co/docs/transformers/main/en/chat_response_parsing#response-parsing) for more " + "details on response parsing." + ) + return processing_class def supports_tool_calling(processing_class) -> bool: @@ -525,7 +538,7 @@ def _validate_tool_calls(tool_calls: list | None) -> None: tool_call["arguments"] = {} -def parse_response(tokenizer_or_processor, ids: list[int]) -> dict: +def parse_response(processing_class, ids: list[int]) -> dict: r""" Parse a token sequence into structured response dictionaries with fallback handling. @@ -538,7 +551,7 @@ def parse_response(tokenizer_or_processor, ids: list[int]) -> dict: For VLM processors, automatically uses the inner tokenizer for parsing. Args: - tokenizer_or_processor (`PreTrainedTokenizer` or VLM processor): + processing_class (`PreTrainedTokenizer` or VLM processor): Tokenizer or processor with a `parse_response()` method (directly or via inner tokenizer). ids (`list[int]`): List of token sequences. @@ -561,7 +574,7 @@ def parse_response(tokenizer_or_processor, ids: list[int]) -> dict: ``` """ # VLM processors don't have parse_response directly; use the inner tokenizer - tokenizer = getattr(tokenizer_or_processor, "tokenizer", tokenizer_or_processor) + tokenizer = getattr(processing_class, "tokenizer", processing_class) try: parsed = tokenizer.parse_response(ids) # Hotfix: remove incorrectly appended EOS token from tool calls diff --git a/trl/experimental/dppo/dppo_trainer.py b/trl/experimental/dppo/dppo_trainer.py index bb5e5933228..400b51c7aad 100644 --- a/trl/experimental/dppo/dppo_trainer.py +++ b/trl/experimental/dppo/dppo_trainer.py @@ -667,12 +667,17 @@ def _generate(self, prompts: list): extra_fields = {} # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. + # `parse_response` handles VLM processors internally by unwrapping the inner tokenizer. if is_conversational({"prompt": prompts[0]}): + tokenizer = ( + self.processing_class.tokenizer + if isinstance(self.processing_class, ProcessorMixin) + else self.processing_class + ) if ( Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 - and isinstance(self.processing_class, PreTrainedTokenizerBase) # doesn't work with processors - and hasattr(self.processing_class, "response_schema") # attribute not set by default for now - and self.processing_class.response_schema is not None # only works if the tokenizer has a schema + and hasattr(tokenizer, "response_schema") # attribute not set by default for now + and tokenizer.response_schema is not None # only works if the tokenizer has a schema ): completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] else: diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 948841dc9a2..09fa558a813 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -514,13 +514,10 @@ def __init__( # At the time of initial implementation, most tokenizers do not have built-in support for response schemas. # While waiting for broader adoption, we provide this utility function to manually set the response schema for - # known chat templates. - # We need `getattr`` until the base class sets a default None value for response_schema - # For VLM processors, check the inner tokenizer too (response_schema lives on the tokenizer) - has_response_schema = getattr(processing_class, "response_schema", None) or ( - self._is_vlm and getattr(processing_class.tokenizer, "response_schema", None) - ) - if self.tools and not has_response_schema: + # known chat templates. `response_schema` lives on the (inner) tokenizer, since `parse_response` is a tokenizer + # method that reads `self.response_schema`. + tokenizer = processing_class.tokenizer if self._is_vlm else processing_class + if self.tools and getattr(tokenizer, "response_schema", None) is None: processing_class = add_response_schema(processing_class) # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template # isn't, we replace it at initialization with a training-safe, prefix-preserving template. @@ -1802,22 +1799,13 @@ def _generate(self, prompts: list): # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): - parsing_class = self.processing_class - # For VLM processors, propagate response_schema to the inner tokenizer if needed - if self._is_vlm: - if getattr(self.processing_class, "response_schema", None) and not getattr( - self.processing_class.tokenizer, "response_schema", None - ): - self.processing_class.tokenizer.response_schema = self.processing_class.response_schema - # parse_response handles VLM processors internally (uses inner tokenizer) - tokenizer = getattr(parsing_class, "tokenizer", parsing_class) + tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class if ( Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 - and isinstance(tokenizer, PreTrainedTokenizerBase) and hasattr(tokenizer, "response_schema") # attribute not set by default for now and tokenizer.response_schema is not None # only works if the tokenizer has a schema ): - completions = [[parse_response(parsing_class, ids)] for ids in completion_ids] + completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] else: contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) completions = [[{"role": "assistant", "content": content}] for content in contents] From be9e232a290688e24d631cd3edfb2c11333e5248 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 11 Apr 2026 01:40:38 +0000 Subject: [PATCH 2/7] rm llama --- tests/test_chat_template_utils.py | 13 +------------ trl/chat_template_utils.py | 2 -- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 5ce74928f72..60ac1012903 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -129,8 +129,6 @@ class TestAddResponseSchema: [ pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"), pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), - pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.1", id="llama3.1"), - pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.2", id="llama3.2"), pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"), ], ) @@ -645,11 +643,6 @@ def test_parse_response_tool_call_with_content(self, model_name): # Gemma4 response_schema regex doesn't capture content after tool calls. # Remove once https://huggingface.co/google/gemma-4-31B-it/discussions/19 is merged. pytest.xfail("Gemma4 response_schema regex bug.") - if model_name in ( - "trl-internal-testing/tiny-LlamaForCausalLM-3.1", - "trl-internal-testing/tiny-LlamaForCausalLM-3.2", - ): - pytest.skip("Llama 3.1 / 3.2 templates only allow a single tool call per assistant turn, with no content.") processing_class = self._load(model_name) tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}] messages = [ @@ -696,11 +689,7 @@ def test_parse_response_tool_call_without_arguments(self, model_name): assert parsed == expected def test_parse_response_multiple_tool_calls(self, model_name): - if model_name in ( - "trl-internal-testing/tiny-GptOssForCausalLM", - "trl-internal-testing/tiny-LlamaForCausalLM-3.1", - "trl-internal-testing/tiny-LlamaForCausalLM-3.2", - ): + if model_name == "trl-internal-testing/tiny-GptOssForCausalLM": pytest.skip("This template only renders one tool call per assistant message.") processing_class = self._load(model_name) tool_calls = [ diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 29a3270512e..8defe20ae56 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -326,8 +326,6 @@ def add_response_schema( tokenizer.response_schema = glm4moe_schema elif chat_template == gptoss_chat_template: tokenizer.response_schema = gptoss_schema - elif chat_template in [llama3_1_chat_template, llama3_2_chat_template]: - tokenizer.response_schema = llama3_schema elif chat_template in [qwen3_chat_template, qwen3_vl_chat_template]: tokenizer.response_schema = qwen3_schema elif chat_template in [qwen3_5_chat_template_2b_and_below, qwen3_5_chat_template_4b_and_above]: From 778432efea4e9f9ad2e671ffd6106545b6653ea6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 11 Apr 2026 01:45:00 +0000 Subject: [PATCH 3/7] handle empty content for vlm --- tests/test_chat_template_utils.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 60ac1012903..48153979c9d 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -139,7 +139,6 @@ def test_add_response_schema(self, tokenizer_name): {"role": "user", "content": "What is 3*4?"}, { "role": "assistant", - "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}], }, ] @@ -164,10 +163,12 @@ def test_add_response_schema_vlm(self, processor_name): processor = add_response_schema(processor) assert processor.tokenizer.response_schema is not None messages = [ - {"role": "user", "content": "What is 3*4?"}, + {"role": "user", "content": [{"type": "text", "text": "What is 3*4?"}]}, { "role": "assistant", - "content": "", + # "content" is required here because VLM processors crash on tokenize=True without it + # (KeyError in processing_utils.py). See huggingface/transformers#45290. + "content": [{"type": "text", "text": ""}], "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}], }, ] @@ -623,7 +624,13 @@ def test_parse_response_tool_call(self, model_name): tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}] messages = [ {"role": "user", "content": "What is 3*4?"}, - {"role": "assistant", "content": "", "tool_calls": tool_calls}, + { + "role": "assistant", + # "content" is required here because VLM processors crash on tokenize=True without it + # (KeyError in processing_utils.py). See huggingface/transformers#45290. + "content": "", + "tool_calls": tool_calls, + }, ] expected = messages[-1] messages = prepare_multimodal_messages(messages) if self.is_vlm else messages @@ -698,7 +705,13 @@ def test_parse_response_multiple_tool_calls(self, model_name): ] messages = [ {"role": "user", "content": "What is 3*4?"}, - {"role": "assistant", "content": "", "tool_calls": tool_calls}, + { + "role": "assistant", + # "content" is required here because VLM processors crash on tokenize=True without it + # (KeyError in processing_utils.py). See huggingface/transformers#45290. + "content": "", + "tool_calls": tool_calls, + }, ] expected = messages[-1] messages = prepare_multimodal_messages(messages) if self.is_vlm else messages From 4c62f60a3bb1a0f15b8cd4a5ffa1d61a14af662e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 11 Apr 2026 01:57:19 +0000 Subject: [PATCH 4/7] bikeshedding --- trl/chat_template_utils.py | 8 +++++--- trl/experimental/dppo/dppo_trainer.py | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 8defe20ae56..6693368a3b9 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -13,6 +13,7 @@ # limitations under the License. from pathlib import Path +from typing import TypeVar from jinja2 import TemplateError from transformers import AddedToken, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin @@ -280,9 +281,10 @@ def clone_chat_template( qwen3_5_chat_template_4b_and_above = (_CHAT_TEMPLATES_DIR / "qwen3_5_4b_and_above.jinja").read_text() -def add_response_schema( - processing_class: PreTrainedTokenizer | ProcessorMixin, -) -> PreTrainedTokenizer | ProcessorMixin: +ProcessingClassT = TypeVar("ProcessingClassT", PreTrainedTokenizer, ProcessorMixin) + + +def add_response_schema(processing_class: ProcessingClassT) -> ProcessingClassT: r""" Adds the appropriate response schema to the given tokenizer based on its chat template. diff --git a/trl/experimental/dppo/dppo_trainer.py b/trl/experimental/dppo/dppo_trainer.py index 400b51c7aad..88a634a0ed2 100644 --- a/trl/experimental/dppo/dppo_trainer.py +++ b/trl/experimental/dppo/dppo_trainer.py @@ -667,7 +667,6 @@ def _generate(self, prompts: list): extra_fields = {} # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. - # `parse_response` handles VLM processors internally by unwrapping the inner tokenizer. if is_conversational({"prompt": prompts[0]}): tokenizer = ( self.processing_class.tokenizer From ecc41e7a5cd9432db7a0659f7d7d1e5681616285 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 11 Apr 2026 01:59:02 +0000 Subject: [PATCH 5/7] type hint --- trl/chat_template_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 6693368a3b9..5e19bdcfd41 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -538,7 +538,7 @@ def _validate_tool_calls(tool_calls: list | None) -> None: tool_call["arguments"] = {} -def parse_response(processing_class, ids: list[int]) -> dict: +def parse_response(processing_class: PreTrainedTokenizer | ProcessorMixin, ids: list[int]) -> dict: r""" Parse a token sequence into structured response dictionaries with fallback handling. From 8cb5f583a6613c0dde8ef9f51b526d44c4c72cf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 11 Apr 2026 02:05:50 +0000 Subject: [PATCH 6/7] alignement --- trl/experimental/dppo/dppo_trainer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/trl/experimental/dppo/dppo_trainer.py b/trl/experimental/dppo/dppo_trainer.py index 88a634a0ed2..260417a1aa6 100644 --- a/trl/experimental/dppo/dppo_trainer.py +++ b/trl/experimental/dppo/dppo_trainer.py @@ -668,11 +668,7 @@ def _generate(self, prompts: list): # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): - tokenizer = ( - self.processing_class.tokenizer - if isinstance(self.processing_class, ProcessorMixin) - else self.processing_class - ) + tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class if ( Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 and hasattr(tokenizer, "response_schema") # attribute not set by default for now From 831d662670e9aef8eca70e1d04f932bcd5625bd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 14 Apr 2026 20:31:36 +0000 Subject: [PATCH 7/7] fix merge main --- tests/test_chat_template_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 3571bdd101b..b1244cc40d0 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -655,9 +655,7 @@ def test_parse_response_tool_call(self, model_name): assert parsed == expected def test_parse_response_tool_call_with_content(self, model_name): - tokenizer = AutoTokenizer.from_pretrained(model_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + processing_class = self._load(model_name) tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}] messages = [ {"role": "user", "content": "What is 3*4?"},