
[GRPO] Truncated Importance Sampling to address rollout-training mismatch #3867

Merged
LeonEricsson merged 31 commits into huggingface:main from LeonEricsson:rollout_off_policy_importance_sampling
Sep 3, 2025

Conversation

@LeonEricsson (Collaborator) commented Aug 7, 2025

Motivation

TRL provides the option of using vLLM for rollouts, enabling fast and scalable generation. However, the token probabilities for the generated completions, which are used in the GRPO objective, do not come directly from vLLM. Instead, they are recomputed by the training backend. It has been known for a while that vLLM probabilities differ from those computed by Hugging Face, which ultimately means we inadvertently train off-policy relative to our generation policy, despite using the same weights. A recent blog post highlights the effect of this discrepancy and proposes a solution in the form of an importance sampling factor.
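For intuition, the importance sampling factor described in the blog can be sketched as a per-token ratio between the training policy and the rollout policy, truncated from above. This is an illustrative sketch with made-up names and cap value, not TRL's actual implementation:

```python
import torch

def truncated_importance_weights(train_logprobs, rollout_logprobs, c_max=2.0):
    """Per-token truncated importance sampling (TIS) weights.

    `train_logprobs`: log-probs of the sampled tokens recomputed by the
    training backend. `rollout_logprobs`: log-probs of the same tokens as
    reported by the rollout engine (e.g. vLLM). `c_max` is an illustrative
    truncation cap, not a TRL default.
    """
    # Ratio pi_train / pi_rollout, computed in log space for stability.
    ratios = torch.exp(train_logprobs - rollout_logprobs)
    # Truncate from above to bound the variance introduced by the correction.
    return torch.clamp(ratios, max=c_max)
```

A ratio of 1 means the two backends agree on that token; ratios far from 1 quantify the rollout-training mismatch that this PR measures.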

What does this PR do?

Initially, we document the numerical differences in token probabilities when using vLLM.

Depending on the results, we may address the issue through the recommended Truncated Importance Sampling method.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@LeonEricsson (Collaborator, Author) commented Aug 7, 2025

First experiment results. These are with vLLM rollouts (in server mode) and a vanilla training backend (no FSDP/DeepSpeed).

[Screenshot 2025-08-07 at 17:57:35]

We observe a considerable difference, in line with the blog.

Run on 2ef3af6.

Code to run:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train[:50]")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(
    num_train_epochs=1,
    steps_per_generation=4,
    per_device_train_batch_size=2,
    num_generations=4,
    logging_steps=1,
    report_to="wandb",
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

@qgallouedec (Member)

This is a very important one, thanks! Is it ready for review?

@LeonEricsson (Collaborator, Author) commented Aug 11, 2025

> This is a very important one, thanks! Is it ready for review?

I'm wavering on how we want to address this. Either we keep recomputing and introduce the Truncated Importance Sampling approach from the blog, or we move away from recomputing and use vLLM logprobs directly, everywhere. Both are valid; I see this as more a question of which approach scales better.
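For reference, the first option (keep recomputing, add TIS) amounts to weighting the per-token policy-gradient term by a truncated ratio between the training policy and the rollout policy. A minimal sketch, with illustrative names and cap value rather than TRL's actual implementation:

```python
import torch

def grpo_token_loss_with_tis(new_logprobs, old_logprobs, rollout_logprobs,
                             advantages, c_max=2.0):
    """Sketch of a per-token GRPO-style loss with a truncated importance
    sampling correction for the rollout/training mismatch. All names and
    the cap are illustrative, not TRL's actual API."""
    # Standard policy ratio between current and old training policies.
    ratio = torch.exp(new_logprobs - old_logprobs)
    # TIS factor: old training policy vs. the rollout (vLLM) policy,
    # truncated from above to bound variance.
    tis = torch.clamp(torch.exp(old_logprobs - rollout_logprobs), max=c_max)
    # The correction is a constant weight; no gradient flows through it.
    return -(tis.detach() * ratio * advantages)
```

The second option would instead feed `rollout_logprobs` straight into the objective in place of the recomputed ones, removing the mismatch at the cost of depending on vLLM's numerics.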

@LeonEricsson (Collaborator, Author) commented Aug 20, 2025

Some training runs with TIS compared to baseline (TIS=-1). I wouldn't expect things to look much different; both are stable. Generally, KL($\pi_{inference}, \pi_{training}$) is low.

For context, this is single-GPU training on GSM8K.

[Screenshots: four training-curve panels, 2025-08-20]

@LeonEricsson (Collaborator, Author)

Same experiments with PPO-IS:

[Screenshot 2025-08-24 at 16:29:20]

@LeonEricsson (Collaborator, Author)

I've cleaned up the implementation now.

There's one outstanding issue: vLLM sometimes emits a NaN logprob for the chosen token. This needs to be handled.
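One way to tolerate such NaNs is to fall back to a neutral weight of 1 (i.e. no off-policy correction) for the affected tokens. A hypothetical sketch, not the handling this PR ultimately adopted:

```python
import torch

def safe_tis_weights(train_logprobs, rollout_logprobs, c_max=2.0):
    """Truncated IS weights that tolerate NaN rollout logprobs (illustrative).

    Tokens for which the rollout engine reported NaN get a neutral weight
    of 1, so they contribute to the loss without any correction.
    """
    ratios = torch.exp(train_logprobs - rollout_logprobs)
    # NaN rollout logprobs propagate into the ratio; replace them with 1.
    ratios = torch.where(torch.isnan(ratios), torch.ones_like(ratios), ratios)
    return torch.clamp(ratios, max=c_max)
```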

@LeonEricsson (Collaborator, Author) commented Aug 27, 2025

We should update to use the final processed logprobs from vLLM, per vllm-project/vllm#22387. Prior versions of vLLM didn't support retrieving the sampled logprobs.

EDIT: the vLLM patch has not been released yet, so we can hold off on that change for a future PR.

@LeonEricsson changed the title from "vLLM rollout numerical differences causing off-policy RL" to "[GRPO] Truncated Importance Sampling to address rollout-training mismatch" on Aug 28, 2025
@qgallouedec (Member)

I think we're good @LeonEricsson, right? Or is this PR still draft?

@LeonEricsson LeonEricsson marked this pull request as ready for review September 1, 2025 12:47
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec (Member) left a comment

LGTM!

LeonEricsson merged commit 12fc85f into huggingface:main on Sep 3, 2025
10 checks passed
SamY724 pushed a commit to SamY724/trl that referenced this pull request Sep 6, 2025
…atch (huggingface#3867)

Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
