diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index ac814b40ed7..eb720a2dda5 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -546,27 +546,28 @@ async def _generate_one( tool_call_count += n_calls tool_failure_count += n_failures completion.extend(tool_messages) - tool_suffix_ids = self._build_messages_suffix_ids(tool_messages) - completion_ids.extend(tool_suffix_ids) - completion_logprobs.extend([0.0] * len(tool_suffix_ids)) - tool_mask.extend([0] * len(tool_suffix_ids)) - prompt_ids = prompt_ids + turn_ids + tool_suffix_ids + suffix_ids = self._get_tool_suffix_ids(tool_messages) + completion_ids.extend(suffix_ids) + completion_logprobs.extend([0.0] * len(suffix_ids)) + tool_mask.extend([0] * len(suffix_ids)) + prompt_ids = prompt_ids + turn_ids + suffix_ids iteration_num += 1 - def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int]: - template_messages = [ + def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int]: + """Get token IDs for tool result formatting by using a minimal dummy conversation.""" + dummy_messages = [ {"role": "user", "content": ""}, {"role": "assistant", "content": ""}, ] prefix_ids = self.tokenizer.apply_chat_template( - template_messages, + dummy_messages, return_dict=False, tools=self.tools, chat_template=self.chat_template, **self.chat_template_kwargs, ) - prefix_and_messages_ids = self.tokenizer.apply_chat_template( - template_messages + messages, + full_ids = self.tokenizer.apply_chat_template( + dummy_messages + tool_messages, return_dict=False, chat_template=self.chat_template, add_generation_prompt=True, @@ -575,15 +576,15 @@ def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int ) # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. - # When we compute `suffix_ids` by slicing `prefix_and_messages_ids`, we must align the slicing boundary to + # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to # EOS (not EOS + newline). last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id) prefix_ids = prefix_ids[: last_eos_idx + 1] - if prefix_and_messages_ids[: len(prefix_ids)] != prefix_ids: + if full_ids[: len(prefix_ids)] != prefix_ids: raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") - return prefix_and_messages_ids[len(prefix_ids) :] + return full_ids[len(prefix_ids) :] def _execute_tool_calls( self, tool_calls: list[dict[str, Any]], tool_dict: dict[str, Callable]