[GRPO] Truncated Importance Sampling to address rollout-training mismatch#3867
Conversation
First experiment results: these are with vLLM rollout (in server mode) and a vanilla training backend (no fsdp/deepspeed).
We observe a considerable difference, in line with the blog. Run on 2ef3af6.

Code to run:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train[:50]")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(...)
trainer = GRPOTrainer(...)
```
This is a very important one, thanks! Is it ready for review?
I'm wavering on how we want to address this. Either we keep recomputing and introduce the Truncated Importance Sampling approach from the blog, or we move away from recomputing and use vLLM logprobs directly, everywhere. Both are valid; I see this as more a question of which approach scales better.
I've cleaned up the implementation now. One remaining issue is that vLLM sometimes emits a NaN logprob for the chosen token; this needs to be handled.
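One way to guard against those NaN logprobs — this is only a sketch of the idea, not the handling adopted in the PR, and `safe_logprobs` is a hypothetical helper name — is to fall back to the logprob recomputed by the training backend wherever vLLM returned NaN, so the importance ratio for that token degenerates to 1 (no correction):

```python
import torch

def safe_logprobs(rollout_logprobs: torch.Tensor, recomputed_logprobs: torch.Tensor) -> torch.Tensor:
    # Wherever vLLM returned NaN for the sampled token, substitute the
    # logprob recomputed by the training backend; the resulting
    # importance ratio exp(recomputed - rollout) is then exactly 1 for
    # that token, i.e. no off-policy correction is applied there.
    nan_mask = torch.isnan(rollout_logprobs)
    return torch.where(nan_mask, recomputed_logprobs, rollout_logprobs)
```

Alternatively, NaN positions could simply be masked out of the loss; the fallback above has the advantage of keeping the token in the objective.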
We should update to use the final processed logprobs from vLLM, from vllm-project/vllm#22387; prior versions of vLLM didn't support retrieving the sampled logprobs. EDIT: the vLLM patch has not been released yet, so we can hold off on that change for a future PR.
I think we're good @LeonEricsson, right? Or is this PR still draft?
…atch (huggingface#3867) Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com> Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Motivation
TRL provides the option of using vLLM for rollouts, enabling fast and scalable generation. However, the token probabilities for the generated completions—used in the GRPO objective—do not come directly from vLLM. Instead, these probabilities are recomputed by the training backend. It has been known for a while that vLLM probabilities differ from those computed by Hugging Face, which ultimately means we inadvertently train off-policy from our generation policy, despite using the same weights. A recent blog post highlights the effect of this discrepancy and proposes a solution in the form of an importance sampling factor.
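The importance sampling correction can be sketched as follows. This is a minimal illustration of Truncated Importance Sampling, not the implementation in this PR; the function name `tis_weights` and the default cap value are illustrative assumptions. Each token's per-token loss is reweighted by the ratio between the training policy's probability and the rollout (vLLM) policy's probability, truncated from above to bound variance:

```python
import torch

def tis_weights(train_logprobs: torch.Tensor, rollout_logprobs: torch.Tensor, cap: float = 2.0) -> torch.Tensor:
    # Per-token importance ratio pi_train(y_t) / pi_rollout(y_t),
    # computed in log space for numerical stability.
    ratio = torch.exp(train_logprobs - rollout_logprobs)
    # Truncate at `cap` so a few tokens with large probability
    # mismatch cannot dominate the gradient; detach so the weight
    # acts as a constant multiplier on the loss.
    return torch.clamp(ratio, max=cap).detach()
```

In use, the per-token GRPO loss would be multiplied elementwise by these weights before reduction, so tokens where generation and training policies agree (ratio near 1) are left essentially unchanged.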
What does this PR do?
Initially, we document the numerical differences in token probabilities when using vLLM.
Depending on the results, we may address the issue through the recommended Truncated Importance Sampling method.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.