Skip to content
236 changes: 163 additions & 73 deletions tests/test_chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
import pytest
import transformers
from packaging.version import Version
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoProcessor,
AutoTokenizer,
)

from trl import clone_chat_template
from trl.chat_template_utils import (
Expand All @@ -27,6 +32,7 @@
parse_response,
supports_tool_calling,
)
from trl.data_utils import prepare_multimodal_messages

from .testing_utils import TrlTestCase, require_jmespath

Expand Down Expand Up @@ -111,23 +117,23 @@ def test_clone_with_sequence_classification_model(self):
assert modified_tokenizer.eos_token == "<|im_end|>"


@pytest.mark.parametrize(
"tokenizer_name",
[
pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"),
pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"),
pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"),
pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"),
],
)
@pytest.mark.xfail(
condition=Version(transformers.__version__) < Version("5.0.0"),
reason="Response parsing is not supported in transformers versions below 5.0.0",
strict=True,
)
@require_jmespath
class TestAddResponseSchema:
@pytest.mark.parametrize(
"tokenizer_name",
[
pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"),
pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.1", id="llama3.1"),
pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.2", id="llama3.2"),
pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"),
],
)
def test_add_response_schema(self, tokenizer_name):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer = add_response_schema(tokenizer)
Expand All @@ -146,6 +152,34 @@ def test_add_response_schema(self, tokenizer_name):
# The correctness of the parsing is tested in TestParseResponse
tokenizer.parse_response(response)

