diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index e15585dd3e1..dd605750a0e 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -1705,8 +1705,26 @@ def _generate_and_score_completions(
     ) -> dict[str, torch.Tensor | Any]:
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"
 
-        prompts = [x["prompt"] for x in inputs]
+        def remove_empty_fields(data):
+            """Return a copy of `data` with `None` entries dropped from nested dicts and lists.
+
+            Only `None` is removed: empty strings and empty containers are kept, and any value that is neither
+            a `dict` nor a `list` is returned unchanged. Dropping `None` items shortens lists.
+            """
+            if isinstance(data, dict):
+                return {k: remove_empty_fields(v) for k, v in data.items() if v is not None}
+            if isinstance(data, list):
+                return [remove_empty_fields(item) for item in data if item is not None]
+            return data
+
+        # Clean the value of every column but keep the columns themselves, even when their value is `None`:
+        # all examples then expose the same keys as before, and `example["prompt"]` cannot start raising
+        # `KeyError` because of the cleanup.
+        # NOTE(review): presumably the `None` entries are placeholders for missing nested fields (e.g. inside
+        # chat messages); confirm that no caller relies on top-level `None` columns being removed.
+        inputs = [{key: remove_empty_fields(value) for key, value in example.items()} for example in inputs]
+        prompts = [example["prompt"] for example in inputs]
 
         if self.environments:
             for prompt, environment, reset_kwargs in zip(prompts, self.environments, inputs, strict=True):