diff --git a/.github/workflows/stash/e2e_fully_async_policy.yml b/.github/workflows/e2e_fully_async_policy.yml
similarity index 100%
rename from .github/workflows/stash/e2e_fully_async_policy.yml
rename to .github/workflows/e2e_fully_async_policy.yml
diff --git a/.github/workflows/stash/e2e_one_step_off_policy.yml b/.github/workflows/e2e_one_step_off_policy.yml
similarity index 100%
rename from .github/workflows/stash/e2e_one_step_off_policy.yml
rename to .github/workflows/e2e_one_step_off_policy.yml
diff --git a/.github/workflows/stash/e2e_one_step_off_policy_ascend.yml b/.github/workflows/e2e_one_step_off_policy_ascend.yml
similarity index 100%
rename from .github/workflows/stash/e2e_one_step_off_policy_ascend.yml
rename to .github/workflows/e2e_one_step_off_policy_ascend.yml
diff --git a/docs/advance/fully_async.md b/docs/advance/fully_async.md
index 0c03bac6e86..a2b7ccb3aea 100644
--- a/docs/advance/fully_async.md
+++ b/docs/advance/fully_async.md
@@ -2,7 +2,7 @@
**Author:** `https://github.com/meituan-search`
-Last updated: 12/25/2025.
+Last updated: 02/05/2026.
This document introduces a fully asynchronous PPO training system that completely decouples the Trainer and Rollouter,
supporting asynchronous sample generation and training.
@@ -37,33 +37,36 @@ can significantly improve training efficiency.
> Generation https://arxiv.org/abs/2504.15930
>
> AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01663
+>
### Core Contributions
-- **Resource Isolation**: Unlike using hybrid_engine, Rollouter and Trainer use separate computing resources and need to
+* **Resource Isolation**: Unlike using hybrid_engine, Rollouter and Trainer use separate computing resources and need to
specify the resources they occupy separately.
-- **Parallel Generation and Training**: While the Trainer is training, the Rollouter is generating new samples.
-- **Multi-step Asynchronous**: Compared to one step off policy, it supports asynchronous settings from 0.x steps to
+* **Parallel Generation and Training**: While the Trainer is training, the Rollouter is generating new samples.
+* **Multi-step Asynchronous**: Compared to one step off policy, it supports asynchronous settings from 0.x steps to
multiple steps, making the asynchronous solution more flexible.
-- **NCCL Parameter Synchronization**: Based on the nccl communication primitive, refer to [checkpoint-engine](https://github.com/MoonshotAI/checkpoint-engine) to
+* **NCCL Parameter Synchronization**: Based on the nccl communication primitive, refer
+ to [checkpoint-engine](https://github.com/MoonshotAI/checkpoint-engine) to
achieve efficient parameter synchronization between Rollouter and Trainer.
-- **Stream Inference and Training**: Rollouter generates data sample by sample, and data transmission uses a single
+* **Stream Inference and Training**: Rollouter generates data sample by sample, and data transmission uses a single
sample as the minimum transmission unit.
-- **Asynchronous Training and Freshness Control**: By setting the parameter async_training.staleness_threshold, it
+* **Asynchronous Training and Freshness Control**: By setting the parameter async_training.staleness_threshold, it
supports training with samples generated by old parameters.
-- **PartialRollout**: The Rollouter's inference process supports partial rollout logic. During parameter
+* **PartialRollout**: The Rollouter's inference process supports partial rollout logic. During parameter
synchronization, by adding `sleep() and resume()` logic, it
saves samples from ongoing rollouts and continues using them in the next rollout, reducing the time spent waiting for
ongoing tasks to finish during parameter synchronization.
-Currently, the supported usage mode is Megatron/FSDP+vLLM/SGLang. vLLM/SGLang must use the server mode based on AgentLoop.
+Currently, the supported usage mode is Megatron/FSDP + vLLM. vLLM must use the server mode based on AgentLoop.
## Design
The overall architecture of fully_async_policy is shown in the figure below. fully_async_policy mainly consists of four
parts: Rollouter, MessageQueue, Trainer, and ParameterSynchronizer.
-
+
1. Rollouter generates sequences sample by sample and puts the generated samples into the MessageQueue, with the
production speed controlled by freshness.
@@ -79,14 +82,15 @@ After we perform resource isolation, the time for rollout and train may be longe
are used),
but the overlap in their time consumption reduces the end-to-end time consumption.
-
+
## Usage
### Parameter Description
| super params | implication |
-| ---------------------------------------------------------------- | ---------------------------------------------------------------------------------------------- |
+|------------------------------------------------------------------|------------------------------------------------------------------------------------------------|
| `trainer.nnodes` | Number of nodes for Trainer |
| `trainer.n_gpus_per_node` | Number of GPUs per node for Trainer |
| `rollout.nnodes` | Number of nodes for Rollouter |
@@ -96,65 +100,59 @@ but the overlap in their time consumption reduces the end-to-end time consumptio
| `rollout.total_rollout_steps` | Total number of rollout samples |
| `rollout.test_freq` | How many times Rollouter updates parameters before performing a validation |
| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus |
+| `actor_rollout_ref.actor.use_rollout_log_probs=True` | Use log_probs generated by rollout |
+| `algorithm.rollout_correction.bypass_mode`                        | Whether to bypass computing log_prob with the training model during training (default `True`, i.e., use rollout log_probs) |
| `async_training.require_batches` | Number of ppo_mini_batch_size that FullyAsyncTrainer fetches at once |
| `async_training.trigger_parameter_sync_step` | Indicates how many local updates FullyAsyncTrainer performs before a parameter synchronization |
| `async_training.staleness_threshold` | Freshness control |
| `async_training.partial_rollout` | Whether to perform partial_rollout |
-| `async_training.use_rollout_log_probs` | Use log_probs generated by rollout |
-| `async_training.compute_prox_log_prob` | Whether to compute log_prob using the training model's parameters during the training phase |
| `async_training.checkpoint_engine.enable` | Whether to use checkpoint_engine for accelerating, default `True` |
| `async_training.checkpoint_engine.overlap_broadcast_and_consume` | When use checkpoint_engine, whether to overlap broadcast and load_weights, default `False` |
| `async_training.checkpoint_engine.device_buffer_size_M` | When use checkpoint_engine, the user-specific bucket size (MB), default `4096` |
-| `async_training.use_trainer_do_validate` | Whether use trainer node to do validate process, default `False`|
+| `async_training.use_trainer_do_validate`                          | Whether to use the trainer node to run validation, default `False`                              |
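+
+As a minimal sketch of how the resource-isolation parameters above are passed (node/GPU counts and batch sizes here
+are illustrative, not recommendations; see the Quick Start below for a full command):
+
+```bash
+# Trainer and Rollouter get disjoint resources; each side is sized independently.
+python -m recipe.fully_async_policy.fully_async_main \
+    trainer.nnodes=1 \
+    trainer.n_gpus_per_node=4 \
+    rollout.nnodes=1 \
+    rollout.n_gpus_per_node=4 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
+    rollout.total_rollout_steps=204800   # e.g. train_batch_size 512 * 400 colocate steps
+```
+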
**Further Explanation:**
-- `rollout.total_rollout_steps`
+* `rollout.total_rollout_steps`
Compared to colocate, the quantity can be aligned by multiplying train_batch_size and step:
`rollout.total_rollout_steps = data.train_batch_size * step`.
-- `async_training.trigger_parameter_sync_step`
+* `async_training.trigger_parameter_sync_step`
In the fully async strategy, it indicates how many local updates the Trainer performs (i.e., how many times it fetches
`require_batches * ppo_mini_batch_size` samples) before a parameter synchronization with Rollouter.
Between every two parameter synchronizations between Rollouter and Trainer, the Trainer will process
`trigger_parameter_sync_step* require_batches*ppo_mini_batch_size` samples.
- To fairly compare speed with colocate, trigger_parameter_sync_step should be set to
+ To fairly compare speed with colocate, `trigger_parameter_sync_step` should be set to
`data.train_batch_size / (require_batches * ppo_mini_batch_size)`.
-- `async_training.staleness_threshold`
+* `async_training.staleness_threshold`
In the fully async strategy, it indicates the maximum proportion of stale samples allowed to be used.
- - staleness_threshold=0, indicates synchronous training.
- Rollouter will generate a fixed number of samples between two parameter updates, the sample count is:
- $$rollout\_num = (trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size)$$
- - staleness_threshold>0, indicates asynchronous training, can be set to a decimal for more flexible asynchronous
- calls.
- Rollouter will generate at most the following number of samples between two parameter updates:
- $$rollout\_num = (1+staleness\_threshold)*(trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size) - num\_staleness\_sample $$
+  * `staleness_threshold`=0 indicates synchronous training.
+    Rollouter will generate a fixed number of samples between two parameter updates:
+
+    `rollout_num = (trigger_parameter_sync_step*require_batches*ppo_mini_batch_size)`
+  * `staleness_threshold`>0 indicates asynchronous training; it can be set to a fractional value for finer-grained
+    control of the degree of asynchrony.
+    Rollouter will generate at most the following number of samples between two parameter updates:
- num_staleness_sample represents the number of stale samples generated in excess during the last rollout.
+ `rollout_num = (1+staleness_threshold)*(trigger_parameter_sync_step*require_batches*ppo_mini_batch_size) - num_staleness_sample`
+
+ `num_staleness_sample` represents the number of stale samples generated in excess during the last rollout.
Since it's a streaming system, rollout continues to generate and trainer continues to consume. If rollouter is slower,
trainer will trigger parameter synchronization earlier, and rollouter will not actually produce rollout_num samples.
- When rollout is fast enough, setting staleness_threshold to 1 is basically equivalent to one_step_off policy.
+ When rollout is fast enough, setting `staleness_threshold` to 1 is basically equivalent to one_step_off policy.
To avoid too many expired samples affecting training accuracy, it is recommended to set this value to less than 1.
-- `async_training.partial_rollout`
+* `async_training.partial_rollout`
partial_rollout only actually takes effect when staleness_threshold>0.
-- `async_training.use_rollout_log_probs`
-
- In reinforcement learning algorithms, log_probs have implicit correlations with parameter versions and tokens. Due to
- the settings of algorithms like PPO/GRPO/DAPO, when calculating importance sampling,
- old_log_prob must use the log_probs corresponding to the rollout parameters and tokens to ensure algorithm
- correctness. In the fully
- async strategy, we default to old_log_prob being calculated by rollout rather than by trainer.
-
-- `async_training.require_batches`
+* `async_training.require_batches`
In streaming training, require_batches should be set to 1, indicating that training is performed after producing
enough ppo_mini_batch_size samples.
@@ -163,37 +161,47 @@ but the overlap in their time consumption reduces the end-to-end time consumptio
Here, we additionally provide require_batches for streaming distribution and control the number of samples
participating in training at once.
-- `async_training.compute_prox_log_prob` (experimental)
+* `actor_rollout_ref.actor.use_rollout_log_probs=True`
+
+ In reinforcement learning algorithms, log_probs have implicit correlations with parameter versions and tokens. Due to
+ the settings of algorithms like PPO/GRPO/DAPO, when calculating importance sampling,
+ old_log_prob must use the log_probs corresponding to the rollout parameters and tokens to ensure algorithm
+  correctness. In the fully async strategy, old_log_prob is calculated by the rollout rather than by the trainer by
+  default.
+
+* `algorithm.rollout_correction.bypass_mode`
+
+  > `algorithm.rollout_correction.bypass_mode` defaults to `True`, i.e., rollout log_probs are used.
During the training process, we observed that metrics and response lengths may become unstable in the later
stages of training. To mitigate this issue, we can use
the [Rollout Importance Sampling](https://verl.readthedocs.io/en/latest/advance/rollout_is.html)
technique for importance sampling. To utilize Rollout Importance Sampling, we need to compute log_prob using
-  the training engine, which requires enabling this switch.
+  the training engine, which requires setting this option to `False`.
- Additionally, when compute_prox_log_prob and Rollout Importance Sampling are enabled under mode d
+ Additionally, when `algorithm.rollout_correction.bypass_mode=False` and Rollout Importance Sampling are enabled under
+ mode d
(async stream pipeline with partial rollout), our implementation approximates `Areal's Decoupled PPO`.
-- `async_training.checkpoint_engine.enable`
+* `async_training.checkpoint_engine.enable`
Enabling the checkpoint engine generally reduces synchronization time overhead by more than 60% compared to
the original per-tensor parameter synchronization method. However, assembling buckets incurs additional
temporary GPU memory overhead.
-- `async_training.checkpoint_engine.overlap_broadcast_and_consume`
+* `async_training.checkpoint_engine.overlap_broadcast_and_consume`
Enabling pipeline between the broadcast and load_weights parameters will allocate additional GPU memory.
Since the main time consumption for parameter synchronization is not in the broadcast and load_weights phases,
but in the parameter generation phase (by megatron or FSDP), this option is off by default.
-- `async_training.checkpoint_engine.device_buffer_size_M`
+* `async_training.checkpoint_engine.device_buffer_size_M`
It controls the size of the memory buffer used for synchronization when the checkpoint-engine is enabled.
The actual `bucket_size` = `max(device_buffer_size_M, maximum parameter tensor size)`.
-
- - When enable `overlap_broadcast_and_consume`, the additional device memory overhead of
- trainer rank is `3 * bucket_size`and rollout rank is `2 * bucket_size`。
- - When disable `overlap_broadcast_and_consume`, the additional device memory overhead of
- trainer rank is `2 * bucket_size`and rollout rank is `1 * bucket_size`。
+  * When `overlap_broadcast_and_consume` is enabled, the additional device memory overhead is `3 * bucket_size` per
+    trainer rank and `2 * bucket_size` per rollout rank.
+  * When `overlap_broadcast_and_consume` is disabled, the additional device memory overhead is `2 * bucket_size` per
+    trainer rank and `1 * bucket_size` per rollout rank.
* `async_training.use_trainer_do_validate`
@@ -205,53 +213,51 @@ but the overlap in their time consumption reduces the end-to-end time consumptio
### Supported Modes
1. on policy pipeline:
-
- 1. **trigger_parameter_sync_step=1, staleness_threshold=0**
- 2. Rollouter produces `require_batches*ppo_mini_batch_size` samples at once, Trainer fetches these samples for
- training, and after training completes, Trainer and Rollouter perform a parameter synchronization;
- 3. During the rollout phase, if there are long-tail samples but few rollout samples, shorter samples cannot fill
- idle resources, causing some resource waste.
- 4. As shown in figure a;
+ 1. **trigger_parameter_sync_step=1, staleness_threshold=0**
+ 2. Rollouter produces `require_batches*ppo_mini_batch_size` samples at once, Trainer fetches these samples for
+ training, and after training completes, Trainer and Rollouter perform a parameter synchronization;
+ 3. During the rollout phase, if there are long-tail samples but few rollout samples, shorter samples cannot fill
+ idle resources, causing some resource waste.
+ 4. As shown in figure a;
2. stream off policy pipeline:
-
- 1. **trigger_parameter_sync_step>1, staleness_threshold=0**
- 2. Synchronous streaming training will be performed. Rollouter produces
- `require_batches*ppo_mini_batch_size*trigger_parameter_sync_step` samples at once, Trainer performs a local
- training every time it fetches `require_batches*ppo_mini_batch_size` samples, and after training
- trigger_parameter_sync_step times, Trainer and Rollouter perform a parameter synchronization;
- 3. Compared to a, since more samples are generated at once, resource idleness will be lower.
- 4. In one step training, there will be two periods of resource idleness: when fetching the first batch of samples,
- train waits for `require_batches*ppo_mini_batch_size` samples to be produced, and during the last parameter
- update, rollout waits for training to complete.
- 5. As shown in figure b;
+ 1. **trigger_parameter_sync_step>1, staleness_threshold=0**
+ 2. Synchronous streaming training will be performed. Rollouter produces
+ `require_batches*ppo_mini_batch_size*trigger_parameter_sync_step` samples at once, Trainer performs a local
+ training every time it fetches `require_batches*ppo_mini_batch_size` samples, and after training
+ trigger_parameter_sync_step times, Trainer and Rollouter perform a parameter synchronization;
+ 3. Compared to a, since more samples are generated at once, resource idleness will be lower.
+ 4. In one step training, there will be two periods of resource idleness: when fetching the first batch of samples,
+ train waits for `require_batches*ppo_mini_batch_size` samples to be produced, and during the last parameter
+ update, rollout waits for training to complete.
+ 5. As shown in figure b;
3. async stream pipeline with stale samples:
-
- 1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=False**
- 2. After each parameter update, Rollouter will plan to produce at most rollout_num samples (in practice, the number
- of samples generated may be less than this value depending on rollout speed).
- 3. If the rollout process is relatively fast, Rollouter will generate some additional samples num_stale_samples
- before parameter synchronization for immediate use by Trainer after synchronization.
- When triggering parameter synchronization, if Rollouter has ongoing tasks, it will wait for the tasks to complete
- and not add new tasks;
- 4. Compared to b, except for the first step training, subsequent training will not have the time to wait for the
- first batch rollout to finish, but will have the time to wait for active tasks to finish.
- 5. As shown in figure c;
+ 1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=False**
+ 2. After each parameter update, Rollouter will plan to produce at most rollout_num samples (in practice, the number
+ of samples generated may be less than this value depending on rollout speed).
+ 3. If the rollout process is relatively fast, Rollouter will generate some additional samples num_stale_samples
+ before parameter synchronization for immediate use by Trainer after synchronization.
+ When triggering parameter synchronization, if Rollouter has ongoing tasks, it will wait for the tasks to complete
+ and not add new tasks;
+ 4. Compared to b, except for the first step training, subsequent training will not have the time to wait for the
+ first batch rollout to finish, but will have the time to wait for active tasks to finish.
+ 5. As shown in figure c;
4. async stream pipeline with partial rollout:
- 1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=True**
- 2. Compared to c, when triggering parameter synchronization, if Rollouter has samples being produced, it will
- interrupt the rollout process and perform parameter synchronization. The interrupted samples will continue to be
- generated after synchronization. This reduces the time to wait for active tasks to finish.
- 3. As shown in figure d;
+ 1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=True**
+ 2. Compared to c, when triggering parameter synchronization, if Rollouter has samples being produced, it will
+ interrupt the rollout process and perform parameter synchronization. The interrupted samples will continue to be
+ generated after synchronization. This reduces the time to wait for active tasks to finish.
+ 3. As shown in figure d;
-
+
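+
+A minimal sketch of the parameter combinations behind the four modes above (the shell variable names are illustrative;
+`trigger_parameter_sync_step=4` and `staleness_threshold=0.5` are example values, not defaults):
+
+```bash
+# a) on policy pipeline
+on_policy="async_training.trigger_parameter_sync_step=1 async_training.staleness_threshold=0"
+# b) stream off policy pipeline
+stream_off_policy="async_training.trigger_parameter_sync_step=4 async_training.staleness_threshold=0"
+# c) async stream pipeline with stale samples
+stale_samples="async_training.trigger_parameter_sync_step=4 async_training.staleness_threshold=0.5 async_training.partial_rollout=False"
+# d) async stream pipeline with partial rollout
+partial_rollout="async_training.trigger_parameter_sync_step=4 async_training.staleness_threshold=0.5 async_training.partial_rollout=True"
+```
+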
### Key Metrics
| metrics | implication |
-| ---------------------------------------------- | ------------------------------------------------------------------------------------------------------ |
+|------------------------------------------------|--------------------------------------------------------------------------------------------------------|
| `trainer/idle_ratio` | Trainer idle rate |
| `rollouter/idle_ratio` | Rollouter idle rate |
| `fully_async/count/stale_samples_processed` | Total number of old samples used in training |
@@ -262,41 +268,39 @@ but the overlap in their time consumption reduces the end-to-end time consumptio
### Parameter Tuning Recommendations
-- Resource Allocation and Adjustment:
-
- - Reasonable resource allocation is the prerequisite for achieving good training efficiency. The ideal resource
- allocation should make the rollout time and train time close, thereby minimizing pipeline bubbles in the entire
- training process,
- avoiding resource idleness, and ensuring Trainer does not use old samples. In real training scenarios, resource
- allocation can be adjusted based on the idle time of rollout and train during actual training,
- which can be obtained from rollouter/idle_ratio and trainer/idle_ratio. If rollouter/idle_ratio is high and
- trainer/idle_ratio is low,
- Trainer resources should be increased and Rollouter resources should be reduced, and vice versa.
-
-- Key Parameters:
-
- - staleness_threshold: Setting it too high will cause more old samples to be used, affecting model performance. It
- is recommended to set it to less than 1.
- - require_batches: The closer to 1, the closer to a pure streaming process, the smaller the training bubbles, and
- the faster the acceleration effect that can be achieved in terms of speed, but it will affect the order of sample
- processing;
- - trigger_parameter_sync_step: The smaller the setting, the closer to on policy, but it will cause frequent
- parameter synchronization. Long-tail samples waste resources that cannot be filled by short samples, resulting in
- low resource utilization.
- The larger the setting, the higher the computational efficiency, but the accuracy will be affected by off policy.
- - rollout.test_freq: It will occupy Rollouter resources and is not recommended to be set too small.
-
-- Mode Selection: By adjusting different parameters, the Fully Async architecture supports optimization acceleration at
+* Resource Allocation and Adjustment:
+ * Reasonable resource allocation is the prerequisite for achieving good training efficiency. The ideal resource
+ allocation should make the rollout time and train time close, thereby minimizing pipeline bubbles in the entire
+ training process,
+ avoiding resource idleness, and ensuring Trainer does not use old samples. In real training scenarios, resource
+ allocation can be adjusted based on the idle time of rollout and train during actual training,
+ which can be obtained from rollouter/idle_ratio and trainer/idle_ratio. If rollouter/idle_ratio is high and
+ trainer/idle_ratio is low,
+ Trainer resources should be increased and Rollouter resources should be reduced, and vice versa.
+
+* Key Parameters:
+  * staleness_threshold: Setting it too high causes more stale samples to be used, which hurts model performance. It
+    is recommended to keep it below 1.
+  * require_batches: The closer to 1, the closer the system is to pure streaming, the smaller the training bubbles,
+    and the larger the speedup, but it affects the order in which samples are processed;
+  * trigger_parameter_sync_step: The smaller the value, the closer the training is to on-policy, but parameter
+    synchronization becomes more frequent and long-tail samples waste resources that short samples cannot fill,
+    lowering resource utilization.
+    The larger the value, the higher the computational efficiency, but accuracy is affected by the off-policy data
+    (see the numeric sketch at the end of this section).
+  * rollout.test_freq: Validation occupies Rollouter resources, so it is not recommended to set this value too small.
+
+* Mode Selection: By adjusting different parameters, the Fully Async architecture supports optimization acceleration at
different levels, suitable for tasks in different scenarios.
- - For small-scale tasks that need to ensure training stability and on-policy nature, and have low speed
- requirements, the on policy pipeline mode (Mode 1) can be tried.
- - For scenarios that need to improve training throughput but are sensitive to staleness, the stream off policy
- pipeline mode can be tried. That is, by
- setting trigger_parameter_sync_step>1 to improve training efficiency, but still maintaining the synchronization
- mechanism (staleness_threshold=0) (Mode 2).
- - For large-scale tasks with high training speed requirements and can tolerate a certain degree of off-policy and
- staleness, setting staleness_threshold>
- 0 and partial_rollout=True can improve training efficiency, using the async stream pipeline mode (Mode 3 or 4).
+ * For small-scale tasks that need to ensure training stability and on-policy nature, and have low speed
+ requirements, the on policy pipeline mode (Mode 1) can be tried.
+ * For scenarios that need to improve training throughput but are sensitive to staleness, the stream off policy
+ pipeline mode can be tried. That is, by
+ setting trigger_parameter_sync_step>1 to improve training efficiency, but still maintaining the synchronization
+ mechanism (staleness_threshold=0) (Mode 2).
+  * For large-scale tasks with high training speed requirements that can tolerate a certain degree of off-policy data
+    and staleness, setting staleness_threshold>0 and partial_rollout=True can improve training efficiency, using the
+    async stream pipeline modes (Mode 3 or 4).
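+
+As a concrete, purely illustrative example of how these key parameters interact, using the values from the
+Qwen2.5-Math-7B experiment below, the per-sync sample budget works out as follows:
+
+```bash
+# Example values (taken from the Qwen2.5-Math-7B experiment in this document)
+ppo_mini_batch_size=32
+require_batches=4
+trigger_parameter_sync_step=4
+# staleness_threshold=0.5; bash arithmetic is integer-only, so it is expressed in tenths below
+staleness_tenths=5
+num_staleness_sample=0   # leftover stale samples from the previous sync window, 0 here for simplicity
+
+# Samples the Trainer consumes between two parameter synchronizations
+samples_per_sync=$((trigger_parameter_sync_step * require_batches * ppo_mini_batch_size))   # 512
+
+# Upper bound on samples the Rollouter may produce in the same window
+rollout_num=$(( (10 + staleness_tenths) * samples_per_sync / 10 - num_staleness_sample ))   # 768
+
+echo "samples per sync: ${samples_per_sync}, max rollout_num: ${rollout_num}"
+```
+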
### Quick Start
@@ -319,7 +323,7 @@ trigger_parameter_sync_step=16
partial_rollout=False
-python -m verl.experimental.fully_async_policy.fully_async_main \
+python -m recipe.fully_async_policy.fully_async_main \
train_batch_size=${train_prompt_bsz} \
data.gen_batch_size=${gen_prompt_bsz} \
data.return_raw_chat=${return_raw_chat} \
@@ -332,7 +336,6 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
- actor_rollout_ref.rollout.calculate_log_probs=True \
trainer.nnodes="${NNODES_TRAIN}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.nnodes="${NNODES_ROLLOUT}" \
@@ -352,36 +355,35 @@ We used Qwen2.5-Math-7B to verify the benefits of the fully async strategy under
Using the `async stream pipeline with stale samples` strategy, we achieved about 2x performance improvement on 32 cards,
64 cards, and 128 cards without significantly affecting experimental results.
-- Machine: H20
-- Model: Qwen2.5-Math-7B
-- Rollout length: max_response_length FSDP2: 28K tokens;
-- Algorithm: DAPO
-- Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet
-- Engine: vLLM + FSDP2
-- rollout.n: 16
-- ppo_mini_batch_size: 32
-- test_freq: 20
-
-- colocate sync:
-
- - step: 400
- - train_batch_size: 512
-
-- fully_async_policy
- - total_rollout_steps: 512\*400
- - require_batches: 4
- - trigger_parameter_sync_step: 4
- - staleness_threshold: 0.5
- - partial_rollout: True
-
-| training mode | resource allocation | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
-| :----------------: | :-----------------: | :----: | :----: | :----------: | :----------: | :--------------------: | :--------------------: | :--------------------: | :--------------------: | :-------------------------: |
-| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 269.80 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313<br>last: 0.2448 |
-| fully_async_policy | 16:16 | 294.77 | 21.26 | \ | 313.81 | 7h 58m<br>(1.72x) | 16h 21m<br>(1.70x) | 1d 0h 53m<br>(2.31x) | 1d 9h 26m<br>(2.66x) | max: 0.3302<br>last: 0.2333 |
-| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365<br>last: 0.2333 |
-| fully_async_policy | 32:32 | 189.26 | 28.46 | \ | 156.98 | 4h 57m<br>(2.09x) | 10h 14m<br>(2.03x) | 16h 58m<br>(1.83x) | 21h 40m<br>(1.92x) | max: 0.3677<br>last: 0.3406 |
-| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
-| fully_async_policy | 64:64 | 150.63 | 33.14 | \ | 113.16 | 3h 13m<br>(2.67x) | 6h 46m<br>(2.65x) | 10h 53m<br>(2.67x) | 17h 22m<br>(2.35x) | max: 0.3521<br>last: 0.3094 |
+* Machine: H20
+* Model: Qwen2.5-Math-7B
+* Rollout length: max_response_length FSDP2: 28K tokens;
+* Algorithm: DAPO
+* Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet
+* Engine: vLLM + FSDP2
+* rollout.n: 16
+* ppo_mini_batch_size: 32
+* test_freq: 20
+
+* colocate sync:
+ * step: 400
+ * train_batch_size: 512
+
+* fully_async_policy
+ * total_rollout_steps: 512*400
+ * require_batches: 4
+ * trigger_parameter_sync_step: 4
+ * staleness_threshold: 0.5
+ * partial_rollout: True
+
+| training mode | resource allocation | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
+|:--------------------:|:---------------------:|:--------:|:--------:|:--------------:|:---------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-------------------------------:|
+| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 269.80 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313<br>last: 0.2448 |
+| fully_async_policy | 16:16 | 294.77 | 21.26 | \ | 313.81 | 7h 58m<br>(1.72x) | 16h 21m<br>(1.70x) | 1d 0h 53m<br>(2.31x) | 1d 9h 26m<br>(2.66x) | max: 0.3302<br>last: 0.2333 |
+| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365<br>last: 0.2333 |
+| fully_async_policy | 32:32 | 189.26 | 28.46 | \ | 156.98 | 4h 57m<br>(2.09x) | 10h 14m<br>(2.03x) | 16h 58m<br>(1.83x) | 21h 40m<br>(1.92x) | max: 0.3677<br>last: 0.3406 |
+| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
+| fully_async_policy | 64:64 | 150.63 | 33.14 | \ | 113.16 | 3h 13m<br>(2.67x) | 6h 46m<br>(2.65x) | 10h 53m<br>(2.67x) | 17h 22m<br>(2.35x) | max: 0.3521<br>last: 0.3094 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg
@@ -391,12 +393,12 @@ We used Qwen2.5-Math-7B to verify the effects of various modes supported by full
We can see that the benefit brought by streaming is approximately 1.6x, and after combining staleness and
partial_rollout, the benefit reaches 2.35x.
-| mode | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
-| :---------------------------------------------------------------------------------------------------: | :----: | :----: | :----------: | :----------: | :--------------------: | :--------------------: | :--------------------: | :--------------------: | :-------------------------: |
-| colocate sync | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
-| `stream off policy pipeline`<br>(+fully async: trigger_parameter_sync_step= 4,<br>require_batches= 4) | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
-| `async stream pipeline with stale samples`<br>(+staleness_threshold=0.5) | | | | | | | | | |
-| `async stream pipeline with partial rollout`<br>(+partial_rollout=True) | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |
+| mode | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
+|:-------------------------------------------------------------------------------------------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------------:|
+| colocate sync | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
+| `stream off policy pipeline`<br>(+fully async: trigger_parameter_sync_step= 4,<br>require_batches= 4) | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
+| `async stream pipeline with stale samples`<br>(+staleness_threshold=0.5) | | | | | | | | | |
+| `async stream pipeline with partial rollout`<br>(+partial_rollout=True) | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg
@@ -409,12 +411,12 @@ We also noticed that the times for staleness values of 0.3 and 0.5 are quite clo
increase, the response length changes significantly, causing training instability.
Further analysis and optimization are needed for this issue.
-| staleness_threshold | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
-| :-----------------: | :----: | :----: | :----------: | :----------: | :--------------------: | :--------------------: | :--------------------: | :--------------------: | :-------------------------: |
-| 0 | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
-| 0.1 | 171.30 | 58.17 | \ | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542<br>last: 0.2979 |
-| 0.3 | 146.11 | 38.88 | \ | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469<br>last: 0.2865 |
-| 0.5 | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |
+| staleness_threshold | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
+|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|
+| 0 | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
+| 0.1 | 171.30 | 58.17 | \ | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542<br>last: 0.2979 |
+| 0.3 | 146.11 | 38.88 | \ | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469<br>last: 0.2865 |
+| 0.5 | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg
@@ -424,11 +426,11 @@ In multiple tests, we found that the number of samples issued each time in strea
training, which in turn affects training time. We verified the impact on results by modifying
`async_training.require_batches`.
-| require_batches | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | acc/mean@1 |
-| :-------------: | :----: | :---: | :----------: | :----------: | :--------------------: | :--------------------: | :--------------------: | :-------------------------: |
-| 1 | 203.47 | 30.88 | \ | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349<br>last: 0.326 |
-| 2 | 158.72 | 26.32 | \ | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351<br>last: 0.3406 |
-| 4 | 124.64 | 25.62 | \ | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521<br>last: 0.3521 |
+| require_batches | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | acc/mean@1 |
+|:-----------------:|:--------:|:-------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|
+| 1 | 203.47 | 30.88 | \ | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349<br>last: 0.326 |
+| 2 | 158.72 | 26.32 | \ | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351<br>last: 0.3406 |
+| 4 | 124.64 | 25.62 | \ | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521<br>last: 0.3521 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg
@@ -445,48 +447,47 @@ resource adjustment less flexible. Additionally, as asynchronous training and de
is gradually narrowing. Therefore, enabling more flexible resource allocation and dynamic resource adjustment in the
future will be our next focus.
-- Machine: H20
-- Model: Qwen3-30B-A3B-Base
-- Rollout length: max_response_length : 8K tokens;
-- Algorithm: GRPO
-- Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet
-- Engine: vLLM + Megatron
-- rollout.n: 16
-- ppo_mini_batch_size: 128
-- test_freq: 20
-
-- colocate sync:
+* Machine: H20
+* Model: Qwen3-30B-A3B-Base
+* Rollout length: max_response_length : 8K tokens;
+* Algorithm: GRPO
+* Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet
+* Engine: vLLM + Megatron
+* rollout.n: 16
+* ppo_mini_batch_size: 128
+* test_freq: 20
- - step:400
- - train_batch_size: 512
+* colocate sync:
+ * step:400
+ * train_batch_size: 512
-- fully_async_policy
- - total_rollout_steps: 512\*400
- - trigger_parameter_sync_step: 512/128 = 4
- - staleness_threshold: 0.5
- - partial_rollout: True
+* fully_async_policy
+ * total_rollout_steps: 512*400
+ * trigger_parameter_sync_step: 512/128 = 4
+ * staleness_threshold: 0.5
+ * partial_rollout: True
| Training Mode | Resource Allocation | Step | Gen | Old Log Prob | Ref | Update Actor | Total Time 100 Step | Total Time 200 Step | Total Time 300 Step | Total Time 400 Step | Acc/Mean@1 |
-| ------------------ | ------------------- | ------ | ------ | ------------ | ----- | ------------ | ------------------- | ------------------- | ------------------- | ------------------- | --------------------------- |
+|--------------------|---------------------|--------|--------|--------------|-------|--------------|---------------------|---------------------|---------------------|---------------------|-----------------------------|
| Colocate Sync | 128 | 497.89 | 348.05 | 28.73 | 20.86 | 86.27 | 13h 36m | 1d 3h 48m | 1d 19h 4m | 2d 11h 39m | max: 0.3500<br>last: 0.3208 |
| Fully Async Policy | 96:32 | 282.75 | 22.06 | \ | 50.05 | 206.63 | 6h 45m (2.01x) | 14h 48m (1.88x) | 1d 0h 9m (1.78x) | 1d 10h 41m (1.72x) | max: 0.3813<br>last: 0.3448 |
-> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-30B?nw=nwuserhouzg | | |
+> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-30B?nw=nwuserhouzg
### checkpoint-engine Ablation Experiment
-
We tested the single-step parameter synchronization time of the checkpoint-engine on three models: Qwen2.5-Math-7B, Qwen3-30B-A3B, and Qwen3-235B-A22B, using default checkpoint-engine configurations. All experiments were performed on H20 machines, and the Megatron engine was used for training.
-| model | trainer rank | rollout rank | checkpoint-engine | total sync time |
-|:-----------------:|:--------:|:-------:|:--------------:|:--------------:|
-| Qwen2.5-Math-7B | 4 | 4 | False | 0.12s |
-| Qwen2.5-Math-7B | 4 | 4 | True | 0.02s |
-| Qwen3-30B-A3B | 16 | 16 | False | 15.76s |
-| Qwen3-30B-A3B | 16 | 16 | True | 4.38s |
-| Qwen3-235B-A22B | 64 | 64 | False | 58.57s |
-| Qwen3-235B-A22B | 64 | 64 | True | 23.70s |
-### use_trainer_do_validate Experiment
+| model | trainer rank | rollout rank | checkpoint-engine | total sync time |
+|:---------------:|:--------------:|:-------------:|:-------------------:|:-----------------:|
+| Qwen2.5-Math-7B | 4 | 4 | False | 0.12s |
+| Qwen2.5-Math-7B | 4 | 4 | True | 0.02s |
+| Qwen3-30B-A3B | 16 | 16 | False | 15.76s |
+| Qwen3-30B-A3B | 16 | 16 | True | 4.38s |
+| Qwen3-235B-A22B | 64 | 64 | False | 58.57s |
+| Qwen3-235B-A22B | 64 | 64 | True | 23.70s |
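+
+For memory planning, the following sketch estimates the extra device memory implied by the `checkpoint_engine`
+settings described earlier (the largest-tensor size is a placeholder; the multipliers come from the parameter
+description above):
+
+```bash
+device_buffer_size_M=4096        # default bucket size in MB
+max_param_tensor_M=1024          # placeholder: size of the largest parameter tensor, model dependent
+
+# bucket_size = max(device_buffer_size_M, maximum parameter tensor size)
+bucket_size_M=$(( device_buffer_size_M > max_param_tensor_M ? device_buffer_size_M : max_param_tensor_M ))
+
+# overlap_broadcast_and_consume=False (default): 2x on trainer ranks, 1x on rollout ranks
+echo "no overlap  : trainer +$((2 * bucket_size_M)) MB, rollout +$((1 * bucket_size_M)) MB per rank"
+# overlap_broadcast_and_consume=True: 3x on trainer ranks, 2x on rollout ranks
+echo "with overlap: trainer +$((3 * bucket_size_M)) MB, rollout +$((2 * bucket_size_M)) MB per rank"
+```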
+
+### use_trainer_do_validate Experiment
We tested the effect of setting `use_trainer_do_validate=True` on the training process. The results show that setting
this parameter to True can reduce the validation time overhead and trainer node idle time.
We used Qwen2.5-Math-7B to verify the benefits of `use_trainer_do_validate=True` on the training process, we achieved about 2x performance improvement on validation time, and the trainer node idle time is reduced by about 40%.
@@ -508,10 +509,10 @@ We used Qwen2.5-Math-7B to verify the benefits of `use_trainer_do_validate=True`
* staleness_threshold: 0.5
* partial_rollout: True
-| training mode | resource allocation | step | gen | old_log_prob | update_actor | validate time | total time<br>50 step | acc/mean@2 |
-|:---------------:|:---------------:|:---------------:|:---------------:|:---------------:|:---------------:|:---------------:|:---------------:|:---------------:|
-| colocate sync | 16 | 484.623 | 52.939 | 0 | 430.263 | 205.080 | 7h9m | 22.6 |
-| fully_async_policy | 8:8 | 489.953 | 52.622 | 0 | 435.874 | 95.699 | 7h2m | 21.0 |
+| training mode | resource allocation | step | gen | old_log_prob | update_actor | validate time | total time<br>50 step | acc/mean@2 |
+|:------------------:|:-------------------:|:-------:|:-------:|:------------:|:------------:|:-------------:|:---------------------:|:----------:|
+| colocate sync | 16 | 484.623 | 52.939 | 0 | 430.263 | 205.080 | 7h9m | 22.6 |
+| fully_async_policy | 8:8 | 489.953 | 52.622 | 0 | 435.874 | 95.699 | 7h2m | 21.0 |
## Multi-Turn Tool Calling
@@ -559,37 +560,44 @@ specifying `multi_turn` configurations in the config file.
To validate the performance of `fully_async_policy` on multi-turn tool-calling tasks, we compared it with the standard
`colocate` synchronous mode. Key parameter settings are as follows.
-- **SFT Model**: Based on `Qwen2.5-7B-Instruct`, trained for 6 epochs on the `ReTool-SFT` dataset
-- **RL Algorithm**: DAPO
-- **Dataset**:
- - Train: `DAPO-Math-17k`
- - Test: `aime_2025`
-- **Resource and Mode Comparison**:
- - `colocate sync`: 32 H20 gpus
- - `fully_async_policy`: 16 gpus for Trainer + 16 gpus for Rollouter
-- **Key Configurations**:
- 1. **Tool Calling Configuration**:
- - `multi_turn.enable: True`
- - `multi_turn.max_user_turns: 16`
- - `multi_turn.max_assistant_turns: 16`
- - `multi_turn.tool_config_path: recipe/retool/sandbox_fusion_tool_config.yaml`
- 2. **`colocate sync` Configuration**:
- - `ppo_mini_batch_size: 16`
- - `train_batch_size: 64`
- 3. **`fully_async_policy` Configuration**:
- - `ppo_mini_batch_size: 16`
- - `trigger_parameter_sync_step: 4`
- - `require_batches: 1`
- - `staleness_threshold: 1`
- - `partial_rollout: True`
-
-| training mode | Resource allocation | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | aime_2025<br>acc/mean@30 |
-| :----------------: | :-----------------: | :----: | :----: | :----------: | :----------: | :--------------------: | :--------------------: | :-------------------------: |
-| colocate | 32 | 375.47 | 228.03 | 35.19 | 111.84 | 9h 46m | 22h 28m | start:0.1078<br>last:0.2056 |
-| fully_async_policy | 16: 16 | 221.36 | 40.59 | \ | 179.58 | 6h 19m<br>(1.55x) | 14h 4m<br>(1.60x) | start:0.11<br>last:0.2044 |
+* **SFT Model**: Based on `Qwen2.5-7B-Instruct`, trained for 6 epochs on the `ReTool-SFT` dataset
+* **RL Algorithm**: DAPO
+* **Dataset**:
+ * Train: `DAPO-Math-17k`
+ * Test: `aime_2025`
+* **Resource and Mode Comparison**:
+ * `colocate sync`: 32 H20 gpus
+ * `fully_async_policy`: 16 gpus for Trainer + 16 gpus for Rollouter
+* **Key Configurations**:
+ 1. **Tool Calling Configuration**:
+ * `multi_turn.enable: True`
+ * `multi_turn.max_user_turns: 16`
+ * `multi_turn.max_assistant_turns: 16`
+ * `multi_turn.tool_config_path: recipe/retool/sandbox_fusion_tool_config.yaml`
+ 2. **`colocate sync` Configuration**:
+ * `ppo_mini_batch_size: 16`
+ * `train_batch_size: 64`
+ 3. **`fully_async_policy` Configuration**:
+ * `ppo_mini_batch_size: 16`
+ * `trigger_parameter_sync_step: 4`
+ * `require_batches: 1`
+ * `staleness_threshold: 1`
+ * `partial_rollout: True`
+
+| training mode | Resource allocation | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | aime_2025<br>acc/mean@30 |
+|:--------------------:|:---------------------:|:---------:|:---------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:-------------------------------:|
+| colocate | 32 | 375.47 | 228.03 | 35.19 | 111.84 | 9h 46m | 22h 28m | start:0.1078<br>last:0.2056 |
+| fully_async_policy | 16: 16 | 221.36 | 40.59 | \ | 179.58 | 6h 19m<br>(1.55x) | 14h 4m<br>(1.60x) | start:0.11<br>last:0.2044 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-multiturn-tool?nw=nwuserhouzg
## Future Plans
-- Transfer queue integration
-- Asynchronous parameter synchronization
+
+* GRPO experiments
+* Megatron adaptation
+* SGLang integration
+* Transfer queue integration
+* Asynchronous parameter synchronization
+* AReaL asynchronous algorithm implementation
+* TPPO algorithm implementation
+* Multi-turn and Tool support
diff --git a/tests/special_e2e/run_fully_async_policy.sh b/tests/special_e2e/run_fully_async_policy.sh
index 3d061a59164..3a657cf4e80 100644
--- a/tests/special_e2e/run_fully_async_policy.sh
+++ b/tests/special_e2e/run_fully_async_policy.sh
@@ -15,7 +15,7 @@ MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
rollout_mode="async"
-rollout_name="sglang" # sglang or vllm
+rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
@@ -123,6 +123,7 @@ common_params=(
trainer.resume_mode=disable
trainer.nnodes=1
trainer.n_gpus_per_node=${n_gpus_training}
+ trainer.log_val_generations=10
rollout.nnodes=1
rollout.n_gpus_per_node=${n_gpus_rollout}
rollout.total_rollout_steps=${total_rollout_steps}
diff --git a/verl/experimental/fully_async_policy/README.md b/verl/experimental/fully_async_policy/README.md
index b7406514f3b..a24e7610102 100644
--- a/verl/experimental/fully_async_policy/README.md
+++ b/verl/experimental/fully_async_policy/README.md
@@ -2,7 +2,7 @@
**Author:** `https://github.com/meituan-search`
-Last updated: 12/25/2025.
+Last updated: 02/05/2026.
This document introduces a fully asynchronous PPO training system that completely decouples the Trainer and Rollouter,
supporting asynchronous sample generation and training.
@@ -88,27 +88,27 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev
### Parameter Description
-| super params | implication |
-|-----------------------------------------------|------------------------------------------------------------------------------------------------|
-| `trainer.nnodes` | Number of nodes for Trainer |
-| `trainer.n_gpus_per_node` | Number of GPUs per node for Trainer |
-| `rollout.nnodes` | Number of nodes for Rollouter |
-| `rollout.n_gpus_per_node` | Number of GPUs per node for Rollouter |
-| `data.train_batch_size` | In the fully async strategy, this value is not effective (default is 0) |
-| `data.gen_batch_size` | In the fully async strategy, uses streaming sample production logic (default is 1) |
-| `rollout.total_rollout_steps` | Total number of rollout samples |
-| `rollout.test_freq` | How many times Rollouter updates parameters before performing a validation |
-| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus |
-| `async_training.require_batches` | Number of ppo_mini_batch_size that FullyAsyncTrainer fetches at once |
-| `async_training.trigger_parameter_sync_step` | Indicates how many local updates FullyAsyncTrainer performs before a parameter synchronization |
-| `async_training.staleness_threshold` | Freshness control |
-| `async_training.partial_rollout` | Whether to perform partial_rollout |
-| `async_training.use_rollout_log_probs` | Use log_probs generated by rollout |
-| `async_training.compute_prox_log_prob` | Whether to compute log_prob using the training model's parameters during the training phase. | |
-| `async_training.checkpoint_engine.enable`| Whether to use checkpoint_engine for accelerating, default `True`|
-| `async_training.checkpoint_engine.overlap_broadcast_and_consume` | When use checkpoint_engine, whether to overlap broadcast and load_weights, default `False`|
-| `async_training.checkpoint_engine.device_buffer_size_M` | When use checkpoint_engine, the user-specific bucket size (MB), default `4096`|
-| `async_training.use_trainer_do_validate` | Whether use trainer node to do validate process, default `False`|
+| super params | implication |
+|------------------------------------------------------------------|------------------------------------------------------------------------------------------------|
+| `trainer.nnodes` | Number of nodes for Trainer |
+| `trainer.n_gpus_per_node` | Number of GPUs per node for Trainer |
+| `rollout.nnodes` | Number of nodes for Rollouter |
+| `rollout.n_gpus_per_node` | Number of GPUs per node for Rollouter |
+| `data.train_batch_size` | In the fully async strategy, this value is not effective (default is 0) |
+| `data.gen_batch_size` | In the fully async strategy, uses streaming sample production logic (default is 1) |
+| `rollout.total_rollout_steps` | Total number of rollout samples |
+| `rollout.test_freq` | How many times Rollouter updates parameters before performing a validation |
+| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus |
+| `actor_rollout_ref.actor.use_rollout_log_probs=True` | Use log_probs generated by rollout |
+| `algorithm.rollout_correction.bypass_mode`                         | Whether to bypass computing log_prob with the training model during training (default `True`, i.e., use rollout log_probs) |
+| `async_training.require_batches` | Number of ppo_mini_batch_size that FullyAsyncTrainer fetches at once |
+| `async_training.trigger_parameter_sync_step` | Indicates how many local updates FullyAsyncTrainer performs before a parameter synchronization |
+| `async_training.staleness_threshold` | Freshness control |
+| `async_training.partial_rollout` | Whether to perform partial_rollout |
+| `async_training.checkpoint_engine.enable` | Whether to use checkpoint_engine for accelerating, default `True` |
+| `async_training.checkpoint_engine.overlap_broadcast_and_consume` | When use checkpoint_engine, whether to overlap broadcast and load_weights, default `False` |
+| `async_training.checkpoint_engine.device_buffer_size_M` | When use checkpoint_engine, the user-specific bucket size (MB), default `4096` |
+| `async_training.use_trainer_do_validate`                          | Whether to use the trainer node to run validation, default `False`                              |
**Further Explanation:**
@@ -151,14 +151,6 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev
partial_rollout only actually takes effect when staleness_threshold>0.
-* `async_training.use_rollout_log_probs`
-
- In reinforcement learning algorithms, log_probs have implicit correlations with parameter versions and tokens. Due to
- the settings of algorithms like PPO/GRPO/DAPO, when calculating importance sampling,
- old_log_prob must use the log_probs corresponding to the rollout parameters and tokens to ensure algorithm
- correctness. In the fully
- async strategy, we default to old_log_prob being calculated by rollout rather than by trainer.
-
* `async_training.require_batches`
In streaming training, require_batches should be set to 1, indicating that training is performed after producing
@@ -168,14 +160,25 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev
Here, we additionally provide require_batches for streaming distribution and control the number of samples
participating in training at once.
-* `async_training.compute_prox_log_prob` (experimental)
+* `actor_rollout_ref.actor.use_rollout_log_probs=True`
+
+ In reinforcement learning algorithms, log_probs have implicit correlations with parameter versions and tokens. Due to
+ the settings of algorithms like PPO/GRPO/DAPO, when calculating importance sampling,
+ old_log_prob must use the log_probs corresponding to the rollout parameters and tokens to ensure algorithm
+  correctness. In the fully async strategy, old_log_prob is calculated by the rollout rather than by the trainer by
+  default.
+
+* `algorithm.rollout_correction.bypass_mode`
+
+  > `algorithm.rollout_correction.bypass_mode` defaults to `True`, i.e., rollout log_probs are used.
During the training process, we observed that metrics and response lengths may become unstable in the later
stages of training. To mitigate this issue, we can use
the [Rollout Importance Sampling](https://verl.readthedocs.io/en/latest/advance/rollout_is.html)
technique for importance sampling. To utilize Rollout Importance Sampling, we need to compute log_prob using
-  the training engine, which requires enabling this switch.
+  the training engine, which requires setting this option to `False`.
- Additionally, when compute_prox_log_prob and Rollout Importance Sampling are enabled under mode d
+ Additionally, when `algorithm.rollout_correction.bypass_mode=False` and Rollout Importance Sampling are enabled under
+ mode d
(async stream pipeline with partial rollout), our implementation approximates `Areal's Decoupled PPO`.
* `async_training.checkpoint_engine.enable`
@@ -332,7 +335,6 @@ python -m recipe.fully_async_policy.fully_async_main \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
- actor_rollout_ref.rollout.calculate_log_probs=True \
trainer.nnodes="${NNODES_TRAIN}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.nnodes="${NNODES_ROLLOUT}" \
@@ -473,14 +475,15 @@ future will be our next focus.
### checkpoint-engine Ablation Experiment
We tested the single-step parameter synchronization time of the checkpoint-engine on three models: Qwen2.5-Math-7B, Qwen3-30B-A3B, and Qwen3-235B-A22B, using default checkpoint-engine configurations. All experiments were performed on H20 machines, and the Megatron engine was used for training.
-| model | trainer rank | rollout rank | checkpoint-engine | total sync time |
-|:-----------------:|:--------:|:-------:|:--------------:|:--------------:|
-| Qwen2.5-Math-7B | 4 | 4 | False | 0.12s |
-| Qwen2.5-Math-7B | 4 | 4 | True | 0.02s |
-| Qwen3-30B-A3B | 16 | 16 | False | 15.76s |
-| Qwen3-30B-A3B | 16 | 16 | True | 4.38s |
-| Qwen3-235B-A22B | 64 | 64 | False | 58.57s |
-| Qwen3-235B-A22B | 64 | 64 | True | 23.70s |
+
+| model | trainer rank | rollout rank | checkpoint-engine | total sync time |
+|:---------------:|:--------------:|:-------------:|:-------------------:|:-----------------:|
+| Qwen2.5-Math-7B | 4 | 4 | False | 0.12s |
+| Qwen2.5-Math-7B | 4 | 4 | True | 0.02s |
+| Qwen3-30B-A3B | 16 | 16 | False | 15.76s |
+| Qwen3-30B-A3B | 16 | 16 | True | 4.38s |
+| Qwen3-235B-A22B | 64 | 64 | False | 58.57s |
+| Qwen3-235B-A22B | 64 | 64 | True | 23.70s |
### use_trainer_do_validate Experiment
@@ -505,10 +508,10 @@ We used Qwen2.5-Math-7B to verify the benefits of `use_trainer_do_validate=True`
* staleness_threshold: 0.5
* partial_rollout: True
-| training mode | resource allocation | step | gen | old_log_prob | update_actor | validate time | total time<br>50 step | acc/mean@2 |
-|:---------------:|:---------------:|:---------------:|:---------------:|:---------------:|:---------------:|:---------------:|:---------------:|:---------------:|
-| colocate sync | 16 | 484.623 | 52.939 | 0 | 430.263 | 205.080 | 7h9m | 22.6 |
-| fully_async_policy | 8:8 | 489.953 | 52.622 | 0 | 435.874 | 95.699 | 7h2m | 21.0 |
+| training mode | resource allocation | step | gen | old_log_prob | update_actor | validate time | total time<br>50 step | acc/mean@2 |
+|:------------------:|:-------------------:|:-------:|:-------:|:------------:|:------------:|:-------------:|:---------------------:|:----------:|
+| colocate sync | 16 | 484.623 | 52.939 | 0 | 430.263 | 205.080 | 7h9m | 22.6 |
+| fully_async_policy | 8:8 | 489.953 | 52.622 | 0 | 435.874 | 95.699 | 7h2m | 21.0 |
## Multi-Turn Tool Calling
diff --git a/verl/experimental/fully_async_policy/README_zh.md b/verl/experimental/fully_async_policy/README_zh.md
index b6b5eb5344a..19a257247c3 100644
--- a/verl/experimental/fully_async_policy/README_zh.md
+++ b/verl/experimental/fully_async_policy/README_zh.md
@@ -2,7 +2,7 @@
**Author:** `https://github.com/meituan-search`
-Last updated: 12/15/2025.
+Last updated: 02/05/2026.
本文档介绍了完全异步PPO训练系统,该系统实现了 Trainer 和 Rollouter 的完全解耦,支持异步样本生成和训练。
在该系统下,我们使用128卡训练qwen2.5-7B模型取得了2.35x-2.67x的性能提升,同时效果没有显著受到影响。
@@ -65,27 +65,27 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev
### 参数说明
-| parameter | description |
-|------------------------------------------------------|-----------------------------------------------------------------|
-| `trainer.nnodes` | Number of Trainer nodes |
-| `trainer.n_gpus_per_node` | Number of GPUs per Trainer node |
-| `rollout.nnodes` | Number of Rollouter nodes |
-| `rollout.n_gpus_per_node` | Number of GPUs per Rollouter node |
-| `data.train_batch_size` | Not effective under the fully async strategy (defaults to 0) |
-| `data.gen_batch_size` | Streaming sample production is used under the fully async strategy (defaults to 1) |
-| `rollout.total_rollout_steps` | Total number of rollout samples |
-| `rollout.test_freq` | Number of Rollouter parameter updates between validations |
-| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus |
-| `async_training.require_batches` | Number of ppo_mini_batch_size batches the FullyAsyncTrainer fetches at a time |
-| `async_training.trigger_parameter_sync_step` | Number of local updates the FullyAsyncTrainer performs before each parameter synchronization |
-| `async_training.staleness_threshold` | Freshness control |
-| `async_training.partial_rollout` | Whether to perform partial_rollout |
-| `async_training.use_rollout_log_probs` | Use the log_probs produced by rollout |
-| `async_training.compute_prox_log_prob`(experimental) | Whether to compute token log_prob with the training model's parameters during the train phase |
-| `async_training.checkpoint_engine.enable` | Whether to enable checkpoint_engine acceleration, default True |
-| `async_training.checkpoint_engine.overlap_broadcast_and_consume` | When checkpoint_engine is enabled, whether to pipeline broadcast and loading during parameter sync, default False |
-| `async_training.checkpoint_engine.device_buffer_size_M` | Bucket size (MB) assembled when checkpoint_engine is enabled, default 4096 |
-| `async_training.use_trainer_do_validate` | Whether to use the Trainer's do_validate method for validation, default False |
+| parameter | description |
+|------------------------------------------------------------------|-----------------------------------------------------------------|
+| `trainer.nnodes` | Number of Trainer nodes |
+| `trainer.n_gpus_per_node` | Number of GPUs per Trainer node |
+| `rollout.nnodes` | Number of Rollouter nodes |
+| `rollout.n_gpus_per_node` | Number of GPUs per Rollouter node |
+| `data.train_batch_size` | Not effective under the fully async strategy (defaults to 0) |
+| `data.gen_batch_size` | Streaming sample production is used under the fully async strategy (defaults to 1) |
+| `rollout.total_rollout_steps` | Total number of rollout samples |
+| `rollout.test_freq` | Number of Rollouter parameter updates between validations |
+| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus |
+| `actor_rollout_ref.actor.use_rollout_log_probs=True` | Use the log_probs produced by rollout |
+| `algorithm.rollout_correction.bypass_mode` | Whether to compute token log_prob with the training model's parameters during the train phase |
+| `async_training.require_batches` | Number of ppo_mini_batch_size batches the FullyAsyncTrainer fetches at a time |
+| `async_training.trigger_parameter_sync_step` | Number of local updates the FullyAsyncTrainer performs before each parameter synchronization |
+| `async_training.staleness_threshold` | Freshness control |
+| `async_training.partial_rollout` | Whether to perform partial_rollout |
+| `async_training.checkpoint_engine.enable` | Whether to enable checkpoint_engine acceleration, default True |
+| `async_training.checkpoint_engine.overlap_broadcast_and_consume` | When checkpoint_engine is enabled, whether to pipeline broadcast and loading during parameter sync, default False |
+| `async_training.checkpoint_engine.device_buffer_size_M` | Bucket size (MB) assembled when checkpoint_engine is enabled, default 4096 |
+| `async_training.use_trainer_do_validate` | Whether to use the Trainer's do_validate method for validation, default False |
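As a quick orientation for the table above, a minimal launch sketch might pass these keys as overrides; node counts and values below are placeholders, not recommended settings.

```bash
python -m verl.experimental.fully_async_policy.fully_async_main \
    trainer.nnodes=1 trainer.n_gpus_per_node=4 \
    rollout.nnodes=1 rollout.n_gpus_per_node=4 \
    rollout.total_rollout_steps=10000 \
    rollout.test_freq=5 \
    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
    async_training.require_batches=1 \
    async_training.trigger_parameter_sync_step=4 \
    async_training.staleness_threshold=0.5 \
    async_training.partial_rollout=True
```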
**Further explanation:**
@@ -124,26 +124,28 @@ https://github.com/ArronHZG/verl-community/blob/main/docs/fully_async_policy_rev
partial_rollout only actually takes effect when staleness_threshold > 0.
-* `async_training.use_rollout_log_probs`
+* `actor_rollout_ref.actor.use_rollout_log_probs=True`
  In reinforcement learning algorithms, log_probs are implicitly tied to both the parameter version and the tokens. Given the design of algorithms such as PPO/GRPO/DAPO, when computing importance sampling
  the old_log_prob must be the log_probs corresponding to the rollout parameters and tokens for the algorithm to remain correct. In the fully
-  async strategy, old_log_prob is therefore computed by the rollout by default rather than by the trainer.
+  async strategy, old_log_prob is therefore computed by the rollout by default rather than by the trainer.
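  As a worked form of this requirement (standard PPO notation, not specific to this repo): the importance ratio is `r_t = exp(log π_train(a_t | s_t) − log π_rollout(a_t | s_t))`, and old_log_prob supplies the second term. If it were recomputed with a newer trainer version instead of the rollout parameters that actually produced the tokens, the ratio would no longer compare against the data-generating distribution and would be silently biased.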
-* `async_training.require_batches`
-
-  In streaming training, require_batches should be set to 1, meaning training starts as soon as ppo_mini_batch_size samples have been produced.
-  In practice we observed that if too few samples are dispatched at a time, the order of data distribution can make training unstable and response lengths grow.
-  We therefore additionally provide require_batches to control how many samples take part in training per streaming dispatch.
-
-* `async_training.compute_prox_log_prob` (experimental)
+* `algorithm.rollout_correction.bypass_mode`
+  algorithm.rollout_correction.bypass_mode defaults to True, which directly uses the rollout log prob.
  During training we observed that, as training progresses, late-stage metrics and response length may become unstable;
  here we can apply the [Rollout Importance Sampling](https://verl.readthedocs.io/en/latest/advance/rollout_is.html) technique to perform
  importance sampling and mitigate this problem. To use `Rollout Importance Sampling`, the training engine needs to compute old_log_prob with the current parameter version, so this switch must be set accordingly.
-  In addition, in mode d (async stream pipeline with partial rollout), enabling `compute_prox_log_prob` together with
+  In addition, in mode d (async stream pipeline with partial rollout), setting `algorithm.rollout_correction.bypass_mode=False`
+  together with
  `Rollout Importance Sampling`, our implementation approximates Areal's `Decoupled PPO`.
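A hedged override sketch for this setting (module path taken from the shell examples in this repository; the remaining rollout_correction options follow the linked rollout_is guide and are omitted here):

```bash
# Let the trainer recompute old_log_prob so Rollout Importance Sampling can be
# applied; partial_rollout=True corresponds to mode d mentioned above.
python -m verl.experimental.fully_async_policy.fully_async_main \
    algorithm.rollout_correction.bypass_mode=False \
    async_training.partial_rollout=True
```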
+* `async_training.require_batches`
+
+  In streaming training, require_batches should be set to 1, meaning training starts as soon as ppo_mini_batch_size samples have been produced.
+  In practice we observed that if too few samples are dispatched at a time, the order of data distribution can make training unstable and response lengths grow.
+  We therefore additionally provide require_batches to control how many samples take part in training per streaming dispatch.
+
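  As a worked example with hypothetical values: with `actor_rollout_ref.actor.ppo_mini_batch_size=32` and `async_training.require_batches=2`, the trainer accumulates 32 × 2 = 64 streamed samples before each local update.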
* `async_training.checkpoint_engine.enable`
  With the checkpoint engine enabled, synchronization time is generally reduced by more than 60% compared with the original per-tensor parameter synchronization. Assembling buckets, however, incurs additional temporary GPU memory overhead.
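  As a hedged illustration of this trade-off, the checkpoint-engine keys from the parameter table can be tuned as follows (values are examples, not recommendations):

```bash
# Smaller buckets lower the temporary GPU memory spike from bucket assembly,
# at some cost in synchronization speed; 4096 MB is the stated default.
python -m verl.experimental.fully_async_policy.fully_async_main \
    async_training.checkpoint_engine.enable=True \
    async_training.checkpoint_engine.overlap_broadcast_and_consume=False \
    async_training.checkpoint_engine.device_buffer_size_M=2048
```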
@@ -269,7 +271,6 @@ python -m recipe.fully_async_policy.fully_async_main \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
- actor_rollout_ref.rollout.calculate_log_probs=True \
trainer.nnodes="${NNODES_TRAIN}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.nnodes="${NNODES_ROLLOUT}" \
diff --git a/verl/experimental/fully_async_policy/agent_loop/agent_loop.py b/verl/experimental/fully_async_policy/agent_loop/agent_loop.py
index 463703cd3af..3098e48ba96 100644
--- a/verl/experimental/fully_async_policy/agent_loop/agent_loop.py
+++ b/verl/experimental/fully_async_policy/agent_loop/agent_loop.py
@@ -347,27 +347,5 @@ async def wake_up(self):
async def sleep(self):
await asyncio.gather(*[replica.sleep() for replica in self.rollout_replicas])
- async def reset_prefix_cache(self):
- print("[FullyAsyncAgentLoopManager] Reset prefix cache ...")
- # await asyncio.gather(*[replica.reset_prefix_cache() for replica in self.rollout_replicas])
- # Note: debug
- timeout = 5.0
-
- async def reset_one(idx, replica):
- print(f"[reset_prefix_cache] start replica={idx}")
- try:
- await asyncio.wait_for(replica.reset_prefix_cache(), timeout=timeout)
- except asyncio.TimeoutError:
- print(f"[reset_prefix_cache] TIMEOUT replica={idx} after {timeout}s")
- return
- except Exception as e:
- print(f"[reset_prefix_cache] ERROR replica={idx}: {e!r}")
- return
- print(f"[reset_prefix_cache] done replica={idx}")
-
- tasks = [reset_one(i, replica) for i, replica in enumerate(self.rollout_replicas)]
- await asyncio.gather(*tasks, return_exceptions=True)
- print("[FullyAsyncAgentLoopManager] Reset prefix cache finished")
-
async def clear_kv_cache(self):
await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas])
diff --git a/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml b/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml
index 85b8307ee0c..eece540865c 100644
--- a/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml
+++ b/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml
@@ -21,12 +21,6 @@ async_training:
# When synchronizing parameters, whether to interrupt rollouter and perform partial rollout
partial_rollout: True
- # Whether to use rollout log probs for training
- use_rollout_log_probs: True
-
- # compute_prox_log_prob
- compute_prox_log_prob: False
-
# whether to use trainer do_validate
use_trainer_do_validate: False
@@ -71,6 +65,21 @@ actor_rollout_ref:
# checkpoint_engine config for accelerating parameter synchronization between rollouter and trainer
checkpoint_engine: ${oc.select:async_training.checkpoint_engine, null}
+ rollout:
+    # Must be turned off! Otherwise, parameter synchronization cannot be performed.
+ free_cache_engine: False
+ # Must be enabled! Otherwise, log_probs cannot be calculated.
+ calculate_log_probs: True
+ # Set to auto mode to prevent incorrect rollout outputs when parameters are not synced.
+ # TODO: Can be removed in the future once parameter synchronization is ready.
+ load_format: "auto"
+
actor:
- # Whether to use rollout log probs for training
- use_rollout_log_probs: ${oc.select:async_training.use_rollout_log_probs, True}
\ No newline at end of file
+ # Must use rollout log probs for training
+ use_rollout_log_probs: True
+
+# Only then will the use of rollout log probs be correct,
+# and it can be used in conjunction with other rollout_correction algorithms.
+algorithm:
+ rollout_correction:
+ bypass_mode: True
\ No newline at end of file
diff --git a/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml b/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml
index c5692b4a931..7dece1cd479 100644
--- a/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml
+++ b/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml
@@ -21,12 +21,6 @@ async_training:
# When synchronizing parameters, whether to interrupt rollouter and perform partial rollout
partial_rollout: True
- # Whether to use rollout log probs for training
- use_rollout_log_probs: True
-
- # compute_prox_log_prob
- compute_prox_log_prob: False
-
# whether to use trainer do_validate
use_trainer_do_validate: False
@@ -71,6 +65,21 @@ actor_rollout_ref:
# checkpoint_engine config for accelerating parameter synchronization between rollouter and trainer
checkpoint_engine: ${oc.select:async_training.checkpoint_engine, null}
+ rollout:
+    # Must be turned off! Otherwise, parameter synchronization cannot be performed.
+ free_cache_engine: False
+ # Must be enabled! Otherwise, log_probs cannot be calculated.
+ calculate_log_probs: True
+ # Set to auto mode to prevent incorrect rollout outputs when parameters are not synced.
+ # TODO: Can be removed in the future once parameter synchronization is ready.
+ load_format: "auto"
+
actor:
- # Whether to use rollout log probs for training
- use_rollout_log_probs: ${oc.select:async_training.use_rollout_log_probs, True}
\ No newline at end of file
+ # Must use rollout log probs for training
+ use_rollout_log_probs: True
+
+# Only then will the use of rollout log probs be correct,
+# and it can be used in conjunction with other rollout_correction algorithms.
+algorithm:
+ rollout_correction:
+ bypass_mode: True
\ No newline at end of file
diff --git a/verl/experimental/fully_async_policy/fully_async_main.py b/verl/experimental/fully_async_policy/fully_async_main.py
index 685af1a2eaa..7dcb91e3b2c 100644
--- a/verl/experimental/fully_async_policy/fully_async_main.py
+++ b/verl/experimental/fully_async_policy/fully_async_main.py
@@ -170,11 +170,17 @@ def _initialize_components(self, config) -> None:
self.components["role_worker_mapping"] = role_worker_mapping
self.components["ray_worker_group_cls"] = ray_worker_group_cls
- print("[ASYNC MAIN] Creating FullyAsyncRollouter...")
- self._create_rollouter(config)
+ from concurrent.futures import ThreadPoolExecutor
- print("[ASYNC MAIN] Creating FullyAsyncTrainer...")
- self._create_trainer(config)
+ print("[ASYNC MAIN] Creating FullyAsyncRollouter and FullyAsyncTrainer in parallel...")
+ with ThreadPoolExecutor(max_workers=2) as executor:
+ rollouter_future = executor.submit(self._create_rollouter, config)
+ rollouter_future.result()
+
+ # TODO: keep _create_rollouter and _create_trainer parallel
+ trainer_future = executor.submit(self._create_trainer, config)
+            # Wait for the trainer creation to complete
+ trainer_future.result()
# sync total_train_steps between rollouter and trainer
total_train_steps = ray.get(self.components["rollouter"].get_total_train_steps.remote())
diff --git a/verl/experimental/fully_async_policy/fully_async_rollouter.py b/verl/experimental/fully_async_policy/fully_async_rollouter.py
index 757432f4cf0..04f7a77fa82 100644
--- a/verl/experimental/fully_async_policy/fully_async_rollouter.py
+++ b/verl/experimental/fully_async_policy/fully_async_rollouter.py
@@ -31,7 +31,7 @@
prepare_single_generation_data,
)
from verl.experimental.fully_async_policy.message_queue import MessageQueueClient
-from verl.experimental.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer
+from verl.experimental.separation.ray_trainer import SeparateRayPPOTrainer
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
from verl.trainer.ppo.reward import load_reward_manager
@@ -42,7 +42,7 @@
@ray.remote(num_cpus=10, max_concurrency=100)
-class FullyAsyncRollouter(FullyAsyncRayPPOTrainer):
+class FullyAsyncRollouter(SeparateRayPPOTrainer):
"""
Asynchronous sample generator, responsible for continuously generating training samples
and putting them into MessageQueue
@@ -83,6 +83,12 @@ def __init__(
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
+ self.use_reference_policy = False
+
+ self.use_rm = False
+ self.use_reward_loop = self.config.reward_model.use_reward_loop
+
+ self.use_critic = False
self.ray_worker_group_cls = ray_worker_group_cls
self.device_name = device_name if device_name else self.config.trainer.device
self.validation_generations_logger = ValidationGenerationsLogger(
@@ -92,9 +98,11 @@ def __init__(
self.ref_in_actor = False
self.kl_ctrl_in_reward = False
- self.use_critic = False
- self.use_reference_policy = False
- self.use_rm = False
+
+ self.use_prefix_grouper = self.config.actor_rollout_ref.actor.get("use_prefix_grouper", False)
+ self.use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
+
+ # ==================== fully async config ====================
print("[FullyAsyncRollouter] Creating datasets...")
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
@@ -120,8 +128,6 @@ def __init__(
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
- # ==================== fully async config ====================
-
self.total_rollout_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
if self.config.rollout.total_rollout_steps is not None:
self.total_rollout_steps = min(self.config.rollout.total_rollout_steps, self.total_rollout_steps)
@@ -264,7 +270,7 @@ async def update_param_version(
or (validate and self.val_reward_fn is not None)
)
print(
- f"[FullyAsyncRollouter] need_validate: {need_validate},"
+ f"[FullyAsyncRollouter] need_validate: {need_validate}, "
f"parallel_validate_and_rollout: {self.parallel_validate_and_rollout}"
)
if not need_validate:
@@ -411,6 +417,7 @@ async def init_workers(self):
self._create_worker_classes()
self._init_worker_groups()
self._init_models()
+ self._init_reward_loop()
await self._init_async_rollout_manager()
def _create_actor_rollout_classes(self):
@@ -439,14 +446,23 @@ def _create_continuous_iterator(self):
yield epoch, batch_dict
async def _init_async_rollout_manager(self):
+ # infrastructure overview: https://verl.readthedocs.io/en/latest/advance/reward_loop.html#architecture-design
+ # agent_reward_loop: streaming reward computation with actor rollout
+ # two conditions satisfied: (1) no reward model, or (2) reward model with extra resource pool
+ enable_agent_reward_loop = self.use_reward_loop and (
+ not self.use_rm or self.config.reward_model.enable_resource_pool
+ )
+ # if enable_agent_reward_loop, we directly pass reward_loop_workers to agent loop manager
+ # to stream reward computation with actor rollout
+ reward_loop_worker_handles = self.reward_loop_manager.reward_loop_workers if enable_agent_reward_loop else None
+
# create async rollout manager and request scheduler
assert self.config.actor_rollout_ref.rollout.mode == "async"
from verl.experimental.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager
self.async_rollout_mode = True
self.async_rollout_manager = await FullyAsyncAgentLoopManager.create(
- config=self.config,
- worker_group=self.rollout_wg,
+ config=self.config, worker_group=self.rollout_wg, reward_loop_worker_handles=reward_loop_worker_handles
)
# Add samples to the pending_queue
@@ -478,7 +494,7 @@ async def _feed_samples(self):
if self.global_steps >= self.total_rollout_steps:
print(
f"[FullyAsyncRollouter][Feed] "
- f"Maximum count has been reached, stop adding new samples"
+ f"Maximum count has been reached, stop adding new samples: "
f"{self.global_steps} >= {self.total_rollout_steps}"
)
break
@@ -751,10 +767,12 @@ async def pause(self):
await asyncio.gather(*self.active_tasks, return_exceptions=True)
self.active_tasks.clear()
print("[FullyAsyncRollouter][Public][Pause] All active tasks completed")
- print("[FullyAsyncRollouter][Public][Pause] Prefix cache reset")
- # Always clear KV cache to release GPU memory during weight synchronization,
- # regardless of partial_rollout setting.
- await self.async_rollout_manager.clear_kv_cache()
+
+ # TODO use checkpoint engine for rollout clear_kv_cache
+ # print("[FullyAsyncRollouter][Public][Pause] clear kv cache")
+ # # Always clear KV cache to release GPU memory during weight synchronization,
+ # # regardless of partial_rollout setting.
+ # await self.async_rollout_manager.clear_kv_cache()
self.monitor_loop_trigger = False
async def resume(self, dependency_ref: ObjectRef = None):
diff --git a/verl/experimental/fully_async_policy/fully_async_trainer.py b/verl/experimental/fully_async_policy/fully_async_trainer.py
index eb272185423..d93d606204b 100644
--- a/verl/experimental/fully_async_policy/fully_async_trainer.py
+++ b/verl/experimental/fully_async_policy/fully_async_trainer.py
@@ -19,16 +19,16 @@
from typing import Any
import ray
-from omegaconf import OmegaConf
from tqdm import tqdm
+from verl import DataProto
from verl.experimental.fully_async_policy.detach_utils import (
MetricsAggregator,
ValidateMetrics,
assemble_batch_from_rollout_samples,
)
from verl.experimental.fully_async_policy.message_queue import MessageQueueClient
-from verl.experimental.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer
+from verl.experimental.separation.ray_trainer import SeparateRayPPOTrainer
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
from verl.trainer.ppo import core_algos
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
@@ -38,8 +38,14 @@
from verl.utils.debug import marked_timer
+class TrainingStopException(Exception):
+ """Exception raised to signal training should stop"""
+
+ pass
+
+
@ray.remote(num_cpus=10)
-class FullyAsyncTrainer(FullyAsyncRayPPOTrainer):
+class FullyAsyncTrainer(SeparateRayPPOTrainer):
"""
A fully asynchronous PPO trainer that obtains samples from a MessageQueue for training.
Based on an improved implementation of OneStepOffRayTrainer
@@ -57,6 +63,8 @@ def __init__(
val_reward_fn=None,
device_name=None,
):
+ # ==================== RayPPOTrainer config ====================
+
# Store the tokenizer for text processing
self.tokenizer = tokenizer
self.processor = processor
@@ -74,22 +82,46 @@ def __init__(
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = need_reference_policy(self.config)
+
self.use_rm = need_reward_model(self.role_worker_mapping)
+ self.use_reward_loop = self.config.reward_model.use_reward_loop
+
self.use_critic = need_critic(self.config)
self.ray_worker_group_cls = ray_worker_group_cls
self.device_name = device_name if device_name else self.config.trainer.device
+ # if ref_in_actor is True, the reference policy will be actor without lora applied
lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
if lora_rank <= 0:
lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
- # if ref_in_actor is True, the reference policy will be actor without lora applied
- self.ref_in_actor = lora_rank > 0
+ self.ref_in_actor = lora_rank > 0 or config.actor_rollout_ref.model.get("lora_adapter_path") is not None
# define in-reward KL control
        # kl loss control currently not supported
if self.config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)
+ self.use_prefix_grouper = self.config.actor_rollout_ref.actor.get("use_prefix_grouper", False)
+ self.use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
+
+ # ==================== SeparateRayPPOTrainer config ====================
+ self.global_steps = 0
+ self.epoch = 0
+ self.max_steps_duration = 0
+ self.progress_bar = None
+ self.logger = None
+ self.is_last_step = False
+ self.prev_step_profile = False
+ self.curr_step_profile = False
+ self.next_step_profile = False
+ self.last_val_metrics = {}
+ self.metrics = {}
+ self.timing_raw = {}
+ # reward message
+ self.future_reward = None
+ self.reward_tensor = None
+ self.reward_extra_infos_dict = {}
+
# ==================== fully async config ====================
self.message_queue_client = None
@@ -113,7 +145,6 @@ def __init__(
# required_samples use ppo_mini_batch_size*require_batches as the minimum number of samples.
self.require_batches = config.async_training.require_batches
self.required_samples = config.actor_rollout_ref.actor.ppo_mini_batch_size * self.require_batches
- self.compute_prox_log_prob = self.config.async_training.compute_prox_log_prob
total_gpus = (
config.trainer.nnodes * config.trainer.n_gpus_per_node
+ config.rollout.nnodes * config.rollout.n_gpus_per_node
@@ -271,14 +302,31 @@ async def _init_async_rollout_manager(self):
# use async rollout do validate
print(f"[FullyAsyncTrainer] use_trainer_do_validate: {self.config.async_training.use_trainer_do_validate}")
if self.config.async_training.use_trainer_do_validate:
- assert self.config.actor_rollout_ref.rollout.mode == "async"
- self.async_rollout_mode = True
print("[FullyAsyncTrainer] Init async rollout manager")
+
+ # infrastructure overview: https://verl.readthedocs.io/en/latest/advance/reward_loop.html#architecture-design
+ # agent_reward_loop: streaming reward computation with actor rollout
+ # two conditions satisfied: (1) no reward model, or (2) reward model with extra resource pool
+ enable_agent_reward_loop = self.use_reward_loop and (
+ not self.use_rm or self.config.reward_model.enable_resource_pool
+ )
+ # if enable_agent_reward_loop, we directly pass reward_loop_workers to agent loop manager
+ # to stream reward computation with actor rollout
+ reward_loop_worker_handles = (
+ self.reward_loop_manager.reward_loop_workers if enable_agent_reward_loop else None
+ )
+
+ # create async rollout manager and request scheduler
+ assert self.config.actor_rollout_ref.rollout.mode == "async"
from verl.experimental.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager
+ self.async_rollout_mode = True
self.async_rollout_manager = await FullyAsyncAgentLoopManager.create(
- config=self.config, worker_group=self.actor_rollout_wg
+ config=self.config,
+ worker_group=self.actor_rollout_wg,
+ reward_loop_worker_handles=reward_loop_worker_handles,
)
+
print("[FullyAsyncTrainer] async_rollout_manager sleep")
await self.async_rollout_manager.sleep()
else:
@@ -297,6 +345,8 @@ async def fit(self):
if self.param_synchronizer is None:
raise ValueError("param_synchronizer client not set. Call set_parameter_synchronizer() first.")
+ from omegaconf import OmegaConf
+
from verl.utils.tracking import Tracking
self.logger = Tracking(
@@ -314,36 +364,11 @@ async def fit(self):
# Use queue mode, no need for traditional dataloader iterator
# Initialize to get the first batch of data
while True:
- metrics = {}
- timing_raw = {}
-
- with marked_timer("step", timing_raw):
- with marked_timer("gen", timing_raw, color="red"):
- epoch, batch = self._get_samples_from_queue()
- if batch is None:
- break
- self._collect_metrics_from_samples(batch, metrics)
- batch, reward_extra_infos_dict = self._process_batch_common(
- batch, metrics, timing_raw, self.local_trigger_step if self.compute_prox_log_prob else None
- )
- self._log_rollout(batch, reward_extra_infos_dict, timing_raw)
-
- self._collect_metrics(batch, 0, metrics, timing_raw)
- self.metrics_aggregator.add_step_metrics(
- metrics=metrics, sample_count=self.required_samples, timestamp=time.time()
- )
- # Trigger parameter synchronization after training step
- time_str = datetime.now().strftime("%H:%M:%S.%f")[:-3]
- print(
- f"[FullyAsyncTrainer] global_steps: {self.global_steps} "
- f"local_trigger_step: {self.local_trigger_step} "
- f"trigger_parameter_sync_step: {self.trigger_parameter_sync_step} "
- f"{time_str}"
- )
- await self._trigger_parameter_sync_after_step(global_steps=self.global_steps)
- self._log_validation_data()
- self._check_save_checkpoint(timing_raw)
- self.global_steps += 1
+ try:
+ await self.fit_step()
+ except TrainingStopException:
+ print("[FullyAsyncTrainer] Training stopped by queue termination signal")
+ break
# final parameter sync and validate
        # 1. wait for the remaining validate task
@@ -355,12 +380,106 @@ async def fit(self):
ray.get(self.param_synchronizer.wait_last_valid.remote())
self._log_validation_data()
self.progress_bar.close()
+ self._fit_save_checkpoint()
- self._check_save_checkpoint(timing_raw)
+ async def fit_step(self, batch_dict: dict = None):
+ """
+ Single-step training template method. Handles all logic for one training step.
- def _check_save_checkpoint(self, timing_raw):
- if self.current_param_version == self.last_ckpt_version:
- return
+ Flow:
+ 1. Pre-step processing -> 2. Get batch -> 3. Generate sequences ->
+    4. Compute reward -> 5. Compute log_prob -> 6. Compute ref_log_prob -> 7. Compute values ->
+    8. Compute advantage -> 9. Update critic -> 10. Update actor -> 11. Sync weights -> 12. Post-step processing
+
+ Args:
+ batch_dict: Raw data dictionary
+ """
+ print("[FullyAsyncTrainer] fit_step")
+ self.metrics = {"training/global_step": self.global_steps, "training/epoch": self.epoch}
+ self.timing_raw = {}
+ # reward message
+ self.future_reward = None
+ self.reward_tensor = None
+ self.reward_extra_infos_dict = {}
+
+ # self._fit_prepare_step()
+ self._fit_start_profile()
+
+ with marked_timer("step", self.timing_raw):
+ batch = self._fit_generate(None)
+ batch = self._fit_compute_reward(batch)
+ batch = self._fit_compute_log_prob(batch)
+ batch = self._fit_compute_ref_log_prob(batch)
+ batch = self._fit_compute_critic(batch)
+ batch = self._fit_compute_advantage(batch)
+ batch = self._fit_update_critic(batch)
+ batch = self._fit_update_actor(batch)
+ await self._fit_update_weights()
+ self._fit_dump_data(batch)
+
+ # self._fit_validate()
+ self._fit_save_checkpoint()
+ self._fit_stop_profile()
+ self._fit_collect_metrics(batch)
+ self._fit_torch_memory()
+ # self._fit_experimental(batch)
+ self._fit_postprocess_step()
+
+ def _fit_generate(self, batch: DataProto = None) -> DataProto:
+ metrics = self.metrics
+ timing_raw = self.timing_raw
+ with marked_timer("gen", timing_raw, color="red"):
+ epoch, batch = self._get_samples_from_queue()
+ if batch is None:
+ raise TrainingStopException("Training terminated: queue returned None")
+ self._collect_metrics_from_samples(batch, metrics)
+ return batch
+
+ def _compute_old_log_prob(self, batch: DataProto):
+ """
+ If algorithm.rollout_correction.bypass_mode is False,
+ use model engine and first version model params to re-calculate old_log_prob.
+
+ If local_trigger_step == 1, load the training engine's parameters to the CPU
+ and save a copy for subsequent MIS use.
+
+ If local_trigger_step == 2, 3, ..., restore the parameters of version 1 to calculate the old_log_prob,
+ then restore the parameters of the current version.
+ """
+ if self.local_trigger_step == 1:
+ self.actor_rollout_wg.save_model_to_cpu(1)
+ old_log_prob, old_log_prob_mfu = super()._compute_old_log_prob(batch)
+ else:
+ self.actor_rollout_wg.save_model_to_cpu(self.local_trigger_step)
+ self.actor_rollout_wg.restore_model_from_cpu(1)
+ old_log_prob, old_log_prob_mfu = super()._compute_old_log_prob(batch)
+ self.actor_rollout_wg.restore_model_from_cpu(self.local_trigger_step)
+ self.actor_rollout_wg.clear_cpu_model(self.local_trigger_step)
+ return old_log_prob, old_log_prob_mfu
+
+ def _fit_collect_metrics(self, batch):
+ super()._fit_collect_metrics(batch)
+ self.metrics_aggregator.add_step_metrics(
+ metrics=self.metrics, sample_count=self.required_samples, timestamp=time.time()
+ )
+ self._log_validation_data()
+
+ async def _fit_update_weights(self):
+ # with marked_timer("update_weights", self.timing_raw, color="red"):
+ # self.checkpoint_manager.update_weights()
+
+ # Trigger parameter synchronization after training step
+ time_str = datetime.now().strftime("%H:%M:%S.%f")[:-3]
+ print(
+ f"[FullyAsyncTrainer] global_steps: {self.global_steps} "
+ f"local_trigger_step: {self.local_trigger_step} "
+ f"trigger_parameter_sync_step: {self.trigger_parameter_sync_step} "
+ f"{time_str}"
+ )
+ await self._trigger_parameter_sync_after_step()
+
+ def _fit_save_checkpoint(self):
+ timing_raw = self.timing_raw
# Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
esi_close_to_expiration = should_save_ckpt_esi(
max_steps_duration=self.max_steps_duration,
@@ -370,16 +489,23 @@ def _check_save_checkpoint(self, timing_raw):
# The conditions include a mandatory condition (1) and
# one of the following optional conditions (2/3/4):
# 1. The save frequency is set to a positive value.
- # 2. The current step number is a multiple of the save frequency.
- # 3. The ESI(Elastic Server Instance)/training plan is close to expiration.
+ # 2. It's the last training step.
+ # 3. The current step number is a multiple of the save frequency.
+ # 4. The ESI(Elastic Server Instance)/training plan is close to expiration.
if self.config.trainer.save_freq > 0 and (
self.current_param_version % self.config.trainer.save_freq == 0 or esi_close_to_expiration
):
if esi_close_to_expiration:
print("Force saving checkpoint: ESI instance expiration approaching.")
with marked_timer("save_checkpoint", timing_raw, color="green"):
+ # sleep replicas to avoid OOM during checkpoint saving
+ # self.checkpoint_manager.sleep_replicas()
self._save_checkpoint()
- self.last_ckpt_version = self.current_param_version
+            # wake replicas again after checkpoint saving
+ # self.checkpoint_manager.update_weights()
+
+ def _fit_postprocess_step(self):
+ self.global_steps += 1
def _save_checkpoint(self):
# Warning: Currently, to align the training process and metrics of colocate,
@@ -522,7 +648,7 @@ def _collect_metrics_from_samples(self, batch, metrics):
if key.startswith("fully_async") or key.startswith("timing_s"):
metrics[key] = value
- async def _trigger_parameter_sync_after_step(self, validate: bool = False, global_steps: int = None):
+ async def _trigger_parameter_sync_after_step(self, validate: bool = False):
"""
Trigger parameter synchronization after training step
This ensures rollouter always uses the latest trained parameters
@@ -547,7 +673,7 @@ async def _trigger_parameter_sync_after_step(self, validate: bool = False, globa
self.param_synchronizer.sync_weights.remote(
self.current_param_version,
validate=validate,
- global_steps=global_steps,
+ global_steps=self.global_steps,
use_trainer_do_validate=self.config.async_training.use_trainer_do_validate,
)
)
diff --git a/verl/experimental/fully_async_policy/message_queue.py b/verl/experimental/fully_async_policy/message_queue.py
index 85860c6f2a0..f5dcec566bc 100644
--- a/verl/experimental/fully_async_policy/message_queue.py
+++ b/verl/experimental/fully_async_policy/message_queue.py
@@ -60,7 +60,7 @@ def __init__(self, config: DictConfig, max_queue_size: int = 1000):
self.dropped_samples = 0
print(
- f"[MessageQueue] initialized with max_queue_size={max_queue_size},"
+ f"[MessageQueue] initialized with max_queue_size={max_queue_size}, "
f"staleness_threshold={self.staleness_threshold}"
)
diff --git a/verl/experimental/fully_async_policy/param_sync.py b/verl/experimental/fully_async_policy/param_sync.py
index 4a9ac167aa3..000568d12d6 100644
--- a/verl/experimental/fully_async_policy/param_sync.py
+++ b/verl/experimental/fully_async_policy/param_sync.py
@@ -129,15 +129,17 @@ def sync_weights(self, version, validate=False, global_steps=0, use_trainer_do_v
# sync weights
# For sglang, always use sync_rollout_weights instead of sync_rollout_weights_by_checkpoint
- rollout_name = getattr(self.config.actor_rollout_ref.rollout, "name", None)
- use_checkpoint_engine = self.config.async_training.checkpoint_engine.enable and rollout_name != "sglang"
- if use_checkpoint_engine:
- self.actor_wg.sync_rollout_weights_by_checkpoint(self.sync_group_name)
- ray.get(self.rollout_wg.sync_rollout_weights_by_checkpoint(self.sync_group_name))
- else:
- self.actor_wg.sync_rollout_weights(self.sync_group_name)
- ray.get(self.rollout_wg.sync_rollout_weights(self.sync_group_name))
+ # TODO use checkpoint engine for sglang rollout
+ # rollout_name = getattr(self.config.actor_rollout_ref.rollout, "name", None)
+ # use_checkpoint_engine = self.config.async_training.checkpoint_engine.enable and rollout_name != "sglang"
+ # if use_checkpoint_engine:
+ # self.actor_wg.sync_rollout_weights_by_checkpoint(self.sync_group_name)
+ # ray.get(self.rollout_wg.sync_rollout_weights_by_checkpoint(self.sync_group_name))
+ # else:
+ # self.actor_wg.sync_rollout_weights(self.sync_group_name)
+ # ray.get(self.rollout_wg.sync_rollout_weights(self.sync_group_name))
+
end_time = time.time()
print(
f"[ParameterSynchronizer] sync_weights success. cost {end_time - start_time:.2f} seconds, "
diff --git a/verl/experimental/fully_async_policy/ray_trainer.py b/verl/experimental/fully_async_policy/ray_trainer.py
deleted file mode 100644
index f31e55d1388..00000000000
--- a/verl/experimental/fully_async_policy/ray_trainer.py
+++ /dev/null
@@ -1,538 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-# Copyright 2023-2024 SGLang Team
-# Copyright 2025 ModelBest Inc. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-PPO Trainer with Ray-based single controller.
-This trainer supports model-agonistic model initialization with huggingface
-"""
-
-import uuid
-from copy import deepcopy
-from pprint import pprint
-
-import numpy as np
-import ray
-import torch
-from omegaconf import OmegaConf
-from tqdm import tqdm
-
-from verl import DataProto
-from verl.experimental.dataset.sampler import AbstractCurriculumSampler
-from verl.single_controller.ray import RayClassWithInitArgs
-from verl.single_controller.ray.base import create_colocated_worker_cls
-from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
-from verl.trainer.ppo.metric_utils import (
- compute_data_metrics,
- compute_throughout_metrics,
- compute_timing_metrics,
-)
-from verl.trainer.ppo.ray_trainer import RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask
-from verl.trainer.ppo.reward import compute_reward, compute_reward_async
-from verl.trainer.ppo.utils import Role
-from verl.utils.config import omega_conf_to_dataclass
-from verl.utils.debug import marked_timer
-from verl.utils.metric import (
- reduce_metrics,
-)
-from verl.utils.rollout_skip import RolloutSkip
-
-
-class FullyAsyncRayPPOTrainer(RayPPOTrainer):
- def init_workers(self):
- """Initialize distributed training workers using Ray backend.
-
- Creates:
- 1. Ray resource pools from configuration
- 2. Worker groups for each role (actor, critic, etc.)
- """
- self._init_resource_pools()
- self._create_worker_classes()
- self._init_worker_groups()
- self._init_models()
- self._init_async_rollout_manager()
-
- def _init_resource_pools(self):
- self.resource_pool_manager.create_resource_pool()
-
- self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
-
- def _create_worker_classes(self):
- self._create_actor_rollout_classes()
- self._create_critic_class()
- self._create_reference_policy_class()
- self._create_reward_model_class()
-
- def _create_actor_rollout_classes(self):
- raise NotImplementedError
-
- def _create_critic_class(self):
- # create critic
- if self.use_critic:
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
- critic_cfg = omega_conf_to_dataclass(self.config.critic)
- critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
- self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls
-
- def _create_reference_policy_class(self):
- # create reference policy if needed
- if self.use_reference_policy:
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
- ref_policy_cls = RayClassWithInitArgs(
- self.role_worker_mapping[Role.RefPolicy],
- config=self.config.actor_rollout_ref,
- role=str(Role.RefPolicy),
- # profile_option=self.config.trainer.npu_profile.options,
- )
- self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls
-
- def _create_reward_model_class(self):
- # create a reward model if reward_fn is None
- if self.use_rm:
- # we create a RM here
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
- rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
- self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls
-
- def _init_worker_groups(self):
- # initialize WorkerGroup
- # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
- # you should not use `create_colocated_worker_cls`.
- # Instead, directly pass different resource pool to different worker groups.
- # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
- all_wg = {}
- wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
- if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
- wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
- if OmegaConf.select(self.config.global_profiler, "steps") is not None:
- wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps")
- # Only require nsight worker options when tool is nsys
- if OmegaConf.select(self.config.global_profiler, "tool") == "nsys":
- assert (
- OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
- is not None
- ), "worker_nsight_options must be set when using nsys with profile_steps"
- wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
- OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
- )
- wg_kwargs["device_name"] = self.device_name
-
- for resource_pool, class_dict in self.resource_pool_to_cls.items():
- worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
- wg_dict = self.ray_worker_group_cls(
- resource_pool=resource_pool,
- ray_cls_with_init=worker_dict_cls,
- **wg_kwargs,
- )
- spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
- all_wg.update(spawn_wg)
- self.all_wg = all_wg
-
- def _init_models(self):
- if self.use_critic:
- self.critic_wg = self.all_wg[str(Role.Critic)]
- self.critic_wg.init_model()
-
- if self.use_reference_policy and not self.ref_in_actor:
- self.ref_policy_wg = self.all_wg[str(Role.RefPolicy)]
- self.ref_policy_wg.init_model()
-
- if self.use_rm:
- self.rm_wg = self.all_wg[str(Role.RewardModel)]
- self.rm_wg.init_model()
-
- # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
- self.actor_rollout_wg = self.all_wg[str(Role.ActorRollout)]
- self.actor_rollout_wg.init_model()
-
- def _init_async_rollout_manager(self):
- pass
-
- def fit(self):
- """
- The training loop of PPO.
- The driver process only need to call the compute functions of the worker group through RPC
- to construct the PPO dataflow.
- The light-weight advantage computation is done on the driver process.
- """
- from omegaconf import OmegaConf
-
- from verl.utils.tracking import Tracking
-
- logger = Tracking(
- project_name=self.config.trainer.project_name,
- experiment_name=self.config.trainer.experiment_name,
- default_backend=self.config.trainer.logger,
- config=OmegaConf.to_container(self.config, resolve=True),
- )
-
- self.global_steps = 0
-
- # load checkpoint before doing anything
- self._load_checkpoint()
-
- # perform validation before training
- # currently, we only support validation using the reward_function.
- if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
- val_metrics = self._validate()
- assert val_metrics, f"{val_metrics=}"
- pprint(f"Initial validation metrics: {val_metrics}")
- logger.log(data=val_metrics, step=self.global_steps)
- if self.config.trainer.get("val_only", False):
- return
-
- if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
- rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
- rollout_skip.wrap_generate_sequences()
-
- # add tqdm
- progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
-
- # we start from step 1
- self.global_steps += 1
- last_val_metrics = None
- self.max_steps_duration = 0
-
- prev_step_profile = False
- curr_step_profile = (
- self.global_steps in self.config.global_profiler.steps
- if self.config.global_profiler.steps is not None
- else False
- )
- next_step_profile = False
-
- for epoch in range(self.config.trainer.total_epochs):
- for batch_dict in self.train_dataloader:
- metrics = {}
- timing_raw = {}
-
- with marked_timer("start_profile", timing_raw):
- self._start_profiling(
- not prev_step_profile and curr_step_profile
- if self.config.global_profiler.profile_continuous_steps
- else curr_step_profile
- )
-
- batch, gen_batch = self._prepare_generate_batch(batch_dict)
-
- is_last_step = self.global_steps >= self.total_training_steps
-
- with marked_timer("step", timing_raw):
- # generate a batch
- with marked_timer("gen", timing_raw, color="red"):
- if not self.async_rollout_mode:
- gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
- else:
- gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
- timing_raw.update(gen_batch_output.meta_info["timing"])
- gen_batch_output.meta_info.pop("timing", None)
-
- if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
- if self.reward_fn is None:
- raise ValueError("A reward_fn is required for REMAX advantage estimation.")
-
- with marked_timer("gen_max", timing_raw, color="purple"):
- gen_baseline_batch = deepcopy(gen_batch)
- gen_baseline_batch.meta_info["do_sample"] = False
- if not self.async_rollout_mode:
- gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
- else:
- gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
- batch = batch.union(gen_baseline_output)
- reward_baseline_tensor = self.reward_fn(batch)
- reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
-
- batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
-
- batch.batch["reward_baselines"] = reward_baseline_tensor
-
- del gen_baseline_batch, gen_baseline_output
-
- batch = self._post_generate_batch(batch, gen_batch_output, metrics)
- batch, reward_extra_infos_dict = self._process_batch_common(batch, metrics, timing_raw)
- self._log_rollout(batch, reward_extra_infos_dict, timing_raw)
-
- last_val_metrics = self._validate_metrics(is_last_step, last_val_metrics, metrics, timing_raw)
- self._check_save_checkpoint(is_last_step, timing_raw)
-
- with marked_timer("stop_profile", timing_raw):
- next_step_profile = (
- self.global_steps + 1 in self.config.global_profiler.steps
- if self.config.global_profiler.steps is not None
- else False
- )
- self._stop_profiling(
- curr_step_profile and not next_step_profile
- if self.config.global_profiler.profile_continuous_steps
- else curr_step_profile
- )
- prev_step_profile = curr_step_profile
- curr_step_profile = next_step_profile
-
- self._collect_metrics(batch, epoch, metrics, timing_raw)
- self._post_batch_processing(batch)
-
- # TODO: make a canonical logger that supports various backend
- logger.log(data=metrics, step=self.global_steps)
-
- progress_bar.update(1)
- self.global_steps += 1
-
- if (
- hasattr(self.config.actor_rollout_ref.actor, "profiler")
- and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory"
- ):
- self.actor_rollout_wg.dump_memory_snapshot(
- tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}"
- )
-
- if is_last_step:
- pprint(f"Final validation metrics: {last_val_metrics}")
- progress_bar.close()
- return
-
- def _prepare_generate_batch(self, batch_dict):
- batch: DataProto = DataProto.from_single_dict(batch_dict)
-
- # add uid to batch
- batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
-
- gen_batch = self._get_gen_batch(batch)
-
- # pass global_steps to trace
- gen_batch.meta_info["global_steps"] = self.global_steps
- gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
- return batch, gen_batch
-
- def _post_generate_batch(self, batch, gen_batch_output, metrics):
- # repeat to align with repeated responses in rollout
- batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
- batch = batch.union(gen_batch_output)
-
- if "response_mask" not in batch.batch.keys():
- batch.batch["response_mask"] = compute_response_mask(batch)
- # Balance the number of valid tokens across DP ranks.
- # NOTE: This usually changes the order of data in the `batch`,
- # which won't affect the advantage calculation (since it's based on uid),
- # but might affect the loss calculation (due to the change of mini-batching).
- # TODO: Decouple the DP balancing and mini-batching.
- if self.config.trainer.balance_batch:
- self._balance_batch(batch, metrics=metrics)
-
- # compute global_valid tokens
- batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
-
- return batch
-
- def _process_batch_common(self, batch, metrics, timing_raw, local_trigger_step=None):
- with marked_timer("reward", timing_raw, color="yellow"):
- # compute reward model score
- if self.use_rm:
- reward_tensor = self.rm_wg.compute_rm_score(batch)
- batch = batch.union(reward_tensor)
-
- if self.config.reward_model.launch_reward_fn_async:
- future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn)
- else:
- reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
-
- with marked_timer("old_log_prob", timing_raw, color="blue"):
-
- def compute_old_log_prob(batch):
- old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
- entropys = old_log_prob.batch["entropys"]
- response_masks = batch.batch["response_mask"]
- actor_config = self.config.actor_rollout_ref.actor
- entropy_agg = agg_loss(
- loss_mat=entropys,
- loss_mask=response_masks,
- loss_agg_mode=actor_config.loss_agg_mode,
- loss_scale_factor=actor_config.loss_scale_factor,
- )
- old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
- metrics.update(old_log_prob_metrics)
- old_log_prob.batch.pop("entropys")
- batch = batch.union(old_log_prob)
- if "rollout_log_probs" in batch.batch.keys():
- # TODO: we may want to add diff of probs too.
- from verl.utils.debug.metrics import calculate_debug_metrics
-
- metrics.update(calculate_debug_metrics(batch))
- return batch
-
- async_training = self.config.get("async_training", None)
- if async_training and async_training.use_rollout_log_probs:
- # If local_triger_step == 1, load the training engine's parameters to the CPU
- # and save a copy for subsequent MIS use.
- # If local_trigger_step == 2, 3, ..., restore the parameters of version 1 to calculate the old_log_prob,
- # then restore the parameters of the current version.
- if local_trigger_step == 1:
- self.actor_rollout_wg.save_model_to_cpu(1)
- batch = compute_old_log_prob(batch)
- elif local_trigger_step is not None:
- self.actor_rollout_wg.save_model_to_cpu(local_trigger_step)
- self.actor_rollout_wg.restore_model_from_cpu(1)
- batch = compute_old_log_prob(batch)
- self.actor_rollout_wg.restore_model_from_cpu(local_trigger_step)
- self.actor_rollout_wg.clear_cpu_model(local_trigger_step)
- else:
- batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"]
- batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
-
- else:
- batch = compute_old_log_prob(batch)
-
- if self.use_reference_policy:
- # compute reference log_prob
- with marked_timer("ref", timing_raw, color="olive"):
- if not self.ref_in_actor:
- ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
- else:
- ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
- batch = batch.union(ref_log_prob)
-
- # compute values
- if self.use_critic:
- with marked_timer("values", timing_raw, color="cyan"):
- values = self.critic_wg.compute_values(batch)
- batch = batch.union(values)
-
- with marked_timer("adv", timing_raw, color="brown"):
- # we combine with rule-based rm
- reward_extra_infos_dict: dict[str, list]
- if self.config.reward_model.launch_reward_fn_async:
- reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
- batch.batch["token_level_scores"] = reward_tensor
-
- if reward_extra_infos_dict:
- batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
-
- # compute rewards. apply_kl_penalty if available
- if self.config.algorithm.use_kl_in_reward:
- batch, kl_metrics = apply_kl_penalty(
- batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
- )
- metrics.update(kl_metrics)
- else:
- batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
-
- # Compute rollout correction weights centrally (once per batch)
- # This corrects for off-policy issues (policy mismatch, model staleness, etc.)
- # Also computes off-policy diagnostic metrics (KL, PPL, etc.)
- from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch
-
- rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
- if rollout_corr_config is not None and "rollout_log_probs" in batch.batch:
- batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
- # IS and off-policy metrics already have rollout_corr/ prefix
- metrics.update(is_metrics)
-
- # compute advantages, executed on the driver process
- norm_adv_by_std_in_grpo = self.config.algorithm.get(
- "norm_adv_by_std_in_grpo", True
- ) # GRPO adv normalization factor
-
- batch = compute_advantage(
- batch,
- adv_estimator=self.config.algorithm.adv_estimator,
- gamma=self.config.algorithm.gamma,
- lam=self.config.algorithm.lam,
- num_repeat=self.config.actor_rollout_ref.rollout.n,
- norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
- config=self.config.algorithm,
- )
-
- # update critic
- if self.use_critic:
- with marked_timer("update_critic", timing_raw, color="pink"):
- critic_output = self.critic_wg.update_critic(batch)
- critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
- metrics.update(critic_output_metrics)
-
- # implement critic warmup
- if self.config.trainer.critic_warmup <= self.global_steps:
- # update actor
- with marked_timer("update_actor", timing_raw, color="red"):
- batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
- actor_output = self.actor_rollout_wg.update_actor(batch)
- actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
- metrics.update(actor_output_metrics)
- return batch, reward_extra_infos_dict
-
- def _log_rollout(self, batch, reward_extra_infos_dict, timing_raw):
- # Log rollout generations if enabled
- rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
- if rollout_data_dir:
- with marked_timer("dump_rollout_generations", timing_raw, color="green"):
- inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
- outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
- scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
- sample_gts = [item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch]
-
- if "request_id" in batch.non_tensor_batch:
- reward_extra_infos_dict.setdefault(
- "request_id",
- batch.non_tensor_batch["request_id"].tolist(),
- )
-
- self._dump_generations(
- inputs=inputs,
- outputs=outputs,
- gts=sample_gts,
- scores=scores,
- reward_extra_infos_dict=reward_extra_infos_dict,
- dump_path=rollout_data_dir,
- )
-
- def _validate_metrics(self, is_last_step, last_val_metrics, metrics, timing_raw):
- if (
- self.val_reward_fn is not None
- and self.config.trainer.test_freq > 0
- and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
- ):
- with marked_timer("testing", timing_raw, color="green"):
- val_metrics: dict = self._validate()
- if is_last_step:
- last_val_metrics = val_metrics
- metrics.update(val_metrics)
- return last_val_metrics
-
- def _collect_metrics(self, batch, epoch, metrics, timing_raw):
- steps_duration = timing_raw["step"]
- self.max_steps_duration = max(self.max_steps_duration, steps_duration)
-
- # training metrics
- metrics.update(
- {
- "training/global_step": self.global_steps,
- "training/epoch": epoch,
- }
- )
- # collect metrics
- metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
- metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
- # TODO: implement actual tflpo and theoretical tflpo
- n_gpus = self.resource_pool_manager.get_n_gpus()
- metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
-
- def _post_batch_processing(self, batch: DataProto):
- # this is experimental and may be changed/removed in the future in favor of a general-purpose one
- if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
- self.train_dataloader.sampler.update(batch=batch)
-
- # this is experimental and may be changed/removed in the future
- # in favor of a general-purpose data buffer pool
- if hasattr(self.train_dataset, "on_batch_end"):
- # The dataset may be changed after each training batch
- self.train_dataset.on_batch_end(batch=batch)
diff --git a/verl/experimental/fully_async_policy/sglang_rollout/sglang_async_server.py b/verl/experimental/fully_async_policy/sglang_rollout/sglang_async_server.py
index 0830ed2abd3..8aad1199146 100644
--- a/verl/experimental/fully_async_policy/sglang_rollout/sglang_async_server.py
+++ b/verl/experimental/fully_async_policy/sglang_rollout/sglang_async_server.py
@@ -171,11 +171,6 @@ async def resume(self):
async with self.lock:
self.paused = False
- async def reset_prefix_cache(self):
- async with self.lock:
- print("Reset prefix cache ...")
- await self.tokenizer_manager.flush_cache()
-
class FullyAsyncSGLangReplica(SGLangReplica):
def __init__(
@@ -196,7 +191,3 @@ async def cancel(self):
async def resume(self):
"""Resume each rollout server."""
await asyncio.gather(*[server.resume.remote() for server in self.servers])
-
- async def reset_prefix_cache(self):
- """reset kv cache in each rollout server."""
- await asyncio.gather(*[server.reset_prefix_cache.remote() for server in self.servers])
diff --git a/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh b/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh
index 09b22145e26..1b50839f896 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh
@@ -187,5 +187,4 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
async_training.require_batches=${require_batches} \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
- async_training.partial_rollout="${partial_rollout}" \
- async_training.use_rollout_log_probs=True
+ async_training.partial_rollout="${partial_rollout}"
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh
index b11705d8eca..7af3748dec9 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_async_retool.sh
@@ -137,5 +137,4 @@ python3 -m verl.experimental.fully_async_policy.fully_async_main \
async_training.staleness_threshold=$staleness_threshold \
async_training.trigger_parameter_sync_step=$trigger_parameter_sync_step \
async_training.require_batches=$require_batches \
- async_training.partial_rollout=$partial_rollout \
- async_training.use_rollout_log_probs=True
\ No newline at end of file
+ async_training.partial_rollout=$partial_rollout
\ No newline at end of file
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh
index 59c83b166b6..9b37bea9dfc 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh
@@ -158,5 +158,4 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
- async_training.partial_rollout="${partial_rollout}" \
- async_training.use_rollout_log_probs=True
\ No newline at end of file
+ async_training.partial_rollout="${partial_rollout}"
\ No newline at end of file
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh
index 7203652da41..87c6cc4ceb7 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh
@@ -158,5 +158,4 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
- async_training.partial_rollout="${partial_rollout}" \
- async_training.use_rollout_log_probs=True
+ async_training.partial_rollout="${partial_rollout}"
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh
index 300cc4551db..afd9bada3c2 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh
@@ -160,5 +160,4 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
- async_training.partial_rollout="${partial_rollout}" \
- async_training.use_rollout_log_probs=True
\ No newline at end of file
+ async_training.partial_rollout="${partial_rollout}"
\ No newline at end of file
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh
index 2dd0adc0ef7..2f8a7fc315f 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh
@@ -160,5 +160,4 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
- async_training.partial_rollout="${partial_rollout}" \
- async_training.use_rollout_log_probs=True
\ No newline at end of file
+ async_training.partial_rollout="${partial_rollout}"
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh
index 6c8341691a8..0891627a901 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh
@@ -158,5 +158,4 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
- async_training.partial_rollout="${partial_rollout}" \
- async_training.use_rollout_log_probs=True
+ async_training.partial_rollout="${partial_rollout}"
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh
index 70237d8725a..be74b657964 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh
@@ -165,8 +165,7 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
async_training.partial_rollout="${partial_rollout}" \
- async_training.use_rollout_log_probs=True \
- async_training.compute_prox_log_prob=True \
+ algorithm.rollout_correction.bypass_mode=False \
algorithm.rollout_correction.rollout_is=${rollout_is} \
algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \
algorithm.rollout_correction.rollout_rs=${rollout_rs} \
diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh
index ec107948395..b584e9dba0d 100644
--- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh
+++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh
@@ -158,5 +158,4 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
- async_training.partial_rollout="${partial_rollout}" \
- async_training.use_rollout_log_probs=True
\ No newline at end of file
+ async_training.partial_rollout="${partial_rollout}"
\ No newline at end of file
diff --git a/verl/experimental/fully_async_policy/shell/geo3k_qwen25vl_7b_megatron_4_4.sh b/verl/experimental/fully_async_policy/shell/geo3k_qwen25vl_7b_megatron_4_4.sh
index 251c0ae840a..8b32c6e0078 100644
--- a/verl/experimental/fully_async_policy/shell/geo3k_qwen25vl_7b_megatron_4_4.sh
+++ b/verl/experimental/fully_async_policy/shell/geo3k_qwen25vl_7b_megatron_4_4.sh
@@ -107,5 +107,4 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
- async_training.partial_rollout="${partial_rollout}" \
- async_training.use_rollout_log_probs=True
\ No newline at end of file
+ async_training.partial_rollout="${partial_rollout}"
\ No newline at end of file
diff --git a/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh b/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh
index bb25144481e..4d96718a352 100644
--- a/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh
+++ b/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh
@@ -225,6 +225,5 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
- async_training.partial_rollout="${partial_rollout}" \
- async_training.use_rollout_log_probs=True \
+ async_training.partial_rollout="${partial_rollout}"
diff --git a/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh b/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh
index ed0716e8c24..29285faa71f 100644
--- a/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh
+++ b/verl/experimental/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh
@@ -235,5 +235,4 @@ python -m verl.experimental.fully_async_policy.fully_async_main \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
- async_training.partial_rollout="${partial_rollout}" \
- async_training.use_rollout_log_probs=True \
+ async_training.partial_rollout="${partial_rollout}"
\ No newline at end of file
diff --git a/verl/experimental/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml b/verl/experimental/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml
index 3aea4e4c94d..cb2f8c2054c 100644
--- a/verl/experimental/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml
+++ b/verl/experimental/one_step_off_policy/config/one_step_off_ppo_megatron_trainer.yaml
@@ -20,6 +20,9 @@ actor_rollout_ref:
free_cache_engine: False
# Must be enabled! Otherwise, log_probs cannot be calculated.
calculate_log_probs: True
+ # Set to auto mode to prevent incorrect rollout outputs when parameters are not synced.
+ # TODO: Can be removed in the future once parameter synchronization is ready.
+ load_format: "auto"
# Only then will the use of log probs be correct.
# And it can be used in conjunction with other rollout_correction algorithms.
diff --git a/verl/experimental/one_step_off_policy/config/one_step_off_ppo_trainer.yaml b/verl/experimental/one_step_off_policy/config/one_step_off_ppo_trainer.yaml
index 4c4deb485e1..012745e2aa3 100644
--- a/verl/experimental/one_step_off_policy/config/one_step_off_ppo_trainer.yaml
+++ b/verl/experimental/one_step_off_policy/config/one_step_off_ppo_trainer.yaml
@@ -20,6 +20,9 @@ actor_rollout_ref:
free_cache_engine: False
# Must be enabled! Otherwise, log_probs cannot be calculated.
calculate_log_probs: True
+ # Set to auto mode to prevent incorrect rollout outputs when parameters are not synced.
+ # TODO: Can be removed in the future once parameter synchronization is ready.
+ load_format: "auto"
# Only then will the use of log probs be correct.
# And it can be used in conjunction with other rollout_correction algorithms.
diff --git a/verl/experimental/one_step_off_policy/main_ppo.py b/verl/experimental/one_step_off_policy/main_ppo.py
index d19c40ffbe2..4ef6fb38f22 100644
--- a/verl/experimental/one_step_off_policy/main_ppo.py
+++ b/verl/experimental/one_step_off_policy/main_ppo.py
@@ -170,12 +170,24 @@ def run(self, config):
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
# Load the reward manager for training and validation.
- reward_fn = load_reward_manager(
- config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
- )
- val_reward_fn = load_reward_manager(
- config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
- )
+ use_reward_loop = config.reward_model.use_reward_loop
+ if not use_reward_loop:
+ print(
+ "WARNING: Initializing the reward manager in the single controller will be deprecated. "
+ "Please set config.reward_model.use_reward_loop to use the distributed reward manager."
+ )
+ # Load the reward manager for training and validation.
+ reward_fn = load_reward_manager(
+ config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
+ )
+ val_reward_fn = load_reward_manager(
+ config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
+ )
+ else:
+ # reward_loop mode: a RewardLoopManager is initialized in ray_trainer
+ # and used there to compute reward scores
+ reward_fn = None
+ val_reward_fn = None
resource_pool_manager = create_resource_pool_manager(config, role_worker_mapping.keys())
diff --git a/verl/experimental/one_step_off_policy/ray_trainer.py b/verl/experimental/one_step_off_policy/ray_trainer.py
index a905ca510cd..55b49705f6a 100644
--- a/verl/experimental/one_step_off_policy/ray_trainer.py
+++ b/verl/experimental/one_step_off_policy/ray_trainer.py
@@ -21,6 +21,7 @@
import asyncio
import uuid
from pprint import pprint
+from typing import Optional
import numpy as np
import ray
@@ -31,46 +32,36 @@
from tqdm import tqdm
from verl import DataProto
-from verl.experimental.dataset.sampler import AbstractCurriculumSampler
from verl.experimental.one_step_off_policy.utils import need_critic
+from verl.experimental.separation.ray_trainer import SeparateRayPPOTrainer
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
-from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo import core_algos
-from verl.trainer.ppo.core_algos import agg_loss
-from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics
from verl.trainer.ppo.ray_trainer import (
- RayPPOTrainer,
ResourcePoolManager,
- apply_kl_penalty,
- compute_advantage,
compute_response_mask,
)
-from verl.trainer.ppo.reward import compute_reward, compute_reward_async
+from verl.trainer.ppo.reward import compute_reward_async
from verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model
-from verl.utils import omega_conf_to_dataclass
-from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi
from verl.utils.debug import marked_timer
-from verl.utils.metric import reduce_metrics
+from verl.utils.rollout_skip import RolloutSkip
from verl.utils.tracking import ValidationGenerationsLogger
-class OneStepOffRayTrainer(RayPPOTrainer):
- # TODO: support each role have individual ray_worker_group_cls,
- # i.e., support different backend of different role
+class OneStepOffRayTrainer(SeparateRayPPOTrainer):
def __init__(
self,
config,
tokenizer,
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
- ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
+ ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup,
processor=None,
reward_fn=None,
val_reward_fn=None,
- train_dataset: Dataset | None = None,
- val_dataset: Dataset | None = None,
+ train_dataset: Optional[Dataset] = None,
+ val_dataset: Optional[Dataset] = None,
collate_fn=None,
- train_sampler: Sampler | None = None,
+ train_sampler: Optional[Sampler] = None,
device_name=None,
):
"""
@@ -90,7 +81,7 @@ def __init__(
val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.
collate_fn: Function to collate data samples into batches.
train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None.
- device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to "cuda".
+ device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to None.
"""
# Store the tokenizer for text processing
@@ -101,61 +92,64 @@ def __init__(
self.val_reward_fn = val_reward_fn
self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
-
assert not self.hybrid_engine
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = need_reference_policy(self.config)
+
self.use_rm = need_reward_model(self.role_worker_mapping)
- self.use_critic = need_critic(config)
+ self.use_reward_loop = self.config.reward_model.use_reward_loop
+
+ self.use_critic = need_critic(self.config)
self.ray_worker_group_cls = ray_worker_group_cls
self.device_name = device_name if device_name else self.config.trainer.device
- self.validation_generations_logger = ValidationGenerationsLogger()
+ self.validation_generations_logger = ValidationGenerationsLogger(
+ project_name=self.config.trainer.project_name,
+ experiment_name=self.config.trainer.experiment_name,
+ )
+ # if ref_in_actor is True, the reference policy will be actor without lora applied
lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
if lora_rank <= 0:
lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
- # if ref_in_actor is True, the reference policy will be actor without lora applied
- self.ref_in_actor = lora_rank > 0
+ self.ref_in_actor = lora_rank > 0 or config.actor_rollout_ref.model.get("lora_adapter_path") is not None
# define in-reward KL control
# kl loss control currently not suppoorted
- if config.algorithm.use_kl_in_reward:
- self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
+ if self.config.algorithm.use_kl_in_reward:
+ self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)
+
+ self.use_prefix_grouper = self.config.actor_rollout_ref.actor.get("use_prefix_grouper", False)
+ self.use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
+ # ==================== SeparateRayPPOTrainer config ====================
+
+ self.global_steps = 0
+ self.epoch = 0
+ self.max_steps_duration = 0
+ self.progress_bar = None
+ self.logger = None
+ self.is_last_step = False
+ self.prev_step_profile = False
+ self.curr_step_profile = False
+ self.next_step_profile = False
+ self.last_val_metrics = {}
+ self.metrics = {}
+ self.timing_raw = {}
+ # reward message
+ self.future_reward = None
+ self.reward_tensor = None
+ self.reward_extra_infos_dict = {}
+
def _validate(self):
self.actor_rollout_wg = self.rollout_wg
ret = super()._validate()
self.actor_rollout_wg = self.actor_wg
return ret
- def init_workers(self):
- """Initialize distributed training workers using Ray backend.
-
- Creates:
- 1. Ray resource pools from configuration
- 2. Worker groups for each role (actor, critic, etc.)
- """
- self._init_resource_pools()
- self._create_worker_classes()
- self._init_worker_groups()
- self._init_models()
- self._init_async_rollout_manager()
-
- def _init_resource_pools(self):
- self.resource_pool_manager.create_resource_pool()
-
- self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
-
- def _create_worker_classes(self):
- self._create_actor_rollout_classes()
- self._create_critic_class()
- self._create_reference_policy_class()
- self._create_reward_model_class()
-
def _create_actor_rollout_classes(self):
for role in [Role.Actor, Role.Rollout]:
resource_pool = self.resource_pool_manager.get_resource_pool(role)
@@ -166,68 +160,6 @@ def _create_actor_rollout_classes(self):
)
self.resource_pool_to_cls[resource_pool][str(role)] = role_cls
- def _create_critic_class(self):
- # create critic
- if self.use_critic:
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
- critic_cfg = omega_conf_to_dataclass(self.config.critic)
- critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
- self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls
-
- def _create_reference_policy_class(self):
- # create reference policy if needed
- if self.use_reference_policy:
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
- ref_policy_cls = RayClassWithInitArgs(
- self.role_worker_mapping[Role.RefPolicy],
- config=self.config.actor_rollout_ref,
- role=str(Role.RefPolicy),
- # profile_option=self.config.trainer.npu_profile.options,
- )
- self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls
-
- def _create_reward_model_class(self):
- # create a reward model if reward_fn is None
- if self.use_rm:
- # we create a RM here
- resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
- rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
- self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls
-
- def _init_worker_groups(self):
- # initialize WorkerGroup
- # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
- # you should not use `create_colocated_worker_cls`.
- # Instead, directly pass different resource pool to different worker groups.
- # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
- all_wg = {}
- wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
- if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
- wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
- if OmegaConf.select(self.config.global_profiler, "steps") is not None:
- wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps")
- # Only require nsight worker options when tool is nsys
- if OmegaConf.select(self.config.global_profiler, "tool") == "nsys":
- assert (
- OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
- is not None
- ), "worker_nsight_options must be set when using nsys with profile_steps"
- wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
- OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
- )
- wg_kwargs["device_name"] = self.device_name
-
- for resource_pool, class_dict in self.resource_pool_to_cls.items():
- worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
- wg_dict = self.ray_worker_group_cls(
- resource_pool=resource_pool,
- ray_cls_with_init=worker_dict_cls,
- **wg_kwargs,
- )
- spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
- all_wg.update(spawn_wg)
- self.all_wg = all_wg
-
def _init_models(self):
if self.use_critic:
self.critic_wg = self.all_wg[str(Role.Critic)]
@@ -251,6 +183,26 @@ def _init_models(self):
self.rollout_wg.set_actor_weights_info(weights_info)
self._create_weight_sync_group()
+ def _init_async_rollout_manager(self):
+ # infrastructure overview: https://verl.readthedocs.io/en/latest/advance/reward_loop.html#architecture-design
+ # agent_reward_loop: streaming reward computation with actor rollout
+ # enabled when either holds: (1) no reward model, or (2) reward model with an extra resource pool
+ enable_agent_reward_loop = self.use_reward_loop and (
+ not self.use_rm or self.config.reward_model.enable_resource_pool
+ )
+ # if enable_agent_reward_loop, we directly pass reward_loop_workers to agent loop manager
+ # to stream reward computation with actor rollout
+ reward_loop_worker_handles = self.reward_loop_manager.reward_loop_workers if enable_agent_reward_loop else None
+
+ # create async rollout manager and request scheduler
+ assert self.config.actor_rollout_ref.rollout.mode == "async"
+ from verl.experimental.one_step_off_policy.agent_loop import OneStepOffAgentLoopManager
+
+ self.async_rollout_mode = True
+ self.async_rollout_manager = OneStepOffAgentLoopManager(
+ config=self.config, worker_group=self.rollout_wg, reward_loop_worker_handles=reward_loop_worker_handles
+ )
+
def _create_weight_sync_group(self):
from verl.utils.device import get_nccl_backend
@@ -284,15 +236,6 @@ def _create_weight_sync_group(self):
group_name="actor_rollout",
)
- def _init_async_rollout_manager(self):
- # create async rollout manager and request scheduler
- assert self.config.actor_rollout_ref.rollout.mode == "async"
- from verl.experimental.one_step_off_policy.agent_loop import OneStepOffAgentLoopManager
-
- self.async_rollout_mode = True
-
- self.async_rollout_manager = OneStepOffAgentLoopManager(config=self.config, worker_group=self.rollout_wg)
-
def sync_rollout_weights(self):
self.actor_wg.sync_rollout_weights()
ray.get(self.rollout_wg.sync_rollout_weights())
@@ -410,11 +353,9 @@ async def fit(self):
The light-weight advantage computation is done on the driver process.
"""
- from omegaconf import OmegaConf
-
from verl.utils.tracking import Tracking
- logger = Tracking(
+ self.logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
@@ -423,339 +364,140 @@ async def fit(self):
self.global_steps = 0
- # load checkpoint before doing anything
+ # load checkpoint and update weights before doing anything
self._load_checkpoint()
-
- # after load checkpoint sync rollout weights
- self.sync_rollout_weights()
- await self.async_rollout_manager.clear_kv_cache()
+ self._fit_update_weights()
# perform validation before training
# currently, we only support validation using the reward_function.
- if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
+ if self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
assert val_metrics, f"{val_metrics=}"
pprint(f"Initial validation metrics: {val_metrics}")
- logger.log(data=val_metrics, step=self.global_steps)
+ self.logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return
+ if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
+ rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
+ rollout_skip.wrap_generate_sequences()
+
# add tqdm
- progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
+ self.progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
# we start from step 1
self.global_steps += 1
- last_val_metrics = None
+ self.last_val_metrics = None
self.max_steps_duration = 0
- prev_step_profile = False
- curr_step_profile = (
+ self.prev_step_profile = False
+ self.curr_step_profile = (
self.global_steps in self.config.global_profiler.steps
if self.config.global_profiler.steps is not None
else False
)
+ self.next_step_profile = False
# across epoch iterator
continuous_iterator = self._create_continuous_iterator()
-
# Start the first asynchronous generation task.
batch_data_future = asyncio.create_task(self._async_gen_next_batch(continuous_iterator))
-
while batch_data_future is not None:
- do_profile = (
- self.global_steps in self.config.global_profiler.steps
- if self.config.global_profiler.steps is not None
- else False
- )
- if do_profile:
- self.actor_wg.start_profile()
- if not self.hybrid_engine:
- self.rollout_wg.start_profile()
- if self.use_reference_policy:
- self.ref_policy_wg.start_profile()
- if self.use_critic:
- self.critic_wg.start_profile()
- if self.use_rm:
- self.rm_wg.start_profile()
-
- metrics = {}
- timing_raw = {}
- is_last_step = self.global_steps >= self.total_training_steps
-
- with marked_timer("start_profile", timing_raw):
- self._start_profiling(
- not prev_step_profile and curr_step_profile
- if self.config.global_profiler.profile_continuous_steps
- else curr_step_profile
- )
+ batch_data_future = await self.fit_step(batch_data_future, continuous_iterator)
+ if self.is_last_step:
+ return
+
+ async def fit_step(self, batch_data_future, continuous_iterator):
+ """
+ Single-step training template method. Handles all logic for one training step.
- with marked_timer("step", timing_raw):
- # wait for the previous batch
- with marked_timer("gen", timing_raw, color="red"):
- _metrics, _timing_raw, epoch, batch, future_reward = await batch_data_future
- timing_raw.update(batch.meta_info["timing"])
- timing_raw.update(_timing_raw)
- metrics.update(_metrics)
- batch.meta_info.pop("timing", None)
-
- # sync weights from actor to rollout
- with marked_timer("sync_rollout_weights", timing_raw, color="purple"):
- self.sync_rollout_weights()
- await self.async_rollout_manager.clear_kv_cache()
-
- # async next generation
- if not is_last_step:
- batch_data_future = asyncio.create_task(self._async_gen_next_batch(continuous_iterator))
- await asyncio.sleep(0)
-
- with marked_timer("reward", timing_raw, color="yellow"):
- # compute reward model score
- if self.use_rm and "rm_scores" not in batch.batch.keys():
- reward_tensor = self.rm_wg.compute_rm_score(batch)
- batch = batch.union(reward_tensor)
-
- if self.config.reward_model.launch_reward_fn_async:
- future_reward = compute_reward_async.remote(
- data=batch, config=self.config, tokenizer=self.tokenizer
- )
- else:
- reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
-
- # await asyncio.sleep(0) ensures:
- # Asynchronous tasks can start executing immediately
- # The event loop can handle other pending coroutines
- # Prevents computations in a certain phase from blocking the entire asynchronous workflow
- #
- # The purpose here is to ensure that after triggering
- # `self.async_rollout_manager.generate_sequences_async(gen_batch_output)`,
- # the subsequent relevant logic can proceed in a timely manner
- await asyncio.sleep(0)
-
- # Operating Mode Selection:
- # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ)
- # - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ)
- # Note: π_old computed once per data batch, serves as stable reference during mini-batch updates
- rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
- bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False)
- if bypass_recomputing_logprobs: # Use `rollout_log_probs`
- from verl.trainer.ppo.rollout_corr_helper import apply_bypass_mode
-
- apply_bypass_mode(
- batch=batch,
- rollout_corr_config=rollout_corr_config,
- policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss,
- )
- else: # Recompute old_log_probs
- with marked_timer("old_log_prob", timing_raw, color="blue"):
- old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
- entropys = old_log_prob.batch["entropys"]
- response_masks = batch.batch["response_mask"]
- actor_config = self.config.actor_rollout_ref.actor
- entropy_agg = agg_loss(
- loss_mat=entropys,
- loss_mask=response_masks,
- loss_agg_mode=actor_config.loss_agg_mode,
- loss_scale_factor=actor_config.loss_scale_factor,
- )
- old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
- metrics.update(old_log_prob_metrics)
- old_log_prob.batch.pop("entropys")
- batch = batch.union(old_log_prob)
- if "rollout_log_probs" in batch.batch.keys():
- # TODO: we may want to add diff of probs too.
- from verl.utils.debug.metrics import calculate_debug_metrics
-
- metrics.update(calculate_debug_metrics(batch))
-
- assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}'
- await asyncio.sleep(0)
-
- if self.use_reference_policy:
- # compute reference log_prob
- with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
- if not self.ref_in_actor:
- ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
- else:
- ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
- batch = batch.union(ref_log_prob)
- await asyncio.sleep(0)
-
- # compute values
- if self.use_critic:
- with marked_timer("values", timing_raw, color="cyan"):
- values = self.critic_wg.compute_values(batch)
- batch = batch.union(values)
- await asyncio.sleep(0)
-
- with marked_timer("adv", timing_raw, color="brown"):
- # we combine with rule-based rm
- reward_extra_infos_dict: dict[str, list]
- if self.config.reward_model.launch_reward_fn_async:
- reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
- batch.batch["token_level_scores"] = reward_tensor
-
- if reward_extra_infos_dict:
- batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
-
- # compute rewards. apply_kl_penalty if available
- if self.config.algorithm.use_kl_in_reward:
- batch, kl_metrics = apply_kl_penalty(
- batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
- )
- metrics.update(kl_metrics)
- else:
- batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
-
- # Compute rollout correction: IS weights, rejection sampling, and metrics
- # Only runs in decoupled mode (computes once per batch using stable π_old)
- # In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout
- if (
- rollout_corr_config is not None
- and "rollout_log_probs" in batch.batch
- and not bypass_recomputing_logprobs # Only in decoupled mode
- ):
- from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch
-
- # Compute IS weights, apply rejection sampling, compute metrics
- batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
- # IS and off-policy metrics already have rollout_corr/ prefix
- metrics.update(is_metrics)
-
- # compute advantages, executed on the driver process
- norm_adv_by_std_in_grpo = self.config.algorithm.get(
- "norm_adv_by_std_in_grpo", True
- ) # GRPO adv normalization factor
-
- batch = compute_advantage(
- batch,
- adv_estimator=self.config.algorithm.adv_estimator,
- gamma=self.config.algorithm.gamma,
- lam=self.config.algorithm.lam,
- num_repeat=self.config.actor_rollout_ref.rollout.n,
- norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
- config=self.config.algorithm,
- )
- await asyncio.sleep(0)
-
- # update critic
- if self.use_critic:
- with marked_timer("update_critic", timing_raw, color="pink"):
- critic_output = self.critic_wg.update_critic(batch)
- critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
- metrics.update(critic_output_metrics)
- await asyncio.sleep(0)
-
- # implement critic warmup
- if self.config.trainer.critic_warmup <= self.global_steps:
- # update actor
- with marked_timer("update_actor", timing_raw, color="red"):
- rollout_config = self.config.actor_rollout_ref.rollout
- batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
- # TODO: Make "temperature" single source of truth from generation.
- batch.meta_info["temperature"] = rollout_config.temperature
- actor_output = self.actor_rollout_wg.update_actor(batch)
- actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
- metrics.update(actor_output_metrics)
- await asyncio.sleep(0)
-
- # Log rollout generations if enabled
- rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
- if rollout_data_dir:
- self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
+ Flow:
+ 1. Pre-step processing -> 2. Get batch -> 3. Generate sequences ->
+ 4. Compute reward -> 5. Compute log_prob -> 6. Compute ref log_prob and values ->
+ 7. Compute advantage -> 8. Update critic -> 9. Update actor -> 10. Post-step processing
+ Args:
+ batch_data_future: pending task that resolves to the next generated training batch
+ continuous_iterator: cross-epoch iterator used to schedule the next asynchronous generation
+ """
+ self.metrics = {"training/global_step": self.global_steps, "training/epoch": self.epoch}
+ self.timing_raw = {}
+ # reward message
+ self.future_reward = None
+ self.reward_tensor = None
+ self.reward_extra_infos_dict = {}
+
+ self._fit_prepare_step()
+ self._fit_start_profile()
+
+ with marked_timer("step", self.timing_raw):
+ batch, batch_data_future = await self._fit_generate(batch_data_future, continuous_iterator)
+
+ # await asyncio.sleep(0) ensures:
+ # Asynchronous tasks can start executing immediately
+ # The event loop can handle other pending coroutines
+ # Prevents computations in a certain phase from blocking the entire asynchronous workflow
+ #
+ # The purpose here is to ensure that after triggering
+ # `self.async_rollout_manager.generate_sequences_async(gen_batch_output)`,
+ # the subsequent relevant logic can proceed in a timely manner
+ await asyncio.sleep(0)
+ batch = self._fit_compute_reward(batch)
+ await asyncio.sleep(0)
+ batch = self._fit_compute_log_prob(batch)
+ await asyncio.sleep(0)
+ batch = self._fit_compute_ref_log_prob(batch)
+ await asyncio.sleep(0)
+ batch = self._fit_compute_critic(batch)
+ await asyncio.sleep(0)
+ batch = self._fit_compute_advantage(batch)
+ await asyncio.sleep(0)
+ batch = self._fit_update_critic(batch)
+ await asyncio.sleep(0)
+ batch = self._fit_update_actor(batch)
await asyncio.sleep(0)
- # validate
- if (
- self.val_reward_fn is not None
- and self.config.trainer.test_freq > 0
- and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
- ):
- with marked_timer("testing", timing_raw, color="green"):
- val_metrics: dict = self._validate()
- if is_last_step:
- last_val_metrics = val_metrics
- metrics.update(val_metrics)
+ self._fit_update_weights()
+ await asyncio.sleep(0)
+ self._fit_dump_data(batch)
await asyncio.sleep(0)
- # Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
- esi_close_to_expiration = should_save_ckpt_esi(
- max_steps_duration=self.max_steps_duration,
- redundant_time=self.config.trainer.esi_redundant_time,
- )
- # Check if the conditions for saving a checkpoint are met.
- # The conditions include a mandatory condition (1) and
- # one of the following optional conditions (2/3/4):
- # 1. The save frequency is set to a positive value.
- # 2. It's the last training step.
- # 3. The current step number is a multiple of the save frequency.
- # 4. The ESI(Elastic Server Instance)/training plan is close to expiration.
- if self.config.trainer.save_freq > 0 and (
- is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration
- ):
- if esi_close_to_expiration:
- print("Force saving checkpoint: ESI instance expiration approaching.")
- with marked_timer("save_checkpoint", timing_raw, color="green"):
- self._save_checkpoint()
-
- with marked_timer("stop_profile", timing_raw):
- next_step_profile = (
- self.global_steps + 1 in self.config.global_profiler.steps
- if self.config.global_profiler.steps is not None
- else False
- )
- self._stop_profiling(
- curr_step_profile and not next_step_profile
- if self.config.global_profiler.profile_continuous_steps
- else curr_step_profile
- )
- prev_step_profile = curr_step_profile
- curr_step_profile = next_step_profile
-
- steps_duration = timing_raw["step"]
- self.max_steps_duration = max(self.max_steps_duration, steps_duration)
-
- # training metrics
- metrics.update(
- {
- "training/global_step": self.global_steps,
- "training/epoch": epoch,
- }
- )
- # collect metrics
- metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
- metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
- # TODO: implement actual tflpo and theoretical tflpo
- n_gpus = self.resource_pool_manager.get_n_gpus()
- metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
- # Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation
-
- # this is experimental and may be changed/removed in the future in favor of a general-purpose one
- if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
- self.train_dataloader.sampler.update(batch=batch)
-
- # TODO: make a canonical logger that supports various backend
- logger.log(data=metrics, step=self.global_steps)
-
- progress_bar.update(1)
- self.global_steps += 1
-
- if (
- hasattr(self.config.actor_rollout_ref.actor, "profiler")
- and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory"
- ):
- self.actor_rollout_wg.dump_memory_snapshot(
- tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}"
- )
+ self._fit_validate()
+ await asyncio.sleep(0)
+ self._fit_save_checkpoint()
+ await asyncio.sleep(0)
+ self._fit_stop_profile()
+ self._fit_collect_metrics(batch)
+ self._fit_torch_memory()
+ self._fit_experimental(batch)
+ self._fit_postprocess_step()
+
+ return batch_data_future
+
+ async def _fit_generate(self, batch_data_future, continuous_iterator):
+ metrics = self.metrics
+ timing_raw = self.timing_raw
+
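+ # wait for the batch that was generated asynchronously while the previous step was training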
+ with marked_timer("gen", timing_raw, color="red"):
+ _metrics, _timing_raw, epoch, batch, future_reward = await batch_data_future
+ timing_raw.update(batch.meta_info["timing"])
+ timing_raw.update(_timing_raw)
+ metrics.update(_metrics)
+ batch.meta_info.pop("timing", None)
+
+ # sync weights from actor to rollout
+ with marked_timer("sync_rollout_weights", timing_raw, color="purple"):
+ self._fit_update_weights()
+ await self.async_rollout_manager.clear_kv_cache()
+
+ # async next generation
+ if not self.is_last_step:
+ batch_data_future = asyncio.create_task(self._async_gen_next_batch(continuous_iterator))
+ await asyncio.sleep(0)
+ else:
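+ # on the last step, stop scheduling new generations so the outer training loop can exit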
+ batch_data_future = None
- if is_last_step:
- if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"):
- self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True)
- pprint(f"Final validation metrics: {last_val_metrics}")
- progress_bar.close()
- return
+ return batch, batch_data_future
- # this is experimental and may be changed/removed in the future
- # in favor of a general-purpose data buffer pool
- if hasattr(self.train_dataset, "on_batch_end"):
- # The dataset may be changed after each training batch
- self.train_dataset.on_batch_end(batch=batch)
+ def _fit_update_weights(self):
+ # TODO: use checkpoint engine to update weight
+ # self.sync_rollout_weights()
+ pass
diff --git a/verl/experimental/separation/__init__.py b/verl/experimental/separation/__init__.py
new file mode 100644
index 00000000000..9cd3ed5b8e9
--- /dev/null
+++ b/verl/experimental/separation/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 Meituan Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/verl/experimental/separation/ray_trainer.py b/verl/experimental/separation/ray_trainer.py
new file mode 100644
index 00000000000..9120da9303e
--- /dev/null
+++ b/verl/experimental/separation/ray_trainer.py
@@ -0,0 +1,746 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2023-2024 SGLang Team
+# Copyright 2025 ModelBest Inc. and/or its affiliates
+# Copyright 2025 Meituan Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PPO Trainer with Ray-based single controller.
+This trainer supports model-agnostic model initialization with HuggingFace
+"""
+
+import uuid
+from copy import deepcopy
+from pprint import pprint
+from typing import Any, Optional
+
+import numpy as np
+import ray
+import torch
+from omegaconf import OmegaConf
+from torch.utils.data import Dataset, Sampler
+from tqdm import tqdm
+
+from verl import DataProto
+from verl.checkpoint_engine import CheckpointEngineManager
+from verl.experimental.dataset.sampler import AbstractCurriculumSampler
+from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup, ResourcePoolManager
+from verl.single_controller.ray.base import create_colocated_worker_cls
+from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
+from verl.trainer.ppo.metric_utils import (
+ compute_data_metrics,
+ compute_throughout_metrics,
+ compute_timing_metrics,
+ compute_variance_proxy_metrics,
+)
+from verl.trainer.ppo.ray_trainer import RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask
+from verl.trainer.ppo.reward import compute_reward_async
+from verl.trainer.ppo.utils import Role, WorkerType
+from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi
+from verl.utils.config import omega_conf_to_dataclass
+from verl.utils.debug import marked_timer
+from verl.utils.metric import reduce_metrics
+from verl.utils.rollout_skip import RolloutSkip
+
+
+class SeparateRayPPOTrainer(RayPPOTrainer):
+ """
+ Support for the initialization and fit process of Ray Trainer in the resource-separated scenario:
+ - Fully async policy
+ - One-step off-policy
+ """
+
+ def __init__(
+ self,
+ config,
+ tokenizer,
+ role_worker_mapping: dict[Role, WorkerType],
+ resource_pool_manager: ResourcePoolManager,
+ ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup,
+ processor=None,
+ reward_fn=None,
+ val_reward_fn=None,
+ train_dataset: Optional[Dataset] = None,
+ val_dataset: Optional[Dataset] = None,
+ collate_fn=None,
+ train_sampler: Optional[Sampler] = None,
+ device_name=None,
+ ):
+ super().__init__(
+ config,
+ tokenizer,
+ role_worker_mapping,
+ resource_pool_manager,
+ ray_worker_group_cls,
+ processor,
+ reward_fn,
+ val_reward_fn,
+ train_dataset,
+ val_dataset,
+ collate_fn,
+ train_sampler,
+ device_name,
+ )
+ self.global_steps = 0
+ self.epoch = 0
+ self.max_steps_duration = 0
+ self.progress_bar = None
+ self.logger = None
+ self.is_last_step = False
+ self.prev_step_profile = False
+ self.curr_step_profile = False
+ self.next_step_profile = False
+ self.last_val_metrics = {}
+ self.metrics = {}
+ self.timing_raw = {}
+ # reward message
+ self.future_reward = None
+ self.reward_tensor = None
+ self.reward_extra_infos_dict = {}
+
+ def init_workers(self):
+ """Initialize distributed training workers using Ray backend.
+
+ Creates:
+ 1. Ray resource pools from configuration
+ 2. Worker groups for each role (actor, critic, etc.)
+ """
+ self._init_resource_pools()
+ self._create_worker_classes()
+ self._init_worker_groups()
+ self._init_models()
+ self._init_reward_loop()
+ self._init_async_rollout_manager()
+
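+ # the checkpoint engine manager ships updated trainer weights to the standalone rollout replicas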
+ self.checkpoint_manager = CheckpointEngineManager(
+ backend=self.config.actor_rollout_ref.rollout.checkpoint_engine.backend,
+ trainer=self.actor_rollout_wg,
+ replicas=self.async_rollout_manager.rollout_replicas,
+ )
+
+ def _init_resource_pools(self):
+ self.resource_pool_manager.create_resource_pool()
+ self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
+
+ def _create_worker_classes(self):
+ self._create_actor_rollout_classes()
+ self._create_critic_class()
+ self._create_reference_policy_class()
+ self._create_reward_model_class()
+
+ def _create_actor_rollout_classes(self):
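+ # subclasses decide how actor and rollout workers are placed on their separate resource pools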
+ raise NotImplementedError
+
+ def _create_critic_class(self):
+ # create critic
+ if self.use_critic:
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
+ critic_cfg = omega_conf_to_dataclass(self.config.critic)
+ critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
+ self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls
+
+ def _create_reference_policy_class(self):
+ # create reference policy if needed
+ if self.use_reference_policy:
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
+ ref_policy_cls = RayClassWithInitArgs(
+ self.role_worker_mapping[Role.RefPolicy],
+ config=self.config.actor_rollout_ref,
+ role=str(Role.RefPolicy),
+ # profile_option=self.config.trainer.npu_profile.options,
+ )
+ self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls
+
+ def _create_reward_model_class(self):
+ # create a reward model if reward_fn is None
+ if self.use_rm:
+ # we create a RM here
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
+ rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
+ self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls
+
+ def _init_worker_groups(self):
+ # initialize WorkerGroup
+ # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
+ # you should not use `create_colocated_worker_cls`.
+ # Instead, directly pass different resource pool to different worker groups.
+ # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
+ all_wg = {}
+ wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
+ if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
+ wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
+ if OmegaConf.select(self.config.global_profiler, "steps") is not None:
+ wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps")
+ # Only require nsight worker options when tool is nsys
+ if OmegaConf.select(self.config.global_profiler, "tool") == "nsys":
+ assert (
+ OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
+ is not None
+ ), "worker_nsight_options must be set when using nsys with profile_steps"
+ wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
+ OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
+ )
+ wg_kwargs["device_name"] = self.device_name
+
+ for resource_pool, class_dict in self.resource_pool_to_cls.items():
+ worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
+ wg_dict = self.ray_worker_group_cls(
+ resource_pool=resource_pool,
+ ray_cls_with_init=worker_dict_cls,
+ **wg_kwargs,
+ )
+ spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
+ all_wg.update(spawn_wg)
+ self.all_wg = all_wg
+
+ def _init_models(self):
+ if self.use_critic:
+ self.critic_wg = self.all_wg[str(Role.Critic)]
+ self.critic_wg.init_model()
+
+ if self.use_reference_policy and not self.ref_in_actor:
+ self.ref_policy_wg = self.all_wg[str(Role.RefPolicy)]
+ self.ref_policy_wg.init_model()
+
+ if self.use_rm:
+ self.rm_wg = self.all_wg[str(Role.RewardModel)]
+ self.rm_wg.init_model()
+
+ # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
+ self.actor_rollout_wg = self.all_wg[str(Role.ActorRollout)]
+ self.actor_rollout_wg.init_model()
+
+ def _init_reward_loop(self):
+ if self.use_reward_loop:
+ # create reward loop manager
+ from verl.experimental.reward_loop import RewardLoopManager
+
+ # initialize reward loop manager
+ # reward model (colocate or standalone): get resource_pool
+ # no reward model: resource_pool = None
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) if self.use_rm else None
+ self.reward_loop_manager = RewardLoopManager(
+ config=self.config,
+ rm_resource_pool=resource_pool,
+ )
+
+ def _init_async_rollout_manager(self):
+ pass
+
+ def fit(self):
+ """
+ The training loop of PPO.
+ The driver process only need to call the compute functions of the worker group through RPC
+ to construct the PPO dataflow.
+ The light-weight advantage computation is done on the driver process.
+
+ !!!
+ The logic of fit is consistent with that of fit_refactor;
+ if any modifications are made, apply them to both methods simultaneously.
+ """
+ from omegaconf import OmegaConf
+
+ from verl.utils.tracking import Tracking
+
+ self.logger = Tracking(
+ project_name=self.config.trainer.project_name,
+ experiment_name=self.config.trainer.experiment_name,
+ default_backend=self.config.trainer.logger,
+ config=OmegaConf.to_container(self.config, resolve=True),
+ )
+
+ self.global_steps = 0
+
+ # load checkpoint and update weights before doing anything
+ self._load_checkpoint()
+ self.checkpoint_manager.update_weights()
+
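+ # derive the epoch to resume from so a restored checkpoint does not replay completed epochs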
+ current_epoch = self.global_steps // len(self.train_dataloader)
+
+ # perform validation before training
+ # currently, we only support validation using the reward_function.
+ if self.config.trainer.get("val_before_train", True):
+ val_metrics = self._validate()
+ assert val_metrics, f"{val_metrics=}"
+ pprint(f"Initial validation metrics: {val_metrics}")
+ self.logger.log(data=val_metrics, step=self.global_steps)
+ if self.config.trainer.get("val_only", False):
+ return
+
+ if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
+ rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
+ rollout_skip.wrap_generate_sequences()
+
+ # add tqdm
+ self.progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
+
+ # we start from step 1
+ self.global_steps += 1
+ self.last_val_metrics = None
+ self.max_steps_duration = 0
+
+ self.prev_step_profile = False
+ self.curr_step_profile = (
+ self.global_steps in self.config.global_profiler.steps
+ if self.config.global_profiler.steps is not None
+ else False
+ )
+ self.next_step_profile = False
+
+ for epoch in range(current_epoch, self.config.trainer.total_epochs):
+ for batch_dict in self.train_dataloader:
+ self.epoch = epoch
+ self.fit_step(batch_dict)
+ if self.is_last_step:
+ return
+
+ def fit_step(self, batch_dict: Any = None):
+ """
+ Single-step training template method. Handles all logic for one training step.
+
+ Flow:
+ 1. Pre-step processing -> 2. Get batch -> 3. Generate sequences ->
+ 4. Compute reward -> 5. Compute log_prob -> 6. Compute ref log_prob and values ->
+ 7. Compute advantage -> 8. Update critic -> 9. Update actor -> 10. Post-step processing
+
+ Args:
+ batch_dict: Raw data dictionary
+ """
+ self.metrics = {"training/global_step": self.global_steps, "training/epoch": self.epoch}
+ self.timing_raw = {}
+ # reward message
+ self.future_reward = None
+ self.reward_tensor = None
+ self.reward_extra_infos_dict = {}
+
+ self._fit_prepare_step()
+ self._fit_start_profile()
+
+ with marked_timer("step", self.timing_raw):
+ batch = self._fit_get_batch(batch_dict)
+ batch = self._fit_generate(batch)
+ batch = self._fit_compute_reward(batch)
+ batch = self._fit_compute_log_prob(batch)
+ batch = self._fit_compute_ref_log_prob(batch)
+ batch = self._fit_compute_critic(batch)
+ batch = self._fit_compute_advantage(batch)
+ batch = self._fit_update_critic(batch)
+ batch = self._fit_update_actor(batch)
+ self._fit_update_weights()
+ self._fit_dump_data(batch)
+
+ self._fit_validate()
+ self._fit_save_checkpoint()
+ self._fit_stop_profile()
+ self._fit_collect_metrics(batch)
+ self._fit_torch_memory()
+ self._fit_experimental(batch)
+ self._fit_postprocess_step()
+
+ def _fit_prepare_step(self):
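+ # flush pending async worker calls from the previous step (non-blocking) and flag the final step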
+ if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"):
+ self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=False)
+ self.is_last_step = self.global_steps >= self.total_training_steps
+
+ def _fit_start_profile(self):
+ timing_raw = self.timing_raw
+ with marked_timer("start_profile", timing_raw):
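+ # with continuous profiling, only start the profiler on the first step of a profiled window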
+ self._start_profiling(
+ not self.prev_step_profile and self.curr_step_profile
+ if self.config.global_profiler.profile_continuous_steps
+ else self.curr_step_profile
+ )
+
+ def _fit_get_batch(self, batch_dict: dict) -> DataProto:
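+ # wrap the raw dataloader dict into a DataProto; the per-prompt uid groups repeated responses for advantage computation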
+ batch = DataProto.from_single_dict(batch_dict)
+ batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
+ # add uid
+ batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
+ return batch
+
+ def _fit_generate(self, batch: DataProto = None) -> DataProto:
+ metrics = self.metrics
+ timing_raw = self.timing_raw
+ gen_batch = self._get_gen_batch(batch)
+ # pass global_steps to trace
+ gen_batch.meta_info["global_steps"] = self.global_steps
+ gen_batch_output = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+
+ with marked_timer("gen", timing_raw, color="red"):
+ if not self.async_rollout_mode:
+ gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
+ else:
+ if self.curr_step_profile:
+ self.async_rollout_manager.start_profile(global_step=self.global_steps)
+ gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
+ self.checkpoint_manager.sleep_replicas()
+ if self.curr_step_profile:
+ self.async_rollout_manager.stop_profile()
+
+ timing_raw.update(gen_batch_output.meta_info["timing"])
+ gen_batch_output.meta_info.pop("timing", None)
+
+ if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
+ with marked_timer("gen_max", timing_raw, color="purple"):
+ gen_baseline_batch = deepcopy(gen_batch)
+ gen_baseline_batch.meta_info["do_sample"] = False
+ if not self.async_rollout_mode:
+ gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
+ else:
+ if self.curr_step_profile:
+ self.async_rollout_manager.start_profile()
+ gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
+ self.checkpoint_manager.sleep_replicas()
+ if self.curr_step_profile:
+ self.async_rollout_manager.stop_profile()
+ batch = batch.union(gen_baseline_output)
+ # compute reward model score on batch
+                rm_scores = None
+                if self.use_rm and "rm_scores" not in batch.batch.keys():
+                    rm_scores = self._compute_reward_colocate(batch)
+                    batch = batch.union(rm_scores)
+
+ # Compute or extract reward for REMAX baseline
+ if not self.use_reward_loop:
+ reward_baseline_tensor = self._compute_reward_legacy(
+ batch, reward_fn=self.reward_fn, sum_reward=True
+ )
+ else:
+ reward_baseline_tensor = batch.batch["rm_scores"].sum(dim=-1)
+
+ keys_to_pop = set(gen_baseline_output.batch.keys())
+ if rm_scores is not None:
+ keys_to_pop.update(rm_scores.batch.keys())
+ batch.pop(batch_keys=list(keys_to_pop))
+
+ batch.batch["reward_baselines"] = reward_baseline_tensor
+
+ del rm_scores, gen_baseline_batch, gen_baseline_output
+ # repeat to align with repeated responses in rollout
+ batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+ batch = batch.union(gen_batch_output)
+
+ if "response_mask" not in batch.batch.keys():
+ batch.batch["response_mask"] = compute_response_mask(batch)
+ # Balance the number of valid tokens across DP ranks.
+ # NOTE: This usually changes the order of data in the `batch`,
+ # which won't affect the advantage calculation (since it's based on uid),
+ # but might affect the loss calculation (due to the change of mini-batching).
+ if self.config.trainer.balance_batch:
+ self._balance_batch(batch, metrics=metrics)
+
+ # compute global_valid tokens
+ batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
+ # get images_seqlens
+ images_seqlens_all = []
+            for multi_modal_input in batch.non_tensor_batch.get("multi_modal_inputs", []):
+ if "image_grid_thw" not in multi_modal_input.keys():
+ continue
+ images_seqlens_all.extend(multi_modal_input["images_seqlens"].tolist())
+ batch.meta_info["images_seqlens"] = images_seqlens_all
+ return batch
+
+ def _fit_compute_reward(self, batch: DataProto) -> DataProto:
+ timing_raw = self.timing_raw
+ with marked_timer("reward", timing_raw, color="yellow"):
+ # compute reward model score
+ if self.use_rm and "rm_scores" not in batch.batch.keys():
+ batch_reward = self._compute_reward_colocate(batch)
+ batch = batch.union(batch_reward)
+
+ # Compute or extract reward_tensor and reward_extra_infos_dict for training
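+            # With launch_reward_fn_async, the reward is computed in a detached Ray task; the future is
+            # resolved later in _fit_compute_advantage.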
+ if not self.use_reward_loop:
+ if self.config.reward_model.launch_reward_fn_async:
+ self.future_reward = compute_reward_async.remote(
+ data=batch, config=self.config, tokenizer=self.tokenizer
+ )
+ else:
+ self.reward_tensor, self.reward_extra_infos_dict = self._compute_reward_legacy(
+ batch, reward_fn=self.reward_fn, reward_for_val=False
+ )
+ else:
+ self.reward_tensor = batch.batch["rm_scores"]
+ reward_extra_keys = batch.meta_info.get("reward_extra_keys", [])
+ self.reward_extra_infos_dict = {key: batch.non_tensor_batch[key] for key in reward_extra_keys}
+ return batch
+
+ def _fit_compute_log_prob(self, batch: DataProto) -> DataProto:
+ metrics = self.metrics
+ timing_raw = self.timing_raw
+ # Operating Mode Selection:
+ # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ)
+ # - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ)
+ # Note: π_old computed once per data batch, serves as stable reference during mini-batch updates
+ rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
+ bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False)
+ if bypass_recomputing_logprobs: # Use `rollout_log_probs`
+ from verl.trainer.ppo.rollout_corr_helper import apply_bypass_mode
+
+ apply_bypass_mode(
+ batch=batch,
+ rollout_corr_config=rollout_corr_config,
+ policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss,
+ )
+ else: # Recompute old_log_probs
+ with marked_timer("old_log_prob", timing_raw, color="blue"):
+ old_log_prob, old_log_prob_mfu = self._compute_old_log_prob(batch)
+ entropys = old_log_prob.batch["entropys"]
+ response_masks = batch.batch["response_mask"]
+ actor_config = self.config.actor_rollout_ref.actor
+ entropy_agg = agg_loss(
+ loss_mat=entropys,
+ loss_mask=response_masks,
+ loss_agg_mode=actor_config.loss_agg_mode,
+ loss_scale_factor=actor_config.loss_scale_factor,
+ )
+ old_log_prob_metrics = {
+ "actor/entropy": entropy_agg.detach().item(),
+ "perf/mfu/actor_infer": old_log_prob_mfu,
+ }
+ metrics.update(old_log_prob_metrics)
+ old_log_prob.batch.pop("entropys")
+ if "routed_experts" in batch.batch and "routed_experts" in old_log_prob.batch:
+ router_mode = getattr(self.config.actor_rollout_ref.actor.router_replay, "mode", "disabled")
+ if router_mode == "R2":
+ batch.batch.pop("routed_experts")
+ else:
+ old_log_prob.batch.pop("routed_experts")
+ batch = batch.union(old_log_prob)
+ if "rollout_log_probs" in batch.batch.keys():
+ # TODO: we may want to add diff of probs too.
+ from verl.utils.debug.metrics import calculate_debug_metrics
+
+ metrics.update(calculate_debug_metrics(batch))
+
+ assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}'
+ return batch
+
+ def _fit_compute_ref_log_prob(self, batch: DataProto) -> DataProto:
+ timing_raw = self.timing_raw
+ if self.use_reference_policy:
+ with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
+ ref_log_prob = self._compute_ref_log_prob(batch)
+ batch = batch.union(ref_log_prob)
+ return batch
+
+ def _fit_compute_critic(self, batch: DataProto) -> DataProto:
+ timing_raw = self.timing_raw
+ if self.use_critic:
+ with marked_timer("values", timing_raw, color="cyan"):
+ values = self._compute_values(batch)
+ batch = batch.union(values)
+ return batch
+
+ def _fit_compute_advantage(self, batch) -> DataProto:
+ metrics = self.metrics
+ timing_raw = self.timing_raw
+ future_reward = self.future_reward
+ reward_tensor = self.reward_tensor
+ reward_extra_infos_dict = self.reward_extra_infos_dict
+
+ with marked_timer("adv", timing_raw, color="brown"):
+ # we combine with rule-based rm
+ reward_extra_infos_dict: dict[str, list]
+ if self.config.reward_model.launch_reward_fn_async:
+ reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
+ batch.batch["token_level_scores"] = reward_tensor
+
+ if reward_extra_infos_dict:
+ batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
+
+ # compute rewards. apply_kl_penalty if available
+ if self.config.algorithm.use_kl_in_reward:
+ batch, kl_metrics = apply_kl_penalty(
+ batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
+ )
+ metrics.update(kl_metrics)
+ else:
+ batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
+
+ # Compute rollout correction: IS weights, rejection sampling, and metrics
+ # Only runs in decoupled mode (computes once per batch using stable π_old)
+ # In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout
+ rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
+ bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False)
+ if (
+ rollout_corr_config is not None
+ and "rollout_log_probs" in batch.batch
+ and not bypass_recomputing_logprobs # Only in decoupled mode
+ ):
+ from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch
+
+ # Compute IS weights, apply rejection sampling, compute metrics
+ batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
+ # IS and off-policy metrics already have rollout_corr/ prefix
+ metrics.update(is_metrics)
+
+ # compute advantages, executed on the driver process
+ norm_adv_by_std_in_grpo = self.config.algorithm.get(
+ "norm_adv_by_std_in_grpo", True
+ ) # GRPO adv normalization factor
+
+ batch = compute_advantage(
+ batch,
+ adv_estimator=self.config.algorithm.adv_estimator,
+ gamma=self.config.algorithm.gamma,
+ lam=self.config.algorithm.lam,
+ num_repeat=self.config.actor_rollout_ref.rollout.n,
+ norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
+ config=self.config.algorithm,
+ )
+ return batch
+
+ def _fit_update_critic(self, batch: DataProto) -> DataProto:
+ metrics = self.metrics
+ timing_raw = self.timing_raw
+ if self.use_critic:
+ with marked_timer("update_critic", timing_raw, color="pink"):
+ critic_output = self._update_critic(batch)
+ critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
+ metrics.update(critic_output_metrics)
+ return batch
+
+ def _fit_update_actor(self, batch: DataProto) -> DataProto:
+ metrics = self.metrics
+ timing_raw = self.timing_raw
+ # implement critic warmup
+ if self.config.trainer.critic_warmup <= self.global_steps:
+ # update actor
+ with marked_timer("update_actor", timing_raw, color="red"):
+ actor_output = self._update_actor(batch)
+
+ actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
+ metrics.update(actor_output_metrics)
+ return batch
+
+ def _fit_update_weights(self):
+ timing_raw = self.timing_raw
+ if self.config.trainer.critic_warmup <= self.global_steps:
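+            # While the critic is still warming up the actor has not been updated, so there are no new
+            # weights to push to the rollout workers.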
+ # update weights from trainer to rollout
+ with marked_timer("update_weights", timing_raw, color="red"):
+ self.checkpoint_manager.update_weights()
+
+ def _fit_dump_data(self, batch: DataProto):
+ timing_raw = self.timing_raw
+ reward_extra_infos_dict = self.reward_extra_infos_dict
+ # Log rollout generations if enabled
+ rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
+ if rollout_data_dir:
+ self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
+
+ def _fit_validate(self):
+ metrics = self.metrics
+ timing_raw = self.timing_raw
+ if self.config.trainer.test_freq > 0 and (
+ self.is_last_step or self.global_steps % self.config.trainer.test_freq == 0
+ ):
+ with marked_timer("testing", timing_raw, color="green"):
+ val_metrics: dict = self._validate()
+ if self.is_last_step:
+ self.last_val_metrics = val_metrics
+ metrics.update(val_metrics)
+
+ def _fit_save_checkpoint(self):
+ timing_raw = self.timing_raw
+ # Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
+ esi_close_to_expiration = should_save_ckpt_esi(
+ max_steps_duration=self.max_steps_duration,
+ redundant_time=self.config.trainer.esi_redundant_time,
+ )
+ # Check if the conditions for saving a checkpoint are met.
+ # The conditions include a mandatory condition (1) and
+ # one of the following optional conditions (2/3/4):
+ # 1. The save frequency is set to a positive value.
+ # 2. It's the last training step.
+ # 3. The current step number is a multiple of the save frequency.
+ # 4. The ESI(Elastic Server Instance)/training plan is close to expiration.
+ if self.config.trainer.save_freq > 0 and (
+ self.is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration
+ ):
+ if esi_close_to_expiration:
+ print("Force saving checkpoint: ESI instance expiration approaching.")
+ with marked_timer("save_checkpoint", timing_raw, color="green"):
+ # sleep replicas to avoid OOM during checkpoint saving
+ # self.checkpoint_manager.sleep_replicas()
+ self._save_checkpoint()
+                # wake replicas again after the checkpoint has been saved
+                # TODO: check whether this is needed under resource separation.
+ # self.checkpoint_manager.update_weights()
+
+ def _fit_stop_profile(self):
+ timing_raw = self.timing_raw
+ with marked_timer("stop_profile", timing_raw):
+ self.next_step_profile = (
+ self.global_steps + 1 in self.config.global_profiler.steps
+ if self.config.global_profiler.steps is not None
+ else False
+ )
+ self._stop_profiling(
+ self.curr_step_profile and not self.next_step_profile
+ if self.config.global_profiler.profile_continuous_steps
+ else self.curr_step_profile
+ )
+ self.prev_step_profile = self.curr_step_profile
+ self.curr_step_profile = self.next_step_profile
+
+ def _fit_collect_metrics(self, batch):
+ metrics = self.metrics
+ timing_raw = self.timing_raw
+
+ # collect metrics
+ metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
+ metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
+        # TODO: implement actual tflops and theoretical tflops
+ n_gpus = self.resource_pool_manager.get_n_gpus()
+ metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
+ # compute variance proxy metrics
+ gradient_norm = metrics.get("actor/grad_norm", None)
+ metrics.update(compute_variance_proxy_metrics(batch=batch, gradient_norm=gradient_norm))
+
+ def _fit_torch_memory(self):
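+        # Dump a CUDA memory snapshot right after the parameter update when the torch_memory profiler
+        # tool is selected.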
+ if (
+ hasattr(self.config.actor_rollout_ref.actor, "profiler")
+ and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory"
+ ):
+ self.actor_rollout_wg.dump_memory_snapshot(
+ tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}"
+ )
+
+ def _fit_experimental(self, batch):
+ # this is experimental and may be changed/removed in the future in favor of a general-purpose one
+ if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
+ self.train_dataloader.sampler.update(batch=batch)
+
+ # this is experimental and may be changed/removed in the future
+ # in favor of a general-purpose data buffer pool
+ if hasattr(self.train_dataset, "on_batch_end"):
+ # The dataset may be changed after each training batch
+ self.train_dataset.on_batch_end(batch=batch)
+
+ def _fit_postprocess_step(self):
+ metrics = self.metrics
+ timing_raw = self.timing_raw
+
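+        # Track the slowest step observed so far; should_save_ckpt_esi uses it to judge whether another
+        # full step fits before the ESI/training plan expires.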
+ steps_duration = timing_raw["step"]
+ self.max_steps_duration = max(self.max_steps_duration, steps_duration)
+
+        # TODO: make a canonical logger that supports various backends
+ self.logger.log(data=metrics, step=self.global_steps)
+ self.progress_bar.update(1)
+ self.global_steps += 1
+ if self.is_last_step:
+ if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"):
+ self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True)
+ pprint(f"Final validation metrics: {self.last_val_metrics}")
+ self.progress_bar.close()
diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py
index da0506b6866..fd160fa86c9 100644
--- a/verl/models/mcore/model_forward.py
+++ b/verl/models/mcore/model_forward.py
@@ -198,7 +198,7 @@ def gptmodel_forward_no_padding(
}
model_kwargs["labels"] = args["label"].contiguous()
model_kwargs["loss_mask"] = args["loss_mask"].contiguous()
- if logits_processor_args and 'loss_mask' in logits_processor_args:
+ if logits_processor_args and "loss_mask" in logits_processor_args:
logits_processor_args.pop("loss_mask")
# For VLM model, need to pass bshd format `input_ids` and `attention_mask`.
@@ -252,7 +252,7 @@ def gptmodel_forward_no_padding(
}
model_kwargs["labels"] = args["label"].contiguous()
model_kwargs["loss_mask"] = args["loss_mask"].contiguous()
- if logits_processor_args and 'loss_mask' in logits_processor_args:
+ if logits_processor_args and "loss_mask" in logits_processor_args:
logits_processor_args.pop("loss_mask")
output_orig = model(
diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py
index a801d711359..ab8ee461dea 100644
--- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py
+++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py
@@ -310,8 +310,8 @@ async def sleep(self):
logger.info("skip sleep in standalone mode")
async def clear_kv_cache(self):
- obj = ReleaseMemoryOccupationReqInput(tags=["kv_cache"])
- await self.tokenizer_manager.release_memory_occupation(obj, None)
+ if self.node_rank == 0:
+ await self.tokenizer_manager.flush_cache()
async def generate(
self,