diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py
index cd3a653a7b2..b1244cc40d0 100644
--- a/tests/test_chat_template_utils.py
+++ b/tests/test_chat_template_utils.py
@@ -18,7 +18,12 @@
import pytest
import transformers
from packaging.version import Version
-from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
+from transformers import (
+ AutoModelForCausalLM,
+ AutoModelForSequenceClassification,
+ AutoProcessor,
+ AutoTokenizer,
+)
from trl import clone_chat_template
from trl.chat_template_utils import (
@@ -28,6 +33,7 @@
parse_response,
supports_tool_calling,
)
+from trl.data_utils import prepare_multimodal_messages
from .testing_utils import TrlTestCase, require_jmespath
@@ -112,16 +118,6 @@ def test_clone_with_sequence_classification_model(self):
assert modified_tokenizer.eos_token == "<|im_end|>"
-@pytest.mark.parametrize(
- "tokenizer_name",
- [
- pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"),
- pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
- pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"),
- pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"),
- pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"),
- ],
-)
@pytest.mark.xfail(
condition=Version(transformers.__version__) < Version("5.0.0"),
reason="Response parsing is not supported in transformers versions below 5.0.0",
@@ -129,6 +125,14 @@ def test_clone_with_sequence_classification_model(self):
)
@require_jmespath
class TestAddResponseSchema:
+ @pytest.mark.parametrize(
+ "tokenizer_name",
+ [
+ pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"),
+ pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
+ pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"),
+ ],
+ )
def test_add_response_schema(self, tokenizer_name):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer = add_response_schema(tokenizer)
@@ -136,7 +140,6 @@ def test_add_response_schema(self, tokenizer_name):
{"role": "user", "content": "What is 3*4?"},
{
"role": "assistant",
- "content": "",
"tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}],
},
]
@@ -147,6 +150,36 @@ def test_add_response_schema(self, tokenizer_name):
# The correctness of the parsing is tested in TestParseResponse
tokenizer.parse_response(response)
+ @pytest.mark.parametrize(
+ "processor_name",
+ [
+ pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"),
+ pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"),
+ ],
+ )
+ def test_add_response_schema_vlm(self, processor_name):
+ # For VLM processors, `add_response_schema` must set the schema on the inner tokenizer, since
+ # `parse_response` is a tokenizer method that reads `self.response_schema` from the tokenizer instance.
+ processor = AutoProcessor.from_pretrained(processor_name)
+ processor = add_response_schema(processor)
+ assert processor.tokenizer.response_schema is not None
+ messages = [
+ {"role": "user", "content": [{"type": "text", "text": "What is 3*4?"}]},
+ {
+ "role": "assistant",
+                # "content" is set because VLM processors hit a KeyError (processing_utils.py) without it when
+                # tokenize=True, see huggingface/transformers#45290. NOTE(review): tokenize=False here; confirm need.
+ "content": [{"type": "text", "text": ""}],
+ "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}],
+ },
+ ]
+ prefix = processor.apply_chat_template(messages[:1], tokenize=False, add_generation_prompt=True)
+ text = processor.apply_chat_template(messages, tokenize=False)
+ response = text[len(prefix) :]
+ # Here, we just test that the parsing doesn't raise an error.
+ # The correctness of the parsing is tested in TestParseResponse
+ processor.tokenizer.parse_response(response)
+
class TestSupportsToolCalling:
@pytest.mark.parametrize(
@@ -509,7 +542,7 @@ def test_assistant_masks_multi_turn(self, tokenizer_name):
@pytest.mark.parametrize(
- "tokenizer_name",
+ "model_name",
[
pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"),
pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
@@ -533,119 +566,178 @@ def test_assistant_masks_multi_turn(self, tokenizer_name):
)
@require_jmespath
class TestParseResponse:
- def test_parse_response(self, tokenizer_name):
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
- if getattr(tokenizer, "response_schema", None) is None:
- tokenizer = add_response_schema(tokenizer)
+    def _load(self, model_name):
+        # Load the tokenizer (LM) or processor (VLM) for `model_name`, adding a response schema if it has none.
+        if "ForCausalLM" in model_name:
+            self.is_vlm = False
+        elif "ForConditionalGeneration" in model_name:
+            self.is_vlm = True
+        else:  # without this branch an unexpected id fails later with an opaque UnboundLocalError
+            raise ValueError(f"Cannot infer tokenizer vs. processor from model name: {model_name!r}")
+        processing_class = (AutoProcessor if self.is_vlm else AutoTokenizer).from_pretrained(model_name)
+        # `parse_response` reads the schema from the tokenizer, so for VLMs check the processor's inner tokenizer.
+        tokenizer = processing_class.tokenizer if self.is_vlm else processing_class
+        if getattr(tokenizer, "response_schema", None) is None:
+            processing_class = add_response_schema(processing_class)
+        return processing_class
+
+ def test_parse_response(self, model_name):
+ processing_class = self._load(model_name)
messages = [
{"role": "user", "content": "What is 3*4?"},
{"role": "assistant", "content": "12"},
]
- prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids
- text = tokenizer.apply_chat_template(messages).input_ids
+ expected = messages[-1]
+ messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
+ prefix = processing_class.apply_chat_template(
+ messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
+ ).input_ids
+ text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
+ if self.is_vlm:
+ prefix = prefix[0]
+ text = text[0]
response = text[len(prefix) :]
- parsed = parse_response(tokenizer, response)
- assert parsed == messages[-1]
+ parsed = parse_response(processing_class, response)
+ assert parsed == expected
- def test_parse_response_with_reasoning_content(self, tokenizer_name):
- if tokenizer_name in (
+ def test_parse_response_with_reasoning_content(self, model_name):
+ if model_name in (
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
"trl-internal-testing/tiny-GptOssForCausalLM",
"trl-internal-testing/tiny-Qwen3VLForConditionalGeneration",
):
pytest.skip("This tokenizer doesn't support inline reasoning_content.")
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
- if getattr(tokenizer, "response_schema", None) is None:
- tokenizer = add_response_schema(tokenizer)
+ processing_class = self._load(model_name)
messages = [
{"role": "user", "content": "What is 3*4?"},
{"role": "assistant", "reasoning_content": "Hmmm.", "content": "12"},
]
+ expected = messages[-1]
+ messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
# enable_thinking=True is required here because for Qwen3.5, the thinking is disabled by default for the
# generation prompt.
- prefix = tokenizer.apply_chat_template(
- messages[:1], add_generation_prompt=True, enable_thinking=True
+ prefix = processing_class.apply_chat_template(
+ messages[:1], add_generation_prompt=True, enable_thinking=True, tokenize=True, return_dict=True
).input_ids
- text = tokenizer.apply_chat_template(messages).input_ids
+ text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
+ if self.is_vlm:
+ prefix = prefix[0]
+ text = text[0]
response = text[len(prefix) :]
- parsed = parse_response(tokenizer, response)
- assert parsed == messages[-1]
+ parsed = parse_response(processing_class, response)
+ assert parsed == expected
- def test_parse_response_tool_call(self, tokenizer_name):
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
- if getattr(tokenizer, "response_schema", None) is None:
- tokenizer = add_response_schema(tokenizer)
+ def test_parse_response_tool_call(self, model_name):
+ processing_class = self._load(model_name)
tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}]
messages = [
{"role": "user", "content": "What is 3*4?"},
- {"role": "assistant", "content": "", "tool_calls": tool_calls},
+ {
+ "role": "assistant",
+ # "content" is required here because VLM processors crash on tokenize=True without it
+ # (KeyError in processing_utils.py). See huggingface/transformers#45290.
+ "content": "",
+ "tool_calls": tool_calls,
+ },
]
- prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids
- text = tokenizer.apply_chat_template(messages).input_ids
+ expected = messages[-1]
+ messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
+ prefix = processing_class.apply_chat_template(
+ messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
+ ).input_ids
+ text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
+ if self.is_vlm:
+ prefix = prefix[0]
+ text = text[0]
response = text[len(prefix) :]
- parsed = parse_response(tokenizer, response)
- assert parsed == messages[-1]
+ parsed = parse_response(processing_class, response)
+ assert parsed == expected
- def test_parse_response_tool_call_with_content(self, tokenizer_name):
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
- if getattr(tokenizer, "response_schema", None) is None:
- tokenizer = add_response_schema(tokenizer)
+ def test_parse_response_tool_call_with_content(self, model_name):
+ processing_class = self._load(model_name)
tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}]
messages = [
{"role": "user", "content": "What is 3*4?"},
{"role": "assistant", "content": "Let's call the tool.", "tool_calls": tool_calls},
]
- prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids
- text = tokenizer.apply_chat_template(messages).input_ids
+ expected = messages[-1]
+ messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
+ prefix = processing_class.apply_chat_template(
+ messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
+ ).input_ids
+ text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
+ if self.is_vlm:
+ prefix = prefix[0]
+ text = text[0]
response = text[len(prefix) :]
- parsed = parse_response(tokenizer, response)
- assert parsed == messages[-1]
+ parsed = parse_response(processing_class, response)
+ assert parsed == expected
- def test_parse_response_tool_call_without_arguments(self, tokenizer_name):
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
- if getattr(tokenizer, "response_schema", None) is None:
- tokenizer = add_response_schema(tokenizer)
+ def test_parse_response_tool_call_without_arguments(self, model_name):
+ processing_class = self._load(model_name)
tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}]
messages = [
{"role": "user", "content": "Ping the service."},
- {"role": "assistant", "tool_calls": tool_calls},
+ {
+ "role": "assistant",
+ # "content" is required here because VLM processors crash on tokenize=True without it
+ # (KeyError in processing_utils.py). See huggingface/transformers#45290.
+ "content": "",
+ "tool_calls": tool_calls,
+ },
]
- prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids
- text = tokenizer.apply_chat_template(messages).input_ids
+ expected = messages[-1]
+ messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
+ prefix = processing_class.apply_chat_template(
+ messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
+ ).input_ids
+ text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
+ if self.is_vlm:
+ prefix = prefix[0]
+ text = text[0]
response = text[len(prefix) :]
- parsed = parse_response(tokenizer, response)
- assert parsed == {"role": "assistant", "content": "", "tool_calls": tool_calls}
+ parsed = parse_response(processing_class, response)
+ assert parsed == expected
- def test_parse_response_multiple_tool_calls(self, tokenizer_name):
- if tokenizer_name == "trl-internal-testing/tiny-GptOssForCausalLM":
- pytest.skip("GPT-OSS template only renders one tool call per assistant message.")
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
- if getattr(tokenizer, "response_schema", None) is None:
- tokenizer = add_response_schema(tokenizer)
+ def test_parse_response_multiple_tool_calls(self, model_name):
+ if model_name == "trl-internal-testing/tiny-GptOssForCausalLM":
+ pytest.skip("This template only renders one tool call per assistant message.")
+ processing_class = self._load(model_name)
tool_calls = [
{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}},
{"type": "function", "function": {"name": "addition", "arguments": {"a": 4, "b": 3}}},
]
messages = [
{"role": "user", "content": "What is 3*4?"},
- {"role": "assistant", "content": "", "tool_calls": tool_calls},
+ {
+ "role": "assistant",
+ # "content" is required here because VLM processors crash on tokenize=True without it
+ # (KeyError in processing_utils.py). See huggingface/transformers#45290.
+ "content": "",
+ "tool_calls": tool_calls,
+ },
]
- prefix = tokenizer.apply_chat_template(messages[:1], add_generation_prompt=True).input_ids
- text = tokenizer.apply_chat_template(messages).input_ids
+ expected = messages[-1]
+ messages = prepare_multimodal_messages(messages) if self.is_vlm else messages
+ prefix = processing_class.apply_chat_template(
+ messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True
+ ).input_ids
+ text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
+ if self.is_vlm:
+ prefix = prefix[0]
+ text = text[0]
response = text[len(prefix) :]
- parsed = parse_response(tokenizer, response)
- assert parsed == messages[-1]
+ parsed = parse_response(processing_class, response)
+ assert parsed == expected
- def test_parse_response_malformed_tool_call(self, tokenizer_name):
- if tokenizer_name != "trl-internal-testing/tiny-Qwen3MoeForCausalLM":
+ def test_parse_response_malformed_tool_call(self, model_name):
+ if model_name != "trl-internal-testing/tiny-Qwen3MoeForCausalLM":
pytest.skip("For simplicity, we only test the malformed tool call case on one tokenizer.")
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
- if getattr(tokenizer, "response_schema", None) is None:
- tokenizer = add_response_schema(tokenizer)
+ processing_class = self._load(model_name)
text = '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n<|im_end|>'
- assistant_text = tokenizer(text)["input_ids"]
- parsed = parse_response(tokenizer, assistant_text)
+ assistant_text = processing_class(text)["input_ids"]
+ parsed = parse_response(processing_class, assistant_text)
expected = {
"role": "assistant",
"content": '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n',
diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py
index c7cbbb1e02e..29dd031b099 100644
--- a/trl/chat_template_utils.py
+++ b/trl/chat_template_utils.py
@@ -13,6 +13,7 @@
# limitations under the License.
from pathlib import Path
+from typing import TypeVar
from jinja2 import TemplateError
from transformers import AddedToken, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
@@ -284,7 +285,10 @@ def clone_chat_template(
qwen3_5_chat_template_4b_and_above = (_CHAT_TEMPLATES_DIR / "qwen3_5_4b_and_above.jinja").read_text()
-def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
+ProcessingClassT = TypeVar("ProcessingClassT", PreTrainedTokenizer, ProcessorMixin)
+
+
+def add_response_schema(processing_class: ProcessingClassT) -> ProcessingClassT:
r"""
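-    Adds the appropriate response schema to the given tokenizer based on its chat template.
+    Adds the appropriate response schema to the given tokenizer or processor based on its chat template.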
@@ -292,13 +296,16 @@ def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
waiting for broader adoption, we provide this utility function to manually set the response schema for known chat
templates.
+ When given a VLM processor, the schema is set on the inner tokenizer, since `parse_response` is a tokenizer method
+ and reads `self.response_schema` from the tokenizer instance.
+
Args:
- tokenizer (`PreTrainedTokenizer`):
- Tokenizer to which the response schema will be added.
+ processing_class (`PreTrainedTokenizer` or `ProcessorMixin`):
+ Tokenizer or VLM processor to which the response schema will be added.
Returns:
- `PreTrainedTokenizer`:
- Tokenizer with the added response schema.
+ `PreTrainedTokenizer` or `ProcessorMixin`:
+ The same object that was passed in, with the response schema set on the underlying tokenizer.
Examples:
@@ -313,24 +320,30 @@ def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
{'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]}
```
"""
- if tokenizer.chat_template == glm4moe_chat_template:
+ # For VLM processors, set the schema on the inner tokenizer (where `parse_response` reads it from).
+ # Match against the top-level chat_template, since that's what was used historically and processors
+ # may carry their own VLM-specific template separate from the inner tokenizer's.
+ chat_template = processing_class.chat_template
+ if isinstance(processing_class, ProcessorMixin):
+ tokenizer = processing_class.tokenizer
+ else:
+ tokenizer = processing_class
+ if chat_template == glm4moe_chat_template:
tokenizer.response_schema = glm4moe_schema
- return tokenizer
- if tokenizer.chat_template == gptoss_chat_template:
+ elif chat_template == gptoss_chat_template:
tokenizer.response_schema = gptoss_schema
- return tokenizer
- if tokenizer.chat_template in [qwen3_chat_template, qwen3_vl_chat_template]:
+ elif chat_template in [qwen3_chat_template, qwen3_vl_chat_template]:
tokenizer.response_schema = qwen3_schema
- return tokenizer
- if tokenizer.chat_template in [qwen3_5_chat_template_2b_and_below, qwen3_5_chat_template_4b_and_above]:
+ elif chat_template in [qwen3_5_chat_template_2b_and_below, qwen3_5_chat_template_4b_and_above]:
tokenizer.response_schema = qwen3_5_schema
- return tokenizer
- raise ValueError(
- "Unrecognized chat template, failed to add response schema. Please manually set the response schema on the "
- "tokenizer or processor. See the Transformers "
- "[docs](https://huggingface.co/docs/transformers/main/en/chat_response_parsing#response-parsing) for more "
- "details on response parsing."
- )
+ else:
+ raise ValueError(
+ "Unrecognized chat template, failed to add response schema. Please manually set the response schema on "
+ "the tokenizer or processor. See the Transformers "
+ "[docs](https://huggingface.co/docs/transformers/main/en/chat_response_parsing#response-parsing) for more "
+ "details on response parsing."
+ )
+ return processing_class
def supports_tool_calling(processing_class) -> bool:
@@ -560,7 +573,7 @@ def _validate_tool_calls(tool_calls: list | None) -> None:
tool_call["arguments"] = {}
-def parse_response(tokenizer_or_processor, ids: list[int]) -> dict:
+def parse_response(processing_class: PreTrainedTokenizer | ProcessorMixin, ids: list[int]) -> dict:
r"""
Parse a token sequence into structured response dictionaries with fallback handling.
@@ -573,7 +586,7 @@ def parse_response(tokenizer_or_processor, ids: list[int]) -> dict:
For VLM processors, automatically uses the inner tokenizer for parsing.
Args:
- tokenizer_or_processor (`PreTrainedTokenizer` or VLM processor):
+        processing_class (`PreTrainedTokenizer` or `ProcessorMixin`):
Tokenizer or processor with a `parse_response()` method (directly or via inner tokenizer).
ids (`list[int]`):
List of token sequences.
@@ -596,7 +609,7 @@ def parse_response(tokenizer_or_processor, ids: list[int]) -> dict:
```
"""
# VLM processors don't have parse_response directly; use the inner tokenizer
- tokenizer = getattr(tokenizer_or_processor, "tokenizer", tokenizer_or_processor)
+ tokenizer = getattr(processing_class, "tokenizer", processing_class)
try:
parsed = tokenizer.parse_response(ids)
# Hotfix: remove incorrectly appended EOS token from tool calls
diff --git a/trl/experimental/dppo/dppo_trainer.py b/trl/experimental/dppo/dppo_trainer.py
index bb5e5933228..260417a1aa6 100644
--- a/trl/experimental/dppo/dppo_trainer.py
+++ b/trl/experimental/dppo/dppo_trainer.py
@@ -668,11 +668,11 @@ def _generate(self, prompts: list):
# Decode completions. It's important to use `parse_response` when possible, because it handles tool calls.
if is_conversational({"prompt": prompts[0]}):
+ tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class
if (
Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5
- and isinstance(self.processing_class, PreTrainedTokenizerBase) # doesn't work with processors
- and hasattr(self.processing_class, "response_schema") # attribute not set by default for now
- and self.processing_class.response_schema is not None # only works if the tokenizer has a schema
+ and hasattr(tokenizer, "response_schema") # attribute not set by default for now
+ and tokenizer.response_schema is not None # only works if the tokenizer has a schema
):
completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids]
else:
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index bcabd5e17be..2f7e66ca197 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -526,13 +526,10 @@ def __init__(
# At the time of initial implementation, most tokenizers do not have built-in support for response schemas.
# While waiting for broader adoption, we provide this utility function to manually set the response schema for
- # known chat templates.
- # We need `getattr`` until the base class sets a default None value for response_schema
- # For VLM processors, check the inner tokenizer too (response_schema lives on the tokenizer)
- has_response_schema = getattr(processing_class, "response_schema", None) or (
- self._is_vlm and getattr(processing_class.tokenizer, "response_schema", None)
- )
- if self.tools and not has_response_schema:
+ # known chat templates. `response_schema` lives on the (inner) tokenizer, since `parse_response` is a tokenizer
+ # method that reads `self.response_schema`.
+ tokenizer = processing_class.tokenizer if self._is_vlm else processing_class
+ if self.tools and getattr(tokenizer, "response_schema", None) is None:
processing_class = add_response_schema(processing_class)
# In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template
# isn't, we replace it at initialization with a training-safe, prefix-preserving template.
@@ -1722,22 +1719,13 @@ def _generate(self, prompts: list):
# Decode completions. It's important to use `parse_response` when possible, because it handles tool calls.
if is_conversational({"prompt": prompts[0]}):
- parsing_class = self.processing_class
- # For VLM processors, propagate response_schema to the inner tokenizer if needed
- if self._is_vlm:
- if getattr(self.processing_class, "response_schema", None) and not getattr(
- self.processing_class.tokenizer, "response_schema", None
- ):
- self.processing_class.tokenizer.response_schema = self.processing_class.response_schema
- # parse_response handles VLM processors internally (uses inner tokenizer)
- tokenizer = getattr(parsing_class, "tokenizer", parsing_class)
+ tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class
if (
Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5
- and isinstance(tokenizer, PreTrainedTokenizerBase)
and hasattr(tokenizer, "response_schema") # attribute not set by default for now
and tokenizer.response_schema is not None # only works if the tokenizer has a schema
):
- completions = [[parse_response(parsing_class, ids)] for ids in completion_ids]
+ completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids]
else:
contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
completions = [[{"role": "assistant", "content": content}] for content in contents]