-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Rollout generation #5299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rollout generation #5299
Changes from 66 commits
495f967
6a35777
1652027
dcd9bdb
a7fecc2
5d2327e
9e7706e
7228399
7e12720
a62fe60
606c98b
887fe44
f6084ff
a162681
94fbeac
99ccd54
ce6e00c
bdaf5c1
3c5ec3d
fae8012
f468106
626a69d
1e5335c
64205d5
9f66524
9be36f7
976f4ee
18d1ed7
ea72fe3
c62c5e8
7374cb7
0a216b9
ed3d10a
392665f
6c38628
5663391
c43d6f8
91b944c
b54beec
2c20788
289138f
406f415
5b6300c
14d893d
d4ab4e4
6502c19
cd6d773
e92b883
e8a512d
a115faf
d8c2908
5d5c46c
9a76dec
6266546
c22c8fa
0822d13
c84b81c
3906f53
7d3f656
232d2e1
8156e11
a4a675e
7a44f21
723236a
1769c34
1ae5ecc
e106e08
d1c7303
b3f534c
0378a09
e851ede
856049a
a1dc9e1
dcb760f
e983fdd
987ed87
82519c9
2e03320
46d4315
674073a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -142,4 +142,7 @@ checklink/cookies.txt | |
| # wandb files | ||
| nbs/wandb/ | ||
| examples/notebooks/wandb/ | ||
| wandb/ | ||
| wandb/ | ||
|
|
||
| # uv | ||
| uv.lock | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| # Asynchronous GRPO | ||
|
|
||
| > [!IMPORTANT] | ||
| > This trainer requires `vllm>=0.17.1` and `transformers>=5.2.0`. For distributed training, only FSDP2 is supported (DeepSpeed ZeRO is not). | ||
| > | ||
| > Currently, `vllm` and `transformers` have conflicting dependency constraints. To work around this, install vLLM first and then force-install transformers: | ||
| > | ||
| > ```bash | ||
| > pip install 'vllm>=0.17.1' | ||
| > pip install 'transformers>=5.2.0' --no-deps | ||
| > ``` | ||
|
|
||
| ## Overview | ||
|
|
||
| [`AsyncGRPOTrainer`] implements the same [GRPO](grpo_trainer) algorithm but decouples rollout generation from training. A background worker continuously streams completions from a vLLM server while the training loop consumes them, so generation and gradient updates overlap instead of alternating. The API mirrors [`GRPOTrainer`] — for full details on the GRPO method itself (advantage computation, KL estimation, loss formulation, reward functions, etc.), see the [GRPO Trainer](grpo_trainer) documentation. Not all features from [`GRPOTrainer`] are available; refer to [`AsyncGRPOConfig`] for the supported parameters. | ||
|
|
||
| This trainer was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Amine Dirhoussi](https://huggingface.co/aminediroHF). | ||
|
|
||
| ## How it differs from [`GRPOTrainer`] | ||
|
|
||
| In the standard [`GRPOTrainer`], generation and training are sequential: generate a batch, compute the loss, update weights, repeat. Even in [vLLM colocate mode](grpo_trainer#speed-up-training-with-vllm), where generation runs on the same GPUs, one phase must finish before the other begins. | ||
|
|
||
| [`AsyncGRPOTrainer`] separates these two concerns: | ||
|
|
||
| - **Rollout worker** (background thread) — sends prompts to a vLLM server, scores completions with reward functions, computes advantages, and pushes ready-to-train samples into a queue. | ||
| - **Training loop** (main process) — pulls samples from the queue, computes the clipped surrogate loss, and updates the model weights. | ||
|
|
||
| After every `weight_sync_steps` training steps, the updated weights are transferred to the vLLM server via NCCL so that subsequent generations reflect the latest policy. | ||
|
|
||
| Because generation and training run concurrently, the training samples may have been generated by a slightly older version of the model. The `max_staleness` parameter controls how many weight updates a sample can lag behind before being discarded. | ||
|
|
||
| The number of concurrent requests sent to the vLLM server is controlled by `max_inflight_tasks`. By default it is set automatically to `max_staleness × per_device_train_batch_size × gradient_accumulation_steps × num_processes` — the maximum number of samples the trainer can consume before they become stale. Generating more than this is wasteful since the excess samples will be discarded. | ||
|
|
||
| ## Quick start | ||
|
|
||
| ```python | ||
| # train_async_grpo.py | ||
| from datasets import load_dataset | ||
| from trl.experimental.async_grpo import AsyncGRPOTrainer | ||
| from trl.rewards import accuracy_reward | ||
|
|
||
| dataset = load_dataset("trl-lib/DeepMath-103K", split="train") | ||
|
|
||
| trainer = AsyncGRPOTrainer( | ||
| model="Qwen/Qwen3-4B", | ||
| reward_funcs=accuracy_reward, | ||
| train_dataset=dataset, | ||
| ) | ||
| trainer.train() | ||
| ``` | ||
|
|
||
| The vLLM server and the trainer must run on **separate GPUs**. Use `CUDA_VISIBLE_DEVICES` to partition your GPUs. For example, with 2 GPUs, you can run the vLLM server on GPU 0 and the trainer on GPU 1 as follows: | ||
|
|
||
| ```bash | ||
| # Terminal 1: vLLM server on GPU 0 (dev mode + NCCL weight transfer are required) | ||
| CUDA_VISIBLE_DEVICES=0 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \ | ||
| --max-model-len 4096 \ | ||
| --weight-transfer-config '{"backend":"nccl"}' | ||
| ``` | ||
|
|
||
| > [!TIP] | ||
| > Set `--max-model-len` to the maximum total sequence length (prompt + completion) you expect. A lower value reduces GPU memory usage on the server, freeing more memory for the KV cache and increasing throughput. A good starting point is the prompt length plus `max_completion_length` from your config. | ||
|
|
||
| ```bash | ||
| # Terminal 2: training on GPU 1 | ||
| CUDA_VISIBLE_DEVICES=1 accelerate launch train_async_grpo.py | ||
| ``` | ||
|
|
||
| ## Design philosophy | ||
|
|
||
| This trainer is intentionally kept minimal and is not meant to grow into a general-purpose solution. If you need a feature that is not supported, we recommend cloning the repository and adapting the trainer to your needs directly. New features will only be considered when there is significant community demand. | ||
|
|
||
| ## AsyncGRPOConfig | ||
|
|
||
| [[autodoc]] trl.experimental.async_grpo.AsyncGRPOConfig | ||
|
|
||
| ## AsyncGRPOTrainer | ||
|
|
||
| [[autodoc]] trl.experimental.async_grpo.AsyncGRPOTrainer |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| # Copyright 2020-2026 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| CUDA_VISIBLE_DEVICES=2,3,4,5 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \ | ||
| --data-parallel-size 4 \ | ||
| --weight-transfer-config '{"backend":"nccl"}' \ | ||
| --max-model-len 9216 | ||
|
|
||
| CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file examples/accelerate_configs/fsdp2.yaml examples/scripts/async_grpo.py | ||
|
|
||
| !/! NOTE: depends on transformers > 5.0.0 | ||
| """ | ||
|
|
||
| import logging | ||
| import os | ||
|
|
||
| from datasets import load_dataset | ||
|
|
||
| from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer | ||
| from trl.rewards import accuracy_reward | ||
|
|
||
|
|
||
| logging.basicConfig( | ||
| level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO), | ||
| format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", | ||
| ) | ||
| logging.getLogger("trl").setLevel(logging.DEBUG) | ||
|
|
||
|
|
||
| def format_sample(sample): | ||
| return {"prompt": sample["messages"][:1], "solution": sample["answer"]} | ||
|
|
||
|
|
||
| def main() -> None: | ||
| dataset = load_dataset("open-r1/OpenR1-Math-220k", split="train[:10000]") | ||
| dataset = dataset.map(format_sample, remove_columns=dataset.column_names) | ||
|
|
||
| config = AsyncGRPOConfig( | ||
| output_dir="./results", | ||
| per_device_train_batch_size=1, | ||
| num_train_epochs=1, | ||
| max_completion_length=4096, | ||
| max_steps=10, | ||
| report_to="trackio", | ||
| trackio_space_id=None, | ||
| project="async_grpo", | ||
| log_completions=True, | ||
| ) | ||
| trainer = AsyncGRPOTrainer( | ||
| model="Qwen/Qwen3-4B", | ||
| args=config, | ||
| train_dataset=dataset, | ||
| reward_funcs=accuracy_reward, | ||
| ) | ||
| trainer.train() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,173 @@ | ||||||||
| # Copyright 2020-2026 The HuggingFace Team. All rights reserved. | ||||||||
| # | ||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||
| # you may not use this file except in compliance with the License. | ||||||||
| # You may obtain a copy of the License at | ||||||||
| # | ||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||
| # | ||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||
| # See the License for the specific language governing permissions and | ||||||||
| # limitations under the License. | ||||||||
|
|
||||||||
| """ | ||||||||
| CUDA_VISIBLE_DEVICES=2,3,4,5 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \ | ||||||||
| --data-parallel-size 4 \ | ||||||||
| --weight-transfer-config '{"backend":"nccl"}' \ | ||||||||
| --max-num-seqs 64 \ | ||||||||
|
AmineDiro marked this conversation as resolved.
Outdated
|
||||||||
| --max-model-len 9216 | ||||||||
|
|
||||||||
| CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file examples/accelerate_configs/fsdp2.yaml examples/scripts/async_grpo_mbpp.py | ||||||||
|
AmineDiro marked this conversation as resolved.
Outdated
|
||||||||
|
|
||||||||
| !/! NOTE: depends on transformers > 5.0.0 | ||||||||
| """ | ||||||||
|
|
||||||||
| import logging | ||||||||
| import os | ||||||||
| import subprocess | ||||||||
| import sys | ||||||||
| import tempfile | ||||||||
|
|
||||||||
| from datasets import load_dataset | ||||||||
|
|
||||||||
| from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer | ||||||||
|
|
||||||||
|
|
||||||||
| logging.basicConfig( | ||||||||
| level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO), | ||||||||
| format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", | ||||||||
| ) | ||||||||
|
|
||||||||
|
|
||||||||
| class MBPPEnvironment: | ||||||||
| """ | ||||||||
| A synchronous environment class designed for `AsyncGRPOTrainer`. | ||||||||
| Each environment instance handles tracking test cases and exposes the `execute_python_code` tool. | ||||||||
| """ | ||||||||
|
|
||||||||
| def __init__(self): | ||||||||
| self.test_list = [] | ||||||||
| self.done = False | ||||||||
|
|
||||||||
|
Comment on lines
+49
to
+52
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO we can remove this
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's safer because it will fail if we try to step before the reset
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think |
||||||||
| def reset(self, test_list: list[str], **kwargs): | ||||||||
| """ | ||||||||
| Resets the environment with the test suite for the new problem. | ||||||||
| `**kwargs` ignores additional columns sent from the dataset map. | ||||||||
| """ | ||||||||
| self.test_list = test_list | ||||||||
| self.done = False | ||||||||
|
|
||||||||
| def execute_python_code(self, code: str) -> str: | ||||||||
| """Execute python code to test against the hidden test cases. Provide the complete python code. | ||||||||
|
|
||||||||
| Args: | ||||||||
| code: The complete python code to execute. | ||||||||
|
|
||||||||
| Returns: | ||||||||
| program stdout string | ||||||||
| """ | ||||||||
| full_code = code + "\n\n" + "\n".join(self.test_list) | ||||||||
|
|
||||||||
| with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f: | ||||||||
| f.write(full_code) | ||||||||
| temp_path = f.name | ||||||||
|
|
||||||||
| try: | ||||||||
| result = subprocess.run( | ||||||||
| [sys.executable, temp_path], | ||||||||
| capture_output=True, | ||||||||
| text=True, | ||||||||
| timeout=3.0, | ||||||||
| ) | ||||||||
| if result.returncode == 0: | ||||||||
| self.done = True | ||||||||
| return "Tests passed." | ||||||||
| else: | ||||||||
| # Return the last 2000 characters of the stderr to fit within context length | ||||||||
| feedback = result.stderr[-2000:] | ||||||||
| return f"Execution failed with error:\n{feedback}\nPlease fix the code and try again." | ||||||||
| except subprocess.TimeoutExpired: | ||||||||
| return "Execution timeout." | ||||||||
| finally: | ||||||||
| if os.path.exists(temp_path): | ||||||||
| os.remove(temp_path) | ||||||||
|
|
||||||||
| def is_done(self) -> bool: | ||||||||
| return self.done | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. because it's missing the docstring, it won't be exposed as a tool to the model. If it's not meant to be exposed, then I'd recommend
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it is not a tool function. the |
||||||||
|
|
||||||||
|
|
||||||||
| def tests_passed_reward(completions, **kwargs) -> list[float]: | ||||||||
| """ | ||||||||
| Reward function that checks the model's chat history for the last tool execution | ||||||||
| and returns 1.0 if the tests passed, 0.0 otherwise. | ||||||||
| """ | ||||||||
| rewards = [] | ||||||||
| for completion in completions: | ||||||||
| passed = False | ||||||||
| # Interrogate the completion, looping backwards to find the last tool interaction result | ||||||||
| for msg in reversed(completion): | ||||||||
| if msg["role"] == "tool" and "Tests passed." in msg.get("content", ""): | ||||||||
| passed = True | ||||||||
| break | ||||||||
| rewards.append(1.0 if passed else 0.0) | ||||||||
| return rewards | ||||||||
|
|
||||||||
|
|
||||||||
| def format_sample(sample): | ||||||||
| """ | ||||||||
| Format the MBPP dataset row into a prompt using OpenAI chat formatting, | ||||||||
| and persist `test_list` so `reset()` can inject it. | ||||||||
| """ | ||||||||
| prompt_text = sample.get("text", "") | ||||||||
| content = ( | ||||||||
| f"You are an expert Python programmer.\n\n" | ||||||||
| f"{prompt_text}\n\n" | ||||||||
| f"Please write a python code to solve the problem and use the execute_python_code tool to test it." | ||||||||
| ) | ||||||||
| prompt = [{"role": "user", "content": content}] | ||||||||
|
|
||||||||
| return {"prompt": prompt, "test_list": sample.get("test_list", [])} | ||||||||
|
|
||||||||
|
|
||||||||
| def main() -> None: | ||||||||
| os.environ["WANDB_PROJECT"] = "async_grpo_trl_mbpp" | ||||||||
| # 1. Load dataset | ||||||||
| dataset = load_dataset("google-research-datasets/mbpp", split="train+test") | ||||||||
| dataset = dataset.map(format_sample, remove_columns=dataset.column_names) | ||||||||
|
|
||||||||
| # 2. Config setup | ||||||||
| config = AsyncGRPOConfig( | ||||||||
|
AmineDiro marked this conversation as resolved.
Outdated
|
||||||||
| output_dir="./results", | ||||||||
| per_device_train_batch_size=1, | ||||||||
| max_completion_length=8192, | ||||||||
| max_seq_length=8192, | ||||||||
| max_tool_calling_iterations=5, | ||||||||
| max_steps=100, | ||||||||
| max_staleness=8, | ||||||||
| # Logging | ||||||||
| log_completions=True, | ||||||||
| num_completions_to_print=2, | ||||||||
| report_to="wandb", | ||||||||
| logging_steps=1, | ||||||||
| # trackio | ||||||||
| project="async_grpo_trl_mbpp", | ||||||||
| trackio_space_id=None, | ||||||||
|
AmineDiro marked this conversation as resolved.
Outdated
|
||||||||
| ) | ||||||||
|
|
||||||||
| # 3. Trainer initialization | ||||||||
| trainer = AsyncGRPOTrainer( | ||||||||
| model="Qwen/Qwen3-4B", | ||||||||
| args=config, | ||||||||
|
AmineDiro marked this conversation as resolved.
Outdated
|
||||||||
| train_dataset=dataset, | ||||||||
| reward_funcs=[tests_passed_reward], | ||||||||
| environment_factory=MBPPEnvironment, | ||||||||
| ) | ||||||||
|
|
||||||||
| # 4. Train | ||||||||
| trainer.train() | ||||||||
|
|
||||||||
|
|
||||||||
| if __name__ == "__main__": | ||||||||
| main() | ||||||||
Uh oh!
There was an error while loading. Please reload this page.