Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/grpo_trainer/run_qwen3moe-30b_megatron_96gb.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ train_prompt_mini_bsz=128
train_ppo_micro_batch_size_per_gpu=2
infer_ppo_micro_batch_size_per_gpu=2
# Paths
MODEL_PATH=Qwen/Qwen3-30B-A3B
MODEL_PATH=Qwen/Qwen3-30B-A3B-Base

RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
TRAIN_FILE=$RAY_DATA_HOME/dataset/dapo-math-17k.parquet
Expand Down
2 changes: 1 addition & 1 deletion recipe/dapo/dapo_ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def fit(self):

rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
if rollout_corr_config is not None and "rollout_log_probs" in batch.batch:
batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch)
batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
# IS and off-policy metrics already have rollout_corr/ prefix
metrics.update(is_metrics)

Expand Down
99 changes: 99 additions & 0 deletions recipe/fully_async_policy/megatron_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from megatron.core.distributed import DistributedDataParallel as DDP


@torch.no_grad()
def copy_megatron_model_to_cpu(models):
    """
    Copy Megatron model parameters to CPU memory (non-destructive copy).

    Unlike offload_megatron_model_to_cpu which moves data, this function creates
    independent copies on CPU while keeping GPU data intact.

    Args:
        models: List of model chunks (DDP-wrapped or unwrapped)

    Returns:
        dict: CPU state keyed by ``"model_chunk_{idx}"``. Each entry records
            whether the chunk was DDP-wrapped (``is_ddp``) so that
            restore_megatron_model_from_cpu can route it back correctly.
    """

    def _pinned_cpu_copy(tensor):
        # Allocate the pinned host destination once and copy straight into it.
        # The previous idiom `tensor.cpu().clone().pin_memory()` performs up to
        # three host-side copies (device->host, clone, pin); this does one.
        dst = torch.empty_like(tensor, device="cpu", pin_memory=True)
        dst.copy_(tensor)
        return dst

    cpu_state = {}

    for model_idx, model_chunk in enumerate(models):
        if isinstance(model_chunk, DDP):
            # DDP-wrapped chunks keep parameters in contiguous buffers
            # (regular + expert-parallel); snapshot each buffer's param_data.
            model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers]
            buffer_states = []

            for buffers in model_chunk_all_buffers:
                buffer_list = []
                for buffer in buffers:
                    buffer_state = {}

                    # Skip buffers whose storage was freed (e.g. offloaded).
                    if buffer.param_data.storage().size() > 0:
                        buffer_state["param_data"] = _pinned_cpu_copy(buffer.param_data.data)

                    buffer_list.append(buffer_state)
                buffer_states.append(buffer_list)

            cpu_state[f"model_chunk_{model_idx}"] = {"buffer_states": buffer_states, "is_ddp": True}
        else:
            # Non-DDP chunks (e.g. the ref module): snapshot per parameter.
            model_state = {}
            for name, param in model_chunk.named_parameters():
                model_state[name] = {"data": _pinned_cpu_copy(param.data)}

            cpu_state[f"model_chunk_{model_idx}"] = {"model_state": model_state, "is_ddp": False}

    return cpu_state


@torch.no_grad()
def restore_megatron_model_from_cpu(models, cpu_state):
    """
    Restore Megatron model parameters from CPU memory back to GPU.

    Args:
        models: List of model chunks to restore to
        cpu_state: CPU state dict returned from copy_megatron_model_to_cpu

    Chunks without a matching entry in ``cpu_state`` (or whose DDP-ness does
    not match the saved state) are left untouched.
    """
    for model_idx, model_chunk in enumerate(models):
        chunk_key = f"model_chunk_{model_idx}"
        if chunk_key not in cpu_state:
            continue

        chunk_state = cpu_state[chunk_key]

        if chunk_state["is_ddp"] and isinstance(model_chunk, DDP):
            # Restore the DDP contiguous buffers (regular + expert-parallel).
            model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers]
            buffer_states = chunk_state["buffer_states"]

            for buffers, buffer_list in zip(model_chunk_all_buffers, buffer_states, strict=False):
                for buffer, buffer_state in zip(buffers, buffer_list, strict=False):
                    if "param_data" in buffer_state:
                        # copy_ performs the host->device transfer directly;
                        # the previous `.to(device)` allocated a temporary GPU
                        # tensor and copied twice. The source is pinned, so
                        # non_blocking lets the copy overlap on the stream.
                        buffer.param_data.data.copy_(buffer_state["param_data"], non_blocking=True)

        elif not chunk_state["is_ddp"] and not isinstance(model_chunk, DDP):
            # Restore non-DDP chunks parameter by parameter.
            model_state = chunk_state["model_state"]
            for name, param in model_chunk.named_parameters():
                if name in model_state:
                    param.data.copy_(model_state[name]["data"], non_blocking=True)
17 changes: 17 additions & 0 deletions recipe/fully_async_policy/megatron_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch.distributed
from omegaconf import DictConfig

