-
Notifications
You must be signed in to change notification settings - Fork 124
feat: async partial rollout trainer with sample supplementation and caching #58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
mamazi0131
wants to merge
1
commit into
verl-project:main
Choose a base branch
from
mamazi0131:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,188 @@ | ||
| # Recipe: Async Partial Rollout Trainer | ||
|
|
||
| **Group:** `Tencent Data & Computation Platform Department` | ||
|
|
||
| **Author:** Yue Wang*, Zhipeng Ma*, Yi Yan, Hang Xu, Yang Li, Bo Qian, Peng Chen | ||
|
|
||
| Last updated: 01/15/2026. | ||
|
|
||
| ## 1. Introduction | ||
|
|
||
| ### 1.1 Background | ||
|
|
||
| During synchronous reinforcement learning training in verl, we observe that the training dataset exhibits significant length imbalance, with a small fraction of exceptionally long samples. As illustrated in Figure 1, the maximum response length in the dataset reaches 160k tokens, while approximately 97% of responses are shorter than 80k tokens. Consequently, the minority of long-tail samples (3%) significantly slows down the training of the majority (97%) of the data. Moreover, these long-tail samples often correspond to more challenging cases, which are essential for effectively enhancing the model’s reasoning capabilities. Therefore, they cannot be removed without compromising training effectiveness. | ||
|
|
||
|
|
||
|  | ||
|
|
||
|
|
||
| ### 1.2 Solution | ||
|
|
||
| We enhance the partial-rollout mechanism by introducing **sample supplementation** and **interruption techniques**. Since response lengths are unknown at inference time, **inference bubbles** are inevitable. We leverage sample supplementation to effectively utilize this otherwise unavoidable idle GPU time. Specifically, when a GPU worker completes its inference workload earlier than others, we supplement it with additional samples until the total number of samples returned by all GPU workers meets the training requirement. Once this requirement is satisfied, some GPU workers may still be processing ongoing inference tasks. To better utilize these partially processed samples, we **cache unfinished samples** and reuse them in the subsequent inference round. | ||
|
|
||
|  | ||
| > reference: [APRIL: ACTIVE PARTIAL ROLLOUTS IN REINFORCEMENT LEARNING TO TAME LONG-TAIL GENERATION]( | ||
| > https://arxiv.org/pdf/2509.18521) | ||
|
mamazi0131 marked this conversation as resolved.
|
||
|
|
||
|
|
||
| Our core contributions include: | ||
|
|
||
| 1. **Sample Supplementation and Interruption Mechanisms**: | ||
| Introducing sample supplementation and interruption mechanisms to enable dynamic sample replenishment and automated scheduling of inference tasks. | ||
|
|
||
| 2. **Rollout Caching**: | ||
| Using a prompt manager to resume partial rollouts, managing complete and partial samples in the buffer based on sample staleness. | ||
|
|
||
|
|
||
| ### 1.3 Experimental Results | ||
|
|
||
| - **Machine Configuration**: 2 nodes with 8 H20 GPUs | ||
| - **Model**: Qwen3-4B | ||
| - **Rollout Configuration**: | ||
| - **Max Response Length**: 18384 tokens (for DAPO-MATH17k), 1024 tokens (for GSM8K) | ||
| - **Algorithm**: GRPO | ||
| - **Rollout Engine**: vLLM | ||
|
|
||
| #### GSM8K | ||
| On the GSM8K dataset, our method achieves comparable convergence and tangible performance gains compared to the baseline. Upon completing the **full dataset** training, it reduces total training time by <span style="color:red">11.7%</span> and improves average GPU utilization by <span style="color:red">5.93%</span>. | ||
|
|
||
| | Training mode | Engine | Step | Total Time |Acc/mean@1 | GPU Avg Utilization | | ||
| |------------------------|---------------|------|------------------|---------------|---------------| | ||
| | GRPO+noPR | VLLM+Megatron | 290 | 4h59m | 94.99 |71.54 | | ||
| | GRPO+PR | VLLM+Megatron | 280 | 4h24m <span style="color:red"> (-35m) </span> | 94.08 |77.47| | ||
|
|
||
|
|
||
| > source data: https://swanlab.cn/@allenzpma/verl_exp_partial-rollout_gsm8k/runs | ||
|
|
||
| #### DAPO-MATH17k | ||
| Furthermore, on the DAPO-math dataset, our approach facilitates **full dataset** training with a <span style="color:red">51.1%</span> reduction in end-to-end execution time and an <span style="color:red">8.77%</span> boost in GPU utilization. And, our method achieves comparable convergence to the baseline. | ||
|
|
||
| | Training Mode | Engine | Step | Total Time |Acc/best@32/mean | Acc/maj@32/mean |GPU Avg Utilization | | ||
| | :--- | :--- | :--- | :--- | :--- | :--- | :--- | | ||
| | GRPO+noPR | VLLM+Megatron | 200 |67h34m | 79.94 | 73.33 |74.64| | ||
| | GRPO+PR | VLLM+Megatron | 110 | 33h02m <span style="color:red"> (-34h32m) </span> | 82.90 | 73.41 |83.41| | ||
|
|
||
|
|
||
| > source data: https://swanlab.cn/@allenzpma/verl_exp_partial-rollout_dapo-math/runs | ||
|
|
||
|
|
||
| ## 2. Implementation | ||
|
|
||
| ### 2.1 Sample Supplementation and Interruption Mechanisms (SSIM) | ||
|
|
||
| The main components of the SSIM mechanism are as follows: | ||
| <!--  --> | ||
|
|
||
| <img src="https://raw.githubusercontent.com/mamazi0131/verl_doc/fca7a6d3acbeca12d69c5de6f85c312c1c9e47b6/Architectural_Design_of_the_SSIM_Mechanism.png" width="60%"> | ||
|
|
||
| The event interaction logic of the SSIM mechanism is as follows: | ||
|  | ||
|
|
||
|
|
||
| ### 2.2 Rollout Caching | ||
| The rollout caching mechanism is implemented using a prompt manager. The prompt manager uses a queue to control the order of sample resumption, with prompt priority defined by the **get_scheduling_priority** function. | ||
|
|
||
| ```python | ||
| class PromptsManager: | ||
| """ | ||
| PromptsManager is used to manage the prompts queue. | ||
| """ | ||
| def __init__( | ||
| self, | ||
| global_config, | ||
| train_dataloader : StatefulDataLoader, | ||
| sampling_num : int, | ||
| rollout_manager_obj, | ||
| trained_prompts_index: set[int] = set(), | ||
| ): | ||
| """ | ||
| Args: | ||
| global_config: the global config | ||
| train_dataloader: the train dataloader from `ray_trainer.py` | ||
| sampling_num: the number of samples to generate for each prompt | ||
| rollout_manager_obj: the rollout manager object | ||
| trained_prompts_index: the prompts that have been trained, used to skip the prompts that have been trained | ||
| """ | ||
| self.global_config = global_config | ||
| self.sampling_num = sampling_num | ||
| self.prompt_queue = PromptsQueue() | ||
| self.trained_prompts_index = trained_prompts_index | ||
|
|
||
| # init dataloader_iter | ||
| self.dataloader_iter = iter(train_dataloader) | ||
| self.dataloader_iter_exhausted = False | ||
| self.filter_cnt = 0 | ||
| self.model_version = 0 | ||
|
|
||
|
|
||
| # Sort Priority (for each prompt) | ||
| def get_scheduling_priority(self, ignored_samples: set[Sample] = set()) -> tuple[int, float, int]: | ||
| """ | ||
| Return a priority key for prompt scheduling. | ||
|
|
||
| The tuple is ordered so that it can be directly used in `sort(key=...)`: | ||
| ( | ||
| unfinished_samples_num, | ||
| finished_mean_response_length (1e9 if no finished samples), | ||
| max_staleness | ||
| ) | ||
| """ | ||
| unfinished_samples = set(self.get_unfinished_samples()) - set(ignored_samples) | ||
| finished_samples = self.get_finished_samples() | ||
|
|
||
| # 1. unfinished samples number | ||
| unfinished_num = len(unfinished_samples) | ||
| # 2. mean response length of finished samples | ||
| finished_mean_resp_len = ( | ||
| np.mean([sample.get_responses_length() for sample in finished_samples]) | ||
| if finished_samples | ||
| else 1e9 | ||
| ) | ||
| # 3. max staleness | ||
| max_staleness = np.max( | ||
| [sample.get_staleness(expected_version=self.expected_model_version) | ||
| for sample in self.samples] | ||
| ) | ||
|
|
||
| return unfinished_num, finished_mean_resp_len, max_staleness | ||
| ``` | ||
|
|
||
| ### 2.3 Off-Policy Correctness | ||
| To ensure the correctness of the PPO algorithm, PPO importance sampling is performed using **rollout log probs** with a decoupled trick, which preserves algorithmic correctness under interruptible generation and policy updates. | ||
|
|
||
| $$ | ||
| J(\theta)=\mathbb{E}_{q \sim \mathcal{D}, a_t \sim \pi_{\text {behav}}^{\text{rollout}}}[\sum_{t=1}^H \min (\frac{\pi_{\theta}^{\text{train}}}{\pi_{\text {behav}}^{\text{rollout}}} \hat{A}_t, \frac{\pi_{\text {prox }}^{\text{rollout}}}{\pi_{\text {behav }}^{\text{rollout}}} \operatorname{clip}\left(\frac{\pi_{\theta}^{\text{train}}}{\pi_{\text {prox }}^{\text{rollout}}}, 1-\epsilon, 1+\epsilon\right) \hat{A}_t)] \\ | ||
| $$ | ||
| > reference: [AREAL: A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning]( | ||
| > https://arxiv.org/pdf/2505.24298) | ||
|
|
||
| ### 2.4 AgentLoop | ||
| In the current implementation, we use AgentLoop mode, which also supports multi-turn tool calling. | ||
|
|
||
| ## 3.Usage | ||
| ### GSM8K Configuration Example | ||
| ```shell | ||
| bash recipe/partial_rollout/run_gsm8k_nopr_4b_bs128.sh | ||
| bash recipe/partial_rollout/run_gsm8k_pr_4b_bs128.sh | ||
| ``` | ||
|
|
||
| ### DAPO_MATH Configuration Example | ||
| ```shell | ||
| bash recipe/partial_rollout/run_dapo_math17k_nopr_4b_2node.sh | ||
| bash recipe/partial_rollout/run_dapo_math17k_pr_4b_2node.sh | ||
| ``` | ||
|
|
||
| ## 4. Functional Support | ||
|
|
||
| Furthermore, **our implementation supports both verl 0.5.0 and 0.6.1.** We recommend freezing the verl version in your environment to ensure long-term stability and prevent potential breaking changes from future upstream PRs. | ||
|
|
||
| | Category | Support Situation | | ||
| |--------------------|-----------------------------------------------------------------------------------------------------------------| | ||
| | train engine | FSDP2 <br/> Megatron | | ||
| | rollout engine | vLLM | | ||
| | AdvantageEstimator | GRPO <br/> GSPO <br/> SAPO <br/> GRPO_PASSK <br/> REINFORCE_PLUS_PLUS <br/> RLOO <br/> OPO <br/> REINFORCE_PLUS_PLUS_BASELINE<br/>GPG | | ||
| | Reward | all | | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| # Copyright 2025 Meituan Ltd. and/or its affiliates | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from .agent_loop import PRv3AgentLoopManager | ||
| from .partial_single_turn_agent_loop import PartialSingleTurnAgentLoop | ||
| from .partial_tool_agent_loop import PartialToolAgentLoop | ||
|
|
||
| _ = [PartialSingleTurnAgentLoop, PartialToolAgentLoop] | ||
| __all__ = [PRv3AgentLoopManager] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.