Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions recipe/fully_async_policy/shell/dapo_30b_a3b_base_math_fsdp.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
#!/usr/bin/env bash
# Launcher for a fully-async DAPO run of Qwen3-30B-A3B-Base (FSDP actor,
# vLLM rollout). This section only defines the knobs; the `ray job submit`
# command at the end of the script forwards them as config overrides.
#
# Every `${VAR:-default}` assignment can be overridden from the environment.
set -xeuo pipefail

project_name='DAPO-Qwen3-30B-A3B-Base-Async'
exp_name='Fsdp2-tp4sp4'

# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}

# Paths
# Root directory for models, data and checkpoints; set RAY_DATA_HOME to
# relocate it. A former second fallback to "/mnt/bn/${BYTENAS}" was removed:
# it could never fire because DATA_PATH is always non-empty at this point.
DATA_PATH=${RAY_DATA_HOME:-"${HOME}/verl"}
# IMPORTANT: after downloading the model from Hugging Face, set
# max_position_embeddings to 32768 in its config.json.
MODEL_PATH=${MODEL_PATH:-"${DATA_PATH}/shared/models/Qwen3-30B-A3B-Base"}
CKPTS_DIR=${CKPTS_DIR:-"${DATA_PATH}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${DATA_PATH}/shared/data/dapo-math/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${DATA_PATH}/shared/data/dapo-math/aime-2024.parquet"}

# Rollout backend
rollout_mode="async"
rollout_name="vllm" # sglang or vllm
# Default first: the submit command always expands ${return_raw_chat}, and
# under `set -u` an unset variable would abort the script as soon as
# rollout_mode is changed to anything other than "async".
return_raw_chat="False"
if [ "${rollout_mode}" = "async" ]; then
    export VLLM_USE_V1=1
    return_raw_chat="True"
fi

# Algorithm parameters
adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

# Response length parameters (in tokens)
max_prompt_length=$((1024 * 2))    # 2048
max_response_length=$((1024 * 20)) # 20480
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))  # 4096
overlong_penalty_factor=1.0

# Training parameters
loss_agg_mode="token-mean"
# NOTE: the three filter-group settings below are defined here but are not
# forwarded by the submit command in this script.
enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10

# Sampling
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# NOTE: NNODES and NGPUS_PER_NODE are accepted from the environment but are
# not read anywhere else in this script; the node/GPU split is fixed below.
NNODES=${NNODES:-4}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}

# Fully async specific parameters
n_gpus_rollout=8
n_gpus_training=8
n_nodes_rollout=2
n_nodes_train=2 # $((NNODES - n_nodes_rollout))

train_bsz=512
train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
total_rollout_steps=$((train_bsz * 400)) # 204800
test_freq=25
staleness_threshold=0.6 # 0 0.3 1
require_batches=1
total_train_gpus=$((n_gpus_training * n_nodes_train))
total_rollout_gpus=$((n_gpus_rollout * n_nodes_rollout))
# 512 / (32 * 1) = 16 with the values above.
trigger_parameter_sync_step=$((train_bsz / (train_prompt_mini_bsz * require_batches))) # 8 16 32
partial_rollout=True
enforce_eager=False
nccl_timeout=72000
enable_sleep_mode=False

# Performance Related Parameter
sp_size=4
use_dynamic_bsz=True
# Per-GPU token budgets are sized to one full prompt + response.
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
ref_offload=True
actor_offload=False
gen_tp=4
fsdp_size=-1


