diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index cd3a653a7b2..b1244cc40d0 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -18,7 +18,12 @@ import pytest import transformers from packaging.version import Version -from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoProcessor, + AutoTokenizer, +) from trl import clone_chat_template from trl.chat_template_utils import ( @@ -28,6 +33,7 @@ parse_response, supports_tool_calling, ) +from trl.data_utils import prepare_multimodal_messages from .testing_utils import TrlTestCase, require_jmespath @@ -112,16 +118,6 @@ def test_clone_with_sequence_classification_model(self): assert modified_tokenizer.eos_token == "<|im_end|>" -@pytest.mark.parametrize( - "tokenizer_name", - [ - pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"), - pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), - pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"), - pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"), - pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"), - ], -) @pytest.mark.xfail( condition=Version(transformers.__version__) < Version("5.0.0"), reason="Response parsing is not supported in transformers versions below 5.0.0", @@ -129,6 +125,14 @@ def test_clone_with_sequence_classification_model(self): ) @require_jmespath class TestAddResponseSchema: + @pytest.mark.parametrize( + "tokenizer_name", + [ + pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"), + pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), + pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"), + ], + ) def test_add_response_schema(self, tokenizer_name): tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) tokenizer = add_response_schema(tokenizer) @@ -136,7 +140,6 @@ def test_add_response_schema(self, tokenizer_name): {"role": "user", "content": "What is 3*4?"}, { "role": "assistant", - "content": "", "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}], }, ] @@ -147,6 +150,36 @@ def test_add_response_schema(self, tokenizer_name): # The correctness of the parsing is tested in TestParseResponse tokenizer.parse_response(response) + @pytest.mark.parametrize( + "processor_name", + [ + pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"), + pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"), + ], + ) + def test_add_response_schema_vlm(self, processor_name): + # For VLM processors, `add_response_schema` must set the schema on the inner tokenizer, since + # `parse_response` is a tokenizer method that reads `self.response_schema` from the tokenizer instance. + processor = AutoProcessor.from_pretrained(processor_name) + processor = add_response_schema(processor) + assert processor.tokenizer.response_schema is not None + messages = [ + {"role": "user", "content": [{"type": "text", "text": "What is 3*4?"}]}, + { + "role": "assistant", + # "content" is required here because VLM processors crash on tokenize=True without it + # (KeyError in processing_utils.py). See huggingface/transformers#45290. + "content": [{"type": "text", "text": ""}], + "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}], + }, + ] + prefix = processor.apply_chat_template(messages[:1], tokenize=False, add_generation_prompt=True) + text = processor.apply_chat_template(messages, tokenize=False) + response = text[len(prefix) :] + # Here, we just test that the parsing doesn't raise an error. + # The correctness of the parsing is tested in TestParseResponse + processor.tokenizer.parse_response(response) + class TestSupportsToolCalling: @pytest.mark.parametrize( @@ -509,7 +542,7 @@ def test_assistant_masks_multi_turn(self, tokenizer_name): @pytest.mark.parametrize( - "tokenizer_name", + "model_name", [ pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"), pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), @@ -533,119 +566,178 @@ def test_assistant_masks_multi_turn(self, tokenizer_name): ) @require_jmespath class TestParseResponse: - def test_parse_response(self, tokenizer_name): - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + def _load(self, model_name): + if "ForCausalLM" in model_name: + self.is_vlm = False + processing_class = AutoTokenizer.from_pretrained(model_name) + response_schema = getattr(processing_class, "response_schema", None) + elif "ForConditionalGeneration" in model_name: + self.is_vlm = True + processing_class = AutoProcessor.from_pretrained(model_name) + response_schema = getattr(processing_class.tokenizer, "response_schema", None) + + if response_schema is None: + processing_class = add_response_schema(processing_class) + + return processing_class + + def test_parse_response(self, model_name): + processing_class = self._load(model_name) messages = [ {"role": "user", "content": "What is 3*4?"}, {"role": "assistant", "content": "12"}, ] - prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids - text = tokenizer.apply_chat_template(messages).input_ids + expected = messages[-1] + messages = prepare_multimodal_messages(messages) if self.is_vlm else messages + prefix = processing_class.apply_chat_template( + messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True + ).input_ids + text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids + if self.is_vlm: + prefix = prefix[0] + text = text[0] response = text[len(prefix) :] - parsed = parse_response(tokenizer, response) - assert parsed == messages[-1] + parsed = parse_response(processing_class, response) + assert parsed == expected - def test_parse_response_with_reasoning_content(self, tokenizer_name): - if tokenizer_name in ( + def test_parse_response_with_reasoning_content(self, model_name): + if model_name in ( "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", "trl-internal-testing/tiny-GptOssForCausalLM", "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", ): pytest.skip("This tokenizer doesn't support inline reasoning_content.") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + processing_class = self._load(model_name) messages = [ {"role": "user", "content": "What is 3*4?"}, {"role": "assistant", "reasoning_content": "Hmmm.", "content": "12"}, ] + expected = messages[-1] + messages = prepare_multimodal_messages(messages) if self.is_vlm else messages # enable_thinking=True is required here because for Qwen3.5, the thinking is disabled by default for the # generation prompt. - prefix = tokenizer.apply_chat_template( - messages[:1], add_generation_prompt=True, enable_thinking=True + prefix = processing_class.apply_chat_template( + messages[:1], add_generation_prompt=True, enable_thinking=True, tokenize=True, return_dict=True ).input_ids - text = tokenizer.apply_chat_template(messages).input_ids + text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids + if self.is_vlm: + prefix = prefix[0] + text = text[0] response = text[len(prefix) :] - parsed = parse_response(tokenizer, response) - assert parsed == messages[-1] + parsed = parse_response(processing_class, response) + assert parsed == expected - def test_parse_response_tool_call(self, tokenizer_name): - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + def test_parse_response_tool_call(self, model_name): + processing_class = self._load(model_name) tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}] messages = [ {"role": "user", "content": "What is 3*4?"}, - {"role": "assistant", "content": "", "tool_calls": tool_calls}, + { + "role": "assistant", + # "content" is required here because VLM processors crash on tokenize=True without it + # (KeyError in processing_utils.py). See huggingface/transformers#45290. + "content": "", + "tool_calls": tool_calls, + }, ] - prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids - text = tokenizer.apply_chat_template(messages).input_ids + expected = messages[-1] + messages = prepare_multimodal_messages(messages) if self.is_vlm else messages + prefix = processing_class.apply_chat_template( + messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True + ).input_ids + text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids + if self.is_vlm: + prefix = prefix[0] + text = text[0] response = text[len(prefix) :] - parsed = parse_response(tokenizer, response) - assert parsed == messages[-1] + parsed = parse_response(processing_class, response) + assert parsed == expected - def test_parse_response_tool_call_with_content(self, tokenizer_name): - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + def test_parse_response_tool_call_with_content(self, model_name): + processing_class = self._load(model_name) tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}] messages = [ {"role": "user", "content": "What is 3*4?"}, {"role": "assistant", "content": "Let's call the tool.", "tool_calls": tool_calls}, ] - prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids - text = tokenizer.apply_chat_template(messages).input_ids + expected = messages[-1] + messages = prepare_multimodal_messages(messages) if self.is_vlm else messages + prefix = processing_class.apply_chat_template( + messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True + ).input_ids + text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids + if self.is_vlm: + prefix = prefix[0] + text = text[0] response = text[len(prefix) :] - parsed = parse_response(tokenizer, response) - assert parsed == messages[-1] + parsed = parse_response(processing_class, response) + assert parsed == expected - def test_parse_response_tool_call_without_arguments(self, tokenizer_name): - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + def test_parse_response_tool_call_without_arguments(self, model_name): + processing_class = self._load(model_name) tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}] messages = [ {"role": "user", "content": "Ping the service."}, - {"role": "assistant", "tool_calls": tool_calls}, + { + "role": "assistant", + # "content" is required here because VLM processors crash on tokenize=True without it + # (KeyError in processing_utils.py). See huggingface/transformers#45290. + "content": "", + "tool_calls": tool_calls, + }, ] - prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids - text = tokenizer.apply_chat_template(messages).input_ids + expected = messages[-1] + messages = prepare_multimodal_messages(messages) if self.is_vlm else messages + prefix = processing_class.apply_chat_template( + messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True + ).input_ids + text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids + if self.is_vlm: + prefix = prefix[0] + text = text[0] response = text[len(prefix) :] - parsed = parse_response(tokenizer, response) - assert parsed == {"role": "assistant", "content": "", "tool_calls": tool_calls} + parsed = parse_response(processing_class, response) + assert parsed == expected - def test_parse_response_multiple_tool_calls(self, tokenizer_name): - if tokenizer_name == "trl-internal-testing/tiny-GptOssForCausalLM": - pytest.skip("GPT-OSS template only renders one tool call per assistant message.") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + def test_parse_response_multiple_tool_calls(self, model_name): + if model_name == "trl-internal-testing/tiny-GptOssForCausalLM": + pytest.skip("This template only renders one tool call per assistant message.") + processing_class = self._load(model_name) tool_calls = [ {"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}, {"type": "function", "function": {"name": "addition", "arguments": {"a": 4, "b": 3}}}, ] messages = [ {"role": "user", "content": "What is 3*4?"}, - {"role": "assistant", "content": "", "tool_calls": tool_calls}, + { + "role": "assistant", + # "content" is required here because VLM processors crash on tokenize=True without it + # (KeyError in processing_utils.py). See huggingface/transformers#45290. + "content": "", + "tool_calls": tool_calls, + }, ] - prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids - text = tokenizer.apply_chat_template(messages).input_ids + expected = messages[-1] + messages = prepare_multimodal_messages(messages) if self.is_vlm else messages + prefix = processing_class.apply_chat_template( + messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True + ).input_ids + text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids + if self.is_vlm: + prefix = prefix[0] + text = text[0] response = text[len(prefix) :] - parsed = parse_response(tokenizer, response) - assert parsed == messages[-1] + parsed = parse_response(processing_class, response) + assert parsed == expected - def test_parse_response_malformed_tool_call(self, tokenizer_name): - if tokenizer_name != "trl-internal-testing/tiny-Qwen3MoeForCausalLM": + def test_parse_response_malformed_tool_call(self, model_name): + if model_name != "trl-internal-testing/tiny-Qwen3MoeForCausalLM": pytest.skip("For simplicity, we only test the malformed tool call case on one tokenizer.") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - if getattr(tokenizer, "response_schema", None) is None: - tokenizer = add_response_schema(tokenizer) + processing_class = self._load(model_name) text = '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n<|im_end|>' - assistant_text = tokenizer(text)["input_ids"] - parsed = parse_response(tokenizer, assistant_text) + assistant_text = processing_class(text)["input_ids"] + parsed = parse_response(processing_class, assistant_text) expected = { "role": "assistant", "content": '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n', diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index c7cbbb1e02e..29dd031b099 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -13,6 +13,7 @@ # limitations under the License. from pathlib import Path +from typing import TypeVar from jinja2 import TemplateError from transformers import AddedToken, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin @@ -284,7 +285,10 @@ def clone_chat_template( qwen3_5_chat_template_4b_and_above = (_CHAT_TEMPLATES_DIR / "qwen3_5_4b_and_above.jinja").read_text() -def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: +ProcessingClassT = TypeVar("ProcessingClassT", PreTrainedTokenizer, ProcessorMixin) + + +def add_response_schema(processing_class: ProcessingClassT) -> ProcessingClassT: r""" Adds the appropriate response schema to the given tokenizer based on its chat template. @@ -292,13 +296,16 @@ def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: waiting for broader adoption, we provide this utility function to manually set the response schema for known chat templates. + When given a VLM processor, the schema is set on the inner tokenizer, since `parse_response` is a tokenizer method + and reads `self.response_schema` from the tokenizer instance. + Args: - tokenizer (`PreTrainedTokenizer`): - Tokenizer to which the response schema will be added. + processing_class (`PreTrainedTokenizer` or `ProcessorMixin`): + Tokenizer or VLM processor to which the response schema will be added. Returns: - `PreTrainedTokenizer`: - Tokenizer with the added response schema. + `PreTrainedTokenizer` or `ProcessorMixin`: + The same object that was passed in, with the response schema set on the underlying tokenizer. Examples: @@ -313,24 +320,30 @@ def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]} ``` """ - if tokenizer.chat_template == glm4moe_chat_template: + # For VLM processors, set the schema on the inner tokenizer (where `parse_response` reads it from). + # Match against the top-level chat_template, since that's what was used historically and processors + # may carry their own VLM-specific template separate from the inner tokenizer's. + chat_template = processing_class.chat_template + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + else: + tokenizer = processing_class + if chat_template == glm4moe_chat_template: tokenizer.response_schema = glm4moe_schema - return tokenizer - if tokenizer.chat_template == gptoss_chat_template: + elif chat_template == gptoss_chat_template: tokenizer.response_schema = gptoss_schema - return tokenizer - if tokenizer.chat_template in [qwen3_chat_template, qwen3_vl_chat_template]: + elif chat_template in [qwen3_chat_template, qwen3_vl_chat_template]: tokenizer.response_schema = qwen3_schema - return tokenizer - if tokenizer.chat_template in [qwen3_5_chat_template_2b_and_below, qwen3_5_chat_template_4b_and_above]: + elif chat_template in [qwen3_5_chat_template_2b_and_below, qwen3_5_chat_template_4b_and_above]: tokenizer.response_schema = qwen3_5_schema - return tokenizer - raise ValueError( - "Unrecognized chat template, failed to add response schema. Please manually set the response schema on the " - "tokenizer or processor. See the Transformers " - "[docs](https://huggingface.co/docs/transformers/main/en/chat_response_parsing#response-parsing) for more " - "details on response parsing." - ) + else: + raise ValueError( + "Unrecognized chat template, failed to add response schema. Please manually set the response schema on " + "the tokenizer or processor. See the Transformers " + "[docs](https://huggingface.co/docs/transformers/main/en/chat_response_parsing#response-parsing) for more " + "details on response parsing." + ) + return processing_class def supports_tool_calling(processing_class) -> bool: @@ -560,7 +573,7 @@ def _validate_tool_calls(tool_calls: list | None) -> None: tool_call["arguments"] = {} -def parse_response(tokenizer_or_processor, ids: list[int]) -> dict: +def parse_response(processing_class: PreTrainedTokenizer | ProcessorMixin, ids: list[int]) -> dict: r""" Parse a token sequence into structured response dictionaries with fallback handling. @@ -573,7 +586,7 @@ def parse_response(tokenizer_or_processor, ids: list[int]) -> dict: For VLM processors, automatically uses the inner tokenizer for parsing. Args: - tokenizer_or_processor (`PreTrainedTokenizer` or VLM processor): + processing_class (`PreTrainedTokenizer` or VLM processor): Tokenizer or processor with a `parse_response()` method (directly or via inner tokenizer). ids (`list[int]`): List of token sequences. @@ -596,7 +609,7 @@ def parse_response(tokenizer_or_processor, ids: list[int]) -> dict: ``` """ # VLM processors don't have parse_response directly; use the inner tokenizer - tokenizer = getattr(tokenizer_or_processor, "tokenizer", tokenizer_or_processor) + tokenizer = getattr(processing_class, "tokenizer", processing_class) try: parsed = tokenizer.parse_response(ids) # Hotfix: remove incorrectly appended EOS token from tool calls diff --git a/trl/experimental/dppo/dppo_trainer.py b/trl/experimental/dppo/dppo_trainer.py index bb5e5933228..260417a1aa6 100644 --- a/trl/experimental/dppo/dppo_trainer.py +++ b/trl/experimental/dppo/dppo_trainer.py @@ -668,11 +668,11 @@ def _generate(self, prompts: list): # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): + tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class if ( Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 - and isinstance(self.processing_class, PreTrainedTokenizerBase) # doesn't work with processors - and hasattr(self.processing_class, "response_schema") # attribute not set by default for now - and self.processing_class.response_schema is not None # only works if the tokenizer has a schema + and hasattr(tokenizer, "response_schema") # attribute not set by default for now + and tokenizer.response_schema is not None # only works if the tokenizer has a schema ): completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] else: diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index bcabd5e17be..2f7e66ca197 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -526,13 +526,10 @@ def __init__( # At the time of initial implementation, most tokenizers do not have built-in support for response schemas. # While waiting for broader adoption, we provide this utility function to manually set the response schema for - # known chat templates. - # We need `getattr`` until the base class sets a default None value for response_schema - # For VLM processors, check the inner tokenizer too (response_schema lives on the tokenizer) - has_response_schema = getattr(processing_class, "response_schema", None) or ( - self._is_vlm and getattr(processing_class.tokenizer, "response_schema", None) - ) - if self.tools and not has_response_schema: + # known chat templates. `response_schema` lives on the (inner) tokenizer, since `parse_response` is a tokenizer + # method that reads `self.response_schema`. + tokenizer = processing_class.tokenizer if self._is_vlm else processing_class + if self.tools and getattr(tokenizer, "response_schema", None) is None: processing_class = add_response_schema(processing_class) # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template # isn't, we replace it at initialization with a training-safe, prefix-preserving template. @@ -1722,22 +1719,13 @@ def _generate(self, prompts: list): # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): - parsing_class = self.processing_class - # For VLM processors, propagate response_schema to the inner tokenizer if needed - if self._is_vlm: - if getattr(self.processing_class, "response_schema", None) and not getattr( - self.processing_class.tokenizer, "response_schema", None - ): - self.processing_class.tokenizer.response_schema = self.processing_class.response_schema - # parse_response handles VLM processors internally (uses inner tokenizer) - tokenizer = getattr(parsing_class, "tokenizer", parsing_class) + tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class if ( Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 - and isinstance(tokenizer, PreTrainedTokenizerBase) and hasattr(tokenizer, "response_schema") # attribute not set by default for now and tokenizer.response_schema is not None # only works if the tokenizer has a schema ): - completions = [[parse_response(parsing_class, ids)] for ids in completion_ids] + completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] else: contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) completions = [[{"role": "assistant", "content": content}] for content in contents]