diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 6f86ff0225..d1a2b4e03d 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -367,6 +367,7 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用 - rollout_importance_sampling_threshold: 重要性采样权重的阈值,用于截断或屏蔽极端权重。默认为2.0。 - log_rollout_offpolicy_metrics: 当 `rollout_importance_sampling_mode` 未设置时,是否记录训推不一致诊断指标(KL、PPL、χ²等)。当设置了 `rollout_importance_sampling_mode` 时,指标会自动记录。默认为False。 - off_policy_sequence_mask_delta: Off-Policy Sequence Masking 阈值,来自 DeepSeek-V3.2 论文。当设置此值时,会计算每个序列的 `mean(old_policy_logps - policy_logps)`,若该值大于阈值且该序列的优势为负,则 mask 掉该序列不参与损失计算。默认为None,不启用。具体参考[文档](../Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md#off-policy-sequence-masking)。 +- router_replay_mode: 路由重放模式,可选项为`disabled`、`R2`、`R3`。默认为disabled,不启用路由重放。 内置奖励函数参数参考[文档](../Instruction/Command-line-parameters.md#奖励函数参数) diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index ff9a03ca5f..5a2de8ac4d 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -393,6 +393,7 @@ In addition to inheriting the training parameters, the following parameters are - rollout_importance_sampling_threshold: Threshold for importance sampling weights, used for truncating or masking extreme weights. Default is 2.0. - log_rollout_offpolicy_metrics: Whether to log training-inference mismatch diagnostic metrics (KL, PPL, χ², etc.) when `rollout_importance_sampling_mode` is not set. When `rollout_importance_sampling_mode` is set, metrics are always logged. Default is False. - off_policy_sequence_mask_delta: Off-Policy Sequence Masking threshold from [DeepSeek-V3.2 paper](https://arxiv.org/abs/2512.02556). 
When set, computes `mean(old_policy_logps - policy_logps)` for each sequence. If this value exceeds the threshold AND the sequence has negative advantage, the sequence is masked out from loss computation. For details, refer to the [documentation](../Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md#off-policy-sequence-masking). +- router_replay_mode: Router replay mode. Options are `disabled`, `R2`, `R3`. Default is disabled. Built-in reward function parameters refer to the [documentation](../Instruction/Command-line-parameters.md#reward-function-parameters). diff --git a/swift/infer_engine/grpo_vllm_engine.py b/swift/infer_engine/grpo_vllm_engine.py index 07436fec47..9037b23d93 100644 --- a/swift/infer_engine/grpo_vllm_engine.py +++ b/swift/infer_engine/grpo_vllm_engine.py @@ -122,7 +122,7 @@ def _create_chat_completion_response(self, result, inputs, request_config, reque finish_reason=output.finish_reason, logprobs=logprobs, token_ids=token_ids, - ) + routed_experts=getattr(output, 'routed_experts', None)) choices.append(choice) prompt_token_ids = None images_size = None diff --git a/swift/infer_engine/protocol.py b/swift/infer_engine/protocol.py index e98840a8f3..c4b5a1a22d 100644 --- a/swift/infer_engine/protocol.py +++ b/swift/infer_engine/protocol.py @@ -2,19 +2,45 @@ import base64 import io import json +import numpy as np import os import time import uuid from copy import deepcopy from dataclasses import asdict, dataclass, field, fields from PIL import Image -from pydantic import BaseModel, Field, field_validator -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from pydantic import AfterValidator, BaseModel, Field, PlainSerializer, field_validator +from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union from swift.template import Messages, Tool from swift.utils import remove_response +def serialize_ndarray(value): + if value is None: + return None + if isinstance(value, np.ndarray): + return { + 
'data': base64.b64encode(value.tobytes()).decode('ascii'), + 'shape': value.shape, + 'dtype': str(value.dtype), + '__ndarray__': True + } + return value + + +def deserialize_ndarray(value): + if value is None: + return None + if isinstance(value, dict) and value.get('__ndarray__'): + data = base64.b64decode(value['data']) + return np.frombuffer(data, dtype=value['dtype']).reshape(value['shape']) + return value + + +NumpyArray = Annotated[Any, PlainSerializer(serialize_ndarray, return_type=Dict), AfterValidator(deserialize_ndarray)] + + @dataclass class InferRequest: """ @@ -392,6 +418,7 @@ class ChatCompletionResponseChoice: finish_reason: Literal['stop', 'length', None] logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None token_ids: Optional[List[int]] = None + routed_experts: Optional[NumpyArray] = None def to_cmpl_choice(self) -> 'CompletionResponseChoice': self = deepcopy(self) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 70f34fb2ba..d2d707f5a6 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -155,6 +155,8 @@ class RLHFMegatronArgumentsMixin: # Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939 top_entropy_quantile: float = 1.0 + router_replay_mode: Literal['disabled', 'R2', 'R3'] = 'disabled' + # ─────────────────────────── Not Supported Yet ─────────────────────────── # reward model diff --git a/swift/megatron/model/model_config.py b/swift/megatron/model/model_config.py index 600f29be35..a10319c1f0 100644 --- a/swift/megatron/model/model_config.py +++ b/swift/megatron/model/model_config.py @@ -583,6 +583,10 @@ def get_mcore_model_config(args, hf_config): if num_moe_experts is None: kwargs['expert_model_parallel_size'] = 1 kwargs['expert_tensor_parallel_size'] = 1 + + if args.router_replay_mode != 'disabled': + kwargs['moe_enable_routing_replay'] = True + config = MegatronModelConfig(**kwargs) config.hf_config = hf_config config.args 
= args diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 2b4b6e8742..236c7a4f62 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -27,9 +27,9 @@ from swift.megatron.callbacks import megatron_callbacks_map from swift.megatron.model import get_mcore_model from swift.megatron.tuners import LoraParallelLinear -from swift.megatron.utils import (copy_original_module_weight, disable_forward_pre_hook, enable_forward_pre_hook, - get_optimizer_param_scheduler, get_padding_to, init_persistent_async_worker, - initialize_tp_communicators, load_mcore_checkpoint, +from swift.megatron.utils import (apply_router_replay_patch, copy_original_module_weight, disable_forward_pre_hook, + enable_forward_pre_hook, get_optimizer_param_scheduler, get_padding_to, + init_persistent_async_worker, initialize_tp_communicators, load_mcore_checkpoint, logical_and_across_model_parallel_group, maybe_finalize_async_save, prepare_mcore_model, reduce_max_stat_across_model_parallel_group, save_mcore_checkpoint, should_disable_forward_pre_hook, warmup_jit_function, @@ -44,8 +44,11 @@ try: from megatron.core.optimizer import param_group_identifier_keys + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction except ImportError: param_group_identifier_keys = None + RouterReplay = None + RouterReplayAction = None mcore_016 = version.parse(megatron.core.__version__) >= version.parse('0.16.0rc0') @@ -55,6 +58,11 @@ class BaseMegatronTrainer(ABC): def __init__(self, args, template: Template): + # validate mcore version and patch routing_replay + self.enable_routing_replay = args.router_replay_mode != 'disabled' + if self.enable_routing_replay: + apply_router_replay_patch() + self.args = args self.template = template self.prepare_model() @@ -839,6 +847,10 @@ def train_step(self, train_data_iterator): self.optimizer.zero_grad() # TODO: refactor _replace_data_iterator data_iterator = 
self._replace_data_iterator(train_data_iterator) + + if self.enable_routing_replay: + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + metrics = forward_backward_func( forward_step_func=self.forward_step, data_iterator=data_iterator, @@ -855,6 +867,10 @@ def train_step(self, train_data_iterator): if update_successful: self.opt_param_scheduler.step(increment=args.global_batch_size) + if self.enable_routing_replay: + RouterReplay.clear_global_router_replay_action() + RouterReplay.clear_global_indices() + return metrics, grad_norm, update_successful def _aggregated_metrics(self, metrics, total_metrics): diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 81d111b599..ddbf691132 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -23,7 +23,7 @@ from swift.dataset import RowPreprocessor from swift.infer_engine.protocol import RequestConfig, RolloutInferRequest, RolloutOutput from swift.megatron.arguments import MegatronArguments, MegatronRLHFArguments -from swift.megatron.utils import forward_step_helper, get_padding_to, set_random_seed +from swift.megatron.utils import RouterReplayHelper, get_padding_to, set_random_seed, set_router_replay_data from swift.rewards import orms from swift.rlhf_trainers.grpo_trainer import DataType from swift.rlhf_trainers.utils import (aggressive_empty_cache, nanstd, pad_logps_back_to_batch, profiling_context, @@ -38,6 +38,12 @@ from .utils import gather, gather_object from .vocab_parallel_utils import compute_logps_and_entropy_from_logits +try: + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction +except ImportError: + RouterReplay = None + RouterReplayAction = None + logger = get_logger() @@ -271,6 +277,8 @@ def _batch_encode(self, infer_requests: List[Dict], template: Template, strict: return batched_inputs, error_list def _get_encoded_batch(self, encoded_list, 
rollout_batch, template): + original_seq_lengths = [item['length'] for item in encoded_list] + args = self.args encoded_batch = to_device(template.data_collator(encoded_list, padding_to=get_padding_to(args)), self.device) @@ -361,6 +369,43 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template): flat_lps, dtype=torch.float32, device=self.device) encoded_batch['rollout_per_token_logps'] = rollout_per_token_logps + + # Validating and processing routed_experts data in R3 mode + if self.args.router_replay_mode == 'R3': + routed_experts_list = [] + cur_seq_lengths = seq_lengths + if (seq_lengths.size(0) > batch_size): + cur_seq_lengths = seq_lengths[:batch_size].clone() + cur_seq_lengths[batch_size - 1] = seq_lengths[batch_size - 1:].sum() + for data, original_seq_len, cur_seq_len in zip(rollout_batch, original_seq_lengths, cur_seq_lengths): + routed_experts = data.get('routed_experts') + assert routed_experts is not None, ( + 'When router_replay_mode = R3, routed_experts must be in rollout data') + routed_experts = torch.tensor(routed_experts) + # The number of experts in the output can be 1 less than (prompt_length + response_token_count) + # This gap of 1 is expected + # For more details, please refer PR https://github.com/vllm-project/vllm/pull/28284 + experts_seq_len = routed_experts.shape[0] + assert (experts_seq_len == original_seq_len or experts_seq_len + 1 + == original_seq_len), (f'The seq_len of routed_experts({experts_seq_len}) in output ', + f'does not match the seq_len of data({original_seq_len}), ' + 'should be equal to or 1 less than the seq_len of data') + # Padding routed_experts(seq_len, layer_num, topk) seq_len to match the seq_len of the input_ids + padding_routed_experts = routed_experts + padding_to = cur_seq_len if template.padding_free else max_seq_len + padding_len = padding_to - experts_seq_len + if padding_len > 0: + padding_right = template.padding_side == 'right' + padding_routed_experts = nn.functional.pad(routed_experts, + 
(0, 0, 0, 0, 0, padding_len) if padding_right else + (0, 0, 0, 0, padding_len, 0), 'constant', 0) + routed_experts_list.append(padding_routed_experts) + if template.padding_free: + global_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0) + else: + global_routed_experts = torch.stack(routed_experts_list) + encoded_batch['routed_experts'] = global_routed_experts.to(device=self.device) + return encoded_batch def _generate_and_score_completions(self, batch): @@ -558,6 +603,9 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out if 'content' in choice.logprobs: rollout_logprobs = [item['logprob'] for item in choice.logprobs['content']] input_data['rollout_logprobs'] = [rollout_logprobs] + + # Step 6: Store rollout routed_experts for routing replay + input_data['routed_experts'] = choice.routed_experts return input_data assert len(batch) == len(outputs) @@ -938,7 +986,7 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: with self.null_ref_context() as ref_models: assert len(ref_models) == 1, 'GRPO currently does not support VPP.' 
ref_model = ref_models[0] - ref_per_token_logps_packed = self.compute_per_token_logps( + ref_per_token_logps_packed, _ = self.compute_per_token_logps( ref_model, iter([deepcopy(inputs)]), temperature=self.temperature) if self.template.padding_free: ref_per_token_logps, _ = pad_logps_back_to_batch( @@ -950,7 +998,13 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: ref_per_token_logps = ref_per_token_logps_packed batch['ref_per_token_logps'] = ref_per_token_logps - old_per_token_logps_packed = self.compute_per_token_logps( + if self.enable_routing_replay: + if self.args.router_replay_mode == 'R2': + RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD) + if self.args.router_replay_mode == 'R3': + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + old_per_token_logps_packed, routing_topk_idx = self.compute_per_token_logps( self.unwrapped_models[0], iter([deepcopy(inputs)]), temperature=self.temperature) if self.template.padding_free: old_per_token_logps, _ = pad_logps_back_to_batch( @@ -962,6 +1016,11 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: old_per_token_logps = old_per_token_logps_packed batch['old_per_token_logps'] = old_per_token_logps + if self.enable_routing_replay: + batch['routed_experts'] = routing_topk_idx + RouterReplay.clear_global_indices() + RouterReplay.clear_global_router_replay_action() + return batch def _compute_kl_from_batches(self, mini_batch_data: List[Dict[str, Any]]) -> torch.Tensor: @@ -1041,6 +1100,15 @@ def forward_step(self, data_iterator, model): 'seq_lengths': seq_lengths, }) data.pop('loss_scale', None) + + if self.enable_routing_replay and RouterReplayHelper.is_replay_backward_action(model.config): + router_instance_list = RouterReplayHelper.get_micro_batch_router_list(model.config) + for router in router_instance_list: + router.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + if self.enable_routing_replay and 
RouterReplayHelper.is_replay_forward_action(model.config): + layers_topk_idx = data.pop('routed_experts', None) + set_router_replay_data(layers_topk_idx, model.config) + inputs = self._prepare_model_inputs(data) labels = data['labels'] @@ -1091,6 +1159,11 @@ def forward_step(self, data_iterator, model): output_tensor = per_token_logps data['per_token_entropy'] = per_token_entropy + if self.enable_routing_replay and RouterReplayHelper.is_replay_forward_action(model.config): + router_instance_list = RouterReplayHelper.get_micro_batch_router_list(model.config) + for router in router_instance_list: + router.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD) + return output_tensor, partial(self.loss_func, data=data) @profiling_decorator diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index f61b000fc6..fb3737081a 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -6,7 +6,8 @@ from transformers.utils import ContextManagers from swift.megatron.model import get_mcore_model -from swift.megatron.utils import forward_step_helper, load_mcore_checkpoint +from swift.megatron.utils import (RouterReplayHelper, forward_step_helper, get_local_topk_idx_for_current_rank, + get_router_replay_data, load_mcore_checkpoint, set_router_replay_data) from swift.rlhf_trainers.utils import identity_data_collator from swift.utils import get_logger from .base import BaseMegatronTrainer @@ -107,18 +108,30 @@ def compute_per_token_logps(self, model, data_iterator, no_grad=True, temperatur Returns: per_token_logps tensor, or None if on a non-last PP stage + routing_topk_idx tensor, or None if router replay is disabled """ data = self.get_batch(data_iterator) data.pop('loss_scale', None) labels = data.get('labels') + routing_topk_idx = None + global_topk_idx = data.pop('routed_experts', None) + if self.enable_routing_replay and RouterReplayHelper.is_replay_forward_action(model.config): + assert 
global_topk_idx is not None, 'When router_replay_mode = R3, routed_experts must be in data' + routing_topk_idx = get_local_topk_idx_for_current_rank(global_topk_idx, model.config, + data.get('packed_seq_params')) + set_router_replay_data(routing_topk_idx, model.config) + data_for_forward = {k: v for k, v in data.items() if k != 'labels'} context = torch.no_grad() if no_grad else nullcontext() with context: output_tensor = forward_step_helper(self.args, model, data_for_forward) + if self.enable_routing_replay and RouterReplayHelper.is_r2_record_action(model.config): + routing_topk_idx = get_router_replay_data(model.config) + if labels is None or output_tensor is None: - return None + return None, routing_topk_idx if temperature != 1.0: output_tensor.div_(temperature) @@ -133,7 +146,7 @@ def compute_per_token_logps(self, model, data_iterator, no_grad=True, temperatur if self.args.context_parallel_size > 1: per_token_logps = self._postprocess_packed_tensor_cp(per_token_logps, packed_seq_params, num_samples) - return per_token_logps + return per_token_logps, routing_topk_idx def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples): """ diff --git a/swift/megatron/trainers/rollout_mixin.py b/swift/megatron/trainers/rollout_mixin.py index 89872fc2cb..06f7abe01e 100644 --- a/swift/megatron/trainers/rollout_mixin.py +++ b/swift/megatron/trainers/rollout_mixin.py @@ -245,6 +245,11 @@ def _prepare_vllm_engine(self): vllm_engine_kwargs = args.vllm_engine_kwargs or {} load_format = vllm_engine_kwargs.pop('load_format', 'dummy') + if self.args.router_replay_mode == 'R3': + assert check_vllm_version_ge('0.14.0'), \ + 'The enable_return_routed_experts attribute is not supported. 
Please upgrade vllm to 0.14.0 or higher' + vllm_engine_kwargs['enable_return_routed_experts'] = True + engine = GRPOVllmEngine( args.model_info.model_dir, torch_dtype=args.torch_dtype, diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index fe351e1f93..4a4f03da11 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -9,5 +9,6 @@ from .parallel_utils import (logical_and_across_model_parallel_group, reduce_max_stat_across_model_parallel_group, split_cp_inputs) from .patcher import patch_merge_fn, patch_torch_dist_shard +from .router_replay_utils import * from .utils import (copy_original_module_weight, forward_step_helper, get_local_layer_specs, get_padding_to, prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/utils/router_replay_utils.py b/swift/megatron/utils/router_replay_utils.py new file mode 100644 index 0000000000..6cd5d485b0 --- /dev/null +++ b/swift/megatron/utils/router_replay_utils.py @@ -0,0 +1,215 @@ +""" +Router Replay Utilities +Utilities for handling router replay functionality in Megatron models. 
+""" + +import torch +from megatron.core import mpu +from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region +from megatron.core.transformer.transformer_block import get_num_layers_to_build +from megatron.core.transformer.transformer_layer import get_transformer_layer_offset + +from swift.megatron.trainers.utils import split_cp_inputs +from swift.utils import get_logger +from swift.utils.torch_utils import get_current_device + +logger = get_logger() + +try: + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction + from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher + ROUTER_REPLAY_AVAILABLE = True +except ImportError: + logger.warning('RouterReplay not available in current megatron-core version') + RouterReplay = None + RouterReplayAction = None + MoEAlltoAllTokenDispatcher = None + ROUTER_REPLAY_AVAILABLE = False + +device_name = get_current_device() + + +def is_moe_layer(tf_config, layer_idx): + moe_layer_freq = getattr(tf_config, 'moe_layer_freq', None) + if isinstance(moe_layer_freq, int): + return layer_idx % moe_layer_freq == 0 + elif isinstance(moe_layer_freq, list): + return moe_layer_freq[layer_idx] == 1 + else: + raise ValueError(f'Unsupported moe_layer_freq type: {type(moe_layer_freq)}') + + +def get_moe_num_layers_to_build(tf_config, vp_stage=None, pp_rank=None): + """Count the number of MoE layers assigned to the current rank. + When ``moe_layer_freq`` is 1 or unset, every transformer layer is an MoE + layer, so the count equals the total layer count. Otherwise only layers + whose global index satisfies the frequency predicate are counted. + Args: + config: Megatron TransformerConfig providing layer layout information. + vp_stage: Virtual-pipeline stage index (None defaults to current). + pp_rank: Pipeline-parallel rank (None defaults to current). + Returns: + Number of MoE layers on the specified rank/stage. 
+ """ + total_layers = get_num_layers_to_build(tf_config, vp_stage=vp_stage, pp_rank=pp_rank) + + layer_offset = get_transformer_layer_offset(tf_config, vp_stage=vp_stage) + local_global_indices = range(layer_offset, layer_offset + total_layers) + + num_moe_layers = sum(1 for idx in local_global_indices if is_moe_layer(tf_config, idx)) + + return num_moe_layers + + +def get_local_layer_range(tf_config, vp_rank=None, only_moe_layer=True): + vp_size = tf_config.virtual_pipeline_model_parallel_size + if vp_size is not None: + vp_rank = 0 if vp_rank is None else vp_rank + offset = 0 + for pre_vp_stage in range(vp_size): + if pre_vp_stage == vp_rank: + break + num_layers_to_build = get_moe_num_layers_to_build( + tf_config, pre_vp_stage) if only_moe_layer else get_num_layers_to_build(tf_config, pre_vp_stage) + offset += num_layers_to_build + else: + offset = 0 + count = get_moe_num_layers_to_build(tf_config, vp_rank) if only_moe_layer else get_num_layers_to_build( + tf_config, vp_rank) + return offset, count + + +def get_local_topk_idx_for_current_rank(global_topk_idx, tf_config, packed_seq_params=None): + if global_topk_idx is None: + return None + # 1. pp slice + # For the hybrid model, global_topk_idx contains data from all layers + # because vLLM reports routed_experts across all transformer layers(including dense). + # However megatron only has routers for MoE layers. + # So local_topk_idx should filter only data from the MoE layer. + layer_offset = get_transformer_layer_offset(tf_config, vp_stage=0) + offset, count = get_local_layer_range( + tf_config, tf_config.virtual_pipeline_model_parallel_size, only_moe_layer=False) + num_layers = offset + count + moe_layer_idx = torch.tensor([ + layer_idx for layer_idx in range(layer_offset, layer_offset + num_layers) if is_moe_layer(tf_config, layer_idx) + ], + dtype=torch.long, + device=global_topk_idx.device) + local_topk_idx = torch.index_select(global_topk_idx, dim=2, index=moe_layer_idx) + # 2. 
cp slice + cp_size = mpu.get_context_parallel_world_size() + if cp_size > 1: + local_topk_idx = split_cp_inputs(local_topk_idx, getattr(packed_seq_params, 'cu_seqlens_q', None), 1) + # 3. sp slice + local_topk_idx = scatter_to_sequence_parallel_region(local_topk_idx.transpose(0, 1)).transpose(0, 1) + return local_topk_idx + + +def get_router_replay_data(tf_config, vp_rank=None): + router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + layers_topk_idx = [] + for router in router_instances_list: + layers_topk_idx.append(router.recorded_topk_idx.to(torch.uint8)) + # layer_num, seq_len, topk -> 1, seq_len, layer_num, topk + layers_topk_idx = torch.stack(layers_topk_idx).transpose(0, 1).unsqueeze(0).to(device_name) + return layers_topk_idx + + +def set_router_replay_data(layers_topk_idx, tf_config, vp_rank=None): + # bs, seq_len, layer_num, topk -> layer_num, total_seq_len, topk + layers_topk_idx_reshape = layers_topk_idx.flatten(0, 1).transpose(0, 1).to(device_name) + offset, count = get_local_layer_range(tf_config, vp_rank) + router_instances_list = RouterReplay.global_router_replay_instances[offset:offset + count] + for i, router in enumerate(router_instances_list): + router.set_target_indices(layers_topk_idx_reshape[i + offset].to(torch.int64)) + + +class RouterReplayHelper: + """Helper class to query router replay state and locate local RouterReplay instances.""" + + @staticmethod + def get_micro_batch_router_list(tf_config, vp_rank=None): + """ + Return the list of RouterReplay instances corresponding to the current micro-batch and local + (pp_rank, vp_stage) layer range. + + When virtual pipeline (VPP) is enabled, the local range for the PP rank is expanded to include + all VP stages by multiplying the per-VP count by vp_size. The returned slice is taken from the + global RouterReplay.global_router_replay_instances list. + + Args: + tf_config: Configuration object used to compute layer assignments. 
+            vp_rank (Optional[int]): Explicit virtual pipeline stage to query. If None, the current VP + rank from Megatron parallel state is used when available. + Returns: + list: A contiguous sublist of RouterReplay.global_router_replay_instances for the local layer range. + """ + offset, count = get_local_layer_range(tf_config, vp_rank) + router_instances_list = RouterReplay.global_router_replay_instances[offset:offset + count] + return router_instances_list + + @staticmethod + def is_r2_record_action(tf_config, vp_rank=None) -> bool: + """Return True if the current router_replay_action is RECORD (R2) for the local router instances. + + This inspects the first local RouterReplay instance's router_replay_action and compares it to + RouterReplayAction.RECORD. + """ + router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + return (router_instances_list and router_instances_list[0].router_replay_action == RouterReplayAction.RECORD) + + @staticmethod + def is_replay_forward_action(tf_config, vp_rank=None) -> bool: + """Return True if the current router_replay_action is REPLAY_FORWARD for the local router instances. + + This inspects the first local RouterReplay instance's router_replay_action and compares it to + RouterReplayAction.REPLAY_FORWARD. + """ + router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + return (router_instances_list + and router_instances_list[0].router_replay_action == RouterReplayAction.REPLAY_FORWARD) + + @staticmethod + def is_replay_backward_action(tf_config, vp_rank=None) -> bool: + """Return True if the current router_replay_action is REPLAY_BACKWARD for the local router instances. + + This inspects the first local RouterReplay instance's router_replay_action and compares it to + RouterReplayAction.REPLAY_BACKWARD. 
+ """ + router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + return (router_instances_list + and router_instances_list[0].router_replay_action == RouterReplayAction.REPLAY_BACKWARD) + + +def apply_router_replay_patch(): + """ + Applies the monkey patch for MoE Router Replay functionality. + """ + logger.info('Applying Router Replay Patch...') + + assert ROUTER_REPLAY_AVAILABLE, \ + 'The routing replay is not supported. Please upgrade megatron-core to 0.16.0 or higher' + + # Patch MoEAlltoAllTokenDispatcher.preprocess to handle router replay + # When router replay is enabled, duplicate indices in top_indices can cause + # routing_map.sum() < num_tokens * topk, leading to split size mismatch in alltoall. + if MoEAlltoAllTokenDispatcher is not None and not hasattr(MoEAlltoAllTokenDispatcher, '_preprocess_patched'): + original_preprocess = MoEAlltoAllTokenDispatcher.preprocess + + def patched_preprocess(self, routing_map): + """Patched preprocess that handles router replay correctly for alltoall dispatcher.""" + # Call original preprocess + result = original_preprocess(self, routing_map) + # Fix num_out_tokens when router replay is enabled + if (getattr(self.config, 'moe_enable_routing_replay', False) and not self.drop_and_pad + and self.config.moe_expert_capacity_factor is None + and not (getattr(self.config, 'moe_router_padding_for_quantization', None) + or getattr(self.config, 'moe_router_padding_for_fp8', None))): + # With router replay, duplicate indices can reduce the actual routed + # token count, so derive it from the routing map instead. + self.num_out_tokens = int(routing_map.sum().item()) + return result + + MoEAlltoAllTokenDispatcher.preprocess = patched_preprocess + MoEAlltoAllTokenDispatcher._preprocess_patched = True