Skip to content
15 changes: 11 additions & 4 deletions trl/experimental/async_grpo/async_rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,10 +573,17 @@ def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int
tools=self.tools,
**self.chat_template_kwargs,
)
prefix_len = len(prefix_ids)
if prefix_and_messages_ids[:prefix_len] != prefix_ids:
raise ValueError("Failed to construct message suffix in token space.")
return prefix_and_messages_ids[prefix_len:]

# Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block, so
# `prefix_ids` may end with a newline token after EOS. Trim `prefix_ids` to end at its last EOS token so
# that the suffix sliced from `prefix_and_messages_ids` starts right after EOS and keeps that newline.
last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id)
prefix_ids = prefix_ids[: last_eos_idx + 1]

if prefix_and_messages_ids[: len(prefix_ids)] != prefix_ids:
Comment thread
casinca marked this conversation as resolved.
raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.")

return prefix_and_messages_ids[len(prefix_ids) :]

def _execute_tool_calls(
self, tool_calls: list[dict[str, Any]], tool_dict: dict[str, Callable]
Expand Down
12 changes: 10 additions & 2 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,8 +1394,16 @@ def _get_tool_suffix_ids(self, tool_messages):
return_dict=False,
**self.chat_template_kwargs,
)
if not full_ids[: len(prefix_ids)] == prefix_ids:
raise ValueError("Unexpected tokenization: the prefix IDs are not a prefix of the full IDs.")

# Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block, so
# `prefix_ids` may end with a newline token after EOS. Trim `prefix_ids` to end at its last EOS token so
# that the suffix sliced from `full_ids` starts right after EOS and keeps that newline.
last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.eos_token_id)
prefix_ids = prefix_ids[: last_eos_idx + 1]

if full_ids[: len(prefix_ids)] != prefix_ids:
raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.")

return full_ids[len(prefix_ids) :]

def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields):
Expand Down
Loading