Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1705,8 +1705,22 @@ def _generate_and_score_completions(
) -> dict[str, torch.Tensor | Any]:
device = self.accelerator.device
mode = "train" if self.model.training else "eval"

prompts = [x["prompt"] for x in inputs]

def remove_empty_fields(data):
if isinstance(data, dict):
return {k: remove_empty_fields(v) for k, v in data.items() if v is not None}
elif isinstance(data, list):
return [remove_empty_fields(item) for item in data if item is not None]
else:
return data

prompts = []
cleaned_inputs = []
for item in inputs:
cleaned_item = remove_empty_fields(item)
cleaned_inputs.append(cleaned_item)
prompts.append(cleaned_item["prompt"])
inputs = cleaned_inputs

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Broad None stripping removes top-level image keys breaking detection

High Severity

remove_empty_fields is applied to the entire input dict, not just the prompt content blocks. This strips top-level keys with None values, including "image". When inputs[0] has "image": None (a text-only sample in a mixed batch) but other inputs have actual images, the key is removed from inputs[0]. The subsequent check "image" in inputs[0] then fails, causing images = None and silently losing all images in the batch. The fix should only clean the nested prompt content, not the entire input dict.

Additional Locations (1)
Fix in Cursor Fix in Web

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inputs[0]["image"]should be PIL and will no be removed by my changes.

print(inputs)
[{'prompt': [{'content': [{'image': None, 'text': 'You are good at step by step reasoning.', 'type': 'text'}], 'role': 'system'}, {'content': [{'image': '../datas/VisuRiddles/images/sichuan/2021_59.png', 'text': None, 'type': 'image'}, {'image': None, 'text': '[Logical Reasoning] \nThe left image shows the unfolded surface of a cube-shaped box. Which option can be folded into the cube depicted?option: A,B,C,D\nWrite the answer into a JSON form\njson\n{"answer": "X"}', 'type': 'text'}], 'role': 'user'}], 'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=560x168 at 0xFFFBE5FEB5D0>, 'metadatas': {'gold_answer': 'A'}}, {'prompt': [{'content': [{'image': None, 'text': 'You are good at step by step reasoning.', 'type': 'text'}], 'role': 'system'}, {'content': [{'image': '../datas/VisuRiddles/images/sichuan/2021_59.png', 'text': None, 'type': 'image'}, {'image': None, 'text': '[Logical Reasoning] \nThe left image shows the unfolded surface of a cube-shaped box. Which option can be folded into the cube depicted?option: A,B,C,D\nWrite the answer into a JSON form\njson\n{"answer": "X"}', 'type': 'text'}], 'role': 'user'}], 'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=560x168 at 0xFFFBE5FEBC90>, 'metadatas': {'gold_answer': 'A'}}]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix not propagated to RLOO trainer's duplicated code

Medium Severity

The remove_empty_fields logic was added only to grpo_trainer.py but not to rloo_trainer.py, which has the same duplicated _generate_and_score_completions method with the identical prompts = [x["prompt"] for x in inputs] pattern. Per project rules, changes to duplicated logic across trainers must be applied consistently to all copies.

Fix in Cursor Fix in Web

Triggered by project rule: BUGBOT.md

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if the subsequent execution flow and call stack of rloo_trainer.py are exactly the same as grpo_trainer.py. So, to be safe, I will only modify the grpo_trainer that has already been tested.


if self.environments:
for prompt, environment, reset_kwargs in zip(prompts, self.environments, inputs, strict=True):
Expand Down