Merged
13 changes: 12 additions & 1 deletion docs/advance/fully_async.md
@@ -57,7 +57,7 @@ can significantly improve training efficiency.
saves samples from ongoing rollouts and continues using them in the next rollout, reducing the time spent waiting for
ongoing tasks to finish during parameter synchronization.

Currently, the supported usage mode is megatron/fsdp+vllm. vllm must use the server mode based on AgentLoop.

## Design

@@ -104,6 +104,7 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a
| `async_training.staleness_threshold` | Freshness control |
| `async_training.partial_rollout` | Whether to perform partial_rollout |
| `async_training.use_rollout_log_probs` | Use log_probs generated by rollout |
| `async_training.compute_prox_log_prob` | Whether to compute log_prob with the training model's current parameters during the training phase (experimental) |

**Further Explanation:**

@@ -161,6 +162,16 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a
Here, we additionally provide require_batches for streaming distribution, controlling the number of samples
that participate in training at once.
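
The accumulation described above can be sketched as follows. This is a hedged illustration of the streaming-distribution idea only — the generator and its names (`stream_training_batches`, `sample_stream`) are hypothetical, not verl's actual trainer loop; only `ppo_mini_batch_size` and `require_batches` come from the config:

```python
from collections import deque


def stream_training_batches(sample_stream, ppo_mini_batch_size, require_batches):
    """Accumulate streamed rollout samples and yield one training batch whenever
    require_batches * ppo_mini_batch_size samples are available (sketch)."""
    required_samples = ppo_mini_batch_size * require_batches
    buffer = deque()
    for sample in sample_stream:
        buffer.append(sample)
        if len(buffer) >= required_samples:
            # Drain exactly one training batch; leftovers wait for the next fetch.
            yield [buffer.popleft() for _ in range(required_samples)]


# Example: 10 streamed samples, mini-batch size 2, 2 mini-batches per fetch.
batches = list(stream_training_batches(range(10), ppo_mini_batch_size=2, require_batches=2))
# Two full batches of 4 samples fit in 10 samples; 2 samples remain buffered.
```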

* `async_training.compute_prox_log_prob` (experimental)

  During training, we observed that metrics and response lengths may become unstable in the later
  stages of training. To mitigate this, we can apply
  the [Rollout Importance Sampling](https://verl.readthedocs.io/en/latest/advance/rollout_is.html)
  technique. Rollout Importance Sampling requires log_probs computed by the training engine, which is
  what this switch enables.
  Additionally, when `compute_prox_log_prob` and Rollout Importance Sampling are both enabled under mode d
  (async stream pipeline with partial rollout), our implementation approximates AReaL's `Decoupled PPO`.
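
A hedged config sketch combining the switches discussed above (keys are from `fully_async_ppo_trainer.yaml`; values are illustrative — see the linked Rollout Importance Sampling doc for its own settings):

```yaml
async_training:
  partial_rollout: True          # mode d: async stream pipeline with partial rollout
  use_rollout_log_probs: True
  compute_prox_log_prob: True    # experimental: compute log_prob with the training
                                 # engine, required for Rollout Importance Sampling
```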

### Supported Modes

1. on policy pipeline:
13 changes: 12 additions & 1 deletion recipe/fully_async_policy/README.md
@@ -57,7 +57,7 @@ can significantly improve training efficiency.
saves samples from ongoing rollouts and continues using them in the next rollout, reducing the time spent waiting for
ongoing tasks to finish during parameter synchronization.

Currently, the supported usage mode is megatron/fsdp+vllm. vllm must use the server mode based on AgentLoop.

## Design

@@ -104,6 +104,7 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a
| `async_training.staleness_threshold` | Freshness control |
| `async_training.partial_rollout` | Whether to perform partial_rollout |
| `async_training.use_rollout_log_probs` | Use log_probs generated by rollout |
| `async_training.compute_prox_log_prob` | Whether to compute log_prob with the training model's current parameters during the training phase (experimental) |

**Further Explanation:**

@@ -161,6 +162,16 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a
Here, we additionally provide require_batches for streaming distribution, controlling the number of samples
that participate in training at once.

* `async_training.compute_prox_log_prob` (experimental)

  During training, we observed that metrics and response lengths may become unstable in the later
  stages of training. To mitigate this, we can apply
  the [Rollout Importance Sampling](https://verl.readthedocs.io/en/latest/advance/rollout_is.html)
  technique. Rollout Importance Sampling requires log_probs computed by the training engine, which is
  what this switch enables.
  Additionally, when `compute_prox_log_prob` and Rollout Importance Sampling are both enabled under mode d
  (async stream pipeline with partial rollout), our implementation approximates AReaL's `Decoupled PPO`.
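
A minimal sketch of the correction this switch enables: per-token importance weights between the training (proximal) policy and the stale rollout policy. The function name and truncation are illustrative assumptions, not verl's API — see the linked Rollout Importance Sampling doc for the real implementation:

```python
import math


def rollout_is_weights(prox_log_probs, rollout_log_probs, clip=2.0):
    """Truncated per-token importance weights exp(log p_train - log p_rollout),
    capped at `clip` to bound variance (illustrative sketch)."""
    return [min(math.exp(p - q), clip) for p, q in zip(prox_log_probs, rollout_log_probs)]


# A token where the training policy agrees with rollout gets weight 1.0;
# a stale token whose ratio exceeds the cap is truncated to `clip`.
weights = rollout_is_weights([-1.0, -2.0], [-1.0, -4.0], clip=2.0)
```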

### Supported Modes

1. on policy pipeline:
43 changes: 26 additions & 17 deletions recipe/fully_async_policy/README_zh.md
@@ -39,7 +39,7 @@ training of rollouts; by properly setting resource allocation, parameter synchronization frequency, etc.
* **PartialRollout**: The Rollouter's inference supports partial-rollout logic. During parameter synchronization, added `sleep()` and `resume()`
hooks save the samples of in-flight rollouts so they can continue in the next rollout, reducing the time spent waiting for in-flight tasks to finish.

Currently, the supported usage mode is megatron/fsdp+vllm. vllm must use the server mode based on AgentLoop.

## Design

@@ -65,22 +65,23 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a

### Parameter Description

| super params | implication |
|------------------------------------------------------|-----------------------------------------------------------------|
| `trainer.nnodes` | Number of Trainer nodes |
| `trainer.n_gpus_per_node` | Number of GPUs per Trainer node |
| `rollout.nnodes` | Number of Rollouter nodes |
| `rollout.n_gpus_per_node` | Number of GPUs per Rollouter node |
| `data.train_batch_size` | Not effective in the fully async strategy (defaults to 0) |
| `data.gen_batch_size` | The fully async strategy uses streaming sample production (defaults to 1) |
| `rollout.total_rollout_steps` | Total number of rollout samples |
| `rollout.test_freq` | Run validation after every N parameter updates of the Rollouter |
| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus |
| `async_training.require_batches` | Number of ppo_mini_batch_size batches the FullyAsyncTrainer fetches at once |
| `async_training.trigger_parameter_sync_step` | Number of local updates the FullyAsyncTrainer performs before each parameter synchronization |
| `async_training.staleness_threshold` | Freshness control |
| `async_training.partial_rollout` | Whether to perform partial_rollout |
| `async_training.use_rollout_log_probs` | Use log_probs generated by rollout |
| `async_training.compute_prox_log_prob` (experimental) | Whether to compute token log_prob with the training model's parameters during the training phase |

**Further Explanation:**

@@ -131,6 +132,14 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a
In practical testing, we found that if too few samples are dispatched at a time, the order of data distribution
can make training unstable and response lengths grow. Here, we additionally provide require_batches for streaming
distribution, controlling the number of samples that participate in training at once.

* `async_training.compute_prox_log_prob` (experimental)

  During training, we observed that metrics and response lengths may become unstable in the later stages of training.
  The [Rollout Importance Sampling](https://verl.readthedocs.io/en/latest/advance/rollout_is.html) technique can be
  used to mitigate this. To use `Rollout Importance Sampling`, the training engine must compute old_log_prob with the
  current parameter version, which requires enabling this switch. Additionally, when `compute_prox_log_prob` and
  `Rollout Importance Sampling` are both enabled under mode d (async stream pipeline with partial rollout), our
  implementation approximates AReaL's `Decoupled PPO`.

### Supported Modes

1. on policy pipeline:
3 changes: 3 additions & 0 deletions recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml
@@ -24,6 +24,9 @@ async_training:
# Whether to use rollout log probs for training
use_rollout_log_probs: True

# Whether to compute log_prob with the training model's current parameters during training (experimental)
compute_prox_log_prob: False

# Rollout config
rollout:

125 changes: 125 additions & 0 deletions recipe/fully_async_policy/fsdp2_utils.py
@@ -0,0 +1,125 @@
@@ -0,0 +1,125 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.distributed as dist
from packaging import version
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec

if version.parse(torch.__version__) < version.parse("2.6"):
raise RuntimeError("PyTorch 2.6 or higher is required to use fsdp2_utils.")


def fsdp2_sharded_save_to_cpu(
model: torch.nn.Module,
) -> tuple[dict[str, tuple[torch.Tensor, DTensorSpec]], DTensorSpec]:
"""
Sharded Save: Each process only saves the local DTensor shard from its own GPU to CPU memory.

Args:
model: FSDP2-wrapped model whose parameters are of DTensor type.

Returns:
cpu_sharded_state: Dictionary of CPU shards for the current process.
Key = parameter name, Value = (CPU shard tensor, original DTensorSpec)
global_spec: DTensorSpec of the first parameter (used to verify global rules during loading)
"""
cpu_sharded_state = {}
global_spec = None # Record global sharding rules (all parameters follow the same spec)

for param_name, param in model.named_parameters():
# Only process sharded parameters of DTensor type (core parameters of FSDP2)
if not isinstance(param, DTensor):
# Save non-sharded parameters (e.g., running_mean of BatchNorm) as local data
cpu_tensor = param.detach().cpu()
cpu_sharded_state[param_name] = (cpu_tensor, None)
continue

# Record global sharding rules (take spec of the first DTensor to ensure consistency)
if global_spec is None:
global_spec = param._spec
assert hasattr(global_spec, "device_mesh"), "DTensorSpec must contain 'device_mesh' attribute"
assert hasattr(global_spec, "placements"), "DTensorSpec must contain 'placements' attribute"

# 1. Extract local shard data from the current GPU (_local_tensor)
local_gpu_tensor = param._local_tensor # local DTensor shard held by the current rank
# 2. Move to CPU memory and detach from computation graph
local_cpu_tensor = local_gpu_tensor.detach().cpu()
# 3. Save CPU shard + original DTensorSpec (ensure sharding rules remain unchanged)
cpu_sharded_state[param_name] = (local_cpu_tensor, param._spec)

assert global_spec is not None, "No DTensor-type parameters found in the model. FSDP2 sharding may not be enabled."
return cpu_sharded_state, global_spec


def fsdp2_sharded_load_from_cpu(
model: torch.nn.Module,
cpu_sharded_state: dict[str, tuple[torch.Tensor, Optional[DTensorSpec]]],
target_spec: DTensorSpec,
) -> None:
"""
Sharded Load: Each process only loads the CPU shard it is responsible for to the GPU,
keeping sharding rules unchanged.

Args:
model: FSDP2 model to be restored (must have the same structure as when saved)
cpu_sharded_state: Shard data read from CPU memory by the current process
(from fsdp2_sharded_save_to_cpu)
target_spec: Global DTensorSpec from saving (used to verify sharding rule consistency)
"""
# Verify device_mesh consistency (core: ensure loaded shards map to original GPUs)
current_device_mesh = None
for param in model.parameters():
if isinstance(param, DTensor):
current_device_mesh = param._spec.device_mesh
break
assert current_device_mesh is not None, "DTensor parameters not initialized in the model to be loaded"
assert current_device_mesh == target_spec.device_mesh, (
f"device_mesh mismatch during loading! Original: {target_spec.device_mesh}, Current: {current_device_mesh}"
)

for param_name, param in model.named_parameters():
# Skip parameters not in the saved state (e.g., newly added parameters)
if param_name not in cpu_sharded_state:
continue

# Extract CPU shard data and original Spec
local_cpu_tensor, saved_spec = cpu_sharded_state[param_name]

# Handle different parameter types: DTensor sharded parameters vs. regular parameters
if isinstance(param, DTensor):
# 1. Verify sharding rule consistency (placements must match original Spec)
assert saved_spec is not None, f"DTensorSpec missing in saved state for parameter {param_name}"
assert saved_spec.placements == target_spec.placements, (
f"Sharding strategy mismatch for parameter {param_name} (conflicts with global rules)!"
)

# 2. Move CPU shard data to the current GPU (device of param._local_tensor)
target_device = param._local_tensor.device
local_gpu_tensor = local_cpu_tensor.to(target_device)

# 3. Restore to DTensor's local shard (directly copy to _local_tensor, keep spec unchanged)
param._local_tensor.copy_(local_gpu_tensor)

else:
# Regular parameters: load directly to original device
target_device = param.device
param.data.copy_(local_cpu_tensor.to(target_device))

# Process synchronization: ensure all processes complete loading before proceeding
dist.barrier()
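
The save/restore contract above can be exercised in a single process. The sketch below substitutes plain CPU tensors for DTensor shards and skips the spec checks and `dist.barrier()` — it shows the round-trip pattern only, not the FSDP2 code path:

```python
# Single-process sketch of the save -> mutate -> restore round-trip that
# fsdp2_sharded_save_to_cpu / fsdp2_sharded_load_from_cpu implement for DTensors.
import torch

model = torch.nn.Linear(4, 2)

# "Save": detach and clone each parameter to CPU (spec is None for plain tensors).
cpu_state = {name: (p.detach().clone(), None) for name, p in model.named_parameters()}

# Simulate a training step mutating the live weights in place.
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

# "Restore": copy the saved tensors back into the live parameters.
with torch.no_grad():
    for name, p in model.named_parameters():
        p.copy_(cpu_state[name][0])

restored_ok = all(
    torch.equal(p, cpu_state[name][0]) for name, p in model.named_parameters()
)
```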
18 changes: 18 additions & 0 deletions recipe/fully_async_policy/fsdp_workers.py
@@ -21,6 +21,7 @@
from omegaconf import DictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from recipe.fully_async_policy.fsdp2_utils import fsdp2_sharded_load_from_cpu, fsdp2_sharded_save_to_cpu
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.device import (
get_device_name,
@@ -124,6 +125,23 @@ def get_actor_weights_info(self):
self._weights_info = ret
return ret

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_model_to_cpu(self, n):
if not hasattr(self, "cpu_saved_models"):
self.cpu_saved_models = {}
self.cpu_saved_models[n] = fsdp2_sharded_save_to_cpu(self.actor_module_fsdp)

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def restore_model_from_cpu(self, n):
if hasattr(self, "cpu_saved_models") and n in self.cpu_saved_models:
cpu_sharded_state, global_spec = self.cpu_saved_models[n]
fsdp2_sharded_load_from_cpu(self.actor_module_fsdp, cpu_sharded_state, global_spec)

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def clear_cpu_model(self, n):
if hasattr(self, "cpu_saved_models") and n in self.cpu_saved_models:
del self.cpu_saved_models[n]


class DetachAsyncRolloutWorker(DetachNcclSync):
def __init__(self, config: DictConfig, role: str):
6 changes: 4 additions & 2 deletions recipe/fully_async_policy/fully_async_trainer.py
@@ -100,6 +100,7 @@ def __init__(
# required_samples use ppo_mini_batch_size*require_batches as the minimum number of samples.
self.require_batches = config.async_training.require_batches
self.required_samples = config.actor_rollout_ref.actor.ppo_mini_batch_size * self.require_batches
self.compute_prox_log_prob = self.config.async_training.compute_prox_log_prob
total_gpus = (
config.trainer.nnodes * config.trainer.n_gpus_per_node
+ config.rollout.nnodes * config.rollout.n_gpus_per_node
@@ -257,8 +258,9 @@ def fit(self):
if batch is None:
break
self._collect_metrics_from_samples(batch, metrics)

batch, reward_extra_infos_dict = self._process_batch_common(
batch, metrics, timing_raw, self.local_trigger_step if self.compute_prox_log_prob else None
)
self._log_rollout(batch, reward_extra_infos_dict, timing_raw)
self._check_save_checkpoint(False, timing_raw)