# Submit the training entry point to the Ray cluster at ${RAY_ADDRESS}.
# Everything after `--` is the driver command line: the key=value arguments
# are config overrides applied on top of fully_async_dapo_trainer.yaml, and
# the ones prefixed with `+` add a key rather than override an existing one.
#
# All expansions are quoted so each override reaches the driver as exactly
# one argument. trainer.logger in particular must be quoted: unquoted, the
# shell would strip the inner quotes and treat `[...]` as a glob pattern.
ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
    --working-dir "${WORKING_DIR}" \
    --address "${RAY_ADDRESS}" \
    -- python3 -m recipe.fully_async_policy.fully_async_main \
    --config-path=config \
    --config-name='fully_async_dapo_trainer.yaml' \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=prompt \
    data.truncation='left' \
    actor_rollout_ref.actor.strategy=fsdp \
    critic.strategy=fsdp \
    data.max_prompt_length="${max_prompt_length}" \
    data.max_response_length="${max_response_length}" \
    data.train_batch_size="${train_prompt_bsz}" \
    data.gen_batch_size="${gen_prompt_bsz}" \
    data.return_raw_chat="${return_raw_chat}" \
    actor_rollout_ref.rollout.n="${n_resp_per_prompt}" \
    algorithm.adv_estimator="${adv_estimator}" \
    algorithm.use_kl_in_reward="${use_kl_in_reward}" \
    algorithm.kl_ctrl.kl_coef="${kl_coef}" \
    actor_rollout_ref.rollout.calculate_log_probs=True \
    actor_rollout_ref.nccl_timeout="${nccl_timeout}" \
    actor_rollout_ref.actor.use_kl_loss="${use_kl_loss}" \
    actor_rollout_ref.actor.kl_loss_coef="${kl_loss_coef}" \
    actor_rollout_ref.actor.clip_ratio_low="${clip_ratio_low}" \
    actor_rollout_ref.actor.clip_ratio_high="${clip_ratio_high}" \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.hybrid_engine=False \
    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
    actor_rollout_ref.actor.use_dynamic_bsz="${use_dynamic_bsz}" \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz="${use_dynamic_bsz}" \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz="${use_dynamic_bsz}" \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu="${actor_ppo_max_token_len}" \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu="${infer_ppo_max_token_len}" \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu="${infer_ppo_max_token_len}" \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size="${train_prompt_mini_bsz}" \
    actor_rollout_ref.actor.fsdp_config.param_offload="${actor_offload}" \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload="${actor_offload}" \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.loss_agg_mode="${loss_agg_mode}" \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size="${sp_size}" \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.50 \
    actor_rollout_ref.rollout.tensor_model_parallel_size="${gen_tp}" \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    +actor_rollout_ref.rollout.enable_sleep_mode="${enable_sleep_mode}" \
    actor_rollout_ref.rollout.max_num_batched_tokens="$((max_prompt_length + max_response_length))" \
    actor_rollout_ref.rollout.enforce_eager="${enforce_eager}" \
    actor_rollout_ref.rollout.temperature="${temperature}" \
    actor_rollout_ref.rollout.top_p="${top_p}" \
    actor_rollout_ref.rollout.top_k="${top_k}" \
    actor_rollout_ref.rollout.val_kwargs.temperature="${temperature}" \
    actor_rollout_ref.rollout.val_kwargs.top_p="${val_top_p}" \
    actor_rollout_ref.rollout.val_kwargs.top_k="${top_k}" \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.ref.fsdp_config.param_offload="${ref_offload}" \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size="${sp_size}" \
    actor_rollout_ref.actor.fsdp_config.fsdp_size="${fsdp_size}" \
    actor_rollout_ref.rollout.name="${rollout_name}" \
    actor_rollout_ref.rollout.mode="${rollout_mode}" \
    reward_model.reward_manager=dapo \
    reward_model.overlong_buffer.enable="${enable_overlong_buffer}" \
    reward_model.overlong_buffer.len="${overlong_buffer_len}" \
    reward_model.overlong_buffer.penalty_factor="${overlong_penalty_factor}" \
    +reward_model.reward_kwargs.overlong_buffer_cfg.enable="${enable_overlong_buffer}" \
    +reward_model.reward_kwargs.overlong_buffer_cfg.len="${overlong_buffer_len}" \
    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor="${overlong_penalty_factor}" \
    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
    +reward_model.reward_kwargs.max_resp_len="${max_response_length}" \
    trainer.logger="['console','wandb']" \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}-i${total_rollout_gpus}_t${total_train_gpus}_s${staleness_threshold}" \
    trainer.val_before_train=True \
    trainer.test_freq="${test_freq}" \
    trainer.save_freq=-1 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=auto \
    trainer.nnodes="${n_nodes_train}" \
    trainer.n_gpus_per_node="${n_gpus_training}" \
    rollout.nnodes="${n_nodes_rollout}" \
    rollout.n_gpus_per_node="${n_gpus_rollout}" \
    rollout.total_rollout_steps="${total_rollout_steps}" \
    rollout.test_freq="${test_freq}" \
    rollout.total_epochs=10 \
    async_training.require_batches="${require_batches}" \
    async_training.staleness_threshold="${staleness_threshold}" \
    async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
    async_training.partial_rollout="${partial_rollout}" \
    async_training.use_rollout_log_probs=True
2 changes: 2 additions & 0 deletions verl/workers/config/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ class RolloutConfig(BaseConfig):

enable_rollout_routing_replay: bool = False

enable_sleep_mode: bool = True

def __post_init__(self):
"""Validate the rollout config"""
if self.expert_parallel_size > 1:
Expand Down
9 changes: 8 additions & 1 deletion verl/workers/rollout/vllm_rollout/vllm_async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
max_new_tokens=self.config.response_length,
)
logger.info(f"override_generation_config: {override_generation_config}")

logger.info(f"enable_sleep_mode: {self.config.enable_sleep_mode}")
if not self.config.enable_sleep_mode:
from verl.utils.device import set_expandable_segments

set_expandable_segments(True)

quantization = self.config.quantization
if quantization is not None:
if quantization == "fp8":
Expand All @@ -279,7 +286,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
"enable_chunked_prefill": self.config.enable_chunked_prefill,
"max_num_batched_tokens": self.config.max_num_batched_tokens,
"enable_prefix_caching": self.config.enable_prefix_caching,
"enable_sleep_mode": True,
"enable_sleep_mode": self.config.enable_sleep_mode,
"disable_custom_all_reduce": True,
"enforce_eager": self.config.enforce_eager,
"gpu_memory_utilization": self.config.gpu_memory_utilization,
Expand Down
Loading