Conversation
…_thw` in GRPO and RLOO trainers; update `split_pixel_values_by_grid` to use `image_grid_thw`
Member
Just a minor change to fix #4609.
After the merge of huggingface/transformers#40936, the attribute does not necessarily exist. Before it was None.
Feel free to ignore it if there is a better solution!
edbeeching reviewed Dec 3, 2025
guess = completion[-1]["content"].strip()
guess = completion[-1]["content"].strip().lower()
guess_clean = guess.replace("*", "").replace("`", "").strip()
reward = 0.0
Collaborator
Could L75-82 be simplified to:
if guess_clean == ans.lower():
    reward = 0.5
else:
    reward = -0.2
edbeeching reviewed Dec 3, 2025
examples/scripts/grpo_agent.py
Outdated
if "error" in turn["content"].lower():
    reward -= 0.3  # penalize errors

if tool_called and tool_response_ok:
Collaborator
107-112 would be easier to parse for a reader like this:
if tool_called:
    if tool_response_ok:
        reward += 0.25
    else:
        reward -= 0.2
else:
    reward -= 0.3
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
qgallouedec commented Dec 8, 2025
Would this be plug-and-play with smolagents for inference? Also curious about extensibility to langgraph or strands agent harness.
This was referenced Dec 16, 2025
pranavvm26 added a commit to aws-samples/amazon-sagemaker-generativeai that referenced this pull request Jan 30, 2026
What does this PR do?
This PR implements tool calling for GRPO. The API is as follows:
This PR contains a few important changes:
🚨 Removal of `max_prompt_length`
This PR contains a breaking change: `max_prompt_length` has been removed from GRPO.
Here are the reasons (tl;dr: because it's extremely hard to implement reliably with multi-turn tool calling, likely harmful to training anyway, likely not used in practice, and dropping it simplifies the API while keeping it consistent across LLMs and VLMs.)
Supporting `max_prompt_length` with tool calling is extremely complex.
For single-turn generation it works fine, but multi-turn generation introduces a major challenge: the prompt grows after every step. Since the model is called repeatedly with an increasingly long prompt, we would need to recalculate the allowed prompt length dynamically based on how many tokens have already been generated. Implementing this reliably is tricky and adds significant complexity.
Truncating prompts is likely worse than dropping samples altogether.
Although I’m not aware of formal studies, intuition suggests that truncation can remove information necessary to solve the task. Training on such incomplete examples can lead to strong biases, whereas simply skipping overly long samples avoids this risk.
It simplifies the API and removes confusing edge cases.
Previously, when training VLMs, we had to tell users to disable prompt truncation entirely because Transformers does not support truncating multimodal prompts. This led to inconsistent, non-user-friendly recommendations. Removing `max_prompt_length` allows us to provide one clean, unified API that works for all model types. It is very likely not a widely used feature anyway.
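To make the first point concrete, here is a toy sketch of the bookkeeping a dynamic prompt budget would require in a multi-turn loop. All numbers and names (`MAX_SEQ_LEN`, `tokens_to_truncate`) are made up for illustration; this is not TRL code.

```python
# Toy illustration: with multi-turn tool calling, the prompt grows each turn,
# so a fixed max_prompt_length would have to be re-derived per turn from the
# tokens already consumed. Numbers and names below are hypothetical.
MAX_SEQ_LEN = 4096  # total context budget (assumed)

def tokens_to_truncate(prompt_len: int, generated_so_far: int, max_new_tokens: int) -> int:
    """How many prompt tokens exceed the budget this turn (0 if everything fits)."""
    budget = MAX_SEQ_LEN - generated_so_far - max_new_tokens
    return max(prompt_len - budget, 0)

# Turn 1: short prompt, nothing generated yet -> fits, nothing to truncate.
t1 = tokens_to_truncate(prompt_len=512, generated_so_far=0, max_new_tokens=256)
# Turn 3: the prompt now embeds earlier completions and tool results, so the
# same static limit no longer applies -- the truncation amount changes per turn.
t3 = tokens_to_truncate(prompt_len=2900, generated_so_far=1800, max_new_tokens=256)
print(t1, t3)  # -> 0 860
```

The per-turn recalculation (and deciding *what* to drop from a prompt that now contains tool results) is exactly the complexity the PR avoids by removing the option.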
Online decoding
Before calling the reward function, we need to decode the completion. Previously, this was done here:
trl/trl/trainer/grpo_trainer.py
Lines 1605 to 1617 in 1a9ff52
The issue is that, while this works for single-turn outputs, it does not allow reliable parsing of multi-turn text. See this internal discussion. The workaround is to parse after each turn, which requires moving the decoding logic inside the generation loop (in `_generate`):
trl/trl/trainer/grpo_trainer.py
Lines 1483 to 1495 in c54bf4f
trl/trl/trainer/grpo_trainer.py
Line 1543 in c54bf4f
trl/trl/trainer/grpo_trainer.py
Lines 1614 to 1618 in c54bf4f
The method then returns the list of messages:
trl/trl/trainer/grpo_trainer.py
Line 1669 in c54bf4f
Note that this change removes support for the "bootstrap" feature. I haven’t had time to investigate adding support for it.
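The per-turn decoding idea can be sketched as follows. Everything here (`fake_generate`, the `<tool_call>` tag format, `run_turns`) is a hypothetical stand-in, not TRL's actual implementation; the point is that each turn is decoded and parsed before the next one is generated.

```python
# Toy sketch: decode after every turn so each assistant response can be parsed
# for tool calls before the next turn is generated. All names are hypothetical.
import json
import re

def fake_generate(prompt: str) -> str:
    # Stand-in for model.generate: emit a tool call on the first turn,
    # then a final answer once a tool result is present in the prompt.
    if "<tool_result>" not in prompt:
        return '<tool_call>{"name": "add", "arguments": {"a": 2, "b": 3}}</tool_call>'
    return "The answer is 5."

def run_turns(prompt: str, tools: dict, max_turns: int = 4) -> list:
    messages = [{"role": "user", "content": prompt}]
    for _ in range(max_turns):
        text = fake_generate(prompt)                 # generate one turn
        messages.append({"role": "assistant", "content": text})
        m = re.search(r"<tool_call>(.*?)</tool_call>", text, re.S)
        if not m:                                    # no tool call -> done
            break
        call = json.loads(m.group(1))                # parse this turn immediately
        result = tools[call["name"]](**call["arguments"])
        messages.append({"role": "tool", "content": str(result)})
        prompt += f"<tool_result>{result}</tool_result>"
    return messages

msgs = run_turns("What is 2 + 3?", {"add": lambda a, b: a + b})
print(msgs[-1]["content"])  # -> The answer is 5.
```

Parsing inside the loop is what makes multi-turn text tractable: each assistant turn is still an isolated string at the moment it is parsed, instead of being fished back out of one long concatenated decode.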
Tool mask
We don't want the loss to be computed on the tokens corresponding to the tool result. Consequently, `_generate` builds and returns a `tool_mask`:
trl/trl/trainer/grpo_trainer.py
Line 1668 in c54bf4f
which is then used to mask these tokens in the loss computation.
trl/trl/trainer/grpo_trainer.py
Line 2100 in c54bf4f
Schema and fixed chat template
Chat template
To make this feature work, we need the chat template to be prefix-preserving, i.e.:
trl/trl/chat_template_utils.py
Lines 195 to 212 in 9f0aa3d
The issue is that some widely used tokenizers, such as GPT-OSS and Qwen3, are not prefix-preserving due to the way they handle think tokens. To address this, I suggest using a slightly modified version of the template that ensures it is prefix-preserving. Additionally, as @lewtun pointed out, it's not even clear whether these modified templates might make the inference OOD.
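What "prefix-preserving" means, as a self-contained check. The toy `render` function below stands in for `tokenizer.apply_chat_template`; the second renderer mimics (in a simplified way) templates that re-render earlier assistant turns differently, e.g. by dropping think blocks, which is what breaks the property.

```python
# A chat template is prefix-preserving if rendering the first k messages
# always yields a string prefix of rendering the full conversation.
def render(messages: list) -> str:
    # Toy stand-in for tokenizer.apply_chat_template.
    return "".join(f"<|{m['role']}|>{m['content']}<|end|>" for m in messages)

def render_not_preserving(messages: list) -> str:
    # Mimics think-token handling: earlier assistant turns are re-rendered
    # with their content stripped, so earlier renders are no longer prefixes.
    out = []
    for i, m in enumerate(messages):
        content = "" if m["role"] == "assistant" and i < len(messages) - 1 else m["content"]
        out.append(f"<|{m['role']}|>{content}<|end|>")
    return "".join(out)

def is_prefix_preserving(render_fn, messages: list) -> bool:
    full = render_fn(messages)
    return all(full.startswith(render_fn(messages[:k])) for k in range(1, len(messages)))

chat = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
    {"role": "user", "content": "bye"},
]
print(is_prefix_preserving(render, chat))                 # -> True
print(is_prefix_preserving(render_not_preserving, chat))  # -> False
```

The property matters for training because the tokens generated in earlier turns must reappear verbatim in the later, longer prompts; otherwise the logged token IDs no longer line up with what the model actually conditioned on.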
Response schema
To parse tool calls from the model's response, we rely on `tokenizer.parse_response`, introduced in huggingface/transformers#40894. This requires the tokenizer to have a `response_schema` (integrated in a similar way as chat templates). However, very few (no?) model repositories currently include such a schema. To enable this feature despite the lack of adoption, I propose adding a mapping from some popular chat templates to their response schemas (currently only Qwen3).
trl/trl/chat_template_utils.py
Lines 172 to 174 in fbb625f
Ideally, once adoption increases and model repos start including proper response schemas, we can remove this custom mapping entirely.
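A rough sketch of the mapping idea. Qwen3-style responses wrap tool calls in `<tool_call>...</tool_call>` tags around a JSON object; the registry and parser below are hypothetical illustrations, not the `transformers` response-schema API.

```python
# Toy sketch of a "chat template -> response parser" mapping. The registry and
# parser names are hypothetical; only the <tool_call> tag format follows Qwen's
# documented convention.
import json
import re

def parse_qwen3_response(text: str) -> dict:
    calls = [
        json.loads(m)
        for m in re.findall(r"<tool_call>\s*(.*?)\s*</tool_call>", text, re.S)
    ]
    # Anything outside the tool-call tags is plain assistant text.
    content = re.sub(r"<tool_call>.*?</tool_call>", "", text, flags=re.S).strip()
    return {"content": content, "tool_calls": calls}

RESPONSE_PARSERS = {"qwen3": parse_qwen3_response}  # hypothetical registry

out = RESPONSE_PARSERS["qwen3"](
    'Let me check.\n<tool_call>\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n</tool_call>'
)
print(out["tool_calls"][0]["name"])  # -> get_weather
```

Once model repos ship real `response_schema` entries, a lookup like this becomes unnecessary and the tokenizer's own `parse_response` can be used directly.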
A fair amount of complexity in the generation
This PR adds 60+ lines of intricate code with many special cases in the generation method. While it’s admittedly hard to follow, after a lot of iteration this is likely the simplest reliable way to implement the feature. Normally, I would be very reluctant to introduce this level of complexity, but given the impact of this feature, I believe it’s truly worth it.
trl/trl/trainer/grpo_trainer.py
Lines 1509 to 1623 in c54bf4f
Next steps