From c02b86c694459050e93bd71ce9a9e50a68f9cea8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 2 Apr 2026 18:22:38 +0000 Subject: [PATCH 1/2] Remove unnecessary `isinstance(part, dict)` checks in image extraction --- trl/trainer/grpo_trainer.py | 6 +++--- trl/trainer/rloo_trainer.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index a81bb9f38f2..988e2c43dee 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1292,7 +1292,7 @@ def _tokenize_prompts(self, prompts: list): for message in prompt: if isinstance(message["content"], list): for part in message["content"]: - if isinstance(part, dict) and part.get("type") == "image": + if part["type"] == "image": prompt_images.append(part["image"]) has_images = True images.append(prompt_images if prompt_images else None) @@ -1442,7 +1442,7 @@ def _get_tool_suffix_ids(self, tool_messages): for msg in tool_messages: if isinstance(msg.get("content"), list): for part in msg["content"]: - if isinstance(part, dict) and part.get("type") == "image": + if part["type"] == "image": tool_images.append(part["image"]) # Normalize string content in tool messages for VLM processors before either path. @@ -1657,7 +1657,7 @@ async def _run_async_tools(async_coros): # Collect images from multimodal tool responses if isinstance(content, list): for part in content: - if isinstance(part, dict) and part.get("type") == "image": + if part["type"] == "image": tool_images[idx_with_tool].append(part["image"]) prompt_completion_tool.append(tool_message) completions[idx_with_tool].append(tool_message) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 585b1f605eb..618ea923814 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -912,7 +912,7 @@ def _tokenize_prompts(self, prompts: list): for message in prompt: if isinstance(message["content"], list): for part in message["content"]: - if isinstance(part, dict) and part.get("type") == "image": + if part["type"] == "image": prompt_images.append(part["image"]) has_images = True images.append(prompt_images if prompt_images else None) From 577503331dd57f8edc32e94073d2b5facd0e5543 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 3 Apr 2026 22:10:19 -0400 Subject: [PATCH 2/2] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 4ea0cd68437..a9cfbf39a03 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1657,7 +1657,7 @@ async def _run_async_tools(async_coros): # Collect images from multimodal tool responses if isinstance(content, list): for part in content: - if part["type"] == "image": + if isinstance(part, dict) and part.get("type") == "image": tool_images[idx_with_tool].append(part["image"]) prompt_completion_tool.append(tool_message) completions[idx_with_tool].append(tool_message)