Closed
Changes from 66 commits
80 commits
495f967
first commit
qgallouedec Mar 16, 2026
6a35777
consistency
qgallouedec Mar 16, 2026
1652027
fix
qgallouedec Mar 16, 2026
dcd9bdb
address review
qgallouedec Mar 16, 2026
a7fecc2
wire timeout + fix name tool
qgallouedec Mar 16, 2026
5d2327e
Merge branch 'main' into async-grpo
qgallouedec Mar 16, 2026
9e7706e
add buffer queue size metric to samples in AsyncRolloutWorker
qgallouedec Mar 16, 2026
7228399
added aiohttp
AmineDiro Mar 17, 2026
7e12720
style
qgallouedec Mar 16, 2026
a62fe60
support async scoring of reward functions and enhance metrics reporting
qgallouedec Mar 16, 2026
606c98b
style
qgallouedec Mar 16, 2026
887fe44
docstring + consistency
qgallouedec Mar 16, 2026
f6084ff
better segment trainer/rollout
qgallouedec Mar 17, 2026
a162681
inherit from base trainer and allow max_steps=None and stub for test
qgallouedec Mar 17, 2026
94fbeac
refactor: adjust max_steps calculation to consider accelerator processes
qgallouedec Mar 17, 2026
99ccd54
doc
qgallouedec Mar 17, 2026
ce6e00c
better log tok/sec
qgallouedec Mar 17, 2026
bdaf5c1
fix: update max_completion_length default value to 2048
qgallouedec Mar 17, 2026
3c5ec3d
fix timeout, tok/sec, and vllm import
qgallouedec Mar 17, 2026
fae8012
style
qgallouedec Mar 17, 2026
f468106
style
qgallouedec Mar 17, 2026
626a69d
Update async_grpo_trainer.md
qgallouedec Mar 17, 2026
1e5335c
Add AsyncGRPOTrainer with async rollout worker and weight transfer
AmineDiro Mar 17, 2026
64205d5
init weight sync"
AmineDiro Mar 17, 2026
9f66524
aiohttp in vllm dep
AmineDiro Mar 17, 2026
9be36f7
fix tracking metrics namespace
AmineDiro Mar 17, 2026
976f4ee
Add logging to weight sync initialization and train metrics
AmineDiro Mar 17, 2026
18d1ed7
license
qgallouedec Mar 17, 2026
ea72fe3
fix conditional import + allow 0.17.1
qgallouedec Mar 17, 2026
c62c5e8
fix type hint
qgallouedec Mar 17, 2026
7374cb7
fix vllm segment in pyproject
qgallouedec Mar 17, 2026
0a216b9
Merge branch 'main' into async-grpo
qgallouedec Mar 17, 2026
ed3d10a
align server timeout
qgallouedec Mar 17, 2026
392665f
rename Prompt type alias to Messages for clarity
qgallouedec Mar 17, 2026
6c38628
update vllm version requirement and clarify installation instructions
qgallouedec Mar 17, 2026
5663391
oups
qgallouedec Mar 17, 2026
c43d6f8
add max-model-len parameter and usage tip for vLLM server configuration
qgallouedec Mar 17, 2026
91b944c
Fix DDP/FSDP1 module prefix handling and vLLM pause mode
AmineDiro Mar 17, 2026
b54beec
some debug logging, remove get and better default values
qgallouedec Mar 17, 2026
2c20788
no need to log bsz (already in config)
qgallouedec Mar 17, 2026
289138f
revert pause mode change
qgallouedec Mar 17, 2026
406f415
more logging
qgallouedec Mar 17, 2026
5b6300c
Enhance max_inflight_tasks handling and documentation for vLLM server…
qgallouedec Mar 18, 2026
14d893d
keep
qgallouedec Mar 18, 2026
d4ab4e4
style
qgallouedec Mar 18, 2026
6502c19
Improve timeout handling and retry logic for HTTP requests in AsyncRo…
qgallouedec Mar 18, 2026
cd6d773
Add max_tool_calling_iterations parameter to AsyncRolloutWorker and A…
qgallouedec Mar 18, 2026
e92b883
fix config
qgallouedec Mar 18, 2026
e8a512d
Raise exceptions from tasks in AsyncRolloutWorker if present
qgallouedec Mar 18, 2026
a115faf
Improve error handling for server connection issues in AsyncRolloutWo…
qgallouedec Mar 18, 2026
d8c2908
Improve queue handling in AsyncRolloutWorker to prevent blocking on f…
qgallouedec Mar 18, 2026
5d5c46c
first sync and then start generation
qgallouedec Mar 18, 2026
9a76dec
nits
qgallouedec Mar 18, 2026
6266546
we don't need to set max-num-seq apparently
qgallouedec Mar 18, 2026
c22c8fa
remove dead code
qgallouedec Mar 18, 2026
0822d13
Add AsyncGRPOTrainer example with MBPP dataset and improve async rollout
AmineDiro Mar 17, 2026
c84b81c
cancel inflight when pausing
AmineDiro Mar 18, 2026
3906f53
fix max_seq_len
AmineDiro Mar 18, 2026
7d3f656
debug msg
AmineDiro Mar 18, 2026
232d2e1
Add callback to sync weights before training begins
AmineDiro Mar 18, 2026
8156e11
Fix attribute access in AsyncGRPOTrainer._inner_training_loop
AmineDiro Mar 18, 2026
a4a675e
Fix missing iteration counter increment and add reward_mean metric
AmineDiro Mar 18, 2026
7a44f21
Add advantage to rollout worker reward output
AmineDiro Mar 18, 2026
723236a
Add tool-calling completion tracking and schema conversion
AmineDiro Mar 18, 2026
1769c34
Remove errored iteration_num increments
AmineDiro Mar 18, 2026
1ae5ecc
Add turn-level and tool-call timing records to async rollout worker
AmineDiro Mar 19, 2026
e106e08
Refactor RolloutGroup to use append_rollout helper method
AmineDiro Mar 19, 2026
d1c7303
merged main
AmineDiro Mar 19, 2026
b3f534c
fix main problems
AmineDiro Mar 19, 2026
0378a09
fix logprobs main
AmineDiro Mar 19, 2026
e851ede
Refactor RolloutCompletion to store turns and compute fields lazily
AmineDiro Mar 19, 2026
856049a
Apply suggestion from @qgallouedec
AmineDiro Mar 19, 2026
a1dc9e1
Apply suggestion from @qgallouedec
AmineDiro Mar 19, 2026
dcb760f
Apply suggestion from @qgallouedec
AmineDiro Mar 19, 2026
e983fdd
Apply suggestion from @qgallouedec
AmineDiro Mar 19, 2026
987ed87
Apply suggestion from @qgallouedec
AmineDiro Mar 19, 2026
82519c9
Apply suggestion from @qgallouedec
AmineDiro Mar 19, 2026
2e03320
Apply suggestion from @qgallouedec
AmineDiro Mar 19, 2026
46d4315
Improve metrics collection in async rollout worker
AmineDiro Mar 19, 2026
674073a
Rename `max_completion_tokens` to `max_completion_length`
AmineDiro Mar 19, 2026
5 changes: 4 additions & 1 deletion .gitignore
@@ -142,4 +142,7 @@ checklink/cookies.txt
# wandb files
nbs/wandb/
examples/notebooks/wandb/
wandb/
wandb/

# uv
uv.lock
4 changes: 3 additions & 1 deletion docs/source/_toctree.yml
@@ -87,7 +87,9 @@
title: Experimental Overview
- local: openenv
title: OpenEnv Integration
- local: bema_for_reference_model # Sorted alphabetically
- local: async_grpo_trainer # Sorted alphabetically
title: Asynchronous GRPO
- local: bema_for_reference_model
title: BEMA for Reference Model
- local: bco_trainer
title: BCO
79 changes: 79 additions & 0 deletions docs/source/async_grpo_trainer.md
@@ -0,0 +1,79 @@
# Asynchronous GRPO

> [!IMPORTANT]
> This trainer requires `vllm>=0.17.1` and `transformers>=5.2.0`. For distributed training, only FSDP2 is supported (DeepSpeed ZeRO is not).
>
> Currently, `vllm` and `transformers` have conflicting dependency constraints. To work around this, install vLLM first and then force-install transformers:
>
> ```bash
> pip install 'vllm>=0.17.1'
> pip install 'transformers>=5.2.0' --no-deps
> ```

## Overview

[`AsyncGRPOTrainer`] implements the same [GRPO](grpo_trainer) algorithm but decouples rollout generation from training. A background worker continuously streams completions from a vLLM server while the training loop consumes them, so generation and gradient updates overlap instead of alternating. The API mirrors [`GRPOTrainer`] — for full details on the GRPO method itself (advantage computation, KL estimation, loss formulation, reward functions, etc.), see the [GRPO Trainer](grpo_trainer) documentation. Not all features from [`GRPOTrainer`] are available; refer to [`AsyncGRPOConfig`] for the supported parameters.

This trainer was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Amine Dirhoussi](https://huggingface.co/aminediroHF).

## How it differs from [`GRPOTrainer`]

In the standard [`GRPOTrainer`], generation and training are sequential: generate a batch, compute the loss, update weights, repeat. Even in [vLLM colocate mode](grpo_trainer#speed-up-training-with-vllm), where generation runs on the same GPUs, one phase must finish before the other begins.

[`AsyncGRPOTrainer`] separates these two concerns:

- **Rollout worker** (background thread) — sends prompts to a vLLM server, scores completions with reward functions, computes advantages, and pushes ready-to-train samples into a queue.
- **Training loop** (main process) — pulls samples from the queue, computes the clipped surrogate loss, and updates the model weights.

After every `weight_sync_steps` training steps, the updated weights are transferred to the vLLM server via NCCL so that subsequent generations reflect the latest policy.
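
The split described above can be sketched as a plain producer/consumer loop. This is only a minimal illustration, assuming a `queue.Queue` and a Python thread; `rollout_worker`, `run_training`, and the `policy_version` counter are hypothetical names, not the trainer's actual internals, and the version bump merely stands in for the NCCL weight transfer.

```python
import queue
import threading


def rollout_worker(sample_queue, stop_event, policy_version):
    """Hypothetical producer standing in for generation, scoring, and advantage computation."""
    sample_id = 0
    while not stop_event.is_set():
        # Tag each sample with the policy version that (notionally) generated it.
        sample = {"id": sample_id, "model_version": policy_version["value"]}
        try:
            sample_queue.put(sample, timeout=0.05)
        except queue.Full:
            continue  # queue is full: re-check the stop flag, then retry
        sample_id += 1


def run_training(num_steps, weight_sync_steps, queue_size=4):
    """Hypothetical consumer: pulls samples, 'trains', and bumps the version on each sync."""
    sample_queue = queue.Queue(maxsize=queue_size)
    stop_event = threading.Event()
    policy_version = {"value": 0}  # shared counter; a real trainer would sync weights instead

    worker = threading.Thread(
        target=rollout_worker, args=(sample_queue, stop_event, policy_version), daemon=True
    )
    worker.start()

    consumed = []
    for step in range(1, num_steps + 1):
        consumed.append(sample_queue.get())  # blocks until the worker has produced a sample
        # ... loss computation and optimizer step would happen here ...
        if step % weight_sync_steps == 0:
            policy_version["value"] += 1  # placeholder for the NCCL weight transfer

    stop_event.set()
    worker.join()
    return consumed, policy_version["value"]


consumed, final_version = run_training(num_steps=6, weight_sync_steps=2)
print(len(consumed), final_version)
```

Because the queue is bounded, the worker naturally pauses when the training loop falls behind, which is one simple way to keep generation from running arbitrarily far ahead of the policy.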

Because generation and training run concurrently, the training samples may have been generated by a slightly older version of the model. The `max_staleness` parameter controls how many weight updates a sample can lag behind before being discarded.
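
As a rough illustration of the staleness rule, the sketch below drops samples that lag too far behind the current policy. The `model_version` field, the `drop_stale` helper, and the inclusive comparison are all assumptions made for the example; the trainer's actual bookkeeping may differ.

```python
def drop_stale(samples, current_version, max_staleness):
    """Keep samples whose lag, in weight updates, is at most max_staleness (assumed inclusive)."""
    return [s for s in samples if current_version - s["model_version"] <= max_staleness]


# Hypothetical samples generated under policy versions 0, 2, 3 and 5.
samples = [{"id": i, "model_version": v} for i, v in enumerate([0, 2, 3, 5])]
fresh = drop_stale(samples, current_version=5, max_staleness=2)
print([s["id"] for s in fresh])  # [2, 3]
```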

The number of concurrent requests sent to the vLLM server is controlled by `max_inflight_tasks`. By default it is set automatically to `max_staleness × per_device_train_batch_size × gradient_accumulation_steps × num_processes` — the maximum number of samples the trainer can consume before they become stale. Generating more than this is wasteful since the excess samples will be discarded.
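
The default is a straight product of the four settings, which the following sketch spells out. The helper name `default_max_inflight_tasks` is hypothetical and the input values are illustrative, not the config defaults.

```python
def default_max_inflight_tasks(
    max_staleness, per_device_train_batch_size, gradient_accumulation_steps, num_processes
):
    """Samples the trainer can consume before they exceed max_staleness weight updates."""
    return max_staleness * per_device_train_batch_size * gradient_accumulation_steps * num_processes


# Illustrative values: 4 updates of slack, batch size 2, 8 accumulation steps, 2 processes.
print(default_max_inflight_tasks(4, 2, 8, 2))  # 128
```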

## Quick start

```python
# train_async_grpo.py
from datasets import load_dataset
from trl.experimental.async_grpo import AsyncGRPOTrainer
from trl.rewards import accuracy_reward

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = AsyncGRPOTrainer(
model="Qwen/Qwen3-4B",
reward_funcs=accuracy_reward,
train_dataset=dataset,
)
trainer.train()
```

The vLLM server and the trainer must run on **separate GPUs**. Use `CUDA_VISIBLE_DEVICES` to partition your GPUs. For example, with 2 GPUs, you can run the vLLM server on GPU 0 and the trainer on GPU 1 as follows:

```bash
# Terminal 1: vLLM server on GPU 0 (dev mode + NCCL weight transfer are required)
CUDA_VISIBLE_DEVICES=0 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \
--max-model-len 4096 \
--weight-transfer-config '{"backend":"nccl"}'
```

> [!TIP]
> Set `--max-model-len` to the maximum total sequence length (prompt + completion) you expect. A lower value reduces GPU memory usage on the server, freeing more memory for the KV cache and increasing throughput. A good starting point is the prompt length plus `max_completion_length` from your config.

```bash
# Terminal 2: training on GPU 1
CUDA_VISIBLE_DEVICES=1 accelerate launch train_async_grpo.py
```
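
The `--max-model-len` tip above is simple arithmetic: the longest prompt you expect plus the completion budget. A minimal sketch, assuming you already have per-prompt token counts (the numbers and the helper name `suggested_max_model_len` are hypothetical):

```python
def suggested_max_model_len(prompt_token_counts, max_completion_length):
    """Longest expected prompt plus the completion budget, in tokens."""
    return max(prompt_token_counts) + max_completion_length


# Illustrative prompt lengths (in tokens) and an illustrative completion budget.
print(suggested_max_model_len([212, 640, 1024], max_completion_length=2048))  # 3072
```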

## Design philosophy

This trainer is intentionally kept minimal and is not meant to grow into a general-purpose solution. If you need a feature that is not supported, we recommend cloning the repository and adapting the trainer to your needs directly. New features will only be considered when there is significant community demand.

## AsyncGRPOConfig

[[autodoc]] trl.experimental.async_grpo.AsyncGRPOConfig

## AsyncGRPOTrainer

[[autodoc]] trl.experimental.async_grpo.AsyncGRPOTrainer
2 changes: 1 addition & 1 deletion docs/source/vllm_integration.md
@@ -3,7 +3,7 @@
This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood.

> [!WARNING]
> TRL currently only supports vLLM versions from `0.10.2` to `0.17.0`. Please ensure you have a version in this range installed to avoid compatibility issues.
> TRL currently only supports vLLM versions from `0.10.2` to `0.17.1`. Please ensure you have a version in this range installed to avoid compatibility issues.

> [!TIP]
> The following trainers currently support generation with vLLM:
71 changes: 71 additions & 0 deletions examples/scripts/async_grpo.py
@@ -0,0 +1,71 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
CUDA_VISIBLE_DEVICES=2,3,4,5 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \
--data-parallel-size 4 \
--weight-transfer-config '{"backend":"nccl"}' \
--max-model-len 9216

CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file examples/accelerate_configs/fsdp2.yaml examples/scripts/async_grpo.py

NOTE: requires transformers>=5.2.0 (see docs/source/async_grpo_trainer.md)
"""

import logging
import os

from datasets import load_dataset

from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer
from trl.rewards import accuracy_reward


logging.basicConfig(
    level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logging.getLogger("trl").setLevel(logging.DEBUG)


def format_sample(sample):
    return {"prompt": sample["messages"][:1], "solution": sample["answer"]}


def main() -> None:
    dataset = load_dataset("open-r1/OpenR1-Math-220k", split="train[:10000]")
    dataset = dataset.map(format_sample, remove_columns=dataset.column_names)

    config = AsyncGRPOConfig(
        output_dir="./results",
        per_device_train_batch_size=1,
        num_train_epochs=1,
        max_completion_length=4096,
        max_steps=10,
        report_to="trackio",
        trackio_space_id=None,
        project="async_grpo",
        log_completions=True,
    )
    trainer = AsyncGRPOTrainer(
        model="Qwen/Qwen3-4B",
        args=config,
        train_dataset=dataset,
        reward_funcs=accuracy_reward,
    )
    trainer.train()


if __name__ == "__main__":
    main()
173 changes: 173 additions & 0 deletions examples/scripts/async_grpo_mbpp.py
@@ -0,0 +1,173 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
CUDA_VISIBLE_DEVICES=2,3,4,5 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \
--data-parallel-size 4 \
--weight-transfer-config '{"backend":"nccl"}' \
--max-num-seqs 64 \
--max-model-len 9216

CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file examples/accelerate_configs/fsdp2.yaml examples/scripts/async_grpo_mbpp.py

NOTE: requires transformers>=5.2.0 (see docs/source/async_grpo_trainer.md)
"""

import logging
import os
import subprocess
import sys
import tempfile

from datasets import load_dataset

from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer


logging.basicConfig(
    level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)


class MBPPEnvironment:
    """
    A synchronous environment class designed for `AsyncGRPOTrainer`.
    Each environment instance handles tracking test cases and exposes the `execute_python_code` tool.
    """

    def __init__(self):
        self.test_list = []
        self.done = False

Comment on lines +49 to +52

Member:
IMO we can remove this

Suggested change
    def __init__(self):
        self.test_list = []
        self.done = False

Member:
It's safer, because it will fail if we try to step before the reset.

Member Author:
I think `test_list` has to be initialized to `[]`, if I am not mistaken.

    def reset(self, test_list: list[str], **kwargs):
        """
        Resets the environment with the test suite for the new problem.
        `**kwargs` ignores additional columns sent from the dataset map.
        """
        self.test_list = test_list
        self.done = False

    def execute_python_code(self, code: str) -> str:
        """Execute python code to test against the hidden test cases. Provide the complete python code.

        Args:
            code: The complete python code to execute.

        Returns:
            program stdout string
        """
        full_code = code + "\n\n" + "\n".join(self.test_list)

        with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
            f.write(full_code)
            temp_path = f.name

        try:
            result = subprocess.run(
                [sys.executable, temp_path],
                capture_output=True,
                text=True,
                timeout=3.0,
            )
            if result.returncode == 0:
                self.done = True
                return "Tests passed."
            else:
                # Return the last 2000 characters of the stderr to fit within context length
                feedback = result.stderr[-2000:]
                return f"Execution failed with error:\n{feedback}\nPlease fix the code and try again."
        except subprocess.TimeoutExpired:
            return "Execution timeout."
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)

    def is_done(self) -> bool:
        return self.done

Member:
Because it's missing a docstring, it won't be exposed as a tool to the model. If it's not meant to be exposed, then I'd recommend `def _is_done`. If it is meant to be exposed as a tool, it requires a docstring.

AmineDiro (Member Author), Mar 19, 2026:
It is not a tool function. `_is_done` would mean it is a private function, but we want to signal that it needs to be implemented.
The issue we currently have is that the user can't define a stop function to exit the generation loop. For example, in a code environment, we want to exit if the tests succeed, but the current setup keeps rolling until we reach max_turns or token exhaustion.



def tests_passed_reward(completions, **kwargs) -> list[float]:
    """
    Reward function that checks the model's chat history for the last tool execution
    and returns 1.0 if the tests passed, 0.0 otherwise.
    """
    rewards = []
    for completion in completions:
        passed = False
        # Interrogate the completion, looping backwards to find the last tool interaction result
        for msg in reversed(completion):
            if msg["role"] == "tool" and "Tests passed." in msg.get("content", ""):
                passed = True
                break
        rewards.append(1.0 if passed else 0.0)
    return rewards


def format_sample(sample):
    """
    Format the MBPP dataset row into a prompt using OpenAI chat formatting,
    and persist `test_list` so `reset()` can inject it.
    """
    prompt_text = sample.get("text", "")
    content = (
        f"You are an expert Python programmer.\n\n"
        f"{prompt_text}\n\n"
        f"Please write a python code to solve the problem and use the execute_python_code tool to test it."
    )
    prompt = [{"role": "user", "content": content}]

    return {"prompt": prompt, "test_list": sample.get("test_list", [])}


def main() -> None:
    os.environ["WANDB_PROJECT"] = "async_grpo_trl_mbpp"
    # 1. Load dataset
    dataset = load_dataset("google-research-datasets/mbpp", split="train+test")
    dataset = dataset.map(format_sample, remove_columns=dataset.column_names)

    # 2. Config setup
    config = AsyncGRPOConfig(
        output_dir="./results",
        per_device_train_batch_size=1,
        max_completion_length=8192,
        max_seq_length=8192,
        max_tool_calling_iterations=5,
        max_steps=100,
        max_staleness=8,
        # Logging
        log_completions=True,
        num_completions_to_print=2,
        report_to="wandb",
        logging_steps=1,
        # trackio
        project="async_grpo_trl_mbpp",
        trackio_space_id=None,
    )

    # 3. Trainer initialization
    trainer = AsyncGRPOTrainer(
        model="Qwen/Qwen3-4B",
        args=config,
        train_dataset=dataset,
        reward_funcs=[tests_passed_reward],
        environment_factory=MBPPEnvironment,
    )

    # 4. Train
    trainer.train()


if __name__ == "__main__":
    main()
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -83,9 +83,10 @@ test = [
"pytest"
]
vllm = [
"vllm>=0.10.2,<=0.17.0",
"vllm>=0.10.2,<=0.17.1",
"fastapi",
"pydantic",
"aiohttp>=3.13.3",
"requests",
"uvicorn"
]