From ebfae8d6653eaca2c45f896e38d2ff8df659457e Mon Sep 17 00:00:00 2001
From: Hollow Man
Date: Sun, 15 Feb 2026 00:36:29 +0800
Subject: [PATCH 1/2] [doc] feat: Enable Megatron-Bridge for MTP

Signed-off-by: Hollow Man
---
 docs/advance/mtp.md | 8 +++++---
 verl/workers/megatron_workers.py | 1 -
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/docs/advance/mtp.md b/docs/advance/mtp.md
index b4c5a25c631..5f1698d3ddc 100644
--- a/docs/advance/mtp.md
+++ b/docs/advance/mtp.md
@@ -2,19 +2,21 @@
 **Author**: `https://github.com/meituan-search`
-Last updated: 01/30/2026
+Last updated: 02/15/2026
 # 1. Scope of Support
 Currently, RL training can be performed on mimo-7B-RL, Qwen-next, and Deepseek series models based on the MTP architecture. The support rules for training and inference engines are as follows:
-- **Training Engine**: Only supports the `mbridge + megatron` combination; other training engines are not compatible at this time;
+- **Training Engine**: Only supports the `mbridge/Megatron-Bridge + megatron` combination; other training engines are not compatible at this time;
 - **Inference Engine**: Compatible with all engines, but the model must be in the corresponding engine's compatibility list;
 - **Dependency Versions**:
- - mbridge: Use the specified branch: [https://github.com/ArronHZG/mbridge/tree/feature/verl_mtp](https://github.com/ArronHZG/mbridge/tree/feature/verl_mtp) (will be merged into the main branch in the future);
+ - mbridge: Apply the patches and review suggestions from PR [#62](https://github.com/ISEEKYAN/mbridge/pull/62) (will be merged into the main branch in the future);
+
+ - Megatron-Bridge: If you want to try out mimo-7B-RL, apply the patches and review suggestions from PR [#2387](https://github.com/NVIDIA-NeMo/Megatron-Bridge/pull/2387) (will be merged into the main branch in the future);
 - megatron: Use the latest dev version (commit: [23e092f41ec8bc659020e401ddac9576c1cfed7e](https://github.com/NVIDIA/Megatron-LM/tree/23e092f41ec8bc659020e401ddac9576c1cfed7e)), which supports MTP + CP training methods.
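Editor's note: the worker and actor changes later in this series gate MTP on a small set of config fields. The sketch below is illustrative only, assuming an OmegaConf-style config object; the exact Hydra paths used in a launch script are not shown anywhere in this series, and only the attribute names (`use_mbridge`, `mtp.enable`, `mtp.mtp_loss_scaling_factor`) appear in the diffs that follow.

```python
# Illustrative sketch only (not part of the patch): the config fields that the
# MTP code path in this series reads. Structure and values are placeholders.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "model": {
            # mtp.enable and mtp.mtp_loss_scaling_factor are read by the actor/worker changes below.
            "mtp": {"enable": True, "mtp_loss_scaling_factor": 0.1},
        },
        # MTP still requires an mbridge-based bridge; after this patch the
        # vanilla_mbridge restriction is dropped, so Megatron-Bridge also works.
        "megatron": {"use_mbridge": True},
    }
)

assert cfg.megatron.use_mbridge, "MTP requires use_mbridge to be True"
```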
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 2dc426cd7f0..6a195cf4216 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -155,7 +155,6 @@ def _init_hf_config_and_tf_config( if enable_mtp: assert hf_config.num_nextn_predict_layers > 0, "MTP requires at least one nextn_predict_layer" assert megatron_config.use_mbridge, "MTP requires use_mbridge to be True" - assert megatron_config.vanilla_mbridge, "MTP requires vanilla_mbridge to be True" override_transformer_config["mtp_loss_scaling_factor"] = self.config.model.mtp.mtp_loss_scaling_factor else: if hasattr(hf_config, "num_nextn_predict_layers"): From 84652eafa5f8fb415ab827ed8db9b51393c41916 Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Sun, 15 Feb 2026 03:11:03 +0800 Subject: [PATCH 2/2] Update mtp patch for latest mcore Signed-off-by: Hollow Man --- verl/models/mcore/mtp_patch.py | 167 +++++++++++++++++---------- verl/workers/actor/megatron_actor.py | 5 + 2 files changed, 112 insertions(+), 60 deletions(-) diff --git a/verl/models/mcore/mtp_patch.py b/verl/models/mcore/mtp_patch.py index 117b6e3f28c..fadf5b7bd52 100644 --- a/verl/models/mcore/mtp_patch.py +++ b/verl/models/mcore/mtp_patch.py @@ -20,11 +20,7 @@ import torch from megatron.core import parallel_state from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.transformer.multi_token_prediction import ( - MTPLossAutoScaler, - MTPLossLoggingHelper, - roll_tensor, -) +from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor try: from megatron.core.utils import unwrap_model @@ -78,19 +74,45 @@ def _megatron_gptmodel_postprocess( runtime_gather_output=None, extra_block_kwargs=None, inference_context=None, + **kwargs, ): - """Postprocesses decoder hidden states to generate logits or compute loss. + """Compatibility patch for GPTModel._postprocess. - Applies Multi-Token Prediction if enabled, generates output logits through - the output layer, and computes language model loss when labels are provided. + For inference (`labels is None`), delegate to the upstream implementation to stay + aligned with Megatron-Core updates. + + For training (`labels is not None`), keep VERL's MTP behavior and always return + logits (instead of CE loss) so PPO paths can compute custom losses from logits. """ + # Keep inference path aligned with whatever upstream Megatron currently expects. + if labels is None: + return self._postprocess_backup( + hidden_states=hidden_states, + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + mtp_in_postprocess=mtp_in_postprocess, + loss_mask=loss_mask, + decoder_input=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, + **kwargs, + ) - # logits and loss + # Training path: keep logits for external loss computation. 
output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - if mtp_in_postprocess and labels is not None: + if mtp_in_postprocess: hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, @@ -109,60 +131,85 @@ def _megatron_gptmodel_postprocess( if not self.post_process: return hidden_states - # Skip when mtp_num_layers is None or 0 - if self.config.mtp_num_layers and labels is not None: - mtp_labels = labels.clone() - - hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) - hidden_states = hidden_states_list[0] - if loss_mask is None: - # if loss_mask is not provided, use all ones as loss_mask - loss_mask = torch.ones_like(mtp_labels) - for mtp_layer_number in range(self.config.mtp_num_layers): - # Calc loss for the current Multi-Token Prediction (MTP) layers. - mtp_labels, _ = roll_tensor( - mtp_labels, - shifts=-1, - dims=-1, - cp_group=self.cp_group, + # Skip when mtp_num_layers is None or 0. + if self.config.mtp_num_layers: + cp_group = None + if getattr(self, "pg_collection", None) is not None: + cp_group = self.pg_collection.cp + elif hasattr(self, "cp_group"): + cp_group = self.cp_group + + # Prefer upstream helper when available (newer Megatron-LM). + try: + from megatron.core.transformer.multi_token_prediction import process_mtp_loss + + hidden_states = process_mtp_loss( + hidden_states=hidden_states, + labels=labels, + loss_mask=loss_mask, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, + is_training=self.training, + compute_language_model_loss=self.compute_language_model_loss, + config=self.config, + cp_group=cp_group, packed_seq_params=packed_seq_params, ) - loss_mask, num_tokens = roll_tensor( - loss_mask, - shifts=-1, - dims=-1, - cp_group=self.cp_group, - packed_seq_params=packed_seq_params, - ) - - # Compute mtp loss without storing logits to save memory. - mtp_loss = self.compute_output_layer_and_language_model_loss( - hidden_states_list[mtp_layer_number + 1], - labels=mtp_labels, - weight=self.shared_embedding_or_output_weight(), - sequence_parallel_enabled=self.output_layer.sequence_parallel, - column_parallel_linear=self.output_layer, - col_linear_kwargs={ - "weight": output_weight, - "runtime_gather_output": runtime_gather_output, - }, - ) + except (ImportError, AttributeError, TypeError): + # Fallback for older Megatron-LM versions without process_mtp_loss API. + mtp_labels = labels.clone() + + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) + for mtp_layer_number in range(self.config.mtp_num_layers): + # Calc loss for the current Multi-Token Prediction (MTP) layers. 
+ mtp_labels, _ = roll_tensor( + mtp_labels, + shifts=-1, + dims=-1, + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) + loss_mask, num_tokens = roll_tensor( + loss_mask, + shifts=-1, + dims=-1, + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) - mtp_loss = loss_mask * mtp_loss - if self.training: - # TODO(shifangx): remove the use of parallel_state here - # after moving loss logging to loss_func in pretrain_gpt.py - MTPLossLoggingHelper.save_loss_to_tracker( - torch.sum(mtp_loss) / num_tokens, - mtp_layer_number, - self.config.mtp_num_layers, - avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True), + # Compute mtp loss without storing logits to save memory. + mtp_loss = self.compute_output_layer_and_language_model_loss( + hidden_states_list[mtp_layer_number + 1], + labels=mtp_labels, + weight=self.shared_embedding_or_output_weight(), + sequence_parallel_enabled=self.output_layer.sequence_parallel, + column_parallel_linear=self.output_layer, + col_linear_kwargs={ + "weight": output_weight, + "runtime_gather_output": runtime_gather_output, + }, ) - mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers - if self.config.calculate_per_token_loss: - hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss) - else: - hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens) + + mtp_loss = loss_mask * mtp_loss + if self.training: + # TODO(shifangx): remove the use of parallel_state here + # after moving loss logging to loss_func in pretrain_gpt.py + MTPLossLoggingHelper.save_loss_to_tracker( + torch.sum(mtp_loss) / num_tokens, + mtp_layer_number, + self.config.mtp_num_layers, + avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True), + ) + mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers + if self.config.calculate_per_token_loss: + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss) + else: + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens) logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) # [s b h] => [b s h] diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 7fdaa6e9811..16c163410bc 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -139,6 +139,11 @@ def __init__( assert self.mtp_config.enable, "MTP requires mtp_config.enable to be True" self.use_fused_kernels = self.config.get("use_fused_kernels", False) + if getattr(self.mtp_config, "enable", False) and self.use_fused_kernels: + self.use_fused_kernels = False + logger.warning_once( + "MTP is not compatible with fused kernels for now. Automatically disable use_fused_kernels." + ) if self.use_fused_kernels and not getattr(self.config, "overlap_moe_expert_parallel_comm", False): # do not patch if overlap_moe_expert_parallel_comm is enabled logger.warning_once(
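Editor's note: the patched `_megatron_gptmodel_postprocess` above delegates to `self._postprocess_backup` when `labels is None`, which implies the patch module saves the upstream method before swapping it out. The wiring below is a minimal sketch of that pattern under the assumption of a helper named `apply_mtp_patch`; the actual installation code lives elsewhere in `verl/models/mcore/mtp_patch.py` and is not part of this diff.

```python
# Minimal sketch (assumption, not part of this diff): how a _postprocess override
# like the one above is typically installed. apply_mtp_patch is a hypothetical name.
from megatron.core.models.gpt.gpt_model import GPTModel


def apply_mtp_patch():
    # Keep the upstream implementation around so the inference path
    # (labels is None) can delegate to it unchanged.
    if not hasattr(GPTModel, "_postprocess_backup"):
        GPTModel._postprocess_backup = GPTModel._postprocess
    # Route all calls through the VERL-compatible version defined above
    # in mtp_patch.py.
    GPTModel._postprocess = _megatron_gptmodel_postprocess
```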