
🕵️‍♂️ GRPO: Agent training #4300

Merged
qgallouedec merged 224 commits into main from tool-call-finally on Dec 9, 2025

Conversation

@qgallouedec (Member) commented Oct 18, 2025

What does this PR do?

This PR implements tool calling for GRPO. The API is as follows:

⚠️ requires transformers v5.0.0.dev0

from datasets import Dataset
from trl import GRPOTrainer

def multiply(a: int, b: int) -> int:
    """
    Multiplies two integers.

    Args:
        a: The first integer.
        b: The second integer.

    Returns:
        The product of the two integers.
    """
    return a * b


dataset = Dataset.from_list(
    [
        {"prompt": [{"role": "user", "content": "What is 3 multiplied by 4?"}], "answer": 12},
        {"prompt": [{"role": "user", "content": "Calculate 7 times 8."}], "answer": 56},
        {"prompt": [{"role": "user", "content": "Find the product of 5 and 6."}], "answer": 30},
        {"prompt": [{"role": "user", "content": "What do you get when you multiply 9 by 9?"}], "answer": 81},
        {"prompt": [{"role": "user", "content": "Compute 12 multiplied by 11."}], "answer": 132},
        {"prompt": [{"role": "user", "content": "What is 15 times 14?"}], "answer": 210},
    ]
)

def accuracy(completions, answer, **kwargs):
    rewards = []
    for completion, ans in zip(completions, answer):
        if str(ans) in completion[-1]["content"]:
            rewards.append(1.0)
        else:
            rewards.append(0.0)
    return rewards

trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B",
    train_dataset=dataset,
    tools=[multiply],
    reward_funcs=accuracy,
)
trainer.train()
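
Since the feature depends on tokenizer.parse_response and response_schema from transformers v5, a quick way to fail fast on older installs is a version check; this is a minimal sketch mirroring the check the trainer itself uses (shown later in this description):

import transformers
from packaging.version import Version

# Tool-call parsing needs transformers v5 (parse_response / response_schema)
if Version(transformers.__version__) < Version("5.0.0.dev0"):
    raise RuntimeError(f"transformers >= 5.0.0.dev0 required, found {transformers.__version__}")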

This PR contains a few important changes:

🚨 Removal of max_prompt_length

This PR contains a breaking change: max_prompt_length has been removed from GRPO.

Here are the reasons (TL;DR: it's extremely hard to implement reliably with multi-turn tool calling, likely harmful to training anyway, likely not used in practice, and dropping it simplifies the API while keeping it consistent across LLMs and VLMs):

  1. Supporting max_prompt_length with tool calling is extremely complex.
    For single-turn generation it works fine, but multi-turn generation introduces a major challenge: the prompt grows after every step. Since the model is called repeatedly with an increasingly long prompt, we would need to recalculate the allowed prompt length dynamically based on how many tokens have already been generated. Implementing this reliably is tricky and adds significant complexity.

  2. Truncating prompts is likely worse than dropping samples altogether.
    Although I’m not aware of formal studies, intuition suggests that truncation can remove information necessary to solve the task. Training on such incomplete examples can lead to strong biases, whereas simply skipping overly long samples avoids this risk.

  3. It simplifies the API and removes confusing edge cases.
    Previously, when training VLMs, we had to tell users to disable prompt truncation entirely because Transformers does not support truncating multimodal prompts. This led to inconsistent, non-user-friendly recommendations. Removing max_prompt_length allows us to provide one clean, unified API that works for all model types.

  4. It is very likely not a widely used feature anyway.
    Users who do need a cap on prompt length can filter overlong samples on the dataset side instead (see the sketch below).
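
Since the trainer no longer truncates prompts, overlong prompts can be dropped before training; a minimal sketch, assuming the same chat template the trainer would apply and an illustrative 512-token budget:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

def is_short_enough(example, max_prompt_tokens=512):  # illustrative budget, pick one for your context window
    # Apply the chat template as the trainer would, then count prompt tokens
    ids = tokenizer.apply_chat_template(example["prompt"], add_generation_prompt=True, tokenize=True)
    return len(ids) <= max_prompt_tokens

dataset = dataset.filter(is_short_enough)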

Online decoding

Before calling the reward function, we need to decode the completion. Previously, this was done here:

# Decode
prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
    completions = []
    for prompt, completion in zip(prompts, completions_text, strict=True):
        bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
        if isinstance(bootstrap, list):  # for VLM, the format might be [{"type": "text", "text": "..."}]
            assert len(bootstrap) == 1 and bootstrap[0]["type"] == "text"
            bootstrap = bootstrap[0]["text"]
        completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:
    completions = completions_text

The issue is that, while this works for single-turn outputs, it does not allow reliable parsing of multi-turn text. See this internal discussion. The workaround is to parse after each turn, which requires moving the decoding logic inside the generation loop (in _generate):

# Decode completions. It's important to use `parse_response` when possible, because it handles tool calls.
if is_conversational({"prompt": prompts[0]}):
    if (
        Version(transformers.__version__) >= Version("5.0.0.dev0")  # parse_response added in v5
        and isinstance(self.processing_class, PreTrainedTokenizerBase)  # doesn't work with processors
        and self.processing_class.response_schema is not None  # only works if the tokenizer has a schema
    ):
        completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids]
    else:
        contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        completions = [[{"role": "assistant", "content": content}] for content in contents]
else:
    completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)

completions[idx_with_tool].append(tool_message)

# Add post-tool completions to the existing completions
for idx in range(len(idxs_with_tool)):
    idx_with_tool = idxs_with_tool[idx]
    if post_tool_completions[idx]:  # {} if post-tool completions completely truncated
        completions[idx_with_tool].append(post_tool_completions[idx])
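
For reference, the structure that parse_response yields (and that the tool loop below reads via tool_calls, function["name"], and function["arguments"]) looks roughly like this; the exact fields depend on the tokenizer's response schema, so treat it as an illustrative shape rather than a guaranteed contract:

parsed = {
    "role": "assistant",
    "content": "Let me compute that.",
    "tool_calls": [
        {
            "type": "function",
            "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}},
        }
    ],
}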

The method then returns the list of messages:

completions,

Note that this change removes support for the "bootstrap" feature. I haven’t had time to investigate adding support for it.

Tool mask

We don't want the loss to be computed on the tokens corresponding to the tool results. Consequently, _generate builds and returns a tool_mask,

tool_mask,

which is then used to mask these tokens in the loss computation.

mask = completion_mask if not self.tools else completion_mask * (1 - inputs["tool_mask"])
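
In other words, the effective mask zeroes out both padding and tool-result tokens before the per-token loss is aggregated. A minimal sketch of that aggregation, assuming per_token_loss, completion_mask, and tool_mask are (batch, seq) tensors (this is a simplified illustration, not the trainer's exact loss reduction):

import torch

def masked_loss(per_token_loss, completion_mask, tool_mask=None):
    # Exclude tool-result tokens in addition to padding, then average over the remaining tokens
    mask = completion_mask if tool_mask is None else completion_mask * (1 - tool_mask)
    return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)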

Schema and fixed chat template

Chat template

To make this feature work, we need the chat template to be prefix-preserving, i.e.:

def is_prefix_preserving(tokenizer) -> bool:
    messages1 = [
        {"role": "user", "content": "What color is the sky?"},
    ]
    messages2 = [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ]
    messages3 = [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
        {"role": "user", "content": "And at night?"},
    ]
    text1 = tokenizer.apply_chat_template(messages1, tokenize=False, add_generation_prompt=True)
    text2 = tokenizer.apply_chat_template(messages2, tokenize=False)
    text3 = tokenizer.apply_chat_template(messages3, tokenize=False)
    return text2.startswith(text1) and text3.startswith(text2)
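
A quick way to try the check above on a concrete tokenizer (the helper name is_prefix_preserving is just the wrapper used in this snippet, not a TRL API):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
print(is_prefix_preserving(tokenizer))  # expected to be False with the stock template, per the note below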

The issue is that some widely used tokenizers, such as GPT-OSS and Qwen3, are not prefix-preserving due to the way they handle think tokens. To address this, I suggest using a slightly modified version of the template that ensures it is prefix-preserving. Additionally, as @lewtun pointed out, it is not even clear whether these modified templates might make inference OOD.


Response schema

To parse tool calls from the model’s response, we rely on tokenizer.parse_response, introduced in huggingface/transformers#40894. This requires the tokenizer to have a response_schema (integrated in a similar way as chat templates). However, very few (no?) model repositories currently include such a schema.

To enable this feature despite the lack of adoption, I propose adding a mapping for some popular chat templates to their response schemas (currently only Qwen3).

if tokenizer.chat_template == qwen3_chat_template:
    tokenizer.response_schema = qwen3_schema
return tokenizer

Ideally, once adoption increases and model repos start including proper response schemas, we can remove this custom mapping entirely.

A fair amount of complexity in the generation

This PR adds 60+ lines of intricate code with many special cases in the generation method. While it’s admittedly hard to follow, after a lot of iteration this is likely the simplest reliable way to implement the feature. Normally, I would be very reluctant to introduce this level of complexity, but given the impact of this feature, I believe it’s truly worth it.

# Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt
while idxs_with_tool:
    prompt_completion_tools = [prompts[i] for i in idxs_with_tool]  # select only prompts that need tool calls
    # Tokenize the current prompt. We will use this to filter out overlong samples later.
    kwargs = {
        "tools": self.tools,
        "add_generation_prompt": True,
        "tokenize": True,
        "chat_template": self.chat_template,
        **self.chat_template_kwargs,
    }
    p_ids = self.processing_class.apply_chat_template(prompt_completion_tools, **kwargs)["input_ids"]
    # Call the tools, and build the new prompt for generation
    for idx in range(len(idxs_with_tool)):
        idx_with_tool = idxs_with_tool[idx]
        tool_call_list = tool_calls[idx]
        prompt_completion_tool = prompt_completion_tools[idx]
        prompt_completion_tool.append(completions[idx_with_tool][-1])
        for tool_call in tool_call_list:
            tool_call_count += 1
            if tool_call["type"] == "function":
                function = tool_call["function"]
                try:
                    result = self._tool_dict[function["name"]](**function["arguments"])
                except Exception as e:
                    result = {"error": str(e)}
                    tool_failure_count += 1
            else:
                result = {"error": f"Unsupported tool call type: {tool_call['type']}"}
            tool_call["result"] = result
            tool_message = {"role": "tool", "name": function["name"], "content": str(result)}
            prompt_completion_tool.append(tool_message)
            completions[idx_with_tool].append(tool_message)
    # Tokenize and filter samples whose length exceeds max allowed length. This is important, because if vLLM
    # is called with an input longer than its max model length, it will error out.
    pct_ids = self.processing_class.apply_chat_template(prompt_completion_tools, **kwargs)["input_ids"]
    overlong = [len(pct) - len(p) >= self.max_completion_length for p, pct in zip(p_ids, pct_ids, strict=True)]
    for idx in range(len(idxs_with_tool)):
        idx_with_tool = idxs_with_tool[idx]
        if overlong[idx]:
            prompt_length = len(prompt_ids[idx_with_tool])
            ct = pct_ids[idx][prompt_length : prompt_length + self.max_completion_length]
            completion_ids[idx_with_tool] = ct
            tool_mask[idx_with_tool] += [0] * (len(ct) - len(tool_mask[idx_with_tool]))
            if logprobs is not None:
                logprobs[idx_with_tool] += [0.0] * (len(ct) - len(logprobs[idx_with_tool]))
    idxs_with_tool = [idx for idx, o in zip(idxs_with_tool, overlong, strict=True) if not o]
    prompt_completion_tools = [pct for pct, o in zip(prompt_completion_tools, overlong, strict=True) if not o]
    if not idxs_with_tool:
        break  # all overlong, exit tool loop
    # Generate new completions after tool execution
    prompt_completion_tool_ids, post_tool_ids, post_tool_logprobs, _ = self._generate_single_turn(
        prompt_completion_tools
    )
    # Sanity check: from experience, this is useful to catch bugs in the chat template
    for idx in range(len(idxs_with_tool)):
        idx_with_tool = idxs_with_tool[idx]
        pct = prompt_completion_tool_ids[idx]  # = prompt-completion-tool
        assert prompt_ids[idx_with_tool] == pct[: len(prompt_ids[idx_with_tool])]
    # Truncate so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length
    for idx in range(len(idxs_with_tool)):
        idx_with_tool = idxs_with_tool[idx]
        prompt_len = len(prompt_ids[idx_with_tool])
        completion_tool_ids = prompt_completion_tool_ids[idx][prompt_len:]
        excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length
        if excess_length > 0:
            # If exceeding max length, truncate post_tool_ids
            post_tool_ids[idx] = post_tool_ids[idx][:-excess_length]
            if logprobs is not None:
                post_tool_logprobs[idx] = post_tool_logprobs[idx][:-excess_length]
        excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length
        if excess_length > 0:
            # If still exceeding max length, truncate completion_tool_ids as well
            prompt_completion_tool_ids[idx] = prompt_completion_tool_ids[idx][:-excess_length]
    # Update tool_mask: the tool result should be 1 and the post-tool 0
    for idx in range(len(idxs_with_tool)):
        idx_with_tool = idxs_with_tool[idx]
        prompt_completion_tool_length = len(prompt_completion_tool_ids[idx])
        prompt_length = len(prompt_ids[idx_with_tool])
        completion_length = len(completion_ids[idx_with_tool])
        post_tool_length = len(post_tool_ids[idx])
        tool_length = prompt_completion_tool_length - prompt_length - completion_length
        tool_mask[idx_with_tool] += [1] * tool_length + [0] * post_tool_length
        if logprobs is not None:
            logprobs[idx_with_tool] += [0.0] * tool_length + post_tool_logprobs[idx]
    # Update completion_ids with the new completions (after tool execution)
    for idx in range(len(idxs_with_tool)):
        idx_with_tool = idxs_with_tool[idx]
        prompt_length = len(prompt_ids[idx_with_tool])
        pct = prompt_completion_tool_ids[idx]  # = prompt-completion-tool
        completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx]
    # Decode post-tool completions
    post_tool_completions = [
        parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids
    ]
    # Add post-tool completions to the existing completions
    for idx in range(len(idxs_with_tool)):
        idx_with_tool = idxs_with_tool[idx]
        if post_tool_completions[idx]:  # {} if post-tool completions completely truncated
            completions[idx_with_tool].append(post_tool_completions[idx])
    # Check for further tool calls
    tool_calls = [completion.get("tool_calls") for completion in post_tool_completions]
    idxs_with_tool = [idx for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True) if tool_call]
    tool_calls = [tool_call for tool_call in tool_calls if tool_call]
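
One property the loop maintains throughout is that the per-sample completion_ids, tool_mask, and (when present) logprobs lists keep the same length, since they are always extended together. A small sanity check one could place after the loop (an illustrative assertion, not part of the PR):

for i in range(len(completion_ids)):
    # Per-sample sequences must stay aligned token-for-token
    assert len(tool_mask[i]) == len(completion_ids[i])
    if logprobs is not None:
        assert len(logprobs[i]) == len(completion_ids[i])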

Next steps

  • Add tool support to the vLLM client/server (currently supported only for collocated vLLM and Transformers generation)
  • Provide the same integration for RLOO

@albertvillanova (Member) left a comment

Just a minor change to fix #4609.

After the merge of huggingface/transformers#40936, the attribute does not necessarily exist; before, it was None.

Feel free to ignore it if there is a better solution!
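
Presumably the fix is a defensive attribute lookup; a minimal sketch of that kind of check, assuming the attribute in question is response_schema (as used in the description above):

# Guard against tokenizers where the attribute is absent rather than None
has_schema = getattr(self.processing_class, "response_schema", None) is not None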

guess = completion[-1]["content"].strip()
guess = completion[-1]["content"].strip().lower()
guess_clean = guess.replace("*", "").replace("`", "").strip()
reward = 0.0

@edbeeching (Collaborator) commented Dec 3, 2025

Could L75-82 be simplified to:


if guess_clean == ans.lower():
    reward = 0.5
else:
    reward = -0.2

if "error" in turn["content"].lower():
reward -= 0.3 # penalize errors

if tool_called and tool_response_ok:
A collaborator commented:

Lines 107-112 would be easier for a reader to parse like this:

if tool_called:
    if tool_response_ok:
        reward += 0.25
    else:
        reward -= 0.2
else:
    reward -= 0.3

@qgallouedec changed the title from 🕵️‍♂️ Agent training to 🕵️‍♂️ GRPO: Agent training on Dec 9, 2025
@qgallouedec merged commit ca21dd9 into main on Dec 9, 2025
12 of 13 checks passed
@qgallouedec deleted the tool-call-finally branch on December 9, 2025 at 03:45
@austinmw

Would this be plug-and-play with smolagents for inference?

Also curious about extensibility to langgraph or strands agent harness.

pranavvm26 added a commit to aws-samples/amazon-sagemaker-generativeai that referenced this pull request Jan 30, 2026