@pytest.mark.parametrize(
    "processor_name",
    [
        pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"),
        pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"),
    ],
)
def test_add_response_schema_vlm(self, processor_name):
    """Smoke-test `add_response_schema` on VLM processors.

    Checks that the schema ends up on the processor's inner tokenizer and that a rendered
    tool-call turn can be parsed without raising.
    """
    # The schema has to live on `processor.tokenizer`: `parse_response` is a tokenizer method that
    # reads `self.response_schema` from the tokenizer instance.
    processor = add_response_schema(AutoProcessor.from_pretrained(processor_name))
    assert processor.tokenizer.response_schema is not None

    user_turn = {"role": "user", "content": "What is 3*4?"}
    tool_call = {"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}
    assistant_turn = {"role": "assistant", "content": "", "tool_calls": [tool_call]}

    # Render the prompt alone and the full conversation as text; the assistant response is
    # whatever follows the prompt in the full rendering.
    prompt = processor.apply_chat_template([user_turn], tokenize=False, add_generation_prompt=True)
    rendered = processor.apply_chat_template([user_turn, assistant_turn], tokenize=False)
    assistant_part = rendered[len(prompt) :]

    # Only the absence of an exception is checked here; parsing correctness is covered by
    # TestParseResponse.
    processor.tokenizer.parse_response(assistant_part)


class TestSupportsToolCalling:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -500,7 +534,7 @@ def test_assistant_masks_multi_turn(self, tokenizer_name):


@pytest.mark.parametrize(
"tokenizer_name",
"model_name",
[
pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"),
pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
Expand All @@ -524,100 +558,151 @@ def test_assistant_masks_multi_turn(self, tokenizer_name):
)
@require_jmespath
class TestParseResponse:
def _load(self, model_name):
    """Load the processing class for `model_name` and make sure it has a response schema.

    The kind of checkpoint is inferred from its name: names containing "ForCausalLM" are loaded
    with `AutoTokenizer`, names containing "ForConditionalGeneration" with `AutoProcessor`.
    As a side effect, `self.is_vlm` is set so the calling test knows which input format to use.

    Args:
        model_name: Hub identifier of the tiny test checkpoint.

    Returns:
        The tokenizer or processor, passed through `add_response_schema` when no schema was
        already present.

    Raises:
        ValueError: If the checkpoint kind cannot be inferred from `model_name`.
    """
    if "ForCausalLM" in model_name:
        self.is_vlm = False
        processing_class = AutoTokenizer.from_pretrained(model_name)
        response_schema = getattr(processing_class, "response_schema", None)
    elif "ForConditionalGeneration" in model_name:
        self.is_vlm = True
        processing_class = AutoProcessor.from_pretrained(model_name)
        # For processors, the schema is looked up on the inner tokenizer.
        response_schema = getattr(processing_class.tokenizer, "response_schema", None)
    else:
        # Fail loudly instead of hitting an UnboundLocalError further down.
        raise ValueError(
            f"Cannot infer the processing class for {model_name!r}: expected 'ForCausalLM' or "
            "'ForConditionalGeneration' in the model name."
        )

    if response_schema is None:
        processing_class = add_response_schema(processing_class)

    return processing_class

def test_parse_response(self, model_name):
    """Round-trip a plain assistant answer through the chat template and `parse_response`."""
    processing_class = self._load(model_name)
    messages = [
        {"role": "user", "content": "What is 3*4?"},
        {"role": "assistant", "content": "12"},
    ]
    # Keep a reference to the plain-text assistant message before the VLM conversion below.
    expected = messages[-1]
    messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
    prefix = processing_class.apply_chat_template(
        messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
    ).input_ids
    text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
    if self.is_vlm:
        # Take the first entry of the processor output (presumably a batch of one — confirm).
        prefix = prefix[0]
        text = text[0]
    # The response is the part of the full conversation that follows the generation prompt.
    response = text[len(prefix) :]
    parsed = parse_response(processing_class, response)
    assert parsed == expected

def test_parse_response_with_reasoning_content(self, model_name):
    """Round-trip an assistant answer carrying `reasoning_content` through `parse_response`."""
    if model_name in (
        "trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
        "trl-internal-testing/tiny-GptOssForCausalLM",
        "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration",
    ):
        pytest.skip("This tokenizer doesn't support inline reasoning_content.")

    processing_class = self._load(model_name)
    messages = [
        {"role": "user", "content": "What is 3*4?"},
        {"role": "assistant", "reasoning_content": "Hmmm.", "content": "12"},
    ]
    # Keep a reference to the plain-text assistant message before the VLM conversion below.
    expected = messages[-1]
    messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
    # enable_thinking=True is required here because for Qwen3.5, the thinking is disabled by default for the
    # generation prompt.
    prefix = processing_class.apply_chat_template(
        messages[:1], add_generation_prompt=True, enable_thinking=True, tokenize=True, return_dict=True
    ).input_ids
    text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
    if self.is_vlm:
        # Take the first entry of the processor output (presumably a batch of one — confirm).
        prefix = prefix[0]
        text = text[0]
    # The response is the part of the full conversation that follows the generation prompt.
    response = text[len(prefix) :]
    parsed = parse_response(processing_class, response)
    assert parsed == expected

def test_parse_response_tool_call(self, model_name):
    """Round-trip an assistant turn made of a single tool call through `parse_response`."""
    processing_class = self._load(model_name)
    tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}]
    messages = [
        {"role": "user", "content": "What is 3*4?"},
        {"role": "assistant", "content": "", "tool_calls": tool_calls},
    ]
    # Keep a reference to the plain-text assistant message before the VLM conversion below.
    expected = messages[-1]
    messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
    prefix = processing_class.apply_chat_template(
        messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
    ).input_ids
    text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
    if self.is_vlm:
        # Take the first entry of the processor output (presumably a batch of one — confirm).
        prefix = prefix[0]
        text = text[0]
    # The response is the part of the full conversation that follows the generation prompt.
    response = text[len(prefix) :]
    parsed = parse_response(processing_class, response)
    assert parsed == expected

def test_parse_response_tool_call_with_content(self, model_name):
    """Round-trip an assistant turn that has both text content and a tool call."""
    if model_name == "trl-internal-testing/tiny-Gemma4ForConditionalGeneration":
        # Gemma4 response_schema regex doesn't capture content after tool calls.
        # Remove once https://huggingface.co/google/gemma-4-31B-it/discussions/19 is merged.
        pytest.xfail("Gemma4 response_schema regex bug.")
    if model_name in (
        "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
        "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
    ):
        pytest.skip("Llama 3.1 / 3.2 templates only allow a single tool call per assistant turn, with no content.")
    processing_class = self._load(model_name)
    tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}]
    messages = [
        {"role": "user", "content": "What is 3*4?"},
        {"role": "assistant", "content": "Let's call the tool.", "tool_calls": tool_calls},
    ]
    # Keep a reference to the plain-text assistant message before the VLM conversion below.
    expected = messages[-1]
    messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
    prefix = processing_class.apply_chat_template(
        messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
    ).input_ids
    text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
    if self.is_vlm:
        # Take the first entry of the processor output (presumably a batch of one — confirm).
        prefix = prefix[0]
        text = text[0]
    # The response is the part of the full conversation that follows the generation prompt.
    response = text[len(prefix) :]
    parsed = parse_response(processing_class, response)
    assert parsed == expected

