Skip to content
2 changes: 1 addition & 1 deletion docs/advance/mtp.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Currently, RL training can be performed on mimo-7B-RL, Qwen-next, and Deepseek s

- **Dependency Versions**:

- mbridge: Apply the patches and review suggestions from PR: [#62](https://github.com/ISEEKYAN/mbridge/pull/62) (will be merged into the main branch in the future);
- mbridge: Apply the patches and review suggestions from PR: [#62](https://github.com/ISEEKYAN/mbridge/pull/62) (already merged into the main branch);

- Megatron-Bridge: Apply the patches and review suggestions from PR if you want to try out mimo-7B-RL: [#2387](https://github.com/NVIDIA-NeMo/Megatron-Bridge/pull/2387) (will be merged into the main branch in the future);

Expand Down
160 changes: 160 additions & 0 deletions examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#!/usr/bin/env bash
# DAPO RL training for MiMo-7B-RL with MTP (multi-token prediction) enabled,
# launched through the fully-async Megatron trainer.

set -xeuo pipefail

project_name='DAPO'
exp_name='DAPO-mimo-7b-rl-megatron'

adv_estimator=grpo

# KL regularization is fully disabled for DAPO.
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

# Asymmetric ("clip-higher") PPO clipping range.
clip_ratio_low=0.2
clip_ratio_high=0.28

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"

# Paths (all overridable via environment variables)
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# IMPORTANT: modify max_position_embeddings in config.json to 32768 after
# downloading the model from Hugging Face.
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/MiMo-7B-RL"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

# Sampling parameters
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Parallelism / offload
offload=False
gen_tp=2
train_tp=2
train_pp=1
train_cp=1

n_resp_per_prompt=16
train_prompt_mini_bsz=32

# MTP-specific overrides
mtp_params=(
    actor_rollout_ref.actor.megatron.use_mbridge=True
    actor_rollout_ref.model.mtp.enable=True
    actor_rollout_ref.model.mtp.enable_train=True
    actor_rollout_ref.model.mtp.mtp_loss_scaling_factor=0.1
    actor_rollout_ref.model.mtp.detach_encoder=True
    actor_rollout_ref.model.mtp.enable_rollout=True
)

# Fully-async training overrides: trainer and rollout each get 1 node x 4 GPUs.
fully_async=(
    data.train_batch_size=0
    data.gen_batch_size=1
    trainer.test_freq=10
    actor_rollout_ref.hybrid_engine=False
    actor_rollout_ref.rollout.calculate_log_probs=True
    actor_rollout_ref.actor.optim.lr_decay_steps=51200
    rollout.total_rollout_steps=$((512 * 100))
    trainer.nnodes=1
    trainer.n_gpus_per_node=4
    rollout.nnodes=1
    rollout.n_gpus_per_node=4
    async_training.staleness_threshold=0.5
    async_training.trigger_parameter_sync_step=4
    async_training.require_batches=1
    async_training.partial_rollout=True
)

python -m verl.experimental.fully_async_policy.fully_async_main \
    --config-path=config \
    --config-name='fully_async_ppo_megatron_trainer.yaml' \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.megatron.param_offload=${offload} \
    actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \
    actor_rollout_ref.actor.megatron.grad_offload=${offload} \
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
    actor_rollout_ref.actor.megatron.context_parallel_size=${train_cp} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.optim.clip_grad=1.0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.top_p=${top_p} \
    actor_rollout_ref.rollout.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
    actor_rollout_ref.ref.megatron.context_parallel_size=${train_cp} \
    actor_rollout_ref.ref.megatron.param_offload=${offload} \
    reward_model.reward_manager=dapo \
    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \
    actor_rollout_ref.rollout.disable_log_stats=False \
    actor_rollout_ref.rollout.prometheus.enable=True \
    actor_rollout_ref.rollout.prometheus.port=44398 \
    actor_rollout_ref.model.trust_remote_code=True \
    data.trust_remote_code=True \
    trainer.logger=['console','tensorboard'] \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.val_before_train=True \
    trainer.save_freq=-1 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=auto \
    trainer.log_val_generations=10 \
    trainer.total_epochs=10 \
    "${mtp_params[@]}" \
    "${fully_async[@]}"
22 changes: 15 additions & 7 deletions verl/checkpoint_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,11 +393,13 @@ def remove_replicas(self, replicas: list[RolloutReplica]):
@auto_await
async def sleep_replicas(self):
    """Put every rollout replica to sleep, freeing weight and kv_cache device memory.

    No-op for disaggregated rollout (any backend other than "naive").
    """
    if self.backend != "naive":
        # Disaggregated rollout manages its own device memory; nothing to do.
        return
    sleep_tasks = [replica.sleep() for replica in self.replicas]
    await asyncio.gather(*sleep_tasks)

@auto_await
async def wake_up_replicas(self):
    """Wake every rollout replica, restoring kv_cache and weight device memory."""
    wake_tasks = [replica.wake_up() for replica in self.replicas]
    await asyncio.gather(*wake_tasks)

@auto_await
async def update_weights(self, global_steps: int = None):
"""Update weights from trainer to rollout replicas.
Expand All @@ -421,17 +423,23 @@ async def update_weights(self, global_steps: int = None):
rollout = RayWorkerGroup(worker_handles=workers, ray_cls_with_init=RayClassWithInitArgs(cls=_worker_cls))
trainer = self.trainer

# 3. build process group
# 3. sleep replicas to free kv_cache before weight sync (if free_cache_engine is enabled)
await self.sleep_replicas()

# 4. build process group
self.build_process_group(rollout)

# 4. update weights of all workers
# 5. update weights of all workers
ray.get(trainer.update_weights(global_steps=global_steps) + rollout.update_weights(global_steps=global_steps))

# 5. finalize all workers
# 6. finalize all workers
ray.get(
trainer.execute_checkpoint_engine(["finalize"] * trainer.world_size)
+ rollout.execute_checkpoint_engine(["finalize"] * rollout.world_size)
)

# 6. resume all unfinished requests for partial rollout
# 7. resume replicas to recover kv_cache (for free_cache_engine scenarios)
await self.wake_up_replicas()

# 8. resume all unfinished requests for partial rollout
await asyncio.gather(*[r.resume_generation() for r in self.replicas])
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ actor_rollout_ref:
# Must use rollout log probs for training
use_rollout_log_probs: True

model:
# To use remove padding (thd)
use_remove_padding: True

# Only then will the use of log probs be correct.
# And it can be used in conjunction with other rollout_correction algorithms.
algorithm:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ actor_rollout_ref:
# Must use rollout log probs for training
use_rollout_log_probs: True

model:
# To use remove padding (thd)
use_remove_padding: True


# Only then will the use of log probs be correct.
# And it can be used in conjunction with other rollout_correction algorithms.
algorithm:
Expand Down
91 changes: 72 additions & 19 deletions verl/models/mcore/model_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

import torch
from torch.nested._internal.nested_tensor import NestedTensor

from verl.utils.megatron_utils import unwrap_model
from verl.workers.config import MtpConfig
Expand Down Expand Up @@ -65,14 +66,19 @@ def model_forward(
model_kwargs["video_grid_thw"] = multi_modal_inputs["video_grid_thw"].to(input_ids.device)

batch_size, seq_len = attention_mask.shape[:2]
mtp_enable_train = mtp_config and mtp_config.enable_train

if data_format == "thd":
input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(
input_ids, attention_mask, pre_process=pre_process or post_process, use_fp8_padding=use_fp8_padding
input_ids,
attention_mask,
pre_process=pre_process or (post_process and mtp_enable_train),
use_fp8_padding=use_fp8_padding,
)
input_ids_rmpad = input_ids_rmpad.contiguous()

# when pp > 1 and processor is not None, we need to pass the labels and loss_mask to the model
if mtp_config and mtp_config.enable_train and post_process:
if mtp_enable_train and post_process:
args = {
k: preprocess_packed_seqs(v, attention_mask, pre_process=True, use_fp8_padding=use_fp8_padding)[0]
for k, v in logits_processor_args.items()
Expand Down Expand Up @@ -169,6 +175,38 @@ def model_forward(
return model_forward


def _convert_to_nested_tensor(v, input_ids_lengths):
"""Convert regular tensor to NestedTensor, slicing according to input_ids_lengths.

Args:
v: Tensor to convert, shape [batch, seq_len]
input_ids_lengths: List of valid lengths for each sample

Returns:
Converted NestedTensor
"""
if isinstance(v, NestedTensor):
return v

batch_size = v.shape[0]
assert len(input_ids_lengths) == batch_size, (
f"len(input_ids_lengths)={len(input_ids_lengths)} != batch_size={batch_size}"
)

v_split_list = []
for i in range(batch_size):
vi = v[i]
target_len = input_ids_lengths[i]
if vi.shape[0] > target_len:
vi = vi[:target_len]
elif vi.shape[0] < target_len:
vi = torch.cat([vi, torch.ones(target_len - vi.shape[0], dtype=vi.dtype, device=vi.device)])
v_split_list.append(vi)

v = torch.nested.nested_tensor(v_split_list, layout=torch.jagged)
return v


def gptmodel_forward_no_padding(
model,
input_ids,
Expand All @@ -179,7 +217,7 @@ def gptmodel_forward_no_padding(
vision_model=False,
pad_token_id=None,
data_format: str = "thd",
enable_mtp: bool = False,
mtp_enable_train: bool = False,
):
"""Default forward pass for GPT models with optional sequence packing."""

Expand All @@ -202,20 +240,28 @@ def gptmodel_forward_no_padding(

batch_size = input_ids.shape[0]
if data_format == "thd":
input_ids_rmpad, packed_seq_params = preprocess_thd_no_padding(
input_ids, pre_process=pre_process, use_fp8_padding=use_fp8_padding
input_ids_rmpad, packed_seq_params, position_ids_rmpad = preprocess_thd_no_padding(
input_ids, pre_process=pre_process or (post_process and mtp_enable_train), use_fp8_padding=use_fp8_padding
)
input_ids_rmpad = input_ids_rmpad.contiguous()

if enable_mtp and post_process:
args = {
k: preprocess_thd_no_padding(
v, pre_process=True, need_roll=(k == "label" or k == "loss_mask"), use_fp8_padding=use_fp8_padding
args = {}
if mtp_enable_train and post_process:
# Use input_ids sequence length to ensure label and loss_mask alignment
input_ids_offsets = input_ids.offsets()
input_ids_lengths = input_ids_offsets.diff().tolist()

for k in ["label", "loss_mask"]:
v = logits_processor_args[k]
v = _convert_to_nested_tensor(v, input_ids_lengths)
logits_processor_args[k] = v
args[k] = preprocess_thd_no_padding(
v, pre_process=True, need_roll=True, use_fp8_padding=use_fp8_padding
)[0]
for k, v in logits_processor_args.items()
}

model_kwargs["labels"] = args["label"].contiguous()
model_kwargs["loss_mask"] = args["loss_mask"].contiguous()

if logits_processor_args and "loss_mask" in logits_processor_args:
logits_processor_args.pop("loss_mask")

Expand All @@ -231,7 +277,7 @@ def gptmodel_forward_no_padding(
output_orig = model(
input_ids=input_ids_rmpad,
attention_mask=attention_mask,
position_ids=None,
position_ids=position_ids_rmpad if not vision_model else None, # vision models will calculate position_ids
packed_seq_params=packed_seq_params,
**model_kwargs,
)
Expand Down Expand Up @@ -262,18 +308,25 @@ def gptmodel_forward_no_padding(
"""

input_ids_bshd, attention_mask_bshd, position_ids_bshd = preprocess_bshd_no_padding(
input_ids, pre_process=pre_process, use_fp8_padding=use_fp8_padding
input_ids, pre_process=pre_process or (post_process and mtp_enable_train), use_fp8_padding=use_fp8_padding
)

if enable_mtp and post_process:
args = {
k: preprocess_bshd_no_padding(
v, pre_process=True, need_roll=(k == "label" or k == "loss_mask"), use_fp8_padding=use_fp8_padding
if mtp_enable_train and post_process:
args = {}
# Use input_ids sequence length to ensure label and loss_mask alignment
input_ids_offsets = input_ids.offsets()
input_ids_lengths = input_ids_offsets.diff().tolist()

for k in ["label", "loss_mask"]:
v = logits_processor_args[k]
v = _convert_to_nested_tensor(v, input_ids_lengths)
logits_processor_args[k] = v
args[k] = preprocess_bshd_no_padding(
v, pre_process=True, need_roll=True, use_fp8_padding=use_fp8_padding
)[0]
for k, v in logits_processor_args.items()
}
model_kwargs["labels"] = args["label"].contiguous()
model_kwargs["loss_mask"] = args["loss_mask"].contiguous()

if logits_processor_args and "loss_mask" in logits_processor_args:
logits_processor_args.pop("loss_mask")

Expand Down
Loading
Loading