From 1bb4a43d9849d6bc7afab8ad4159741a2e8fca1e Mon Sep 17 00:00:00 2001 From: ArronHZG Date: Wed, 11 Mar 2026 15:56:09 +0800 Subject: [PATCH 01/14] revert mtp_patch --- .../shell/dapo_7b_math_fsdp2_4_4.sh | 2 +- verl/models/mcore/mtp_patch.py | 167 +++++++----------- 2 files changed, 61 insertions(+), 108 deletions(-) diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh index 8bf73be8c85..ac765a58fd1 100644 --- a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh @@ -17,7 +17,7 @@ TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} rollout_mode="async" -rollout_name="vllm" # sglang or vllm +rollout_name="sglang" # sglang or vllm if [ "$rollout_mode" = "async" ]; then export VLLM_USE_V1=1 return_raw_chat="True" diff --git a/verl/models/mcore/mtp_patch.py b/verl/models/mcore/mtp_patch.py index fadf5b7bd52..117b6e3f28c 100644 --- a/verl/models/mcore/mtp_patch.py +++ b/verl/models/mcore/mtp_patch.py @@ -20,7 +20,11 @@ import torch from megatron.core import parallel_state from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor +from megatron.core.transformer.multi_token_prediction import ( + MTPLossAutoScaler, + MTPLossLoggingHelper, + roll_tensor, +) try: from megatron.core.utils import unwrap_model @@ -74,45 +78,19 @@ def _megatron_gptmodel_postprocess( runtime_gather_output=None, extra_block_kwargs=None, inference_context=None, - **kwargs, ): - """Compatibility patch for GPTModel._postprocess. + """Postprocesses decoder hidden states to generate logits or compute loss. 
- For inference (`labels is None`), delegate to the upstream implementation to stay - aligned with Megatron-Core updates. - - For training (`labels is not None`), keep VERL's MTP behavior and always return - logits (instead of CE loss) so PPO paths can compute custom losses from logits. + Applies Multi-Token Prediction if enabled, generates output logits through + the output layer, and computes language model loss when labels are provided. """ - # Keep inference path aligned with whatever upstream Megatron currently expects. - if labels is None: - return self._postprocess_backup( - hidden_states=hidden_states, - input_ids=input_ids, - position_ids=position_ids, - labels=labels, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - mtp_in_postprocess=mtp_in_postprocess, - loss_mask=loss_mask, - decoder_input=decoder_input, - attention_mask=attention_mask, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - runtime_gather_output=runtime_gather_output, - extra_block_kwargs=extra_block_kwargs, - inference_context=inference_context, - **kwargs, - ) - # Training path: keep logits for external loss computation. + # logits and loss output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - if mtp_in_postprocess: + if mtp_in_postprocess and labels is not None: hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, @@ -131,85 +109,60 @@ def _megatron_gptmodel_postprocess( if not self.post_process: return hidden_states - # Skip when mtp_num_layers is None or 0. - if self.config.mtp_num_layers: - cp_group = None - if getattr(self, "pg_collection", None) is not None: - cp_group = self.pg_collection.cp - elif hasattr(self, "cp_group"): - cp_group = self.cp_group - - # Prefer upstream helper when available (newer Megatron-LM). 
- try: - from megatron.core.transformer.multi_token_prediction import process_mtp_loss - - hidden_states = process_mtp_loss( - hidden_states=hidden_states, - labels=labels, - loss_mask=loss_mask, - output_layer=self.output_layer, - output_weight=output_weight, - runtime_gather_output=runtime_gather_output, - is_training=self.training, - compute_language_model_loss=self.compute_language_model_loss, - config=self.config, - cp_group=cp_group, + # Skip when mtp_num_layers is None or 0 + if self.config.mtp_num_layers and labels is not None: + mtp_labels = labels.clone() + + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) + for mtp_layer_number in range(self.config.mtp_num_layers): + # Calc loss for the current Multi-Token Prediction (MTP) layers. + mtp_labels, _ = roll_tensor( + mtp_labels, + shifts=-1, + dims=-1, + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) + loss_mask, num_tokens = roll_tensor( + loss_mask, + shifts=-1, + dims=-1, + cp_group=self.cp_group, packed_seq_params=packed_seq_params, ) - except (ImportError, AttributeError, TypeError): - # Fallback for older Megatron-LM versions without process_mtp_loss API. - mtp_labels = labels.clone() - - hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) - hidden_states = hidden_states_list[0] - if loss_mask is None: - # if loss_mask is not provided, use all ones as loss_mask - loss_mask = torch.ones_like(mtp_labels) - for mtp_layer_number in range(self.config.mtp_num_layers): - # Calc loss for the current Multi-Token Prediction (MTP) layers. 
- mtp_labels, _ = roll_tensor( - mtp_labels, - shifts=-1, - dims=-1, - cp_group=self.cp_group, - packed_seq_params=packed_seq_params, - ) - loss_mask, num_tokens = roll_tensor( - loss_mask, - shifts=-1, - dims=-1, - cp_group=self.cp_group, - packed_seq_params=packed_seq_params, - ) - # Compute mtp loss without storing logits to save memory. - mtp_loss = self.compute_output_layer_and_language_model_loss( - hidden_states_list[mtp_layer_number + 1], - labels=mtp_labels, - weight=self.shared_embedding_or_output_weight(), - sequence_parallel_enabled=self.output_layer.sequence_parallel, - column_parallel_linear=self.output_layer, - col_linear_kwargs={ - "weight": output_weight, - "runtime_gather_output": runtime_gather_output, - }, - ) + # Compute mtp loss without storing logits to save memory. + mtp_loss = self.compute_output_layer_and_language_model_loss( + hidden_states_list[mtp_layer_number + 1], + labels=mtp_labels, + weight=self.shared_embedding_or_output_weight(), + sequence_parallel_enabled=self.output_layer.sequence_parallel, + column_parallel_linear=self.output_layer, + col_linear_kwargs={ + "weight": output_weight, + "runtime_gather_output": runtime_gather_output, + }, + ) - mtp_loss = loss_mask * mtp_loss - if self.training: - # TODO(shifangx): remove the use of parallel_state here - # after moving loss logging to loss_func in pretrain_gpt.py - MTPLossLoggingHelper.save_loss_to_tracker( - torch.sum(mtp_loss) / num_tokens, - mtp_layer_number, - self.config.mtp_num_layers, - avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True), - ) - mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers - if self.config.calculate_per_token_loss: - hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss) - else: - hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens) + mtp_loss = loss_mask * mtp_loss + if self.training: + # TODO(shifangx): remove the use of 
parallel_state here + # after moving loss logging to loss_func in pretrain_gpt.py + MTPLossLoggingHelper.save_loss_to_tracker( + torch.sum(mtp_loss) / num_tokens, + mtp_layer_number, + self.config.mtp_num_layers, + avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True), + ) + mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers + if self.config.calculate_per_token_loss: + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss) + else: + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens) logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) # [s b h] => [b s h] From bf4a6b8ddbb0e846ea893e5509898e342db51682 Mon Sep 17 00:00:00 2001 From: ArronHZG Date: Wed, 11 Mar 2026 15:57:22 +0800 Subject: [PATCH 02/14] revert mtp_patch --- docs/advance/mtp.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/advance/mtp.md b/docs/advance/mtp.md index 48489e342c1..b342b670d74 100644 --- a/docs/advance/mtp.md +++ b/docs/advance/mtp.md @@ -14,7 +14,7 @@ Currently, RL training can be performed on mimo-7B-RL, Qwen-next, and Deepseek s - **Dependency Versions**: - - mbridge: Apply the patches and review suggestions from PR: [#62](https://github.com/ISEEKYAN/mbridge/pull/62) (will be merged into the main branch in the future); + - mbridge: Apply the patches and review suggestions from PR: [#62](https://github.com/ISEEKYAN/mbridge/pull/62) (Already merged into the main branch); - Megatron-Bridge: Apply the patches and review suggestions from PR if you want to try out mimo-7B-RL: [#2387](https://github.com/NVIDIA-NeMo/Megatron-Bridge/pull/2387) (will be merged into the main branch in the future); From a39609a6105725dd4294b1daa5e6b9180aff6d83 Mon Sep 17 00:00:00 2001 From: ArronHZG Date: Wed, 11 Mar 2026 15:59:18 +0800 Subject: [PATCH 03/14] DAPO-mimo-7b-rl-megatron --- 
...dapo_mimo_7b_with_mtp_math_megatron_4_4.sh | 161 ++++++++++++++++++ .../shell/dapo_7b_math_fsdp2_4_4.sh | 2 +- 2 files changed, 162 insertions(+), 1 deletion(-) create mode 100644 examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh diff --git a/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh b/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh new file mode 100644 index 00000000000..17d8fbca059 --- /dev/null +++ b/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh @@ -0,0 +1,161 @@ +#!/usr/bin/env bash + +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-mimo-7b-rl-megatron' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 1)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 1)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/examples/mtp_trainer/runtime_env.yaml"} +NNODES=${NNODES:-16} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! 
please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/MiMo-7B-RL"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=False +gen_tp=2 +train_tp=1 +train_pp=1 +train_cp=1 + +train_prompt_bsz=16 +n_resp_per_prompt=8 +train_prompt_mini_bsz=16 + +mtp_params=( + actor_rollout_ref.actor.megatron.use_mbridge=True + actor_rollout_ref.model.mtp.enable=True + actor_rollout_ref.model.mtp.enable_train=True + actor_rollout_ref.model.mtp.mtp_loss_scaling_factor=0.1 + actor_rollout_ref.model.mtp.detach_encoder=True + actor_rollout_ref.model.mtp.enable_rollout=True + ) + +fully_async=( + data.train_batch_size=0 + data.gen_batch_size=1 + trainer.test_freq=10 + actor_rollout_ref.hybrid_engine=False + actor_rollout_ref.rollout.calculate_log_probs=True + actor_rollout_ref.actor.optim.lr_decay_steps=51200 + rollout.total_rollout_steps=$(((512*100))) + trainer.nnodes=1 + trainer.n_gpus_per_node=4 + rollout.nnodes=1 + rollout.n_gpus_per_node=4 + async_training.staleness_threshold=0.5 + async_training.trigger_parameter_sync_step=4 + async_training.require_batches=1 + async_training.partial_rollout=True +) + +python -m verl.experimental.fully_async_policy.fully_async_main \ + --config-path=config \ + --config-name='fully_async_ppo_megatron_trainer.yaml'\ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + 
data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.megatron.context_parallel_size=${train_cp} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.60 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + 
actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.ref.megatron.context_parallel_size=${train_cp} \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.prometheus.enable=True \ + actor_rollout_ref.rollout.prometheus.port=44398 \ + actor_rollout_ref.model.trust_remote_code=True \ + data.trust_remote_code=True \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.val_before_train=True \ + trainer.save_freq=-1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 \ + trainer.total_epochs=10 \ + "${mtp_params[@]}" \ + "${fully_async[@]}" \ No newline at end of file diff --git a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh index ac765a58fd1..8bf73be8c85 100644 --- 
a/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh +++ b/verl/experimental/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh @@ -17,7 +17,7 @@ TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} rollout_mode="async" -rollout_name="sglang" # sglang or vllm +rollout_name="vllm" # sglang or vllm if [ "$rollout_mode" = "async" ]; then export VLLM_USE_V1=1 return_raw_chat="True" From 08167a40b8048ace18464ce4d88c974553471807 Mon Sep 17 00:00:00 2001 From: ArronHZG Date: Wed, 11 Mar 2026 23:20:16 +0800 Subject: [PATCH 04/14] model engine mtp --- verl/models/mcore/model_forward.py | 93 +++++++++++++++---- verl/models/mcore/model_forward_fused.py | 4 +- verl/models/mcore/registry.py | 8 +- verl/models/mcore/util.py | 32 ++++++- verl/utils/megatron/router_replay_utils.py | 2 +- verl/utils/megatron_utils.py | 53 ++++++----- .../engine/megatron/transformer_impl.py | 2 +- verl/workers/engine_workers.py | 6 +- 8 files changed, 146 insertions(+), 54 deletions(-) diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py index f8fbd9d24f4..a9becf93564 100644 --- a/verl/models/mcore/model_forward.py +++ b/verl/models/mcore/model_forward.py @@ -15,6 +15,7 @@ # limitations under the License. 
import torch +from torch.nested._internal.nested_tensor import NestedTensor from verl.utils.megatron_utils import unwrap_model from verl.workers.config import MtpConfig @@ -65,14 +66,19 @@ def model_forward( model_kwargs["video_grid_thw"] = multi_modal_inputs["video_grid_thw"].to(input_ids.device) batch_size, seq_len = attention_mask.shape[:2] + mtp_enable_train = mtp_config and mtp_config.enable_train + if data_format == "thd": input_ids_rmpad, packed_seq_params = preprocess_packed_seqs( - input_ids, attention_mask, pre_process=pre_process or post_process, use_fp8_padding=use_fp8_padding + input_ids, + attention_mask, + pre_process=pre_process or (post_process and mtp_enable_train), + use_fp8_padding=use_fp8_padding, ) input_ids_rmpad = input_ids_rmpad.contiguous() # when pp > 1 and processor is not None, we need to pass the labels and loss_mask to the model - if mtp_config and mtp_config.enable_train and post_process: + if mtp_enable_train and post_process: args = { k: preprocess_packed_seqs(v, attention_mask, pre_process=True, use_fp8_padding=use_fp8_padding)[0] for k, v in logits_processor_args.items() @@ -158,6 +164,38 @@ def model_forward( return model_forward +def _convert_to_nested_tensor(v, input_ids_lengths): + """Convert regular tensor to NestedTensor, slicing according to input_ids_lengths. 
+ + Args: + v: Tensor to convert, shape [batch, seq_len] + input_ids_lengths: List of valid lengths for each sample + + Returns: + Converted NestedTensor + """ + if isinstance(v, NestedTensor): + return v + + batch_size = v.shape[0] + assert len(input_ids_lengths) == batch_size, ( + f"len(input_ids_lengths)={len(input_ids_lengths)} != batch_size={batch_size}" + ) + + v_split_list = [] + for i in range(batch_size): + vi = v[i] + target_len = input_ids_lengths[i] + if vi.shape[0] > target_len: + vi = vi[:target_len] + elif vi.shape[0] < target_len: + vi = torch.cat([vi, torch.ones(target_len - vi.shape[0], dtype=vi.dtype, device=vi.device)]) + v_split_list.append(vi) + + v = torch.nested.nested_tensor(v_split_list, layout=torch.jagged) + return v + + def gptmodel_forward_no_padding( model, input_ids, @@ -168,7 +206,7 @@ def gptmodel_forward_no_padding( vision_model=False, pad_token_id=None, data_format: str = "thd", - enable_mtp: bool = False, + mtp_enable_train: bool = False, ): """Default forward pass for GPT models with optional sequence packing.""" @@ -191,20 +229,30 @@ def gptmodel_forward_no_padding( batch_size = input_ids.shape[0] if data_format == "thd": - input_ids_rmpad, packed_seq_params = preprocess_thd_no_padding( - input_ids, pre_process=pre_process, use_fp8_padding=use_fp8_padding + input_ids_rmpad, packed_seq_params, position_ids_rmpad = preprocess_thd_no_padding( + input_ids, pre_process=pre_process or (post_process and mtp_enable_train), use_fp8_padding=use_fp8_padding ) input_ids_rmpad = input_ids_rmpad.contiguous() - if enable_mtp and post_process: - args = { - k: preprocess_thd_no_padding( - v, pre_process=True, need_roll=(k == "label" or k == "loss_mask"), use_fp8_padding=use_fp8_padding + args = {} + if mtp_enable_train and post_process: + # 使用 input_ids 的序列长度来确保 label 和 loss_mask 对齐 + input_ids_offsets = input_ids.offsets() + input_ids_lengths = input_ids_offsets.diff().tolist() + + print(f"hzg input_ids_lengths={input_ids_lengths}") + + for k 
in ["label", "loss_mask"]: + v = logits_processor_args[k] + v = _convert_to_nested_tensor(v, input_ids_lengths) + logits_processor_args[k] = v + args[k] = preprocess_thd_no_padding( + v, pre_process=True, need_roll=True, use_fp8_padding=use_fp8_padding )[0] - for k, v in logits_processor_args.items() - } + model_kwargs["labels"] = args["label"].contiguous() model_kwargs["loss_mask"] = args["loss_mask"].contiguous() + if logits_processor_args and "loss_mask" in logits_processor_args: logits_processor_args.pop("loss_mask") @@ -220,7 +268,7 @@ def gptmodel_forward_no_padding( output_orig = model( input_ids=input_ids_rmpad, attention_mask=attention_mask, - position_ids=None, + position_ids=position_ids_rmpad if not vision_model else None, # vision models will calculate position_ids packed_seq_params=packed_seq_params, **model_kwargs, ) @@ -251,18 +299,25 @@ def gptmodel_forward_no_padding( """ input_ids_bshd, attention_mask_bshd, position_ids_bshd = preprocess_bshd_no_padding( - input_ids, pre_process=pre_process, use_fp8_padding=use_fp8_padding + input_ids, pre_process=pre_process or (post_process and mtp_enable_train), use_fp8_padding=use_fp8_padding ) - if enable_mtp and post_process: - args = { - k: preprocess_bshd_no_padding( - v, pre_process=True, need_roll=(k == "label" or k == "loss_mask"), use_fp8_padding=use_fp8_padding + if mtp_enable_train and post_process: + args = {} + # 使用 input_ids 的序列长度来确保 label 和 loss_mask 对齐 + input_ids_offsets = input_ids.offsets() + input_ids_lengths = input_ids_offsets.diff().tolist() + + for k in ["label", "loss_mask"]: + v = logits_processor_args[k] + v = _convert_to_nested_tensor(v, input_ids_lengths) + logits_processor_args[k] = v + args[k] = preprocess_bshd_no_padding( + v, pre_process=True, need_roll=True, use_fp8_padding=use_fp8_padding )[0] - for k, v in logits_processor_args.items() - } model_kwargs["labels"] = args["label"].contiguous() model_kwargs["loss_mask"] = args["loss_mask"].contiguous() + if logits_processor_args 
and "loss_mask" in logits_processor_args: logits_processor_args.pop("loss_mask") diff --git a/verl/models/mcore/model_forward_fused.py b/verl/models/mcore/model_forward_fused.py index 935569be2f7..273ade1f69e 100644 --- a/verl/models/mcore/model_forward_fused.py +++ b/verl/models/mcore/model_forward_fused.py @@ -153,7 +153,7 @@ def fused_forward_no_padding( fp8 = unwrap_model(model).config.fp8 use_fp8_padding = fp8 in ["e4m3", "hybrid"] - input_ids_rmpad, packed_seq_params = preprocess_thd_no_padding( + input_ids_rmpad, packed_seq_params, _ = preprocess_thd_no_padding( input_ids, pre_process=pre_process, use_fp8_padding=use_fp8_padding ) input_ids_rmpad = input_ids_rmpad.contiguous() @@ -177,7 +177,7 @@ def fused_forward_no_padding( 0 ) < seqlens_in_batch.unsqueeze(1) - labels_rmpad, _ = preprocess_thd_no_padding( + labels_rmpad, _, _ = preprocess_thd_no_padding( labels, pre_process=True, need_roll=True, use_fp8_padding=use_fp8_padding ) labels_rmpad = labels_rmpad.contiguous() diff --git a/verl/models/mcore/registry.py b/verl/models/mcore/registry.py index b1b5c03406b..42c411e8602 100644 --- a/verl/models/mcore/registry.py +++ b/verl/models/mcore/registry.py @@ -131,7 +131,7 @@ class SupportedModel(Enum): QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration" QWEN3_VL = "Qwen3VLForConditionalGeneration" GPT_OSS = "GptOssForCausalLM" - MiMO = "MiMoForCausalLM" + MIMO = "MiMoForCausalLM" # Registry for model configuration converters @@ -181,7 +181,7 @@ class SupportedModel(Enum): SupportedModel.QWEN3_TOKEN_CLASSIFICATION: model_forward_gen(), SupportedModel.LLAMA_TOKEN_CLASSIFICATION: model_forward_gen(), SupportedModel.GPT_OSS: model_forward_gen(), - SupportedModel.MiMO: model_forward_gen(), + SupportedModel.MIMO: model_forward_gen(), } # Registry for model forward functions @@ -201,7 +201,7 @@ class SupportedModel(Enum): SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward_no_padding, SupportedModel.LLAMA_TOKEN_CLASSIFICATION: gptmodel_forward_no_padding, 
SupportedModel.GPT_OSS: gptmodel_forward_no_padding, - SupportedModel.MiMO: gptmodel_forward_no_padding, + SupportedModel.MIMO: gptmodel_forward_no_padding, } # Registry for model forward functions @@ -219,7 +219,7 @@ class SupportedModel(Enum): SupportedModel.DEEPSEEK_V3: fused_forward_model_gen(), SupportedModel.GLM4_MOE: fused_forward_model_gen(), SupportedModel.GPT_OSS: fused_forward_model_gen(), - SupportedModel.MiMO: fused_forward_model_gen(), + SupportedModel.MIMO: fused_forward_model_gen(), } # Registry for model weight converters diff --git a/verl/models/mcore/util.py b/verl/models/mcore/util.py index e51bb5359c8..b5d16c48fa8 100644 --- a/verl/models/mcore/util.py +++ b/verl/models/mcore/util.py @@ -16,6 +16,7 @@ import logging import math import os +from typing import Optional import torch from megatron.core import parallel_state as mpu @@ -289,11 +290,9 @@ def postprocess_packed_seqs_for_dict_output( ### No padding versions for model engine ### inputs are nested tensors - - def preprocess_thd_no_padding( input_ids: torch.Tensor, pre_process: bool = True, need_roll: bool = False, use_fp8_padding: bool = False -) -> tuple[torch.Tensor, PackedSeqParams]: +) -> tuple[torch.Tensor, PackedSeqParams, Optional[torch.Tensor]]: """ Preprocess packed sequences CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 @@ -341,14 +340,20 @@ def preprocess_thd_no_padding( shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size if pre_process: input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device) + position_ids_rmpad = torch.zeros(shape, dtype=torch.long, device=input_ids.device) if need_roll: saved_roll_dict = {} + saved_position_roll_dict = {} for i in range(batch_size): # Use Python int, so no GPU→CPU sync in the loop if cp_size <= 1: seqlen = seqlens_in_batch_cpu[i] start_idx = cu_seqlens_padded_cpu[i] input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i] + # Build position_ids: 0, 
1, 2, ..., seqlen-1 for this sequence + position_ids_rmpad[start_idx : start_idx + seqlen] = torch.arange( + seqlen, dtype=torch.long, device=input_ids.device + ) continue seqlen_padded_i = seqlens_in_batch_padded_cpu[i] @@ -374,6 +379,11 @@ def preprocess_thd_no_padding( half_seqlen * cp_rank : half_seqlen * (cp_rank + 1) ] + # Build position_ids for the first chunk + position_ids_rmpad[start_idx : start_idx + half_seqlen] = torch.arange( + half_seqlen * cp_rank, half_seqlen * (cp_rank + 1), dtype=torch.long, device=input_ids.device + ) + remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1) remain_end = seqlen_padded_i - half_seqlen * cp_rank remain_end = min(remain_end, d.shape[0]) @@ -382,21 +392,33 @@ def preprocess_thd_no_padding( input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[ remain_start:remain_end ] + # Build position_ids for the remaining chunk + position_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = torch.arange( + seqlen_padded_i - remain_len, seqlen_padded_i, dtype=torch.long, device=input_ids.device + ) if need_roll: # Handle roll for cp_size > 1 case saved_roll_dict[start_idx + half_seqlen - 1] = d[(cp_rank + 1) * half_seqlen] + saved_position_roll_dict[start_idx + half_seqlen - 1] = position_ids_rmpad[start_idx + half_seqlen - 1] if remain_len > 0: if remain_end == d.shape[0]: saved_roll_dict[start_idx + half_seqlen + remain_len - 1] = d[0] + saved_position_roll_dict[start_idx + half_seqlen + remain_len - 1] = 0 else: saved_roll_dict[start_idx + half_seqlen + remain_len - 1] = d[remain_end] + saved_position_roll_dict[start_idx + half_seqlen + remain_len - 1] = position_ids_rmpad[ + start_idx + half_seqlen + remain_len - 1 + ] if need_roll: input_ids_rmpad = torch.roll(input_ids_rmpad, shifts=-1, dims=0) + position_ids_rmpad = torch.roll(position_ids_rmpad, shifts=-1, dims=0) if len(saved_roll_dict) > 0: for k, v in saved_roll_dict.items(): input_ids_rmpad[k] = v + for k, v in 
saved_position_roll_dict.items(): + position_ids_rmpad[k] = v packed_seq_params = PackedSeqParams( qkv_format="thd", @@ -408,9 +430,9 @@ def preprocess_thd_no_padding( cu_seqlens_kv_padded=cu_seqlens_padded, ) if pre_process: - return input_ids_rmpad.unsqueeze(0), packed_seq_params + return input_ids_rmpad.unsqueeze(0), packed_seq_params, position_ids_rmpad.unsqueeze(0) else: - return input_ids, packed_seq_params + return input_ids, packed_seq_params, None def postprocess_thd_no_padding( diff --git a/verl/utils/megatron/router_replay_utils.py b/verl/utils/megatron/router_replay_utils.py index 1e08b2b44f7..5514778e413 100644 --- a/verl/utils/megatron/router_replay_utils.py +++ b/verl/utils/megatron/router_replay_utils.py @@ -271,7 +271,7 @@ def set_router_replay_data(layers_topk_idx, attention_mask, tf_config, vp_rank=N """ with torch.no_grad(): if layers_topk_idx.is_nested: - layers_topk_idx_rmpad, _ = preprocess_thd_no_padding(layers_topk_idx, pre_process=True) + layers_topk_idx_rmpad, _, _ = preprocess_thd_no_padding(layers_topk_idx, pre_process=True) else: layers_topk_idx_rmpad, _ = preprocess_packed_seqs(layers_topk_idx, attention_mask, pre_process=True) layers_topk_idx_rmpad = layers_topk_idx_rmpad.contiguous() # 1, dynamic_bs_all, layer_num, topk diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py index aa93ea55087..820981c8e7d 100644 --- a/verl/utils/megatron_utils.py +++ b/verl/utils/megatron_utils.py @@ -1331,6 +1331,15 @@ def get_megatron_module_device(models: list[Any]) -> str: def check_mtp_config(model_config: HFModelConfig, engine_config: McoreEngineConfig): + """ + Check and configure MTP (Multi-Token Prediction) settings. 
+ + Cases: + - mtp.enable == False and no MTP layers: return directly + - mtp.enable == False and has MTP layers: set num_nextn_predict_layers = 0 + - mtp.enable == True and has MTP layers: configure override_transformer_config + - mtp.enable == True and no MTP layers: raise ValueError + """ has_mtp = ( model_config.hf_config.num_nextn_predict_layers > 0 if hasattr(model_config.hf_config, "num_nextn_predict_layers") @@ -1338,36 +1347,38 @@ def check_mtp_config(model_config: HFModelConfig, engine_config: McoreEngineConf ) enable_mtp = model_config.mtp.enable - if "mtp_loss_scaling_factor" not in engine_config.override_transformer_config: - engine_config.override_transformer_config["mtp_loss_scaling_factor"] = model_config.mtp.mtp_loss_scaling_factor - - if enable_mtp and not model_config.mtp.enable_train: - # disable parameter update by configure the loss scale to 0 - engine_config.override_transformer_config["mtp_loss_scaling_factor"] = 0 - - # Modify the hf_config before initialization, and apply patch after innitialization - if enable_mtp and not has_mtp: - logger.error("enable mtp while model has no mtp layer, ignore model.mtp.enable") - model_config.mtp.enable = False - model_config.mtp.enable_train = False - elif has_mtp and not enable_mtp: + if not enable_mtp and not has_mtp: + return + elif not enable_mtp and has_mtp: model_config.hf_config.num_nextn_predict_layers = 0 + elif enable_mtp and not has_mtp: + raise ValueError("enable mtp while model has no mtp layer, please use a model with mtp layer") + elif enable_mtp and has_mtp: + if "mtp_loss_scaling_factor" not in engine_config.override_transformer_config: + engine_config.override_transformer_config["mtp_loss_scaling_factor"] = ( + model_config.mtp.mtp_loss_scaling_factor + ) + return def patch_engine_mtp(module, model_config): + """ + Apply MTP patches to the model module. + + Args: + module: The model module to patch. Can be a single module or a list of modules. 
+ model_config: The model configuration containing MTP settings. + """ logger.warning("Applying mtp patch...") from verl.models.mcore.mtp_patch import patch_mtp_layer_get_embeddings, patch_postprocess print(module) - if isinstance(module, list): - for m in module: - patch_postprocess(m) - if model_config.mtp.detach_encoder: - patch_mtp_layer_get_embeddings(m) - else: - patch_postprocess(module) + + modules = module if isinstance(module, list) else [module] + for m in modules: + patch_postprocess(m) if model_config.mtp.detach_encoder: - patch_mtp_layer_get_embeddings(module) + patch_mtp_layer_get_embeddings(m) @torch.no_grad() diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py index 8fb18042978..892180fc5cc 100644 --- a/verl/workers/engine/megatron/transformer_impl.py +++ b/verl/workers/engine/megatron/transformer_impl.py @@ -807,7 +807,7 @@ def logits_processor(logits, label, temperature): vision_model=hasattr(self.model_config.hf_config, "vision_config"), pad_token_id=self.model_config.tokenizer.pad_token_id, data_format="thd" if self.engine_config.use_remove_padding else "bshd", - enable_mtp=self.model_config.mtp.enable_train, + mtp_enable_train=self.model_config.mtp.enable and self.model_config.mtp.enable_train, ) # Router replay: switch to backward replay mode for next backward pass diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py index abca5cdb65b..27d45d161de 100644 --- a/verl/workers/engine_workers.py +++ b/verl/workers/engine_workers.py @@ -15,6 +15,7 @@ import logging import os from contextlib import nullcontext +from copy import deepcopy from functools import partial from itertools import chain @@ -481,7 +482,10 @@ def init_model(self): self.config.ref.use_dynamic_bsz = self.config.ref.pop("log_prob_use_dynamic_bsz", False) self.config.ref.ppo_max_token_len_per_gpu = self.config.ref.pop("log_prob_max_token_len_per_gpu", None) ref_config: ActorConfig = 
omega_conf_to_dataclass(self.config.ref) - ref_config.model_config = model_config + + # The ref model does not need to enable MTP; force it to false. + ref_config.model_config = deepcopy(model_config) + ref_config.model_config.mtp.enable = False # construct TrainingWorkerConfig ref_training_config = TrainingWorkerConfig( From 450a964b6746298705b9786fd627bb6a6b20b367 Mon Sep 17 00:00:00 2001 From: ArronHZG Date: Wed, 11 Mar 2026 23:23:03 +0800 Subject: [PATCH 05/14] thd --- .../config/fully_async_ppo_megatron_trainer.yaml | 4 ++++ .../config/fully_async_ppo_trainer.yaml | 5 +++++ .../shell/runtime_env_4_4.yaml | 16 ++++++++++++++++ 3 files changed, 25 insertions(+) create mode 100644 verl/experimental/fully_async_policy/shell/runtime_env_4_4.yaml diff --git a/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml b/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml index 9acc742817e..9caad8c0fd7 100644 --- a/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml +++ b/verl/experimental/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml @@ -60,6 +60,10 @@ actor_rollout_ref: # Must use rollout log probs for training use_rollout_log_probs: True + model: + # To use remove padding (thd) + use_remove_padding: True + # Only then will the use of log probs be correct. # And it can be used in conjunction with other rollout_correction algorithms. 
algorithm: diff --git a/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml b/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml index f0753d969ee..40b7f0acb53 100644 --- a/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml +++ b/verl/experimental/fully_async_policy/config/fully_async_ppo_trainer.yaml @@ -60,6 +60,11 @@ actor_rollout_ref: # Must use rollout log probs for training use_rollout_log_probs: True + model: + # To use remove padding (thd) + use_remove_padding: True + + # Only then will the use of log probs be correct. # And it can be used in conjunction with other rollout_correction algorithms. algorithm: diff --git a/verl/experimental/fully_async_policy/shell/runtime_env_4_4.yaml b/verl/experimental/fully_async_policy/shell/runtime_env_4_4.yaml new file mode 100644 index 00000000000..f28db74eef8 --- /dev/null +++ b/verl/experimental/fully_async_policy/shell/runtime_env_4_4.yaml @@ -0,0 +1,16 @@ +working_dir: ./ + +excludes: + - ".git/" + +env_vars: + VLLM_USE_V1: "1" + HYDRA_FULL_ERROR: "1" + NCCL_NVLS_ENABLE: "0" + NCCL_SOCKET_IFNAME: "eth0" + TMPDIR: "/tmp" + CUDA_HOME: "/usr/local/cuda" + CUDA_TMPDIR: "/tmp" + RAY_DATA_HOME: "/home/hadoop-djst-algoplat" + TENSORBOARD_DIR: "/home/hadoop-djst-algoplat/data/tensorboard/qwen2.5-7b-math/4-4-gsm8k-sglang" + VERL_LOGGING_LEVEL: "DEBUG" \ No newline at end of file From 4f9e74a8e9bb7d9c9ff88f17232bdbe0230e0b08 Mon Sep 17 00:00:00 2001 From: ArronHZG Date: Wed, 11 Mar 2026 23:28:27 +0800 Subject: [PATCH 06/14] train_prompt_bsz=128 --- .../test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh b/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh index 17d8fbca059..7f92725f757 100644 --- a/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh +++ 
b/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh @@ -53,9 +53,9 @@ train_tp=1 train_pp=1 train_cp=1 -train_prompt_bsz=16 -n_resp_per_prompt=8 -train_prompt_mini_bsz=16 + +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 mtp_params=( actor_rollout_ref.actor.megatron.use_mbridge=True @@ -93,7 +93,6 @@ python -m verl.experimental.fully_async_policy.fully_async_main \ data.truncation='left' \ data.max_prompt_length=${max_prompt_length} \ data.max_response_length=${max_response_length} \ - data.train_batch_size=${train_prompt_bsz} \ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ algorithm.adv_estimator=${adv_estimator} \ algorithm.use_kl_in_reward=${use_kl_in_reward} \ From ff19c3ed94a0aba0d3af3e29009911c4d3301d9a Mon Sep 17 00:00:00 2001 From: ArronHZG Date: Wed, 11 Mar 2026 23:28:43 +0800 Subject: [PATCH 07/14] train_prompt_bsz=128 --- .../mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh b/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh index 7f92725f757..b6925466901 100644 --- a/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh +++ b/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh @@ -53,7 +53,7 @@ train_tp=1 train_pp=1 train_cp=1 - +train_prompt_bsz=128 n_resp_per_prompt=16 train_prompt_mini_bsz=32 From 5d676c6a4bc5af94c3b2a573cb5978dc7bd4c2c8 Mon Sep 17 00:00:00 2001 From: ArronHZG Date: Wed, 11 Mar 2026 23:33:23 +0800 Subject: [PATCH 08/14] ref enable mtp false --- verl/workers/config/model.py | 1 + verl/workers/engine_workers.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/verl/workers/config/model.py b/verl/workers/config/model.py index 9205a99f038..1a11cfabc07 100644 --- a/verl/workers/config/model.py +++ b/verl/workers/config/model.py @@ -82,6 +82,7 @@ class HFModelConfig(BaseConfig): "architectures", 
"local_hf_config_path", "local_tokenizer_path", + "mtp", } path: str = MISSING diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py index 27d45d161de..f1a09069f5d 100644 --- a/verl/workers/engine_workers.py +++ b/verl/workers/engine_workers.py @@ -43,7 +43,7 @@ from verl.utils.py_functional import append_to_dict from verl.utils.tensordict_utils import maybe_fix_3d_position_ids from verl.utils.torch_functional import allgather_dict_into_dict -from verl.workers.config import ActorConfig, HFModelConfig, RolloutConfig, TrainingWorkerConfig +from verl.workers.config import ActorConfig, HFModelConfig, MtpConfig, RolloutConfig, TrainingWorkerConfig from verl.workers.rollout.base import BaseRollout, get_rollout_class from verl.workers.utils.losses import ppo_loss @@ -485,7 +485,7 @@ def init_model(self): # The ref model does not need to enable MTP; force it to false. ref_config.model_config = deepcopy(model_config) - ref_config.model_config.mtp.enable = False + ref_config.model_config.mtp = MtpConfig(enable=False) # construct TrainingWorkerConfig ref_training_config = TrainingWorkerConfig( From 137be6d34d2cf0d7ba5b3024521e1eb39341a73b Mon Sep 17 00:00:00 2001 From: ArronHZG Date: Wed, 11 Mar 2026 23:39:52 +0800 Subject: [PATCH 09/14] rm log --- verl/models/mcore/model_forward.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py index a9becf93564..cf5ef5467a1 100644 --- a/verl/models/mcore/model_forward.py +++ b/verl/models/mcore/model_forward.py @@ -240,8 +240,6 @@ def gptmodel_forward_no_padding( input_ids_offsets = input_ids.offsets() input_ids_lengths = input_ids_offsets.diff().tolist() - print(f"hzg input_ids_lengths={input_ids_lengths}") - for k in ["label", "loss_mask"]: v = logits_processor_args[k] v = _convert_to_nested_tensor(v, input_ids_lengths) From 43939ff3dfb202a9d9159b0181488f8617830620 Mon Sep 17 00:00:00 2001 From: ArronHZG Date: Wed, 11 Mar 2026 23:47:27 +0800 
Subject: [PATCH 10/14] rm log --- .../shell/runtime_env_4_4.yaml | 16 ---------------- 1 file changed, 16 deletions(-) delete mode 100644 verl/experimental/fully_async_policy/shell/runtime_env_4_4.yaml diff --git a/verl/experimental/fully_async_policy/shell/runtime_env_4_4.yaml b/verl/experimental/fully_async_policy/shell/runtime_env_4_4.yaml deleted file mode 100644 index f28db74eef8..00000000000 --- a/verl/experimental/fully_async_policy/shell/runtime_env_4_4.yaml +++ /dev/null @@ -1,16 +0,0 @@ -working_dir: ./ - -excludes: - - ".git/" - -env_vars: - VLLM_USE_V1: "1" - HYDRA_FULL_ERROR: "1" - NCCL_NVLS_ENABLE: "0" - NCCL_SOCKET_IFNAME: "eth0" - TMPDIR: "/tmp" - CUDA_HOME: "/usr/local/cuda" - CUDA_TMPDIR: "/tmp" - RAY_DATA_HOME: "/home/hadoop-djst-algoplat" - TENSORBOARD_DIR: "/home/hadoop-djst-algoplat/data/tensorboard/qwen2.5-7b-math/4-4-gsm8k-sglang" - VERL_LOGGING_LEVEL: "DEBUG" \ No newline at end of file From 0ad43aef3ef6596d98d3d67d87995315533032f5 Mon Sep 17 00:00:00 2001 From: ArronHZG Date: Wed, 11 Mar 2026 23:47:59 +0800 Subject: [PATCH 11/14] rm log --- verl/models/mcore/model_forward.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py index cf5ef5467a1..9b975de8c9b 100644 --- a/verl/models/mcore/model_forward.py +++ b/verl/models/mcore/model_forward.py @@ -236,7 +236,7 @@ def gptmodel_forward_no_padding( args = {} if mtp_enable_train and post_process: - # 使用 input_ids 的序列长度来确保 label 和 loss_mask 对齐 + # Use input_ids sequence length to ensure label and loss_mask alignment input_ids_offsets = input_ids.offsets() input_ids_lengths = input_ids_offsets.diff().tolist() @@ -302,7 +302,7 @@ def gptmodel_forward_no_padding( if mtp_enable_train and post_process: args = {} - # 使用 input_ids 的序列长度来确保 label 和 loss_mask 对齐 + # Use input_ids sequence length to ensure label and loss_mask alignment input_ids_offsets = input_ids.offsets() input_ids_lengths = 
input_ids_offsets.diff().tolist() From eb6699f56803ddae803268e84c721674eecf527c Mon Sep 17 00:00:00 2001 From: ArronHZG Date: Wed, 11 Mar 2026 23:57:09 +0800 Subject: [PATCH 12/14] update fully_async_policy mtp shell --- .../test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh b/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh index b6925466901..533643deb2d 100644 --- a/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh +++ b/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh @@ -16,9 +16,9 @@ clip_ratio_low=0.2 clip_ratio_high=0.28 max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 1)) +max_response_length=$((1024 * 8)) enable_overlong_buffer=True -overlong_buffer_len=$((1024 * 1)) +overlong_buffer_len=$((1024 * 4)) overlong_penalty_factor=1.0 loss_agg_mode="token-mean" @@ -49,7 +49,7 @@ actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) offload=False gen_tp=2 -train_tp=1 +train_tp=2 train_pp=1 train_cp=1 From 1b1c568484877a0254ab3fa13169e16bf9aa3d59 Mon Sep 17 00:00:00 2001 From: ArronHZG Date: Thu, 12 Mar 2026 18:55:17 +0800 Subject: [PATCH 13/14] before sync params, clear kv_cache --- ...dapo_mimo_7b_with_mtp_math_megatron_4_4.sh | 2 +- verl/checkpoint_engine/base.py | 20 +++++++++++++------ .../sglang_rollout/async_sglang_server.py | 9 +++++++-- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh b/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh index 533643deb2d..4883d11e8ea 100644 --- a/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh +++ b/examples/mtp_trainer/test_dapo_mimo_7b_with_mtp_math_megatron_4_4.sh @@ -119,7 +119,7 @@ python -m 
verl.experimental.fully_async_policy.fully_async_main \ actor_rollout_ref.actor.entropy_coeff=0 \ actor_rollout_ref.actor.optim.clip_grad=1.0 \ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.60 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ actor_rollout_ref.rollout.enable_chunked_prefill=True \ actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ diff --git a/verl/checkpoint_engine/base.py b/verl/checkpoint_engine/base.py index a4d99041748..6415d274b9e 100644 --- a/verl/checkpoint_engine/base.py +++ b/verl/checkpoint_engine/base.py @@ -393,11 +393,13 @@ def remove_replicas(self, replicas: list[RolloutReplica]): @auto_await async def sleep_replicas(self): """Sleep all rollout replicas: free weight and kv_cache device memory.""" - # skip sleep replicas for disaggregated rollout - if self.backend != "naive": - return await asyncio.gather(*[r.sleep() for r in self.replicas]) + @auto_await + async def wake_up_replicas(self): + """Resume all rollout replicas: recover kv_cache and weights device memory.""" + await asyncio.gather(*[r.wake_up() for r in self.replicas]) + @auto_await async def update_weights(self, global_steps: int = None): """Update weights from trainer to rollout replicas. @@ -424,14 +426,20 @@ async def update_weights(self, global_steps: int = None): # 3. build process group self.build_process_group(rollout) - # 4. update weights of all workers + # 4. sleep replicas to free kv_cache before weight sync (if free_cache_engine is enabled) + await self.sleep_replicas() + + # 5. update weights of all workers ray.get(trainer.update_weights(global_steps=global_steps) + rollout.update_weights(global_steps=global_steps)) - # 5. finalize all workers + # 6. 
finalize all workers ray.get( trainer.execute_checkpoint_engine(["finalize"] * trainer.world_size) + rollout.execute_checkpoint_engine(["finalize"] * rollout.world_size) ) - # 6. resume all unfinished requests for partial rollout + # 7. resume replicas to recover kv_cache (for free_cache_engine scenarios) + await self.wake_up_replicas() + + # 8. resume all unfinished requests for partial rollout await asyncio.gather(*[r.resume_generation() for r in self.replicas]) diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index c03ada27aba..09d415bb468 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -306,7 +306,10 @@ async def wake_up(self): await self.tokenizer_manager.resume_memory_occupation(obj, None) await self.tokenizer_manager.flush_cache() elif self.rollout_mode == RolloutMode.STANDALONE: - logger.info("skip wake_up in standalone mode") + # In standalone mode, resume kv_cache if free_cache_engine is enabled + obj = ResumeMemoryOccupationReqInput(tags=["kv_cache"]) + await self.tokenizer_manager.resume_memory_occupation(obj, None) + await self.tokenizer_manager.flush_cache() async def sleep(self): if self.node_rank != 0 or not self.config.free_cache_engine: @@ -319,7 +322,9 @@ async def sleep(self): obj = ReleaseMemoryOccupationReqInput(tags=["kv_cache", "weights"]) await self.tokenizer_manager.release_memory_occupation(obj, None) elif self.rollout_mode == RolloutMode.STANDALONE: - logger.info("skip sleep in standalone mode") + # In standalone mode, release kv_cache if free_cache_engine is enabled + obj = ReleaseMemoryOccupationReqInput(tags=["kv_cache"]) + await self.tokenizer_manager.release_memory_occupation(obj, None) async def clear_kv_cache(self): if self.node_rank == 0: From 1517e8e250b02cc730a36c5bdd9e708b26d0546c Mon Sep 17 00:00:00 2001 From: ArronHZG Date: Fri, 13 Mar 2026 10:52:37
+0800 Subject: [PATCH 14/14] before build_process_group, clear kv_cache --- verl/checkpoint_engine/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/verl/checkpoint_engine/base.py b/verl/checkpoint_engine/base.py index 6415d274b9e..6b3a7cd2584 100644 --- a/verl/checkpoint_engine/base.py +++ b/verl/checkpoint_engine/base.py @@ -423,12 +423,12 @@ async def update_weights(self, global_steps: int = None): rollout = RayWorkerGroup(worker_handles=workers, ray_cls_with_init=RayClassWithInitArgs(cls=_worker_cls)) trainer = self.trainer - # 3. build process group - self.build_process_group(rollout) - - # 4. sleep replicas to free kv_cache before weight sync (if free_cache_engine is enabled) + # 3. sleep replicas to free kv_cache before weight sync (if free_cache_engine is enabled) await self.sleep_replicas() + # 4. build process group + self.build_process_group(rollout) + # 5. update weights of all workers ray.get(trainer.update_weights(global_steps=global_steps) + rollout.update_weights(global_steps=global_steps))