from recipe.fully_async_policy.megatron_utils import copy_megatron_model_to_cpu, restore_megatron_model_from_cpu
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.device import (
get_device_name,
Expand Down Expand Up @@ -89,6 +90,22 @@ def sync_rollout_weights(self):
if self._is_rollout:
inference_model.load_weights([(key, tensor)])

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_model_to_cpu(self, n):
    """Snapshot the current actor weights into CPU memory under key ``n``."""
    saved = getattr(self, "cpu_saved_models", None)
    if saved is None:
        # Lazily create the per-worker snapshot registry on first use.
        saved = {}
        self.cpu_saved_models = saved
    saved[n] = copy_megatron_model_to_cpu(self.actor.actor_module)

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def restore_model_from_cpu(self, n):
    """Restore actor weights previously saved under key ``n``.

    No-op when nothing was saved for ``n`` — including when no save has
    happened at all yet (the original raised AttributeError in that case
    because ``cpu_saved_models`` is created lazily by save_model_to_cpu).
    """
    saved = getattr(self, "cpu_saved_models", {})
    if n in saved:
        restore_megatron_model_from_cpu(self.actor.actor_module, saved[n])

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def clear_cpu_model(self, n):
    """Drop the CPU snapshot saved under key ``n``, releasing pinned memory.

    Safe to call with an unknown key or before any save (the original raised
    AttributeError when ``cpu_saved_models`` had never been created).
    """
    getattr(self, "cpu_saved_models", {}).pop(n, None)


class DetachActorWorker(DetachNcclSync):
def _get_actor_params_generator(self):
Expand Down
2 changes: 1 addition & 1 deletion recipe/fully_async_policy/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def compute_old_log_prob(batch):

rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
if rollout_corr_config is not None and "rollout_log_probs" in batch.batch:
batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch)
batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
# IS and off-policy metrics already have rollout_corr/ prefix
metrics.update(is_metrics)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
#!/usr/bin/env bash
# Fully-async GRPO training launcher for Qwen3-30B-A3B-Base (Megatron backend).
# -x trace commands, -e exit on error, -u error on unset vars, pipefail on pipes.
set -xeuo pipefail

project_name='GRPO-Qwen3-30b-Base-MATH'
exp_name='GRPO-Qwen3-30b-Base-MATH-megatron-fully-async_96-32'

# Paths — every one overridable from the environment.
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

rollout_mode="async"
rollout_name="vllm" # sglang or vllm
# Default first: return_raw_chat was previously assigned only inside the async
# branch, so any other rollout_mode crashed under `set -u` when the variable
# was expanded in the launch command ("return_raw_chat: unbound variable").
return_raw_chat="False"
if [ "$rollout_mode" = "async" ]; then
    export VLLM_USE_V1=1
    return_raw_chat="True"
fi
# Algorithm parameters
adv_estimator=grpo

# KL control: no KL term in the reward; a low-variance KL loss on the actor.
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=True
kl_loss_coef=0.001
kl_loss_type=low_var_kl

# Asymmetric PPO clipping (clip-higher, DAPO-style).
clip_ratio_low=0.2
clip_ratio_high=0.28

# Response length parameters
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"

# Algorithm
# Sampling settings; val_top_p applies only to validation generations.
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length)))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length)))
offload=True
train_ppo_micro_batch_size_per_gpu=2
infer_ppo_micro_batch_size_per_gpu=2

# Fraction of optimizer state offloaded to CPU (1. = all of it).
optimizer_offload_fraction=${OFFLOAD_FRACTION:-1.}

# Shared Megatron parallelism layout; per-role values below default to these.
# PP=pipeline, VPP=virtual pipeline, CP=context, TP=tensor, EP=expert,
# ETP=expert tensor parallel.
COMMON_PP=${COMMON_PP:-1}
COMMON_VPP=${COMMON_VPP:-null}
COMMON_CP=${COMMON_CP:-2}
COMMON_TP=${COMMON_TP:-2}
COMMON_EP=${COMMON_EP:-8}
COMMON_ETP=${COMMON_ETP:-1}

TRAIN_TP=${TRAIN_TP:-$COMMON_TP}
INFER_TP=${INFER_TP:-4}

ACTOR_PP=${ACTOR_PP:-$COMMON_PP}
ACTOR_VPP=${ACTOR_VPP:-$COMMON_VPP}
ACTOR_CP=${ACTOR_CP:-$COMMON_CP}
ACTOR_TP=${ACTOR_TP:-$TRAIN_TP}
ACTOR_EP=${ACTOR_EP:-$COMMON_EP}
ACTOR_ETP=${ACTOR_ETP:-$COMMON_ETP}
ROLLOUT_TP=${ROLLOUT_TP:-$INFER_TP}
REF_PP=${REF_PP:-$COMMON_PP}
REF_VPP=${REF_VPP:-$COMMON_VPP}
REF_CP=${REF_CP:-$COMMON_CP}
REF_TP=${REF_TP:-$TRAIN_TP}
REF_EP=${REF_EP:-$COMMON_EP}
REF_ETP=${REF_ETP:-$COMMON_ETP}
CRITIC_PP=${CRITIC_PP:-$COMMON_PP}
CRITIC_VPP=${CRITIC_VPP:-$COMMON_VPP}
CRITIC_CP=${CRITIC_CP:-$COMMON_CP}
CRITIC_TP=${CRITIC_TP:-$TRAIN_TP}
CRITIC_EP=${CRITIC_EP:-$COMMON_EP}
CRITIC_ETP=${CRITIC_ETP:-$COMMON_ETP}
RM_PP=${RM_PP:-$COMMON_PP}
RM_VPP=${RM_VPP:-$COMMON_VPP}
RM_CP=${RM_CP:-$COMMON_CP}
RM_TP=${RM_TP:-$TRAIN_TP}
RM_EP=${RM_EP:-$COMMON_EP}
RM_ETP=${RM_ETP:-$COMMON_ETP}

