diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index f5d81447bd0..3ebd3c649ae 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1211,11 +1211,8 @@ async def _run_async_funcs(): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate_single_turn(self, prompts: list): - device = self.accelerator.device - mode = "train" if self.model.training else "eval" - - # Tokenize prompts once, shared across all generation backends + def _tokenize_prompts(self, prompts: list): + """Tokenize prompts and extract images/multimodal fields for generation.""" if is_conversational({"prompt": prompts[0]}): # Extract images from messages for VLM support images = [] @@ -1255,6 +1252,11 @@ def _generate_single_turn(self, prompts: list): prompt_ids = self.processing_class(text=prompts)["input_ids"] images = None multimodal_fields = {} + return prompt_ids, images, multimodal_fields + + def _generate_single_turn(self, prompt_ids, images, multimodal_fields): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" # Generate completions using either vLLM or regular generation if self.use_vllm: @@ -1456,8 +1458,9 @@ async def _run_async_tools(async_coros): break # all overlong, exit tool loop # Generate new completions after tool execution + pct_prompt_ids, pct_images, pct_multimodal_fields = self._tokenize_prompts(prompt_completion_tools) prompt_completion_tool_ids, post_tool_ids, post_tool_logprobs = self._generate_single_turn( - prompt_completion_tools + pct_prompt_ids, pct_images, pct_multimodal_fields ) # Sanity check: from experience, this is useful to catch bugs in the chat template @@ -1549,7 +1552,8 @@ def _generate(self, prompts: list): extra_fields = {k: v for k, v in output.items() if k not in required_keys} prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"] else: - prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts) + prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts) + prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompt_ids, images, multimodal_fields) extra_fields = {} # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index dd5650dff31..e149518cca5 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -885,11 +885,8 @@ async def _run_async_funcs(): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate_single_turn(self, prompts: list): - device = self.accelerator.device - mode = "train" if self.model.training else "eval" - - # Tokenize prompts once, shared across all generation backends + def _tokenize_prompts(self, prompts: list): + """Tokenize prompts and extract images/multimodal fields for generation.""" if is_conversational({"prompt": prompts[0]}): # Extract images from messages for VLM support images = [] @@ -927,6 +924,11 @@ def _generate_single_turn(self, prompts: list): prompt_ids = self.processing_class(text=prompts)["input_ids"] images = None multimodal_fields = {} + return prompt_ids, images, multimodal_fields + + def _generate_single_turn(self, prompt_ids, images, multimodal_fields): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" # Generate completions using either vLLM or regular generation if self.use_vllm: @@ -1026,7 +1028,8 @@ def _generate(self, prompts: list): # Copy the prompts to avoid modifying the original list prompts = copy.deepcopy(prompts) - prompt_ids, completion_ids = self._generate_single_turn(prompts) + prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts) + prompt_ids, completion_ids = self._generate_single_turn(prompt_ids, images, multimodal_fields) # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}):