diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index f54ba039c7b..7354202258b 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1295,7 +1295,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): # Generate using vLLM with raw token IDs num_generations = self.num_generations if mode == "train" else self.num_generations_eval - prompt_ids, completion_ids, logprobs, _ = self.vllm_generation.generate( + _, completion_ids, logprobs, _ = self.vllm_generation.generate( prompts=prompt_ids, images=images, num_generations=num_generations, @@ -1361,8 +1361,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): **generate_inputs, generation_config=self.generation_config, disable_compile=True ) # Compute prompt length and extract completion ids - prompt_ids_tensor, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] - prompt_length = prompt_ids_tensor.size(1) + prompt_length = generate_inputs["input_ids"].size(1) completion_ids = prompt_completion_ids[:, prompt_length:] # Mask everything after the first EOS token @@ -1371,18 +1370,35 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() - # Move tensors to CPU before per-sample to avoid many CUDA syncs/copies (costly at scale/contention). - prompt_ids = [ - p[m].tolist() for p, m in zip(prompt_ids_tensor.cpu(), prompt_mask.bool().cpu(), strict=True) - ] completion_ids = [ c[m].tolist() for c, m in zip(completion_ids.cpu(), completion_mask.bool().cpu(), strict=True) ] logprobs = None # not used in this case - return prompt_ids, completion_ids, logprobs + return completion_ids, logprobs + + def _get_tool_suffix_ids(self, tool_messages): + """Get token IDs for tool result formatting by using a minimal dummy conversation.""" + dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] + prefix_ids = self.processing_class.apply_chat_template( + dummy_messages, + add_generation_prompt=False, + chat_template=self.chat_template, + return_dict=False, + **self.chat_template_kwargs, + ) + full_ids = self.processing_class.apply_chat_template( + dummy_messages + tool_messages, + add_generation_prompt=True, + chat_template=self.chat_template, + return_dict=False, + **self.chat_template_kwargs, + ) + if not full_ids[: len(prefix_ids)] == prefix_ids: + raise ValueError("Unexpected tokenization: the prefix IDs are not a prefix of the full IDs.") + return full_ids[len(prefix_ids) :] - def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs): + def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields): # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt tool_calls = [completion[0].get("tool_calls") for completion in completions] idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] @@ -1449,17 +1465,24 @@ async def _run_async_tools(async_coros): prompt_completion_tool.append(tool_message) completions[idx_with_tool].append(tool_message) - # Tokenize and filter samples whose length exceeds max allowed length. This is important, because both + # Build token IDs by concatenation: prompt + completion + tool_suffix. + prompt_completion_tool_ids = [] + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + # Extract trailing tool messages from completions + tool_messages = [] + for message in reversed(completions[idx_with_tool]): + if message["role"] == "tool": + tool_messages.insert(0, message) + else: + break + suffix_ids = self._get_tool_suffix_ids(tool_messages) + prompt_completion_tool_ids.append( + prompt_ids[idx_with_tool] + completion_ids[idx_with_tool] + suffix_ids + ) + + # Filter samples whose length exceeds max allowed length. This is important, because both # vLLM and transformers will error out if the input is longer than the model's max length. - pct_ids = self.processing_class.apply_chat_template( - prompt_completion_tools, - tools=self.tools, - chat_template=self.chat_template, - add_generation_prompt=True, - tokenize=True, - return_dict=False, - **self.chat_template_kwargs, - ) if self.use_vllm and self.vllm_mode == "colocate": max_model_len = self.vllm_generation.llm.llm_engine.model_config.max_model_len elif not self.use_vllm: @@ -1468,12 +1491,12 @@ async def _run_async_tools(async_coros): raise NotImplementedError( f"Unsupported mode detected: use_vllm={self.use_vllm}, vllm_mode={self.vllm_mode}" ) - overlong = [len(pct) >= max_model_len for pct in pct_ids] + overlong = [len(pct) >= max_model_len for pct in prompt_completion_tool_ids] for idx in range(len(idxs_with_tool)): idx_with_tool = idxs_with_tool[idx] if overlong[idx]: prompt_length = len(prompt_ids[idx_with_tool]) - ct = pct_ids[idx][prompt_length : prompt_length + self.max_completion_length] + ct = prompt_completion_tool_ids[idx][prompt_length : prompt_length + self.max_completion_length] completion_ids[idx_with_tool] = ct tool_mask[idx_with_tool] += [1] * (len(ct) - len(tool_mask[idx_with_tool])) if logprobs is not None: @@ -1481,24 +1504,22 @@ async def _run_async_tools(async_coros): # Keep only non-overlong items for further processing idxs_with_tool = [idx for idx, o in zip(idxs_with_tool, overlong, strict=True) if not o] prompt_completion_tools = [pct for pct, o in zip(prompt_completion_tools, overlong, strict=True) if not o] + prompt_completion_tool_ids = [ + pct for pct, o in zip(prompt_completion_tool_ids, overlong, strict=True) if not o + ] if not idxs_with_tool: break # all overlong, exit tool loop - # Generate new completions after tool execution - pct_prompt_ids, pct_images, pct_multimodal_fields = self._tokenize_prompts(prompt_completion_tools) - prompt_completion_tool_ids, post_tool_ids, post_tool_logprobs = self._generate_single_turn( - pct_prompt_ids, pct_images, pct_multimodal_fields + # Filter images and multimodal fields to match the current subset (index into full batch) + loop_images = [images[i] for i in idxs_with_tool] if images else None + loop_multimodal_fields = ( + {k: [v[i] for i in idxs_with_tool] for k, v in multimodal_fields.items()} if multimodal_fields else {} ) - # Sanity check: from experience, this is useful to catch bugs in the chat template - for idx in range(len(idxs_with_tool)): - idx_with_tool = idxs_with_tool[idx] - pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool - if prompt_ids[idx_with_tool] != pct[: len(prompt_ids[idx_with_tool])]: - raise ValueError( - "The chat template is not prefix-preserving. Please update it to use a prefix-preserving " - "format." - ) + # Generate new completions after tool execution (using concatenated IDs, no re-tokenization) + post_tool_ids, post_tool_logprobs = self._generate_single_turn( + prompt_completion_tool_ids, loop_images, loop_multimodal_fields + ) # Truncate so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length for idx in range(len(idxs_with_tool)): @@ -1580,7 +1601,7 @@ def _generate(self, prompts: list): prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"] else: prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts) - prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompt_ids, images, multimodal_fields) + completion_ids, logprobs = self._generate_single_turn(prompt_ids, images, multimodal_fields) extra_fields = {} # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. @@ -1607,7 +1628,9 @@ def _generate(self, prompts: list): logprobs, tool_call_count, tool_failure_count, - ) = self._tool_call_loop(prompts, prompt_ids, completion_ids, completions, logprobs) + ) = self._tool_call_loop( + prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields + ) else: # Support custom env_mask from rollout_func (e.g., for environment feedback masking) # Internally treated as tool_mask - marks model tokens (1) vs external tokens (0) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 9ac34c19a64..7179b05d0c4 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -952,7 +952,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): # Generate using vLLM (note: RLOO doesn't use logprobs from generation, so we ignore them) num_generations = self.num_generations if mode == "train" else self.num_generations_eval - prompt_ids, completion_ids, _, _ = self.vllm_generation.generate( + _, completion_ids, _, _ = self.vllm_generation.generate( prompts=prompt_ids, images=images, num_generations=num_generations, @@ -1013,8 +1013,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): **generate_inputs, generation_config=self.generation_config, disable_compile=True ) # Compute prompt length and extract completion ids - prompt_ids_tensor, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] - prompt_length = prompt_ids_tensor.size(1) + prompt_length = generate_inputs["input_ids"].size(1) completion_ids = prompt_completion_ids[:, prompt_length:] # Mask everything after the first EOS token @@ -1023,15 +1022,11 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() - # Move tensors to CPU before per-sample to avoid many CUDA syncs/copies (costly at scale/contention). - prompt_ids = [ - p[m].tolist() for p, m in zip(prompt_ids_tensor.cpu(), prompt_mask.bool().cpu(), strict=True) - ] completion_ids = [ c[m].tolist() for c, m in zip(completion_ids.cpu(), completion_mask.bool().cpu(), strict=True) ] - return prompt_ids, completion_ids + return completion_ids def _generate(self, prompts: list): device = self.accelerator.device @@ -1041,7 +1036,7 @@ def _generate(self, prompts: list): prompts = copy.deepcopy(prompts) prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts) - prompt_ids, completion_ids = self._generate_single_turn(prompt_ids, images, multimodal_fields) + completion_ids = self._generate_single_turn(prompt_ids, images, multimodal_fields) # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}):