# install mbridge
# pip3 install git+https://github.com/ISEEKYAN/mbridge
USE_MBRIDGE=True
USE_DIST_CKPT=False

# Fully async specific parameters
# Cluster split: 12 rollout nodes + 4 trainer nodes, 8 GPUs each (96/32 GPUs).
NNODES_ROLLOUT=${NNODES_ROLLOUT:-12}
NNODES_TRAIN=${NNODES_TRAIN:-4}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}

# NOTE(review): train_prompt_bsz=0 — presumably the fully-async trainer
# streams rollouts instead of using a fixed train batch size; confirm against
# fully_async_ppo_megatron_trainer.yaml semantics.
train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=128
total_rollout_steps=$(((512*400)))
test_freq=20
# Staleness/sync knobs for the async pipeline — names suggest: tolerate up to
# 50% stale samples, sync parameters every 4 steps; verify against the recipe.
staleness_threshold=0.5
trigger_parameter_sync_step=4
require_batches=1
partial_rollout=True

# Launch the fully-async driver. Everything below is ONE command: every line
# must end with a backslash continuation except the last.
# NOTE(review): the final option line also ends with a trailing backslash, so
# the command continues onto the following blank line. It runs correctly as-is,
# but appending any line directly after this block would silently become part
# of the command — consider dropping that last backslash.
python -m recipe.fully_async_policy.fully_async_main \
    --config-path=config \
    --config-name='fully_async_ppo_megatron_trainer.yaml'\
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.train_batch_size=${train_prompt_bsz} \
    data.return_raw_chat=${return_raw_chat} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    +actor_rollout_ref.model.override_config.model_config.max_position_embeddings=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.model.use_fused_kernels=False \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
    actor_rollout_ref.actor.optim.lr_decay_style='constant' \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.optim.lr_decay_steps=${total_rollout_steps} \
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${optimizer_offload_fraction} \
    +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \
    +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \
    +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \
    actor_rollout_ref.actor.megatron.use_mbridge=$USE_MBRIDGE \
    actor_rollout_ref.actor.megatron.use_dist_checkpointing=$USE_DIST_CKPT \
    actor_rollout_ref.actor.megatron.param_offload=${offload} \
    actor_rollout_ref.actor.megatron.grad_offload=${offload} \
    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${ACTOR_TP} \
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${ACTOR_PP} \
    actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=${ACTOR_VPP} \
    actor_rollout_ref.actor.megatron.context_parallel_size=${ACTOR_CP} \
    actor_rollout_ref.actor.megatron.expert_model_parallel_size=${ACTOR_EP} \
    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ACTOR_ETP} \
    +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.masked_softmax_fusion=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.bias_dropout_fusion=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.deallocate_pipeline_outputs=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.persist_layer_norm=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_grouped_gemm=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type="flex" \
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \
    +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${INFER_TP} \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.top_p=${top_p} \
    actor_rollout_ref.rollout.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.rollout.name=${rollout_name} \
    actor_rollout_ref.rollout.mode=${rollout_mode} \
    actor_rollout_ref.rollout.calculate_log_probs=True \
    actor_rollout_ref.hybrid_engine=False \
    actor_rollout_ref.rollout.enforce_eager=True \
    actor_rollout_ref.rollout.free_cache_engine=True \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \
    actor_rollout_ref.ref.megatron.param_offload=${offload} \
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${REF_TP} \
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${REF_PP} \
    actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=${REF_VPP} \
    actor_rollout_ref.ref.megatron.context_parallel_size=${REF_CP} \
    actor_rollout_ref.ref.megatron.expert_model_parallel_size=${REF_EP} \
    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${REF_ETP} \
    reward_model.reward_manager=dapo \
    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \
    trainer.logger=['console','tensorboard'] \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.val_before_train=True \
    trainer.save_freq=-1 \
    trainer.total_epochs=10 \
    trainer.resume_mode=auto \
    trainer.log_val_generations=10 \
    trainer.nnodes="${NNODES_TRAIN}" \
    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
    rollout.nnodes="${NNODES_ROLLOUT}" \
    rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
    rollout.total_rollout_steps="${total_rollout_steps}" \
    rollout.total_epochs=10 \
    rollout.test_freq="${test_freq}" \
    async_training.staleness_threshold="${staleness_threshold}" \
    async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
    async_training.require_batches="${require_batches}" \
    async_training.partial_rollout="${partial_rollout}" \
    async_training.use_rollout_log_probs=True \

Loading
Loading