Skip to content
242 changes: 167 additions & 75 deletions tests/test_chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
import pytest
import transformers
from packaging.version import Version
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoProcessor,
AutoTokenizer,
)

from trl import clone_chat_template
from trl.chat_template_utils import (
Expand All @@ -28,6 +33,7 @@
parse_response,
supports_tool_calling,
)
from trl.data_utils import prepare_multimodal_messages

from .testing_utils import TrlTestCase, require_jmespath

Expand Down Expand Up @@ -112,31 +118,28 @@ def test_clone_with_sequence_classification_model(self):
assert modified_tokenizer.eos_token == "<|im_end|>"


@pytest.mark.parametrize(
"tokenizer_name",
[
pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"),
pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"),
pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"),
pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"),
],
)
@pytest.mark.xfail(
condition=Version(transformers.__version__) < Version("5.0.0"),
reason="Response parsing is not supported in transformers versions below 5.0.0",
strict=True,
)
@require_jmespath
class TestAddResponseSchema:
@pytest.mark.parametrize(
"tokenizer_name",
[
pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"),
pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"),
],
)
def test_add_response_schema(self, tokenizer_name):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer = add_response_schema(tokenizer)
messages = [
{"role": "user", "content": "What is 3*4?"},
{
"role": "assistant",
"content": "",
"tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}],
},
]
Expand All @@ -147,6 +150,36 @@ def test_add_response_schema(self, tokenizer_name):
# The correctness of the parsing is tested in TestParseResponse
tokenizer.parse_response(response)

@pytest.mark.parametrize(
    "processor_name",
    [
        pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"),
        pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"),
    ],
)
def test_add_response_schema_vlm(self, processor_name):
    """Smoke-test `add_response_schema` on a VLM processor.

    Checks that the schema ends up on the processor's inner tokenizer and that a rendered tool-call
    response can be passed to `parse_response` without raising.
    """
    # For VLM processors, `add_response_schema` must set the schema on the inner tokenizer, since
    # `parse_response` is a tokenizer method that reads `self.response_schema` from the tokenizer instance.
    processor = add_response_schema(AutoProcessor.from_pretrained(processor_name))
    assert processor.tokenizer.response_schema is not None

    tool_call = {"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}
    user_turn = {"role": "user", "content": [{"type": "text", "text": "What is 3*4?"}]}
    assistant_turn = {
        "role": "assistant",
        # "content" is required here because VLM processors crash on tokenize=True without it
        # (KeyError in processing_utils.py). See huggingface/transformers#45290.
        "content": [{"type": "text", "text": ""}],
        "tool_calls": [tool_call],
    }

    # The assistant response is whatever the full rendering adds after the generation prompt.
    prompt_text = processor.apply_chat_template([user_turn], tokenize=False, add_generation_prompt=True)
    full_text = processor.apply_chat_template([user_turn, assistant_turn], tokenize=False)
    completion = full_text[len(prompt_text) :]

    # Here, we just test that the parsing doesn't raise an error.
    # The correctness of the parsing is tested in TestParseResponse
    processor.tokenizer.parse_response(completion)


class TestSupportsToolCalling:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -509,7 +542,7 @@ def test_assistant_masks_multi_turn(self, tokenizer_name):


@pytest.mark.parametrize(
"tokenizer_name",
"model_name",
[
pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"),
pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
Expand All @@ -533,119 +566,178 @@ def test_assistant_masks_multi_turn(self, tokenizer_name):
)
@require_jmespath
class TestParseResponse:
def test_parse_response(self, tokenizer_name):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if getattr(tokenizer, "response_schema", None) is None:
tokenizer = add_response_schema(tokenizer)
def _load(self, model_name):
if "ForCausalLM" in model_name:
self.is_vlm = False
processing_class = AutoTokenizer.from_pretrained(model_name)
response_schema = getattr(processing_class, "response_schema", None)
elif "ForConditionalGeneration" in model_name:
self.is_vlm = True
processing_class = AutoProcessor.from_pretrained(model_name)
response_schema = getattr(processing_class.tokenizer, "response_schema", None)

if response_schema is None:
processing_class = add_response_schema(processing_class)

return processing_class

def test_parse_response(self, model_name):
    """Round-trip a plain assistant answer through the chat template and `parse_response`."""
    processing_class = self._load(model_name)
    messages = [
        {"role": "user", "content": "What is 3*4?"},
        {"role": "assistant", "content": "12"},
    ]
    # Keep the plain-text assistant message as the expectation before the messages are converted
    # for VLM processors.
    expected = messages[-1]
    messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
    prefix = processing_class.apply_chat_template(
        messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
    ).input_ids
    text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
    if self.is_vlm:
        # The processor path is indexed at [0] to get the ids of the single conversation.
        prefix = prefix[0]
        text = text[0]
    # The response is what the full rendering adds after the generation prompt.
    response = text[len(prefix) :]
    parsed = parse_response(processing_class, response)
    assert parsed == expected

def test_parse_response_with_reasoning_content(self, model_name):
    """Round-trip an assistant answer carrying `reasoning_content` through `parse_response`."""
    if model_name in (
        "trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
        "trl-internal-testing/tiny-GptOssForCausalLM",
        "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration",
    ):
        pytest.skip("This tokenizer doesn't support inline reasoning_content.")

    processing_class = self._load(model_name)
    messages = [
        {"role": "user", "content": "What is 3*4?"},
        {"role": "assistant", "reasoning_content": "Hmmm.", "content": "12"},
    ]
    # Keep the plain-text assistant message as the expectation before the messages are converted
    # for VLM processors.
    expected = messages[-1]
    messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
    # enable_thinking=True is required here because for Qwen3.5, the thinking is disabled by default for the
    # generation prompt.
    prefix = processing_class.apply_chat_template(
        messages[:1], add_generation_prompt=True, enable_thinking=True, tokenize=True, return_dict=True
    ).input_ids
    text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
    if self.is_vlm:
        # The processor path is indexed at [0] to get the ids of the single conversation.
        prefix = prefix[0]
        text = text[0]
    # The response is what the full rendering adds after the generation prompt.
    response = text[len(prefix) :]
    parsed = parse_response(processing_class, response)
    assert parsed == expected

def test_parse_response_tool_call(self, model_name):
    """Round-trip an assistant message holding a single tool call through `parse_response`."""
    processing_class = self._load(model_name)
    tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}]
    messages = [
        {"role": "user", "content": "What is 3*4?"},
        {
            "role": "assistant",
            # "content" is required here because VLM processors crash on tokenize=True without it
            # (KeyError in processing_utils.py). See huggingface/transformers#45290.
            "content": "",
            "tool_calls": tool_calls,
        },
    ]
    # Keep the plain-text assistant message as the expectation before the messages are converted
    # for VLM processors.
    expected = messages[-1]
    messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
    prefix = processing_class.apply_chat_template(
        messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
    ).input_ids
    text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
    if self.is_vlm:
        # The processor path is indexed at [0] to get the ids of the single conversation.
        prefix = prefix[0]
        text = text[0]
    # The response is what the full rendering adds after the generation prompt.
    response = text[len(prefix) :]
    parsed = parse_response(processing_class, response)
    assert parsed == expected

