Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions trl/experimental/async_grpo/async_rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,27 +546,28 @@ async def _generate_one(
tool_call_count += n_calls
tool_failure_count += n_failures
completion.extend(tool_messages)
tool_suffix_ids = self._build_messages_suffix_ids(tool_messages)
completion_ids.extend(tool_suffix_ids)
completion_logprobs.extend([0.0] * len(tool_suffix_ids))
tool_mask.extend([0] * len(tool_suffix_ids))
prompt_ids = prompt_ids + turn_ids + tool_suffix_ids
suffix_ids = self._get_tool_suffix_ids(tool_messages)
completion_ids.extend(suffix_ids)
completion_logprobs.extend([0.0] * len(suffix_ids))
tool_mask.extend([0] * len(suffix_ids))
prompt_ids = prompt_ids + turn_ids + suffix_ids
iteration_num += 1

def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int]:
template_messages = [
def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int]:
"""Get token IDs for tool result formatting by using a minimal dummy conversation."""
dummy_messages = [
{"role": "user", "content": ""},
{"role": "assistant", "content": ""},
]
prefix_ids = self.tokenizer.apply_chat_template(
template_messages,
dummy_messages,
return_dict=False,
tools=self.tools,
chat_template=self.chat_template,
**self.chat_template_kwargs,
)
prefix_and_messages_ids = self.tokenizer.apply_chat_template(
template_messages + messages,
full_ids = self.tokenizer.apply_chat_template(
dummy_messages + tool_messages,
return_dict=False,
chat_template=self.chat_template,
add_generation_prompt=True,
Expand All @@ -575,15 +576,15 @@ def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int
)

# Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block.
# When we compute `suffix_ids` by slicing `prefix_and_messages_ids`, we must align the slicing boundary to
# When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to
# EOS (not EOS + newline).
last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id)
prefix_ids = prefix_ids[: last_eos_idx + 1]

if prefix_and_messages_ids[: len(prefix_ids)] != prefix_ids:
if full_ids[: len(prefix_ids)] != prefix_ids:
raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.")

return prefix_and_messages_ids[len(prefix_ids) :]
return full_ids[len(prefix_ids) :]

def _execute_tool_calls(
self, tool_calls: list[dict[str, Any]], tool_dict: dict[str, Callable]
Expand Down
Loading