diff --git a/examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh b/examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh index 6937db5fcfa..1db311e28f2 100644 --- a/examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh +++ b/examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh @@ -30,7 +30,7 @@ train_prompt_mini_bsz=128 train_ppo_micro_batch_size_per_gpu=2 infer_ppo_micro_batch_size_per_gpu=2 # Paths -MODEL_PATH=Qwen/Qwen3-30B-A3B +MODEL_PATH=Qwen/Qwen3-30B-A3B-Base RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} TRAIN_FILE=$RAY_DATA_HOME/dataset/dapo-math-17k.parquet diff --git a/recipe/dapo/dapo_ray_trainer.py b/recipe/dapo/dapo_ray_trainer.py index a6a06192374..18995f67787 100644 --- a/recipe/dapo/dapo_ray_trainer.py +++ b/recipe/dapo/dapo_ray_trainer.py @@ -321,7 +321,7 @@ def fit(self): rollout_corr_config = self.config.algorithm.get("rollout_correction", None) if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: - batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) # IS and off-policy metrics already have rollout_corr/ prefix metrics.update(is_metrics) diff --git a/recipe/fully_async_policy/megatron_utils.py b/recipe/fully_async_policy/megatron_utils.py new file mode 100644 index 00000000000..9f5380f25c5 --- /dev/null +++ b/recipe/fully_async_policy/megatron_utils.py @@ -0,0 +1,99 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core.distributed import DistributedDataParallel as DDP + + +@torch.no_grad() +def copy_megatron_model_to_cpu(models): + """ + Copy Megatron model parameters to CPU memory (non-destructive copy). + Unlike offload_megatron_model_to_cpu which moves data, this function creates + independent copies on CPU while keeping GPU data intact. 
+ + Args: + models: List of model chunks (DDP-wrapped or unwrapped) + + Returns: + dict: CPU state containing copied parameters and buffers + """ + cpu_state = {} + + for model_idx, model_chunk in enumerate(models): + if isinstance(model_chunk, DDP): + # Handle DDP-wrapped models + model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + buffer_states = [] + + for buffers in model_chunk_all_buffers: + buffer_list = [] + for buffer in buffers: + buffer_state = {} + + # Copy parameter data to CPU + if buffer.param_data.storage().size() > 0: + buffer_state["param_data"] = buffer.param_data.data.cpu().clone().pin_memory() + + buffer_list.append(buffer_state) + buffer_states.append(buffer_list) + + cpu_state[f"model_chunk_{model_idx}"] = {"buffer_states": buffer_states, "is_ddp": True} + else: + # Handle non-DDP models (ref module) + model_state = {} + for name, param in model_chunk.named_parameters(): + param_state = {"data": param.data.cpu().clone().pin_memory()} + model_state[name] = param_state + + cpu_state[f"model_chunk_{model_idx}"] = {"model_state": model_state, "is_ddp": False} + + return cpu_state + + +@torch.no_grad() +def restore_megatron_model_from_cpu(models, cpu_state): + """ + Restore Megatron model parameters from CPU memory back to GPU. + + Args: + models: List of model chunks to restore to + cpu_state: CPU state dict returned from copy_megatron_model_to_cpu + """ + for model_idx, model_chunk in enumerate(models): + chunk_key = f"model_chunk_{model_idx}" + if chunk_key not in cpu_state: + continue + + chunk_state = cpu_state[chunk_key] + + if chunk_state["is_ddp"] and isinstance(model_chunk, DDP): + # Restore DDP buffers + model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + buffer_states = chunk_state["buffer_states"] + + for buffers, buffer_list in zip(model_chunk_all_buffers, buffer_states, strict=False): + for buffer, buffer_state in zip(buffers, buffer_list, strict=False): + # Restore parameter data + if "param_data" in buffer_state: + buffer.param_data.data.copy_(buffer_state["param_data"].to(buffer.param_data.device)) + + elif not chunk_state["is_ddp"] and not isinstance(model_chunk, DDP): + # Restore non-DDP models + model_state = chunk_state["model_state"] + for name, param in model_chunk.named_parameters(): + if name in model_state: + param_state = model_state[name] + param.data.copy_(param_state["data"].to(param.device)) diff --git a/recipe/fully_async_policy/megatron_worker.py b/recipe/fully_async_policy/megatron_worker.py index 7097ef64c2a..24bac66be77 100644 --- a/recipe/fully_async_policy/megatron_worker.py +++ b/recipe/fully_async_policy/megatron_worker.py @@ -21,6 +21,7 @@ import torch.distributed from omegaconf import DictConfig +from recipe.fully_async_policy.megatron_utils import copy_megatron_model_to_cpu, restore_megatron_model_from_cpu from verl.single_controller.base.decorator import Dispatch, register from verl.utils.device import ( get_device_name, @@ -89,6 +90,22 @@ def sync_rollout_weights(self): if self._is_rollout: inference_model.load_weights([(key, tensor)]) + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_model_to_cpu(self, n): + if not hasattr(self, "cpu_saved_models"): + self.cpu_saved_models = {} + self.cpu_saved_models[n] = copy_megatron_model_to_cpu(self.actor.actor_module) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def restore_model_from_cpu(self, n): + if n in self.cpu_saved_models: + restore_megatron_model_from_cpu(self.actor.actor_module, 
self.cpu_saved_models[n]) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def clear_cpu_model(self, n): + if n in self.cpu_saved_models: + del self.cpu_saved_models[n] + class DetachActorWorker(DetachNcclSync): def _get_actor_params_generator(self): diff --git a/recipe/fully_async_policy/ray_trainer.py b/recipe/fully_async_policy/ray_trainer.py index ba6c1448ae6..c27851ddf19 100644 --- a/recipe/fully_async_policy/ray_trainer.py +++ b/recipe/fully_async_policy/ray_trainer.py @@ -430,7 +430,7 @@ def compute_old_log_prob(batch): rollout_corr_config = self.config.algorithm.get("rollout_correction", None) if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: - batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) # IS and off-policy metrics already have rollout_corr/ prefix metrics.update(is_metrics) diff --git a/recipe/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh b/recipe/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh new file mode 100644 index 00000000000..be9523f9e08 --- /dev/null +++ b/recipe/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32.sh @@ -0,0 +1,230 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='GRPO-Qwen3-30b-Base-MATH' +exp_name='GRPO-Qwen3-30b-Base-MATH-megatron-fully-async_96-32' + +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi +# Algorithm parameters +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=True +kl_loss_coef=0.001 +kl_loss_type=low_var_kl + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Response length parameters +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length))) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length))) +offload=True +train_ppo_micro_batch_size_per_gpu=2 +infer_ppo_micro_batch_size_per_gpu=2 + +optimizer_offload_fraction=${OFFLOAD_FRACTION:-1.} + +COMMON_PP=${COMMON_PP:-1} +COMMON_VPP=${COMMON_VPP:-null} +COMMON_CP=${COMMON_CP:-2} +COMMON_TP=${COMMON_TP:-2} +COMMON_EP=${COMMON_EP:-8} +COMMON_ETP=${COMMON_ETP:-1} + +TRAIN_TP=${TRAIN_TP:-$COMMON_TP} +INFER_TP=${INFER_TP:-4} + +ACTOR_PP=${ACTOR_PP:-$COMMON_PP} +ACTOR_VPP=${ACTOR_VPP:-$COMMON_VPP} +ACTOR_CP=${ACTOR_CP:-$COMMON_CP} +ACTOR_TP=${ACTOR_TP:-$TRAIN_TP} +ACTOR_EP=${ACTOR_EP:-$COMMON_EP} +ACTOR_ETP=${ACTOR_ETP:-$COMMON_ETP} +ROLLOUT_TP=${ROLLOUT_TP:-$INFER_TP} +REF_PP=${REF_PP:-$COMMON_PP} +REF_VPP=${REF_VPP:-$COMMON_VPP} +REF_CP=${REF_CP:-$COMMON_CP} +REF_TP=${REF_TP:-$TRAIN_TP} +REF_EP=${REF_EP:-$COMMON_EP} +REF_ETP=${REF_ETP:-$COMMON_ETP} +CRITIC_PP=${CRITIC_PP:-$COMMON_PP} +CRITIC_VPP=${CRITIC_VPP:-$COMMON_VPP} +CRITIC_CP=${CRITIC_CP:-$COMMON_CP} 
+CRITIC_TP=${CRITIC_TP:-$TRAIN_TP} +CRITIC_EP=${CRITIC_EP:-$COMMON_EP} +CRITIC_ETP=${CRITIC_ETP:-$COMMON_ETP} +RM_PP=${RM_PP:-$COMMON_PP} +RM_VPP=${RM_VPP:-$COMMON_VPP} +RM_CP=${RM_CP:-$COMMON_CP} +RM_TP=${RM_TP:-$TRAIN_TP} +RM_EP=${RM_EP:-$COMMON_EP} +RM_ETP=${RM_ETP:-$COMMON_ETP} + +# install mbridge +# pip3 install git+https://github.com/ISEEKYAN/mbridge +USE_MBRIDGE=True +USE_DIST_CKPT=False + +# Fully async specific parameters +NNODES_ROLLOUT=${NNODES_ROLLOUT:-12} +NNODES_TRAIN=${NNODES_TRAIN:-4} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=128 +total_rollout_steps=$(((512*400))) +test_freq=20 +staleness_threshold=0.5 +trigger_parameter_sync_step=4 +require_batches=1 +partial_rollout=True + +python -m recipe.fully_async_policy.fully_async_main \ + --config-path=config \ + --config-name='fully_async_ppo_megatron_trainer.yaml'\ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + +actor_rollout_ref.model.override_config.model_config.max_position_embeddings=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.model.use_fused_kernels=False \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.lr_decay_style='constant' \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.lr_decay_steps=${total_rollout_steps} \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction} \ + +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \ + actor_rollout_ref.actor.megatron.use_mbridge=$USE_MBRIDGE \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=$USE_DIST_CKPT \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${ACTOR_TP} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${ACTOR_PP} \ + actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=${ACTOR_VPP} \ + 
actor_rollout_ref.actor.megatron.context_parallel_size=${ACTOR_CP} \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${ACTOR_EP} \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ACTOR_ETP} \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.masked_softmax_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.bias_dropout_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.deallocate_pipeline_outputs=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.persist_layer_norm=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_grouped_gemm=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type="flex" \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${INFER_TP} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.hybrid_engine=False \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${REF_TP} \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${REF_PP} \ + actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=${REF_VPP} \ + actor_rollout_ref.ref.megatron.context_parallel_size=${REF_CP} \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=${REF_EP} \ + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${REF_ETP} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + 
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.val_before_train=True \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 \ + trainer.nnodes="${NNODES_TRAIN}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.nnodes="${NNODES_ROLLOUT}" \ + rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.total_epochs=10 \ + rollout.test_freq="${test_freq}" \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.require_batches="${require_batches}" \ + async_training.partial_rollout="${partial_rollout}" \ + async_training.use_rollout_log_probs=True \ + diff --git a/recipe/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh b/recipe/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh new file mode 100644 index 00000000000..ba2d6e4680b --- /dev/null +++ b/recipe/fully_async_policy/shell/grpo_30b_a3b_base_math_megatron_96_32_mis.sh @@ -0,0 +1,244 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='GRPO-Qwen3-30b-Base-MATH' +exp_name='GRPO-Qwen3-30b-Base-MATH-megatron-fully-async_96-32' + +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi +# Algorithm parameters +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=True +kl_loss_coef=0.001 +kl_loss_type=low_var_kl + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Response length parameters +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length))) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length))) +offload=True +train_ppo_micro_batch_size_per_gpu=2 +infer_ppo_micro_batch_size_per_gpu=2 + +optimizer_offload_fraction=${OFFLOAD_FRACTION:-1.} + +COMMON_PP=${COMMON_PP:-1} +COMMON_VPP=${COMMON_VPP:-null} +COMMON_CP=${COMMON_CP:-2} +COMMON_TP=${COMMON_TP:-2} +COMMON_EP=${COMMON_EP:-8} +COMMON_ETP=${COMMON_ETP:-1} + +TRAIN_TP=${TRAIN_TP:-$COMMON_TP} +INFER_TP=${INFER_TP:-4} + +ACTOR_PP=${ACTOR_PP:-$COMMON_PP} +ACTOR_VPP=${ACTOR_VPP:-$COMMON_VPP} +ACTOR_CP=${ACTOR_CP:-$COMMON_CP} +ACTOR_TP=${ACTOR_TP:-$TRAIN_TP} +ACTOR_EP=${ACTOR_EP:-$COMMON_EP} +ACTOR_ETP=${ACTOR_ETP:-$COMMON_ETP} +ROLLOUT_TP=${ROLLOUT_TP:-$INFER_TP} +REF_PP=${REF_PP:-$COMMON_PP} 
+REF_VPP=${REF_VPP:-$COMMON_VPP} +REF_CP=${REF_CP:-$COMMON_CP} +REF_TP=${REF_TP:-$TRAIN_TP} +REF_EP=${REF_EP:-$COMMON_EP} +REF_ETP=${REF_ETP:-$COMMON_ETP} +CRITIC_PP=${CRITIC_PP:-$COMMON_PP} +CRITIC_VPP=${CRITIC_VPP:-$COMMON_VPP} +CRITIC_CP=${CRITIC_CP:-$COMMON_CP} +CRITIC_TP=${CRITIC_TP:-$TRAIN_TP} +CRITIC_EP=${CRITIC_EP:-$COMMON_EP} +CRITIC_ETP=${CRITIC_ETP:-$COMMON_ETP} +RM_PP=${RM_PP:-$COMMON_PP} +RM_VPP=${RM_VPP:-$COMMON_VPP} +RM_CP=${RM_CP:-$COMMON_CP} +RM_TP=${RM_TP:-$TRAIN_TP} +RM_EP=${RM_EP:-$COMMON_EP} +RM_ETP=${RM_ETP:-$COMMON_ETP} + +# install mbridge +# pip3 install git+https://github.com/ISEEKYAN/mbridge +USE_MBRIDGE=True +USE_DIST_CKPT=False + +# Fully async specific parameters +NNODES_ROLLOUT=${NNODES_ROLLOUT:-12} +NNODES_TRAIN=${NNODES_TRAIN:-4} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=128 +total_rollout_steps=$(((512*400))) +test_freq=20 +staleness_threshold=0.5 +trigger_parameter_sync_step=4 +require_batches=1 +partial_rollout=True + +# Rollout Importance Sampling + +rollout_is=null +rollout_rs=geometric +rollout_rs_threshold=1.001 +rollout_rs_threshold_lower=0.999 +rollout_token_veto_threshold=1e-4 + +python -m recipe.fully_async_policy.fully_async_main \ + --config-path=config \ + --config-name='fully_async_ppo_megatron_trainer.yaml'\ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + async_training.compute_prox_log_prob=True \ + algorithm.rollout_correction.rollout_is=${rollout_is} \ + algorithm.rollout_correction.rollout_rs=${rollout_rs} \ + algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \ + algorithm.rollout_correction.rollout_rs_threshold_lower=${rollout_rs_threshold_lower} \ + algorithm.rollout_correction.rollout_token_veto_threshold=${rollout_token_veto_threshold} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + +actor_rollout_ref.model.override_config.model_config.max_position_embeddings=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.model.use_fused_kernels=False \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.lr_decay_style='constant' \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.lr_decay_steps=${total_rollout_steps} \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction} \ + 
+actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \ + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \ + actor_rollout_ref.actor.megatron.use_mbridge=$USE_MBRIDGE \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=$USE_DIST_CKPT \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${ACTOR_TP} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${ACTOR_PP} \ + actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=${ACTOR_VPP} \ + actor_rollout_ref.actor.megatron.context_parallel_size=${ACTOR_CP} \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${ACTOR_EP} \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ACTOR_ETP} \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.masked_softmax_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.bias_dropout_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.deallocate_pipeline_outputs=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.persist_layer_norm=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_grouped_gemm=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type="flex" \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${INFER_TP} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.hybrid_engine=False \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + 
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${REF_TP} \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${REF_PP} \ + actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=${REF_VPP} \ + actor_rollout_ref.ref.megatron.context_parallel_size=${REF_CP} \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=${REF_EP} \ + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${REF_ETP} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.val_before_train=True \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 \ + trainer.nnodes="${NNODES_TRAIN}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.nnodes="${NNODES_ROLLOUT}" \ + rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.total_epochs=10 \ + rollout.test_freq="${test_freq}" \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.require_batches="${require_batches}" \ + async_training.partial_rollout="${partial_rollout}" \ + async_training.use_rollout_log_probs=True \ + diff --git a/recipe/one_step_off_policy/ray_trainer.py b/recipe/one_step_off_policy/ray_trainer.py index 3f5ec30ac17..f5c1c4612f1 100644 --- a/recipe/one_step_off_policy/ray_trainer.py +++ b/recipe/one_step_off_policy/ray_trainer.py @@ -582,7 +582,7 @@ def fit(self): rollout_corr_config = self.config.algorithm.get("rollout_correction", None) if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: - batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) # IS and off-policy metrics already have rollout_corr/ prefix metrics.update(is_metrics) diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index fc121e3c2e0..a6b54eb7904 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -469,14 +469,15 @@ reward_model: use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} load_weight: true algorithm: - rollout_is: null - rollout_is_threshold: 2.0 - rollout_rs: null - rollout_rs_threshold: null - rollout_rs_threshold_lower: null - rollout_token_veto_threshold: null - bypass_old_logprob_for_rollout: false - use_pure_rollout_correction: false + rollout_correction: + rollout_is: null + rollout_is_threshold: 2.0 + rollout_rs: null + rollout_rs_threshold: null + 
rollout_rs_threshold_lower: null + rollout_token_veto_threshold: null + bypass_old_logprob_for_rollout: false + use_pure_rollout_correction: false _target_: verl.trainer.config.AlgoConfig gamma: 1.0 lam: 1.0 diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index c126dc418ee..0815451d178 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -16,7 +16,7 @@ defaults: # Reward model config. - reward_model@reward_model: megatron_reward_model # Rollout correction config. - - algorithm@algorithm: rollout_correction + - algorithm@algorithm.rollout_correction: rollout_correction - _self_ actor_rollout_ref: diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 3ec59760a05..88cfd715f1c 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -1172,7 +1172,7 @@ def fit(self): # This corrects for off-policy issues (policy mismatch, model staleness, etc.) # Also computes off-policy diagnostic metrics (KL, PPL, etc.) if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: - batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) # IS and off-policy metrics already have rollout_corr/ prefix metrics.update(is_metrics)
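
Note on the call-site changes (not part of the patch): the four updated call sites (recipe/dapo/dapo_ray_trainer.py, recipe/fully_async_policy/ray_trainer.py, recipe/one_step_off_policy/ray_trainer.py, verl/trainer/ppo/ray_trainer.py) all make the same change: compute_rollout_correction_and_add_to_batch now receives the nested algorithm.rollout_correction group explicitly instead of being called with the batch alone. A minimal sketch of how that nested group resolves, assuming only the keys and defaults visible in the regenerated YAML above; the OmegaConf usage here is illustrative, not verl's actual config-loading code.

from omegaconf import OmegaConf

# Keys and defaults mirror the rollout_correction block added to
# _generated_ppo_megatron_trainer.yaml in this patch.
algo = OmegaConf.create(
    {
        "rollout_correction": {
            "rollout_is": None,
            "rollout_is_threshold": 2.0,
            "rollout_rs": None,
            "rollout_rs_threshold": None,
            "rollout_rs_threshold_lower": None,
            "rollout_token_veto_threshold": None,
            "bypass_old_logprob_for_rollout": False,
            "use_pure_rollout_correction": False,
        }
    }
)

# The same guarded lookup used at every updated call site: returns None
# when the group is absent, so the correction path is skipped entirely.
rollout_corr_config = algo.get("rollout_correction", None)
if rollout_corr_config is not None:
    print(rollout_corr_config.rollout_is_threshold)  # 2.0

The companion defaults-list change in ppo_megatron_trainer.yaml (algorithm@algorithm.rollout_correction: rollout_correction) is what moves these settings from flat algorithm.* keys into the nested group, which is why each call site must now pass the sub-config through.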
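
Note on the new snapshot utilities (not part of the patch): recipe/fully_async_policy/megatron_utils.py implements a non-destructive snapshot pair. copy_megatron_model_to_cpu clones parameter data into pinned host memory while leaving the GPU weights untouched, and restore_megatron_model_from_cpu writes it back with in-place copy_. The worker-side save_model_to_cpu / restore_model_from_cpu / clear_cpu_model methods key these snapshots by n, so multiple versions can coexist until explicitly cleared. A minimal sketch of the same contract on a plain nn.Module, with the Megatron DDP buffer walking and pin_memory() omitted; copy_to_cpu and restore_from_cpu are hypothetical names, not part of the patch.

import torch
from torch import nn


@torch.no_grad()
def copy_to_cpu(module: nn.Module) -> dict:
    # Non-destructive: clone every parameter to host memory; the
    # device-side weights are left exactly as they were.
    return {name: p.data.cpu().clone() for name, p in module.named_parameters()}


@torch.no_grad()
def restore_from_cpu(module: nn.Module, state: dict) -> None:
    # In-place copy back, matching the param.data.copy_(...) path used
    # for non-DDP modules in megatron_utils.py above.
    for name, p in module.named_parameters():
        if name in state:
            p.data.copy_(state[name].to(p.device))


model = nn.Linear(4, 4)
snapshot = copy_to_cpu(model)      # save_model_to_cpu(n) keeps one of these per key n
model.weight.data.add_(1.0)        # training perturbs the live weights
restore_from_cpu(model, snapshot)  # restore_model_from_cpu(n) rolls them back

In the real utility, pin_memory() keeps the host copies page-locked so the eventual copy back to the GPU can use the faster DMA path.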
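
Note on the script arithmetic (not part of the patch): the length and step arithmetic both new shell scripts delegate to bash $(( )) expansion works out as follows, restated in plain Python with values copied from the scripts.

max_prompt_length = 1024 * 2    # 2048
max_response_length = 1024 * 8  # 8192

# Shared by actor_ppo_max_token_len, infer_ppo_max_token_len,
# max_num_batched_tokens, and the max_position_embeddings override.
budget = max_prompt_length + max_response_length
assert budget == 10240

total_rollout_steps = 512 * 400
assert total_rollout_steps == 204800  # also reused as lr_decay_steps

The overlong-buffer reward shaping then reserves overlong_buffer_len = 4096 of the 8192 response tokens, i.e. half the response budget, with penalty_factor 1.0.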