From 861887155d691bc9144c5abc61fba431d4c04054 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 15 Apr 2026 18:21:34 +0000 Subject: [PATCH 1/4] Revert VLM support in `parse_response` --- trl/chat_template_utils.py | 10 +++------- trl/experimental/dppo/dppo_trainer.py | 20 +++++++++++++------- trl/trainer/grpo_trainer.py | 13 +++++-------- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index c7cbbb1e02e..d33363d98fb 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -560,7 +560,7 @@ def _validate_tool_calls(tool_calls: list | None) -> None: tool_call["arguments"] = {} -def parse_response(tokenizer_or_processor, ids: list[int]) -> dict: +def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict: r""" Parse a token sequence into structured response dictionaries with fallback handling. @@ -570,11 +570,9 @@ def parse_response(tokenizer_or_processor, ids: list[int]) -> dict: Also removes incorrectly appended EOS tokens from tool call content when present, and validates tool_calls to ensure all required fields exist. - For VLM processors, automatically uses the inner tokenizer for parsing. - Args: - tokenizer_or_processor (`PreTrainedTokenizer` or VLM processor): - Tokenizer or processor with a `parse_response()` method (directly or via inner tokenizer). + tokenizer (`PreTrainedTokenizer`): + Tokenizer with a `parse_response()` method. ids (`list[int]`): List of token sequences. @@ -595,8 +593,6 @@ def parse_response(tokenizer_or_processor, ids: list[int]) -> dict: {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]} ``` """ - # VLM processors don't have parse_response directly; use the inner tokenizer - tokenizer = getattr(tokenizer_or_processor, "tokenizer", tokenizer_or_processor) try: parsed = tokenizer.parse_response(ids) # Hotfix: remove incorrectly appended EOS token from tool calls diff --git a/trl/experimental/dppo/dppo_trainer.py b/trl/experimental/dppo/dppo_trainer.py index bb5e5933228..0ab5b47a72e 100644 --- a/trl/experimental/dppo/dppo_trainer.py +++ b/trl/experimental/dppo/dppo_trainer.py @@ -592,9 +592,8 @@ async def _run_async_tools(async_coros): completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] # Decode post-tool completions - post_tool_completions = [ - parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids - ] + tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class + post_tool_completions = [parse_response(tokenizer, ids) if ids else {} for ids in post_tool_ids] for idx in range(len(idxs_with_tool)): idx_with_tool = idxs_with_tool[idx] @@ -668,13 +667,20 @@ def _generate(self, prompts: list): # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): + # For VLM processors, propagate response_schema to the inner tokenizer if needed + if self._is_vlm: + if getattr(self.processing_class, "response_schema", None) and not getattr( + self.processing_class.tokenizer, "response_schema", None + ): + self.processing_class.tokenizer.response_schema = self.processing_class.response_schema + tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class if ( Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 - and isinstance(self.processing_class, PreTrainedTokenizerBase) # doesn't work with processors - and hasattr(self.processing_class, "response_schema") # attribute not set by default for now - and self.processing_class.response_schema is not None # only works if the tokenizer has a schema + and isinstance(tokenizer, PreTrainedTokenizerBase) + and hasattr(tokenizer, "response_schema") # attribute not set by default for now + and tokenizer.response_schema is not None # only works if the tokenizer has a schema ): - completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] + completions = [[parse_response(tokenizer, ids)] for ids in completion_ids] else: contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) completions = [[{"role": "assistant", "content": content}] for content in contents] diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index bcabd5e17be..b75722baa4f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1669,10 +1669,9 @@ async def _run_async_tools(async_coros): pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] - # Decode post-tool completions. - post_tool_completions = [ - parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids - ] + # Decode post-tool completions + tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class + post_tool_completions = [parse_response(tokenizer, ids) if ids else {} for ids in post_tool_ids] # Add post-tool completions to the existing completions for idx in range(len(idxs_with_tool)): @@ -1722,22 +1721,20 @@ def _generate(self, prompts: list): # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): - parsing_class = self.processing_class # For VLM processors, propagate response_schema to the inner tokenizer if needed if self._is_vlm: if getattr(self.processing_class, "response_schema", None) and not getattr( self.processing_class.tokenizer, "response_schema", None ): self.processing_class.tokenizer.response_schema = self.processing_class.response_schema - # parse_response handles VLM processors internally (uses inner tokenizer) - tokenizer = getattr(parsing_class, "tokenizer", parsing_class) + tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class if ( Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 and isinstance(tokenizer, PreTrainedTokenizerBase) and hasattr(tokenizer, "response_schema") # attribute not set by default for now and tokenizer.response_schema is not None # only works if the tokenizer has a schema ): - completions = [[parse_response(parsing_class, ids)] for ids in completion_ids] + completions = [[parse_response(tokenizer, ids)] for ids in completion_ids] else: contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) completions = [[{"role": "assistant", "content": content}] for content in contents] From cc3905c0fa30cc5e9a176a7dfd015e41b110553a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 17 Apr 2026 18:32:54 +0000 Subject: [PATCH 2/4] fix --- tests/test_chat_template_utils.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 4ce74c0f760..c448e2e49fa 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -660,7 +660,8 @@ def test_parse_response(self, model_name): prefix = prefix[0] text = text[0] response = text[len(prefix) :] - parsed = parse_response(processing_class, response) + tokenizer = processing_class.tokenizer if self.is_vlm else processing_class + parsed = parse_response(tokenizer, response) assert parsed == expected def test_parse_response_with_reasoning_content(self, model_name): @@ -690,7 +691,8 @@ def test_parse_response_with_reasoning_content(self, model_name): prefix = prefix[0] text = text[0] response = text[len(prefix) :] - parsed = parse_response(processing_class, response) + tokenizer = processing_class.tokenizer if self.is_vlm else processing_class + parsed = parse_response(tokenizer, response) assert parsed == expected def test_parse_response_tool_call(self, model_name): @@ -716,7 +718,8 @@ def test_parse_response_tool_call(self, model_name): prefix = prefix[0] text = text[0] response = text[len(prefix) :] - parsed = parse_response(processing_class, response) + tokenizer = processing_class.tokenizer if self.is_vlm else processing_class + parsed = parse_response(tokenizer, response) assert parsed == expected def test_parse_response_tool_call_with_content(self, model_name): @@ -741,7 +744,8 @@ def test_parse_response_tool_call_with_content(self, model_name): prefix = prefix[0] text = text[0] response = text[len(prefix) :] - parsed = parse_response(processing_class, response) + tokenizer = processing_class.tokenizer if self.is_vlm else processing_class + parsed = parse_response(tokenizer, response) assert parsed == expected def test_parse_response_tool_call_without_arguments(self, model_name): @@ -767,7 +771,8 @@ def test_parse_response_tool_call_without_arguments(self, model_name): prefix = prefix[0] text = text[0] response = text[len(prefix) :] - parsed = parse_response(processing_class, response) + tokenizer = processing_class.tokenizer if self.is_vlm else processing_class + parsed = parse_response(tokenizer, response) assert parsed == expected def test_parse_response_multiple_tool_calls(self, model_name): @@ -802,7 +807,8 @@ def test_parse_response_multiple_tool_calls(self, model_name): prefix = prefix[0] text = text[0] response = text[len(prefix) :] - parsed = parse_response(processing_class, response) + tokenizer = processing_class.tokenizer if self.is_vlm else processing_class + parsed = parse_response(tokenizer, response) assert parsed == expected def test_parse_response_malformed_tool_call(self, model_name): From 8a5ffa4e1f1209b2f06240700460478cef86f3cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 29 Apr 2026 15:58:13 +0000 Subject: [PATCH 3/4] empty for ci trigger From 58859f5b6031f7bba075bfd2977bc07c419df6d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 29 Apr 2026 16:02:38 +0000 Subject: [PATCH 4/4] fix --- trl/chat_template_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 19d152598b5..5b63e43e173 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -685,7 +685,7 @@ def _validate_tool_calls(tool_calls: list | None) -> None: tool_call["arguments"] = {} -def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict: +def parse_response(tokenizer: PreTrainedTokenizerBase, ids: list[int]) -> dict: r""" Parse a token sequence into structured response dictionaries with fallback handling. @@ -696,7 +696,7 @@ def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict: ensure all required fields exist. Args: - tokenizer (`PreTrainedTokenizer`): + tokenizer (`PreTrainedTokenizerBase`): Tokenizer with a `parse_response()` method. ids (`list[int]`): List of token sequences.