def test_parse_response_tool_call_without_arguments(self, model_name):
    """Round-trip an assistant tool call whose `arguments` dict is empty."""
    processing_class = self._load(model_name)
    tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}]
    messages = [
        {"role": "user", "content": "Ping the service."},
        {
            "role": "assistant",
            # "content" is required here because VLM processors crash on tokenize=True without it
            # (KeyError in processing_utils.py). See huggingface/transformers#45290.
            "content": "",
            "tool_calls": tool_calls,
        },
    ]
    # Keep a reference to the plain-text assistant message before the VLM conversion below.
    expected = messages[-1]
    messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
    prefix = processing_class.apply_chat_template(
        messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
    ).input_ids
    text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
    if self.is_vlm:
        # Take the first entry of the processor output (presumably a batch of one — confirm).
        prefix = prefix[0]
        text = text[0]
    # The response is the part of the full conversation that follows the generation prompt.
    response = text[len(prefix) :]
    parsed = parse_response(processing_class, response)
    assert parsed == expected

def test_parse_response_multiple_tool_calls(self, tokenizer_name):
if tokenizer_name == "trl-internal-testing/tiny-GptOssForCausalLM":
pytest.skip("GPT-OSS template only renders one tool call per assistant message.")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if getattr(tokenizer, "response_schema", None) is None:
tokenizer = add_response_schema(tokenizer)
def test_parse_response_multiple_tool_calls(self, model_name):
if model_name in (
"trl-internal-testing/tiny-GptOssForCausalLM",
"trl-internal-testing/tiny-LlamaForCausalLM-3.1",
"trl-internal-testing/tiny-LlamaForCausalLM-3.2",
):
pytest.skip("This template only renders one tool call per assistant message.")
processing_class = self._load(model_name)
tool_calls = [
{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}},
{"type": "function", "function": {"name": "addition", "arguments": {"a": 4, "b": 3}}},
Expand All @@ -626,21 +711,26 @@ def test_parse_response_multiple_tool_calls(self, tokenizer_name):
{"role": "user", "content": "What is 3*4?"},
{"role": "assistant", "content": "", "tool_calls": tool_calls},
]
prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids
text = tokenizer.apply_chat_template(messages).input_ids
expected = messages[-1]
messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
prefix = processing_class.apply_chat_template(
messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
).input_ids
text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
if self.is_vlm:
prefix = prefix[0]
text = text[0]
response = text[len(prefix) :]
parsed = parse_response(tokenizer, response)
assert parsed == messages[-1]
parsed = parse_response(processing_class, response)
assert parsed == expected

def test_parse_response_malformed_tool_call(self, tokenizer_name):
if tokenizer_name != "trl-internal-testing/tiny-Qwen3MoeForCausalLM":
def test_parse_response_malformed_tool_call(self, model_name):
if model_name != "trl-internal-testing/tiny-Qwen3MoeForCausalLM":
pytest.skip("For simplicity, we only test the malformed tool call case on one tokenizer.")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if getattr(tokenizer, "response_schema", None) is None:
tokenizer = add_response_schema(tokenizer)
processing_class = self._load(model_name)
text = '<tool_call>\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n</tool_call><|im_end|>'
assistant_text = tokenizer(text)["input_ids"]
parsed = parse_response(tokenizer, assistant_text)
assistant_text = processing_class(text)["input_ids"]
parsed = parse_response(processing_class, assistant_text)
expected = {
"role": "assistant",
"content": '<tool_call>\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n</tool_call>',
Expand Down
Loading
Loading