diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ab83c8550a1..a9cfbf39a03 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1292,7 +1292,7 @@ def _tokenize_prompts(self, prompts: list): for message in prompt: if isinstance(message["content"], list): for part in message["content"]: - if isinstance(part, dict) and part.get("type") == "image": + if part["type"] == "image": prompt_images.append(part["image"]) has_images = True images.append(prompt_images if prompt_images else None) @@ -1442,7 +1442,7 @@ def _get_tool_suffix_ids(self, tool_messages): for msg in tool_messages: if isinstance(msg.get("content"), list): for part in msg["content"]: - if isinstance(part, dict) and part.get("type") == "image": + if part["type"] == "image": tool_images.append(part["image"]) # Normalize string content in tool messages for VLM processors before either path. diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 67ed23ad65f..f9dd4280a5a 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -912,7 +912,7 @@ def _tokenize_prompts(self, prompts: list): for message in prompt: if isinstance(message["content"], list): for part in message["content"]: - if isinstance(part, dict) and part.get("type") == "image": + if part["type"] == "image": prompt_images.append(part["image"]) has_images = True images.append(prompt_images if prompt_images else None)