def test_parse_response_tool_call_with_content(self, model_name):
    """Round-trip an assistant message that has both text content and a tool call."""
    processing_class = self._load(model_name)
    tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}]
    messages = [
        {"role": "user", "content": "What is 3*4?"},
        {"role": "assistant", "content": "Let's call the tool.", "tool_calls": tool_calls},
    ]
    # Keep the plain-text assistant message as the expectation before the messages are converted
    # for VLM processors.
    expected = messages[-1]
    messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
    prefix = processing_class.apply_chat_template(
        messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
    ).input_ids
    text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
    if self.is_vlm:
        # The processor path is indexed at [0] to get the ids of the single conversation.
        prefix = prefix[0]
        text = text[0]
    # The response is what the full rendering adds after the generation prompt.
    response = text[len(prefix) :]
    parsed = parse_response(processing_class, response)
    assert parsed == expected

def test_parse_response_tool_call_without_arguments(self, model_name):
    """Round-trip a tool call whose `arguments` mapping is empty through `parse_response`."""
    processing_class = self._load(model_name)
    tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}]
    messages = [
        {"role": "user", "content": "Ping the service."},
        {
            "role": "assistant",
            # "content" is required here because VLM processors crash on tokenize=True without it
            # (KeyError in processing_utils.py). See huggingface/transformers#45290.
            "content": "",
            "tool_calls": tool_calls,
        },
    ]
    # Keep the plain-text assistant message as the expectation before the messages are converted
    # for VLM processors.
    expected = messages[-1]
    messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
    prefix = processing_class.apply_chat_template(
        messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
    ).input_ids
    text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
    if self.is_vlm:
        # The processor path is indexed at [0] to get the ids of the single conversation.
        prefix = prefix[0]
        text = text[0]
    # The response is what the full rendering adds after the generation prompt.
    response = text[len(prefix) :]
    parsed = parse_response(processing_class, response)
    assert parsed == expected

def test_parse_response_multiple_tool_calls(self, model_name):
    """Round-trip an assistant message holding two tool calls through `parse_response`."""
    if model_name == "trl-internal-testing/tiny-GptOssForCausalLM":
        pytest.skip("This template only renders one tool call per assistant message.")
    processing_class = self._load(model_name)
    tool_calls = [
        {"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}},
        {"type": "function", "function": {"name": "addition", "arguments": {"a": 4, "b": 3}}},
    ]
    messages = [
        {"role": "user", "content": "What is 3*4?"},
        {
            "role": "assistant",
            # "content" is required here because VLM processors crash on tokenize=True without it
            # (KeyError in processing_utils.py). See huggingface/transformers#45290.
            "content": "",
            "tool_calls": tool_calls,
        },
    ]
    # Keep the plain-text assistant message as the expectation before the messages are converted
    # for VLM processors.
    expected = messages[-1]
    messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
    prefix = processing_class.apply_chat_template(
        messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
    ).input_ids
    text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
    if self.is_vlm:
        # The processor path is indexed at [0] to get the ids of the single conversation.
        prefix = prefix[0]
        text = text[0]
    # The response is what the full rendering adds after the generation prompt.
    response = text[len(prefix) :]
    parsed = parse_response(processing_class, response)
    assert parsed == expected

def test_parse_response_malformed_tool_call(self, tokenizer_name):
if tokenizer_name != "trl-internal-testing/tiny-Qwen3MoeForCausalLM":
def test_parse_response_malformed_tool_call(self, model_name):
if model_name != "trl-internal-testing/tiny-Qwen3MoeForCausalLM":
pytest.skip("For simplicity, we only test the malformed tool call case on one tokenizer.")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if getattr(tokenizer, "response_schema", None) is None:
tokenizer = add_response_schema(tokenizer)
processing_class = self._load(model_name)
text = '<tool_call>\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n</tool_call><|im_end|>'
assistant_text = tokenizer(text)["input_ids"]
parsed = parse_response(tokenizer, assistant_text)
assistant_text = processing_class(text)["input_ids"]
parsed = parse_response(processing_class, assistant_text)
expected = {
"role": "assistant",
"content": '<tool_call>\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n</tool_call>',
Expand Down
Loading
Loading