diff --git a/docs/advance/fully_async.md b/docs/advance/fully_async.md index 962d4ffdb16..314c9f324fb 100644 --- a/docs/advance/fully_async.md +++ b/docs/advance/fully_async.md @@ -57,7 +57,7 @@ can significantly improve training efficiency. saves samples from ongoing rollouts and continues using them in the next rollout, reducing the time spent waiting for ongoing tasks to finish during parameter synchronization. -Currently, the supported usage mode is fsdp+vllm. vllm must use the server mode based on AgentLoop. +Currently, the supported usage mode is megatron/fsdp+vllm. vllm must use the server mode based on AgentLoop. ## Design @@ -104,6 +104,7 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a | `async_training.staleness_threshold` | Freshness control | | `async_training.partial_rollout` | Whether to perform partial_rollout | | `async_training.use_rollout_log_probs` | Use log_probs generated by rollout | +| `async_training.compute_prox_log_prob` | Whether to compute log_prob using the training model's parameters during the training phase. | | **Further Explanation:** @@ -161,6 +162,16 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a Here, we additionally provide require_batches for streaming distribution and control the number of samples participating in training at once. +* `async_training.compute_prox_log_prob` (experimental) + + During the training process, we observed that metrics and response lengths may become unstable in the later + stages of training. To mitigate this issue, we can use + the [Rollout Importance Sampling](https://verl.readthedocs.io/en/latest/advance/rollout_is.html) + technique for importance sampling. To utilize Rollout Importance Sampling, we need to compute log_prob using + the training engine, which requires enabling this switch. 
+ Additionally, when compute_prox_log_prob and Rollout Importance Sampling are enabled under mode d + (async stream pipeline with partial rollout), our implementation approximates `AReaL's Decoupled PPO`. + ### Supported Modes 1. on policy pipeline: diff --git a/recipe/fully_async_policy/README.md b/recipe/fully_async_policy/README.md index 962d4ffdb16..314c9f324fb 100644 --- a/recipe/fully_async_policy/README.md +++ b/recipe/fully_async_policy/README.md @@ -57,7 +57,7 @@ can significantly improve training efficiency. saves samples from ongoing rollouts and continues using them in the next rollout, reducing the time spent waiting for ongoing tasks to finish during parameter synchronization. -Currently, the supported usage mode is fsdp+vllm. vllm must use the server mode based on AgentLoop. +Currently, the supported usage mode is megatron/fsdp+vllm. vllm must use the server mode based on AgentLoop. ## Design @@ -104,6 +104,7 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a | `async_training.staleness_threshold` | Freshness control | | `async_training.partial_rollout` | Whether to perform partial_rollout | | `async_training.use_rollout_log_probs` | Use log_probs generated by rollout | +| `async_training.compute_prox_log_prob` | Whether to compute log_prob using the training model's parameters during the training phase. | | **Further Explanation:** @@ -161,6 +162,16 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a Here, we additionally provide require_batches for streaming distribution and control the number of samples participating in training at once. +* `async_training.compute_prox_log_prob` (experimental) + + During the training process, we observed that metrics and response lengths may become unstable in the later + stages of training.
To mitigate this issue, we can use + the [Rollout Importance Sampling](https://verl.readthedocs.io/en/latest/advance/rollout_is.html) + technique for importance sampling. To utilize Rollout Importance Sampling, we need to compute log_prob using + the training engine, which requires enabling this switch. + Additionally, when compute_prox_log_prob and Rollout Importance Sampling are enabled under mode d + (async stream pipeline with partial rollout), our implementation approximates `Areal's Decoupled PPO`. + ### Supported Modes 1. on policy pipeline: diff --git a/recipe/fully_async_policy/README_zh.md b/recipe/fully_async_policy/README_zh.md index 23b457f14ef..b4fa24759b7 100644 --- a/recipe/fully_async_policy/README_zh.md +++ b/recipe/fully_async_policy/README_zh.md @@ -39,7 +39,7 @@ rollout的训练, 通过合理设置资源分配情况、参数同步频率等 * **PartialRollout**: Rollouter推理过程支持partial rollout逻辑,通过参数同步时,添加`sleep()`和`resume()` 逻辑,保存进行中的rollout的样本,并在下一次rollout中继续使用,减少参数同步等待进行中的任务结束时间。 -目前支持使用模式为 fsdp+vllm。vllm必须使用基于AgentLoop的server模式。 +目前支持使用模式为 megatron/fsdp+vllm。vllm必须使用基于AgentLoop的server模式。 ## 设计 @@ -65,22 +65,23 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a ### 参数说明 -| super params | implication | -|-----------------------------------------------|-----------------------------------------------------------------| -| `trainer.nnodes` | Trainer的node数量 | -| `trainer.n_gpus_per_node` | Trainer每个node上gpu的数量 | -| `rollout.nnodes` | Rollouter的node数量 | -| `rollout.n_gpus_per_node` | Rollouter每个node上gpu的数量 | -| `data.train_batch_size` | 在fully async策略中,该值不生效(默认设置为0) | -| `data.gen_batch_size` | 在fully async策略中,使用流式的样本生产逻辑(默认设置为1) | -| `rollout.total_rollout_steps` | 总的rollout的sample数量 | -| `rollout.test_freq` | Rollouter每更新多少次参数,进行一次validation | -| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus | -| `async_training.require_batches` | FullyAsyncTrainer一次性获取的ppo_mini_batch_size的数量 | -| 
`async_training.trigger_parameter_sync_step` | 表示FullyAsyncTrainer进行多少次本地更新后,进行一次参数同步 | -| `async_training.staleness_threshold` | 新鲜度控制 | -| `async_training.partial_rollout` | 是否进行partial_rollout | -| `async_training.use_rollout_log_probs` | 使用rollout产生的log_probs | +| super params | implication | +|------------------------------------------------------|-----------------------------------------------------------------| +| `trainer.nnodes` | Trainer的node数量 | +| `trainer.n_gpus_per_node` | Trainer每个node上gpu的数量 | +| `rollout.nnodes` | Rollouter的node数量 | +| `rollout.n_gpus_per_node` | Rollouter每个node上gpu的数量 | +| `data.train_batch_size` | 在fully async策略中,该值不生效(默认设置为0) | +| `data.gen_batch_size` | 在fully async策略中,使用流式的样本生产逻辑(默认设置为1) | +| `rollout.total_rollout_steps` | 总的rollout的sample数量 | +| `rollout.test_freq` | Rollouter每更新多少次参数,进行一次validation | +| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus | +| `async_training.require_batches` | FullyAsyncTrainer一次性获取的ppo_mini_batch_size的数量 | +| `async_training.trigger_parameter_sync_step` | 表示FullyAsyncTrainer进行多少次本地更新后,进行一次参数同步 | +| `async_training.staleness_threshold` | 新鲜度控制 | +| `async_training.partial_rollout` | 是否进行partial_rollout | +| `async_training.use_rollout_log_probs` | 使用rollout产生的log_probs | +| `async_training.compute_prox_log_prob`(experimental) | 是否在train阶段,使用train模型的参数计算token的 log_prob | **进一步的解释:** @@ -131,6 +132,14 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a 在实际测试中,我们发现,如果单次下发的样本较少,由于数据分发的顺序,会导致训练不稳定,response 长度变长。 在这里,我们额外提供 require_batches 进行流式分发,单次参与训练的样本数量控制。 +* `async_training.compute_prox_log_prob` (experimental) + + 我们在训练过程中,观测到随着训练的进行,训练后期指标和response长度可能会出现不稳定的情况, + 这里我们可以使用 [Rollout Importance Sampling](https://verl.readthedocs.io/en/latest/advance/rollout_is.html) 的技术进行 + 重要性采样,缓解这一问题。为了使用 `Rollout Importance Sampling` 我们需要使用训练引擎使用当前的参数版本计算old_log_prob,此开关需要打开。 + 此外,在 mode d (async stream pipeline 
with partial rollout) 的情况下开启 `compute_prox_log_prob` 以及 + `Rollout Importance Sampling` 后,我们的实现已近似AReaL的 `Decoupled PPO`。 + ### 模式支持 1. on policy pipeline: diff --git a/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml b/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml index 4a8b8fc32e7..8ec5a1ba8b5 100644 --- a/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml +++ b/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml @@ -24,6 +24,9 @@ async_training: # Whether to use rollout log probs for training use_rollout_log_probs: True + # Whether to compute log_prob with the training model's current parameters during training (used by Rollout Importance Sampling) + compute_prox_log_prob: False + # Rollout config rollout: diff --git a/recipe/fully_async_policy/fsdp2_utils.py b/recipe/fully_async_policy/fsdp2_utils.py new file mode 100644 index 00000000000..1f1856596fb --- /dev/null +++ b/recipe/fully_async_policy/fsdp2_utils.py @@ -0,0 +1,125 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + from typing import Optional + + import torch + import torch.distributed as dist + from packaging import version + from torch.distributed.tensor import DTensor + from torch.distributed.tensor._dtensor_spec import DTensorSpec + + if version.parse(torch.__version__) < version.parse("2.6"): + raise RuntimeError("PyTorch 2.6 or higher is required to use fsdp2_utils.") + + + def fsdp2_sharded_save_to_cpu( + model: torch.nn.Module, + ) -> tuple[dict[str, tuple[torch.Tensor, DTensorSpec]], DTensorSpec]: + """ + Sharded Save: Each process only saves the local DTensor shard from its own GPU to CPU memory. + + Args: + model: FSDP2-wrapped model whose parameters are of DTensor type. + + Returns: + cpu_sharded_state: Dictionary of CPU shards for the current process. + Key = parameter name, Value = (CPU shard tensor, original DTensorSpec) + global_spec: DTensorSpec of the first parameter (used to verify global rules during loading) + """ + cpu_sharded_state = {} + global_spec = None # Record global sharding rules (all parameters follow the same spec) + + for param_name, param in model.named_parameters(): + # Only process sharded parameters of DTensor type (core parameters of FSDP2) + if not isinstance(param, DTensor): + # Save non-sharded parameters (e.g., running_mean of BatchNorm) as local data + cpu_tensor = param.detach().cpu() + cpu_sharded_state[param_name] = (cpu_tensor, None) + continue + + # Record global sharding rules (take spec of the first DTensor to ensure consistency) + if global_spec is None: + global_spec = param._spec + assert hasattr(global_spec, "device_mesh"), "DTensorSpec must contain 'device_mesh' attribute" + assert hasattr(global_spec, "placements"), "DTensorSpec must contain 'placements' attribute" + + # 1. Extract local shard data from the current GPU (_local_tensor) + local_gpu_tensor = param._local_tensor # DTensor's internal local shard + # 2. 
Move to CPU memory and detach from computation graph + local_cpu_tensor = local_gpu_tensor.detach().cpu() + # 3. Save CPU shard + original DTensorSpec (ensure sharding rules remain unchanged) + cpu_sharded_state[param_name] = (local_cpu_tensor, param._spec) + + assert global_spec is not None, "No DTensor-type parameters found in the model. FSDP2 sharding may not be enabled." + return cpu_sharded_state, global_spec + + +def fsdp2_sharded_load_from_cpu( + model: torch.nn.Module, + cpu_sharded_state: dict[str, tuple[torch.Tensor, Optional[DTensorSpec]]], + target_spec: DTensorSpec, +) -> None: + """ + Sharded Load: Each process only loads the CPU shard it is responsible for to the GPU, + keeping sharding rules unchanged. + + Args: + model: FSDP2 model to be restored (must have the same structure as when saved) + cpu_sharded_state: Shard data read from CPU memory by the current process + (from fsdp2_sharded_save_to_cpu) + target_spec: Global DTensorSpec from saving (used to verify sharding rule consistency) + """ + # Verify device_mesh consistency (core: ensure loaded shards map to original GPUs) + current_device_mesh = None + for param in model.parameters(): + if isinstance(param, DTensor): + current_device_mesh = param._spec.device_mesh + break + assert current_device_mesh is not None, "DTensor parameters not initialized in the model to be loaded" + assert current_device_mesh == target_spec.device_mesh, ( + f"device_mesh mismatch during loading! Original: {target_spec.device_mesh}, Current: {current_device_mesh}" + ) + + for param_name, param in model.named_parameters(): + # Skip parameters not in the saved state (e.g., newly added parameters) + if param_name not in cpu_sharded_state: + continue + + # Extract CPU shard data and original Spec + local_cpu_tensor, saved_spec = cpu_sharded_state[param_name] + + # Handle different parameter types: DTensor sharded parameters vs. regular parameters + if isinstance(param, DTensor): + # 1. 
Verify sharding rule consistency (placements must match original Spec) + assert saved_spec is not None, f"DTensorSpec missing in saved state for parameter {param_name}" + assert saved_spec.placements == target_spec.placements, ( + f"Sharding strategy mismatch for parameter {param_name} (conflicts with global rules)!" + ) + + # 2. Move CPU shard data to the current GPU (device of param._local_tensor) + target_device = param._local_tensor.device + local_gpu_tensor = local_cpu_tensor.to(target_device) + + # 3. Restore to DTensor's local shard (directly copy to _local_tensor, keep spec unchanged) + param._local_tensor.copy_(local_gpu_tensor) + + else: + # Regular parameters: load directly to original device + target_device = param.device + param.data.copy_(local_cpu_tensor.to(target_device)) + + # Process synchronization: ensure all processes complete loading before proceeding + dist.barrier() diff --git a/recipe/fully_async_policy/fsdp_workers.py b/recipe/fully_async_policy/fsdp_workers.py index ad6b0db8b51..8e665352683 100644 --- a/recipe/fully_async_policy/fsdp_workers.py +++ b/recipe/fully_async_policy/fsdp_workers.py @@ -21,6 +21,7 @@ from omegaconf import DictConfig from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from recipe.fully_async_policy.fsdp2_utils import fsdp2_sharded_load_from_cpu, fsdp2_sharded_save_to_cpu from verl.single_controller.base.decorator import Dispatch, register from verl.utils.device import ( get_device_name, @@ -124,6 +125,23 @@ def get_actor_weights_info(self): self._weights_info = ret return ret + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_model_to_cpu(self, n): + if not hasattr(self, "cpu_saved_models"): + self.cpu_saved_models = {} + self.cpu_saved_models[n] = fsdp2_sharded_save_to_cpu(self.actor_module_fsdp) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def restore_model_from_cpu(self, n): + if n in self.cpu_saved_models: + cpu_sharded_state, global_spec = self.cpu_saved_models[n] + 
fsdp2_sharded_load_from_cpu(self.actor_module_fsdp, cpu_sharded_state, global_spec) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def clear_cpu_model(self, n): + if n in self.cpu_saved_models: + del self.cpu_saved_models[n] + class DetachAsyncRolloutWorker(DetachNcclSync): def __init__(self, config: DictConfig, role: str): diff --git a/recipe/fully_async_policy/fully_async_trainer.py b/recipe/fully_async_policy/fully_async_trainer.py index c62f7d924f1..554850e922b 100644 --- a/recipe/fully_async_policy/fully_async_trainer.py +++ b/recipe/fully_async_policy/fully_async_trainer.py @@ -100,6 +100,7 @@ def __init__( # required_samples use ppo_mini_batch_size*require_batches as the minimum number of samples. self.require_batches = config.async_training.require_batches self.required_samples = config.actor_rollout_ref.actor.ppo_mini_batch_size * self.require_batches + self.compute_prox_log_prob = self.config.async_training.compute_prox_log_prob total_gpus = ( config.trainer.nnodes * config.trainer.n_gpus_per_node + config.rollout.nnodes * config.rollout.n_gpus_per_node @@ -257,8 +258,9 @@ def fit(self): if batch is None: break self._collect_metrics_from_samples(batch, metrics) - - batch, reward_extra_infos_dict = self._process_batch_common(batch, metrics, timing_raw) + batch, reward_extra_infos_dict = self._process_batch_common( + batch, metrics, timing_raw, self.local_trigger_step if self.compute_prox_log_prob else None + ) self._log_rollout(batch, reward_extra_infos_dict, timing_raw) self._check_save_checkpoint(False, timing_raw) diff --git a/recipe/fully_async_policy/ray_trainer.py b/recipe/fully_async_policy/ray_trainer.py index a4c80bac776..9d216e05639 100644 --- a/recipe/fully_async_policy/ray_trainer.py +++ b/recipe/fully_async_policy/ray_trainer.py @@ -336,7 +336,7 @@ def _post_generate_batch(self, batch, gen_batch_output, metrics): return batch - def _process_batch_common(self, batch, metrics, timing_raw): + def _process_batch_common(self, batch, metrics, 
timing_raw, local_trigger_step=None): with marked_timer("reward", timing_raw, color="yellow"): # compute reward model score if self.use_rm: @@ -348,14 +348,9 @@ def _process_batch_common(self, batch, metrics, timing_raw): else: reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) - # recompute old_log_probs with marked_timer("old_log_prob", timing_raw, color="blue"): - async_training = self.config.get("async_training", None) - if async_training and async_training.use_rollout_log_probs: - batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"] - batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature - else: + def compute_old_log_prob(batch): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) entropys = old_log_prob.batch["entropys"] response_masks = batch.batch["response_mask"] @@ -365,12 +360,34 @@ def _process_batch_common(self, batch, metrics, timing_raw): metrics.update(old_log_prob_metrics) old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) - if "rollout_log_probs" in batch.batch.keys(): # TODO: we may want to add diff of probs too. from verl.utils.debug.metrics import calculate_debug_metrics metrics.update(calculate_debug_metrics(batch)) + return batch + + async_training = self.config.get("async_training", None) + if async_training and async_training.use_rollout_log_probs: + # If local_trigger_step == 1, save the training engine's parameters to the CPU + # and keep a copy for subsequent MIS use. + # If local_trigger_step == 2, 3, ..., restore the parameters of version 1 to calculate the old_log_prob, + # then restore the parameters of the current version. 
+ if local_trigger_step == 1: + self.actor_rollout_wg.save_model_to_cpu(1) + batch = compute_old_log_prob(batch) + elif local_trigger_step is not None: + self.actor_rollout_wg.save_model_to_cpu(local_trigger_step) + self.actor_rollout_wg.restore_model_from_cpu(1) + batch = compute_old_log_prob(batch) + self.actor_rollout_wg.restore_model_from_cpu(local_trigger_step) + self.actor_rollout_wg.clear_cpu_model(local_trigger_step) + else: + batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"] + batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature + + else: + batch = compute_old_log_prob(batch) if self.use_reference_policy: # compute reference log_prob @@ -406,8 +423,14 @@ def _process_batch_common(self, batch, metrics, timing_raw): else: batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] - # compute advantages, executed on the driver process + # Compute rollout importance sampling weights centrally (once per batch) + # This corrects for mismatch between rollout policy and training policy + # Also computes mismatch metrics (KL, PPL, etc.) 
+ batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch) + # IS and mismatch metrics already have mismatch/ prefix + metrics.update(is_metrics) + # compute advantages, executed on the driver process norm_adv_by_std_in_grpo = self.config.algorithm.get( "norm_adv_by_std_in_grpo", True ) # GRPO adv normalization factor diff --git a/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_16-16.sh b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh similarity index 99% rename from recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_16-16.sh rename to recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh index 82072c3a0eb..71a11202376 100644 --- a/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_16-16.sh +++ b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_16_16.sh @@ -159,4 +159,4 @@ python -m recipe.fully_async_policy.fully_async_main \ async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ async_training.require_batches="${require_batches}" \ async_training.partial_rollout="${partial_rollout}" \ - async_training.use_rollout_log_probs=True + async_training.use_rollout_log_probs=True \ No newline at end of file diff --git a/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh index c03e880eec8..ee2f6d332d8 100644 --- a/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh +++ b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh @@ -71,7 +71,7 @@ n_resp_per_prompt=16 train_prompt_mini_bsz=32 total_rollout_steps=$(((512*400))) test_freq=20 -staleness_threshold=0.1 +staleness_threshold=0.5 trigger_parameter_sync_step=4 require_batches=4 partial_rollout=True diff --git a/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh new file mode 100644 index 00000000000..c72ea786c12 --- /dev/null +++ 
b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh @@ -0,0 +1,178 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='dapo_qwen2-7B-math_28k_fsdp2_fully-async_64-64' + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +# Algorithm parameters +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Response length parameters +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 28)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +# Training parameters +loss_agg_mode="token-mean" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +ref_offload=True +actor_offload=False +gen_tp=4 +sp_size=4 +fsdp_size=8 + +# Fully async specific parameters +NNODES_ROLLOUT=${NNODES_ROLLOUT:-8} +NNODES_TRAIN=${NNODES_TRAIN:-8} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +train_prompt_bsz=0 +gen_prompt_bsz=1 
+n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +total_rollout_steps=$(((512*400))) +test_freq=20 +staleness_threshold=0.5 +trigger_parameter_sync_step=4 +require_batches=4 +partial_rollout=True + +# Rollout Importance Sampling +rollout_is_threshold=1.001 +rollout_is=True +rollout_is_threshold_lower=0.99 +rollout_is_level=geometric +rollout_is_mode=mask +rollout_is_veto_threshold=1e-4 + +python -m recipe.fully_async_policy.fully_async_main \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.hybrid_engine=False \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + 
actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + 
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.val_before_train=True \ + trainer.save_freq=-1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.nnodes="${NNODES_TRAIN}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.nnodes="${NNODES_ROLLOUT}" \ + rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.total_epochs=10 \ + rollout.test_freq="${test_freq}" \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.require_batches="${require_batches}" \ + async_training.partial_rollout="${partial_rollout}" \ + async_training.use_rollout_log_probs=True \ + async_training.compute_prox_log_prob=True \ + algorithm.rollout_is=${rollout_is} \ + algorithm.rollout_is_threshold=${rollout_is_threshold} \ + algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \ + algorithm.rollout_is_level=${rollout_is_level} \ + algorithm.rollout_is_mode=${rollout_is_mode} \ + algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} +
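To illustrate what the `algorithm.rollout_is_*` options in the script above control: the mismatch between the training policy's log-probs and the log-probs recorded by the rollout engine is turned into a per-sequence importance weight; with `rollout_is_level=geometric` that weight is the geometric mean of per-token ratios, with `rollout_is_mode=mask` sequences whose weight leaves the `[rollout_is_threshold_lower, rollout_is_threshold]` band are masked out, and `rollout_is_veto_threshold` drops any sequence containing a catastrophically unlikely token. The following is a minimal numpy sketch of that computation; the function name and exact semantics are illustrative assumptions, not verl's actual implementation:

```python
import numpy as np

def rollout_is_weights(old_log_probs, rollout_log_probs, response_mask,
                       threshold_upper=1.001, threshold_lower=0.99,
                       veto_threshold=1e-4):
    """Sequence-level IS weights (geometric level, mask mode) - a sketch.

    old_log_probs:     [B, T] log-probs under the training (prox) policy
    rollout_log_probs: [B, T] log-probs recorded by the rollout engine
    response_mask:     [B, T] 1 for response tokens, 0 for padding
    """
    log_ratio = (old_log_probs - rollout_log_probs) * response_mask
    tok_ratio = np.exp(old_log_probs - rollout_log_probs)

    # geometric mean of per-token ratios over the valid response tokens
    n_tokens = np.maximum(response_mask.sum(axis=-1), 1)
    seq_weight = np.exp(log_ratio.sum(axis=-1) / n_tokens)

    # mask mode: zero out sequences whose weight leaves the trust band
    in_band = (seq_weight >= threshold_lower) & (seq_weight <= threshold_upper)

    # veto: drop sequences containing any catastrophically unlikely token
    veto = ((tok_ratio < veto_threshold) & (response_mask > 0)).any(axis=-1)

    return seq_weight * in_band * ~veto
```

When the two policies agree exactly the weight is 1 and the sample is kept; as staleness grows the weight drifts out of the band and, in mask mode, the sample simply stops contributing.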
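The docs above note that enabling `compute_prox_log_prob` together with Rollout Importance Sampling under mode d approximates AReaL's Decoupled PPO: the usual ratio against the behavior (rollout) policy is split as pi_theta/pi_behav = (pi_theta/pi_prox) * (pi_prox/pi_behav), where the first factor is clipped as in standard PPO against a proximal parameter version (the one used to compute `old_log_probs`) and the second is treated as a fixed importance weight. A hedged numpy sketch of the per-token objective under that split; names follow the script's `clip_ratio_low`/`clip_ratio_high` variables but the function itself is illustrative, not verl's loss:

```python
import numpy as np

def decoupled_ppo_token_loss(log_probs, prox_log_probs, rollout_log_probs,
                             advantages, clip_ratio_low=0.2, clip_ratio_high=0.28):
    """Per-token decoupled-PPO-style loss (to be minimized) - a sketch."""
    # trust-region ratio, measured against the proximal policy, not the behavior policy
    ratio = np.exp(log_probs - prox_log_probs)            # pi_theta / pi_prox
    # fixed importance weight correcting for the stale rollout policy (no gradient)
    is_weight = np.exp(prox_log_probs - rollout_log_probs)  # pi_prox / pi_behav
    surr1 = ratio * advantages
    surr2 = np.clip(ratio, 1.0 - clip_ratio_low, 1.0 + clip_ratio_high) * advantages
    return -is_weight * np.minimum(surr1, surr2)
```

When all three log-prob sets coincide (fully on-policy), the weight and ratio are both 1 and this reduces to plain policy-gradient with clipping inactive.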