128 changes: 80 additions & 48 deletions verl/models/mcore/mtp_patch.py
@@ -26,6 +26,13 @@
roll_tensor,
)

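# Resolve the optional upstream MTP loss helper once at import time: newer Megatron-LM
# releases expose process_mtp_loss, while older ones leave _PROCESS_MTP_LOSS as None.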
try:
import megatron.core.transformer.multi_token_prediction as _mtp_module
except ImportError:
_PROCESS_MTP_LOSS: Callable | None = None
else:
_PROCESS_MTP_LOSS: Callable | None = getattr(_mtp_module, "process_mtp_loss", None)

try:
from megatron.core.utils import unwrap_model
except ImportError:
@@ -78,6 +85,7 @@ def _megatron_gptmodel_postprocess(
runtime_gather_output=None,
extra_block_kwargs=None,
inference_context=None,
**kwargs,
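# Extra keyword arguments are accepted but unused (assumption: a forward-compatibility
# shim so the patched signature tolerates arguments added by newer callers).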
):
"""Postprocesses decoder hidden states to generate logits or compute loss.

@@ -111,58 +119,82 @@ def _megatron_gptmodel_postprocess(

# Skip when mtp_num_layers is None or 0
if self.config.mtp_num_layers and labels is not None:
mtp_labels = labels.clone()

hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
hidden_states = hidden_states_list[0]
if loss_mask is None:
# if loss_mask is not provided, use all ones as loss_mask
loss_mask = torch.ones_like(mtp_labels)
for mtp_layer_number in range(self.config.mtp_num_layers):
# Compute the loss for the current Multi-Token Prediction (MTP) layer.
mtp_labels, _ = roll_tensor(
mtp_labels,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
loss_mask, num_tokens = roll_tensor(
loss_mask,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
# Prefer upstream helper when available (newer Megatron-LM), using
# a cached reference resolved at module import time.
if _PROCESS_MTP_LOSS:
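# Resolve the context-parallel group from pg_collection when present (newer
# Megatron-LM), otherwise fall back to the legacy cp_group attribute.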
cp_group = None
if getattr(self, "pg_collection", None) is not None:
cp_group = self.pg_collection.cp
elif hasattr(self, "cp_group"):
cp_group = self.cp_group

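# The upstream helper is expected to return the main-branch hidden states with the
# scaled MTP loss attached for backward, mirroring the fallback branch below.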
hidden_states = _PROCESS_MTP_LOSS(
hidden_states=hidden_states,
labels=labels,
loss_mask=loss_mask,
output_layer=self.output_layer,
output_weight=output_weight,
runtime_gather_output=runtime_gather_output,
is_training=self.training,
compute_language_model_loss=self.compute_language_model_loss,
config=self.config,
cp_group=cp_group,
packed_seq_params=packed_seq_params,
)
else:
# Fallback for older Megatron-LM versions without the process_mtp_loss API.
mtp_labels = labels.clone()

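# The decoder output stacks the main branch plus one chunk per MTP layer along the
# sequence dimension; keep the first chunk for the standard language-model head.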
hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
hidden_states = hidden_states_list[0]
if loss_mask is None:
# if loss_mask is not provided, use all ones as loss_mask
loss_mask = torch.ones_like(mtp_labels)
for mtp_layer_number in range(self.config.mtp_num_layers):
# Compute the loss for the current Multi-Token Prediction (MTP) layer.
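# Each pass rolls the targets one more token to the left, so every successive MTP
# layer predicts one token further ahead than the previous one.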
mtp_labels, _ = roll_tensor(
mtp_labels,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
loss_mask, num_tokens = roll_tensor(
loss_mask,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
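# num_tokens from roll_tensor is used below to normalize the logged loss and, when
# per-token loss is not requested, the scaled MTP loss.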

# Compute mtp loss without storing logits to save memory.
mtp_loss = self.compute_output_layer_and_language_model_loss(
hidden_states_list[mtp_layer_number + 1],
labels=mtp_labels,
weight=self.shared_embedding_or_output_weight(),
sequence_parallel_enabled=self.output_layer.sequence_parallel,
column_parallel_linear=self.output_layer,
col_linear_kwargs={
"weight": output_weight,
"runtime_gather_output": runtime_gather_output,
},
)

mtp_loss = loss_mask * mtp_loss
if self.training:
# TODO(shifangx): remove the use of parallel_state here
# after moving loss logging to loss_func in pretrain_gpt.py
MTPLossLoggingHelper.save_loss_to_tracker(
torch.sum(mtp_loss) / num_tokens,
mtp_layer_number,
self.config.mtp_num_layers,
avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
# Compute mtp loss without storing logits to save memory.
mtp_loss = self.compute_output_layer_and_language_model_loss(
hidden_states_list[mtp_layer_number + 1],
labels=mtp_labels,
weight=self.shared_embedding_or_output_weight(),
sequence_parallel_enabled=self.output_layer.sequence_parallel,
column_parallel_linear=self.output_layer,
col_linear_kwargs={
"weight": output_weight,
"runtime_gather_output": runtime_gather_output,
},
)
mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
if self.config.calculate_per_token_loss:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
else:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)

mtp_loss = loss_mask * mtp_loss
if self.training:
# TODO(shifangx): remove the use of parallel_state here
# after moving loss logging to loss_func in pretrain_gpt.py
MTPLossLoggingHelper.save_loss_to_tracker(
torch.sum(mtp_loss) / num_tokens,
mtp_layer_number,
self.config.mtp_num_layers,
avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
)
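# MTPLossAutoScaler returns hidden_states unchanged in the forward pass and adds the
# scaled MTP loss into the backward graph.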
mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
if self.config.calculate_per_token_loss:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
else:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)

logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
# [s b h] => [b s h]