From affc976497c7a4e31c760005abd89736cf8d7276 Mon Sep 17 00:00:00 2001 From: XianlongLi <2286061024@qq.com> Date: Fri, 27 Feb 2026 16:27:52 +0800 Subject: [PATCH 01/14] Support routing replay in Megatron GRPO --- swift/infer_engine/grpo_vllm_engine.py | 1 + swift/infer_engine/protocol.py | 32 +- swift/megatron/arguments/megatron_args.py | 2 + swift/megatron/trainers/base.py | 17 +- swift/megatron/trainers/grpo_trainer.py | 88 ++++- swift/megatron/trainers/rollout_mixin.py | 3 + swift/megatron/utils/__init__.py | 2 + swift/megatron/utils/router_replay_patch.py | 351 ++++++++++++++++++++ swift/megatron/utils/router_replay_utils.py | 130 ++++++++ 9 files changed, 619 insertions(+), 7 deletions(-) create mode 100644 swift/megatron/utils/router_replay_patch.py create mode 100644 swift/megatron/utils/router_replay_utils.py diff --git a/swift/infer_engine/grpo_vllm_engine.py b/swift/infer_engine/grpo_vllm_engine.py index 07436fec47..a435dd228b 100644 --- a/swift/infer_engine/grpo_vllm_engine.py +++ b/swift/infer_engine/grpo_vllm_engine.py @@ -122,6 +122,7 @@ def _create_chat_completion_response(self, result, inputs, request_config, reque finish_reason=output.finish_reason, logprobs=logprobs, token_ids=token_ids, + routed_experts=getattr(output, 'routed_experts', None) ) choices.append(choice) prompt_token_ids = None diff --git a/swift/infer_engine/protocol.py b/swift/infer_engine/protocol.py index e98840a8f3..8d8e141b5a 100644 --- a/swift/infer_engine/protocol.py +++ b/swift/infer_engine/protocol.py @@ -8,13 +8,40 @@ from copy import deepcopy from dataclasses import asdict, dataclass, field, fields from PIL import Image -from pydantic import BaseModel, Field, field_validator -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from pydantic import BaseModel, Field, field_validator, PlainSerializer, AfterValidator +from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Annotated +import numpy as np from swift.template import Messages, Tool 
from swift.utils import remove_response +def serialize_ndarray(value): + if value is None: + return None + if isinstance(value, np.ndarray): + return { + 'data': value.tolist(), + 'shape': value.shape, + 'dtype': str(value.dtype), + '__ndarray__': True + } + return value + +def deserialize_ndarray(value): + if value is None: + return None + if isinstance(value, dict) and value.get('__ndarray__'): + return np.array(value['data'], dtype=value['dtype']).reshape(value['shape']) + return value + +NumpyArray = Annotated[ + Any, + PlainSerializer(serialize_ndarray, return_type=Dict), + AfterValidator(deserialize_ndarray) +] + + @dataclass class InferRequest: """ @@ -392,6 +419,7 @@ class ChatCompletionResponseChoice: finish_reason: Literal['stop', 'length', None] logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None token_ids: Optional[List[int]] = None + routed_experts: Optional[NumpyArray] = None def to_cmpl_choice(self) -> 'CompletionResponseChoice': self = deepcopy(self) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index 082ba37ee4..84771e6e57 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -144,6 +144,8 @@ class RLHFMegatronArgumentsMixin: log_entropy: bool = False # Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939 top_entropy_quantile: float = 1.0 + + router_replay_mode: Literal['disabled', 'R2', 'R3'] = 'disabled' # ─────────────────────────── Not Supported Yet ─────────────────────────── diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index f999b8493b..efb0884b8e 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -32,7 +32,8 @@ initialize_tp_communicators, load_mcore_checkpoint, logical_and_across_model_parallel_group, maybe_finalize_async_save, prepare_mcore_model, reduce_max_stat_across_model_parallel_group, - save_mcore_checkpoint, should_disable_forward_pre_hook, 
wrap_model) + save_mcore_checkpoint, should_disable_forward_pre_hook, wrap_model, + apply_router_replay_patch, RouterReplay, RouterReplayAction) from swift.template import Template from swift.trainers import dynamic_gradient_checkpointing from swift.trainers.utils import patch_modelscope_hub_timeout @@ -60,6 +61,12 @@ def __init__(self, args, template: Template): self.optimizer, self.opt_param_scheduler = self.get_optimizer_and_scheduler() self.data_collator = self._get_data_collator() + self.enable_routing_replay = args.router_replay_mode != "disabled" + self.args.extra_args['enable_routing_replay'] = self.enable_routing_replay + # patch routing_replay + if self.enable_routing_replay: + apply_router_replay_patch() + self.state = TrainerState(max_steps=args.train_iters) initialize_embedding = args.new_special_tokens or args.task_type == 'seq_cls' if initialize_embedding: @@ -765,6 +772,10 @@ def train_step(self, train_data_iterator): self.optimizer.zero_grad() # TODO: refactor _replace_data_iterator data_iterator = self._replace_data_iterator(train_data_iterator) + + if self.enable_routing_replay: + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + metrics = forward_backward_func( forward_step_func=self.forward_step, data_iterator=data_iterator, @@ -781,6 +792,10 @@ def train_step(self, train_data_iterator): if update_successful: self.opt_param_scheduler.step(increment=args.global_batch_size) + if self.enable_routing_replay: + RouterReplay.clear_global_router_replay_action() + RouterReplay.clear_global_indices() + return metrics, grad_norm, update_successful def _aggregated_metrics(self, metrics, total_metrics): diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index e8f0b65c70..0ab5bc9dd4 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -23,7 +23,9 @@ from swift.dataset import RowPreprocessor from swift.infer_engine.protocol import 
RequestConfig, RolloutInferRequest, RolloutOutput from swift.megatron.arguments import MegatronArguments, MegatronRLHFArguments -from swift.megatron.utils import forward_step_helper, get_padding_to, set_random_seed +from swift.megatron.utils import (forward_step_helper, get_padding_to, set_random_seed, + RouterReplay, RouterReplayAction, RouterReplayHelper, + get_router_replay_data, set_router_replay_data, get_local_topk_idx_for_current_rank) from swift.rewards import orms from swift.rlhf_trainers.grpo_trainer import DataType from swift.rlhf_trainers.utils import (aggressive_empty_cache, nanstd, pad_logps_back_to_batch, profiling_context, @@ -278,6 +280,8 @@ def _batch_encode(self, infer_requests: List[Dict], template: Template, strict: return batched_inputs, error_list def _get_encoded_batch(self, encoded_list, rollout_batch, template): + original_seq_lengths = [item['length'] for item in encoded_list] + args = self.args encoded_batch = to_device(template.data_collator(encoded_list, padding_to=get_padding_to(args)), self.device) @@ -367,6 +371,40 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template): encoded_batch['rollout_per_token_logps'] = rollout_per_token_logps + # Validating and processing routed_experts data in R3 mode + if self.args.router_replay_mode == 'R3': + routed_experts_list = [] + cur_seq_lengths = seq_lengths + if (seq_lengths.size(0) > batch_size): + cur_seq_lengths = seq_lengths[:batch_size].clone() + cur_seq_lengths[batch_size - 1] = seq_lengths[batch_size-1:].sum() + for data, original_seq_len, cur_seq_len in zip(rollout_batch, original_seq_lengths, cur_seq_lengths): + routed_experts = data.get('routed_experts') + assert routed_experts is not None, 'When router_replay_mode = R3, routed_experts must be in rollout data' + routed_experts = torch.tensor(routed_experts) + # The number of experts in the output can be 1 less than (prompt_length + response_token_count) + # This gap of 1 is expected + # For more details, please refer 
PR https://github.com/vllm-project/vllm/pull/28284 + experts_seq_len = routed_experts.shape[0] + assert (experts_seq_len == original_seq_len + or experts_seq_len + 1 == original_seq_len), \ + f'The seq_len of routed_experts({experts_seq_len}) in output does not match the seq_len of data({original_seq_len}), should be equal to or 1 less than the seq_len of data' + # Padding routed_experts(seq_len, layer_num, topk) seq_len to match the seq_len of the input_ids + padding_routed_experts = routed_experts + padding_to = cur_seq_len if template.padding_free else max_seq_len + padding_len = padding_to - experts_seq_len + if padding_len > 0: + padding_right = template.padding_side == 'right' + padding_routed_experts = nn.functional.pad(routed_experts, + (0, 0, 0, 0, 0, padding_len) if padding_right else (0, 0, 0, 0, padding_len, 0), + 'constant', 0) + routed_experts_list.append(padding_routed_experts) + if template.padding_free: + gloabl_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0) + else: + gloabl_routed_experts = torch.stack(routed_experts_list) + encoded_batch['routed_experts'] = gloabl_routed_experts.to(device=self.device) + return encoded_batch def _generate_and_score_completions(self, batch): @@ -564,6 +602,9 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out if 'content' in choice.logprobs: rollout_logprobs = [item['logprob'] for item in choice.logprobs['content']] input_data['rollout_logprobs'] = [rollout_logprobs] + + # Step 6: Store rollout routed_experts for routing replay + input_data['routed_experts'] = choice.routed_experts return input_data assert len(batch) == len(outputs) @@ -958,9 +999,16 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: # In non-padding_free mode, logps are already in batch format [batch_size, seq_len] ref_per_token_logps = ref_per_token_logps_raw batch['ref_per_token_logps'] = ref_per_token_logps - - old_per_token_logps_raw = self.model_forward( - 
self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps'] + + if self.enable_routing_replay: + if self.args.router_replay_mode == 'R2': + RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD) + if self.args.router_replay_mode == 'R3': + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + output = self.model_forward(self.unwrapped_models[0], iter([deepcopy(inputs)]), + no_grad=True, per_token=True) + old_per_token_logps_raw = output['logps'] if self.template.padding_free: old_per_token_logps, _ = pad_logps_back_to_batch( logps_rmpad=old_per_token_logps_raw, @@ -971,6 +1019,11 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: old_per_token_logps = old_per_token_logps_raw batch['old_per_token_logps'] = old_per_token_logps + if self.enable_routing_replay: + batch['routed_experts'] = output['layers_topk_idx'] + RouterReplay.clear_global_indices() + RouterReplay.clear_global_router_replay_action() + return batch def _compute_kl_from_batches(self, mini_batch_data: List[Dict[str, Any]]) -> torch.Tensor: @@ -1050,6 +1103,15 @@ def forward_step(self, data_iterator, model): 'seq_lengths': seq_lengths, }) data.pop('loss_scale', None) + + if RouterReplayHelper.is_replay_backward_action(model.config): + router_instance_list = RouterReplayHelper.get_micro_batch_router_list(model.config) + for router in router_instance_list: + router.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + if RouterReplayHelper.is_replay_forward_action(model.config): + layers_topk_idx = data.pop('routed_experts', None) + set_router_replay_data(layers_topk_idx, model.config) + inputs = self._prepare_model_inputs(data) labels = data['labels'] @@ -1123,6 +1185,11 @@ def forward_step(self, data_iterator, model): data['per_token_logps'] = per_token_logps data['per_token_entropy'] = None + if RouterReplayHelper.is_replay_forward_action(model.config): + router_instance_list = 
RouterReplayHelper.get_micro_batch_router_list(model.config) + for router in router_instance_list: + router.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD) + return output_tensor, partial(self.loss_func, data=data) @profiling_decorator @@ -1437,6 +1504,13 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False): labels = data.get('labels') context = torch.no_grad() if no_grad else nullcontext() + layers_topk_idx = None + global_topk_idx = data.pop('routed_experts', None) + if RouterReplayHelper.is_replay_forward_action(model.config): + assert global_topk_idx is not None, "When router_replay_mode = R3, routed_experts must be in data" + layers_topk_idx = get_local_topk_idx_for_current_rank(global_topk_idx, model.config, data.get('packed_seq_params')) + set_router_replay_data(layers_topk_idx, model.config) + with context: output_tensor = forward_step_helper(self.args, model, data) @@ -1448,6 +1522,12 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False): num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0] data['logps'] = None if labels is None else self.get_logps( output_tensor, labels, packed_seq_params, num_samples, per_token=per_token) + + if RouterReplayHelper.is_r2_record_action(model.config): + layers_topk_idx = get_router_replay_data(model.config) + if layers_topk_idx is not None: + data['layers_topk_idx'] = layers_topk_idx + return data def inputs2requests(self, inputs: Union[DataType, List[RolloutInferRequest]]) -> List[RolloutInferRequest]: diff --git a/swift/megatron/trainers/rollout_mixin.py b/swift/megatron/trainers/rollout_mixin.py index 89872fc2cb..7601a35de3 100644 --- a/swift/megatron/trainers/rollout_mixin.py +++ b/swift/megatron/trainers/rollout_mixin.py @@ -245,6 +245,9 @@ def _prepare_vllm_engine(self): vllm_engine_kwargs = args.vllm_engine_kwargs or {} load_format = vllm_engine_kwargs.pop('load_format', 'dummy') + if self.args.router_replay_mode == 'R3': + 
vllm_engine_kwargs['enable_return_routed_experts'] = True + engine = GRPOVllmEngine( args.model_info.model_dir, torch_dtype=args.torch_dtype, diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index 84838ae06d..415bf7cc67 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -10,3 +10,5 @@ from .patcher import patch_merge_fn, patch_torch_dist_shard from .utils import (copy_original_module_weight, forward_step_helper, get_local_layer_specs, get_padding_to, prepare_mcore_model, tuners_sharded_state_dict) +from .router_replay_patch import RouterReplay, RouterReplayAction, apply_router_replay_patch +from .router_replay_utils import * diff --git a/swift/megatron/utils/router_replay_patch.py b/swift/megatron/utils/router_replay_patch.py new file mode 100644 index 0000000000..7760ef2777 --- /dev/null +++ b/swift/megatron/utils/router_replay_patch.py @@ -0,0 +1,351 @@ +import warnings +from enum import Enum +from typing import List, Optional, Callable + +import torch + +try: + from megatron.core.transformer.moe.moe_utils import ( + apply_router_token_dropping, + compute_routing_scores_for_aux_loss, + group_limited_topk, + ) +except ImportError: + warnings.warn("NPU not support router replay for now.", stacklevel=2) + pass +from megatron.core.transformer.moe.router import TopKRouter +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.training import get_args + + +class RouterReplayAction(Enum): + RECORD = "record" # Record the topk indices for replay + REPLAY_FORWARD = "replay_forward" # Replay the recorded topk indices for forward pass + REPLAY_BACKWARD = "replay_backward" # Replay topk indices for re-compute during backward pass + + +class RouterReplay: + """ + A class to manage the recording and replaying of MoE routing decisions. + It holds all router instances and provides static methods to globally + control recording and replaying. 
+ """ + + # Static variable to hold all router instances, one per MoE layer. + global_router_replay_instances: List['RouterReplay'] = [] + + @staticmethod + def set_replay_data(all_layers_topk_indices: List[torch.Tensor]): + """ + Distributes the topk indices for all layers to their respective RouterReplay instances. + :param all_layers_topk_indices: A list of tensors, where each tensor contains the + topk indices for a specific layer. The order + must match the instantiation order of the routers. + """ + if len(all_layers_topk_indices) != len(RouterReplay.global_router_replay_instances): + raise ValueError( + f"The number of replay tensors ({len(all_layers_topk_indices)}) " + f"does not match router instances ({len(RouterReplay.global_router_replay_instances)})." + ) + for i, router_instance in enumerate(RouterReplay.global_router_replay_instances): + router_instance.set_target_indices(all_layers_topk_indices[i]) + + @staticmethod + def get_recorded_data() -> List[torch.Tensor]: + """ + Collects the recorded topk indices from all RouterReplay instances. + :return: A list of tensors, each containing the recorded topk indices for a layer. 
+ """ + return [ + router.get_recorded_indices() for router in RouterReplay.global_router_replay_instances + ] + + @staticmethod + def clear_global_indices(): + """Clears the recorded and target topk indices in all instances.""" + for router in RouterReplay.global_router_replay_instances: + router.clear_indices() + + @staticmethod + def set_global_router_replay_action(router_replay_action: RouterReplayAction): + """Sets the router replay action for all router instances.""" + for router in RouterReplay.global_router_replay_instances: + router.set_router_replay_action(router_replay_action) + + @staticmethod + def clear_global_router_replay_action(): + """Clears the router replay action for all router instances.""" + for router in RouterReplay.global_router_replay_instances: + router.clear_router_replay_action() + + def __init__(self): + """Initializes a RouterReplay instance for a specific layer.""" + self.target_topk_idx: Optional[torch.Tensor] = None # Target topk indices for replay + self.recorded_topk_idx: Optional[torch.Tensor] = None # Recorded topk indices for replay + self.router_replay_action: Optional[RouterReplayAction] = ( + None # Router replay action for this layer + ) + self.replay_backward_list: List[torch.Tensor] = ( + [] + ) # List of tensors for backward pass replay + RouterReplay.global_router_replay_instances.append(self) + + def set_target_indices(self, topk_indices: torch.Tensor): + """Sets the target topk indices for replay.""" + self.target_topk_idx = topk_indices + self.replay_backward_list.append(topk_indices) + + def get_recorded_indices(self) -> Optional[torch.Tensor]: + """Returns the recorded topk indices.""" + return self.recorded_topk_idx + + def record_indices(self, topk_indices: torch.Tensor): + """Records the topk indices.""" + self.recorded_topk_idx = topk_indices + + def clear_indices(self): + """Clears the recorded and target topk indices.""" + self.recorded_topk_idx = None + self.target_topk_idx = None + 
self.replay_backward_list = [] + + def set_router_replay_action(self, router_replay_action: RouterReplayAction): + """Sets the router replay action for this layer.""" + self.router_replay_action = router_replay_action + + def clear_router_replay_action(self): + """Clears the router replay action for this layer.""" + self.router_replay_action = None + + def get_replay_topk( + self, + scores: torch.Tensor, + topk: int, + num_groups: Optional[int] = None, + group_topk: Optional[int] = None, + default_compute_topk: Callable[ + [torch.Tensor, int, Optional[int], Optional[int]], torch.Tensor + ] = None, + ) -> torch.Tensor: + """Returns the target topk indices for replay.""" + if self.router_replay_action == RouterReplayAction.RECORD: + probs, top_indices = default_compute_topk( + scores, topk, num_groups=num_groups, group_topk=group_topk + ) + self.record_indices(top_indices) + return probs, top_indices + elif self.router_replay_action == RouterReplayAction.REPLAY_FORWARD: + top_indices = self.target_topk_idx + # Ensure indices are on the correct device + top_indices = top_indices.to(scores.device) + # Gather the scores for the replayed indices to get the probabilities + probs = scores.gather(1, top_indices) + return probs, top_indices + elif self.router_replay_action == RouterReplayAction.REPLAY_BACKWARD: + top_indices = self.replay_backward_list.pop(0) + # Ensure indices are on the correct device + top_indices = top_indices.to(scores.device) + # Gather the scores for the replayed indices to get the probabilities + probs = scores.gather(1, top_indices) + return probs, top_indices + else: + return default_compute_topk(scores, topk, num_groups, group_topk) + + +def _patched_topk_routing_with_score_function( + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool = False, + num_groups: Optional[int] = None, + group_topk: Optional[int] = None, + scaling_factor: Optional[float] = None, + score_function: str = "softmax", + expert_bias: Optional[torch.Tensor] = None, + 
fused: bool = False, + router_replay: Optional['RouterReplay'] = None, +): + """ + Patched version of topk_routing_with_score_function that supports router replay. + """ + assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}." + num_tokens, num_experts = logits.shape + + def _compute_topk(scores, topk, num_groups=None, group_topk=None): + if group_topk: + return group_limited_topk( + scores=scores, + topk=topk, + num_tokens=num_tokens, + num_experts=num_experts, + num_groups=num_groups, + group_topk=group_topk, + ) + else: + return torch.topk(scores, k=topk, dim=1) + + def compute_topk(scores, topk, num_groups=None, group_topk=None): + # Default behavior if no replay is active + + if router_replay is None: + return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + else: + return router_replay.get_replay_topk( + scores, topk, num_groups, group_topk, _compute_topk + ) + + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) + probs, top_indices = compute_topk(scores, topk, num_groups, group_topk) + else: + scores, top_indices = compute_topk(logits, topk, num_groups, group_topk) + probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits) + elif score_function == "sigmoid": + scores = torch.sigmoid(logits.float()).type_as(logits) + if expert_bias is not None: + scores_for_routing = scores + expert_bias + _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) + scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits) + else: + scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) + probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + else: + raise ValueError(f"Invalid score_function: {score_function}") + + if scaling_factor: + probs = probs * scaling_factor + + if torch.are_deterministic_algorithms_enabled(): + # build 
[num_tokens, num_experts] from [num_tokens, topk] + routing_probs = torch.zeros_like(logits) + rows = torch.arange(num_tokens, device=logits.device).unsqueeze(1) + routing_probs.index_put_((rows, top_indices), probs, accumulate=False) + + routing_map = torch.zeros_like(logits, dtype=logits.dtype) + routing_map.index_put_( + (rows, top_indices), torch.ones_like(probs, dtype=routing_map.dtype), accumulate=False + ) + routing_map = routing_map.bool() + else: + # TODO Try using element-wise operations instead of scatter? + routing_probs = torch.zeros_like(logits).scatter(1, top_indices, probs) + routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool() + + return routing_probs, routing_map + + +def patched_routing(self, logits: torch.Tensor): + """Top-k routing function + + Args: + logits (torch.Tensor): Logits tensor after gating. + + Returns: + probs (torch.Tensor): The probabilities of token to experts assignment. + routing_map (torch.Tensor): The mapping of token to experts assignment, + with shape [num_tokens, num_experts]. + """ + seq_length, bsz = logits.shape[:2] + logits = logits.view(-1, self.config.num_moe_experts) + + # Apply Z-Loss + logits = self.apply_z_loss(logits) + + # Calculate probs and routing_map for token dispatching + if self.routing_type == "sinkhorn": + probs, routing_map = self.sinkhorn_load_balancing(logits) + else: + probs, routing_map = _patched_topk_routing_with_score_function( + logits, + self.topk, + use_pre_softmax=self.config.moe_router_pre_softmax, + num_groups=self.config.moe_router_num_groups, + group_topk=self.config.moe_router_group_topk, + scaling_factor=self.config.moe_router_topk_scaling_factor, + score_function=self.score_function, + expert_bias=self.expert_bias, + fused=self.config.moe_router_fusion, + router_replay=self.router_replay, + ) + + # Apply token dropping to probs and routing_map. 
+ if self.config.moe_expert_capacity_factor is not None: + probs, routing_map = apply_router_token_dropping( + probs, + routing_map, + router_topk=self.topk, + capacity_factor=self.config.moe_expert_capacity_factor, + drop_policy=self.config.moe_token_drop_policy, + pad_to_capacity=self.config.moe_pad_expert_input_to_capacity, + ) + + # Apply each aux loss type and attach aux loss autograd function to probs + if self.training and torch.is_grad_enabled() and self.is_aux_loss_enabled(): + # Calculate scores and routing_map for aux loss + routing_map_for_aux_loss, scores_for_aux_loss = compute_routing_scores_for_aux_loss( + logits, self.topk, self.score_function, fused=self.config.moe_router_fusion + ) + probs = self._apply_aux_loss(probs, scores_for_aux_loss, routing_map_for_aux_loss) + probs = self._apply_seq_aux_loss( + probs, scores_for_aux_loss, routing_map_for_aux_loss, seq_length, bsz + ) + probs = self._apply_global_aux_loss( + probs, scores_for_aux_loss, routing_map_for_aux_loss + ) + + # Optionally apply expert bias + # Update expert bias and tokens_per_expert + # Prevent extra local tokens accumulation on evaluation or activation recomputation + if self.enable_expert_bias and torch.is_grad_enabled(): + with torch.no_grad(): + self.local_tokens_per_expert += routing_map.sum(dim=0) + + return probs, routing_map + + +def apply_router_replay_patch(): + """ + Applies the monkey patch for MoE Router Replay functionality. + This patch dynamically adds the 'enable_routing_replay' attribute to TransformerConfig + and modifies the TopKRouter to support recording and replaying of routing decisions. + """ + print("Applying Router Replay Patch...") + # Clear router instances to avoid state leakage between model initializations. 
+ RouterReplay.global_router_replay_instances.clear() + # Step 1: Patch TransformerConfig to include the feature flag + if not hasattr(TransformerConfig, "enable_routing_replay"): + # Add class attribute with default value + TransformerConfig.enable_routing_replay = False + + # Store original __init__ method + original_tf_config_init = TransformerConfig.__init__ + + # Define new __init__ method that safely handles enable_routing_replay parameter + def patched_tf_config_init(self, *args, **kwargs): + # Call original constructor with remaining kwargs + original_tf_config_init(self, *args, **kwargs) + + mg_args = get_args() + # Set the instance attribute + self.enable_routing_replay = mg_args.enable_routing_replay or TransformerConfig.enable_routing_replay + + # Apply the patch + TransformerConfig.__init__ = patched_tf_config_init + + # Step 2: Patch TopKRouter only once to ensure idempotency. + if hasattr(TopKRouter, "_router_replay_patched"): + return + + original_init = TopKRouter.__init__ + + # Step 3: Define the new __init__ method + def patched_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + self.router_replay = None + if self.config.enable_routing_replay: + self.router_replay = RouterReplay() + + # Step 4: Apply the patches + TopKRouter.__init__ = patched_init + TopKRouter.routing = patched_routing + TopKRouter._router_replay_patched = True diff --git a/swift/megatron/utils/router_replay_utils.py b/swift/megatron/utils/router_replay_utils.py new file mode 100644 index 0000000000..1b80428a58 --- /dev/null +++ b/swift/megatron/utils/router_replay_utils.py @@ -0,0 +1,130 @@ +""" +Router Replay Utilities +Utilities for handling router replay functionality in Megatron models. 
+""" + +import torch + +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import get_transformer_layer_offset +from megatron.core.transformer.transformer_block import get_num_layers_to_build +from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region + +from megatron.core import mpu +from swift.utils.torch_utils import get_current_device +from swift.megatron.utils import RouterReplay, RouterReplayAction +from swift.megatron.trainers.utils import split_cp_inputs + +device_name = get_current_device() + +def get_local_layer_range(tf_config, vp_rank=None): + vp_size = tf_config.virtual_pipeline_model_parallel_size + if vp_size is not None: + vp_rank = 0 if vp_rank is None else vp_rank + offset = 0 + for pre_vp_stage in range(vp_size): + if pre_vp_stage == vp_rank: + break + num_layers_to_build = get_num_layers_to_build(tf_config, pre_vp_stage) + offset += num_layers_to_build + else: + offset = 0 + count = get_num_layers_to_build(tf_config, vp_rank) + return offset, count + +def get_local_topk_idx_for_current_rank(global_topk_idx, tf_config, packed_seq_params=None): + if global_topk_idx is None: + return None + # 1. pp slice + layer_offset = get_transformer_layer_offset(tf_config, vp_stage=0) + offset, count = get_local_layer_range(tf_config, tf_config.virtual_pipeline_model_parallel_size) + num_layers = offset + count + local_topk_idx = torch.narrow(global_topk_idx, dim=2, start=layer_offset, length=num_layers) + # 2. cp slice + cp_size = mpu.get_context_parallel_world_size() + if cp_size > 1: + local_topk_idx = split_cp_inputs(local_topk_idx, getattr(packed_seq_params, 'cu_seqlens_q', None), 1) + # 3. 
sp slice + local_topk_idx = scatter_to_sequence_parallel_region(local_topk_idx.transpose(0, 1)).transpose(0, 1) + return local_topk_idx + +def get_router_replay_data(tf_config, vp_rank=None): + router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + layers_topk_idx = [] + for router in router_instances_list: + layers_topk_idx.append(router.recorded_topk_idx.to(torch.uint8)) + # layer_num, seq_len, topk -> 1, seq_len, layer_num, topk + layers_topk_idx = torch.stack(layers_topk_idx).transpose(0, 1).unsqueeze(0).to(device_name) + return layers_topk_idx + +def set_router_replay_data(layers_topk_idx, tf_config, vp_rank=None): + # bs, seq_len, layer_num, topk -> layer_num, total_seq_len, topk + layers_topk_idx_reshape = layers_topk_idx.flatten(0, 1).transpose(0, 1).to(device_name) + router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + offset, _ = get_local_layer_range(tf_config, vp_rank) + for i, router in enumerate(router_instances_list): + router.set_target_indices(layers_topk_idx_reshape[i + offset].to(torch.int64)) + + +class RouterReplayHelper: + """Helper class to query router replay state and locate local RouterReplay instances.""" + + @staticmethod + def get_micro_batch_router_list(tf_config, vp_rank=None): + """ + Return the list of RouterReplay instances corresponding to the current micro-batch and local + (pp_rank, vp_stage) layer range. + + When virtual pipeline (VPP) is enabled, the local range for the PP rank is expanded to include + all VP stages by multiplying the per-VP count by vp_size. The returned slice is taken from the + global RouterReplay.global_router_replay_instances list. + + Args: + tf_config: Configuration object used to compute layer assignments. + vp_rank (Optional[int]): Explicit virtual pipeline stage to query. If None, the current VP + rank from Megatron parallel state is used when available. 
+ Returns: + list: A contiguous sublist of RouterReplay.router_instances for the local layer range. + """ + offset, count = get_local_layer_range(tf_config, vp_rank) + router_instances_list = RouterReplay.global_router_replay_instances[offset : offset + count] + return router_instances_list + + @staticmethod + def is_r2_record_action(tf_config, vp_rank=None) -> bool: + """Return True if the current router_replay_action is RECORD (R2) for the local router instances. + + This inspects the first local RouterReplay instance's router_replay_action and compares it to + RouterReplayAction.RECORD. + """ + router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + return ( + router_instances_list + and router_instances_list[0].router_replay_action == RouterReplayAction.RECORD + ) + + @staticmethod + def is_replay_forward_action(tf_config, vp_rank=None) -> bool: + """Return True if the current router_replay_action is REPLAY_FORWARD for the local router instances. + + This inspects the first local RouterReplay instance's router_replay_action and compares it to + RouterReplayAction.REPLAY_FORWARD. + """ + router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + return ( + router_instances_list + and router_instances_list[0].router_replay_action == RouterReplayAction.REPLAY_FORWARD + ) + + @staticmethod + def is_replay_backward_action(tf_config, vp_rank=None) -> bool: + """Return True if the current router_replay_action is REPLAY_BACKWARD for the local router instances. + + This inspects the first local RouterReplay instance's router_replay_action and compares it to + RouterReplayAction.REPLAY_BACKWARD. 
+ """ + router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + return ( + router_instances_list + and router_instances_list[0].router_replay_action == RouterReplayAction.REPLAY_BACKWARD + ) From 03f718c41e4e0a613d5f38f3a2f73d3ddb283dab Mon Sep 17 00:00:00 2001 From: XianlongLi <2286061024@qq.com> Date: Mon, 2 Mar 2026 11:35:40 +0800 Subject: [PATCH 02/14] fix merged code --- swift/megatron/model/model_config.py | 6 ++++++ swift/megatron/trainers/base.py | 11 +++++------ swift/megatron/utils/router_replay_patch.py | 6 ++---- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/swift/megatron/model/model_config.py b/swift/megatron/model/model_config.py index bc12c02390..cbedcbfcdf 100644 --- a/swift/megatron/model/model_config.py +++ b/swift/megatron/model/model_config.py @@ -157,6 +157,8 @@ class MegatronModelConfig(TransformerConfig): 'none'] = 'aux_loss' use_shared_expert_gate: bool = False + enable_routing_replay: bool = False + # mla multi_latent_attention: bool = False q_lora_rank: Optional[int] = None @@ -469,6 +471,10 @@ def get_mcore_model_config(args, hf_config): if num_moe_experts is None: kwargs['expert_model_parallel_size'] = 1 kwargs['expert_tensor_parallel_size'] = 1 + + if args.router_replay_mode != "disabled": + kwargs['enable_routing_replay'] = True + config = MegatronModelConfig(**kwargs) config.hf_config = hf_config config.args = args diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index efb0884b8e..a491f3a91e 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -53,6 +53,11 @@ class BaseMegatronTrainer(ABC): def __init__(self, args, template: Template): + # patch routing_replay + self.enable_routing_replay = args.router_replay_mode != "disabled" + if self.enable_routing_replay: + apply_router_replay_patch() + self.args = args self.template = template self.bridge = args.megatron_model_meta.bridge_cls(args) @@ -61,12 +66,6 @@ def 
__init__(self, args, template: Template): self.optimizer, self.opt_param_scheduler = self.get_optimizer_and_scheduler() self.data_collator = self._get_data_collator() - self.enable_routing_replay = args.router_replay_mode != "disabled" - self.args.extra_args['enable_routing_replay'] = self.enable_routing_replay - # patch routing_replay - if self.enable_routing_replay: - apply_router_replay_patch() - self.state = TrainerState(max_steps=args.train_iters) initialize_embedding = args.new_special_tokens or args.task_type == 'seq_cls' if initialize_embedding: diff --git a/swift/megatron/utils/router_replay_patch.py b/swift/megatron/utils/router_replay_patch.py index 7760ef2777..c630f6dd81 100644 --- a/swift/megatron/utils/router_replay_patch.py +++ b/swift/megatron/utils/router_replay_patch.py @@ -15,7 +15,6 @@ pass from megatron.core.transformer.moe.router import TopKRouter from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.training import get_args class RouterReplayAction(Enum): @@ -322,12 +321,11 @@ def apply_router_replay_patch(): # Define new __init__ method that safely handles enable_routing_replay parameter def patched_tf_config_init(self, *args, **kwargs): + enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay) # Call original constructor with remaining kwargs original_tf_config_init(self, *args, **kwargs) - - mg_args = get_args() # Set the instance attribute - self.enable_routing_replay = mg_args.enable_routing_replay or TransformerConfig.enable_routing_replay + self.enable_routing_replay = enable_routing_replay # Apply the patch TransformerConfig.__init__ = patched_tf_config_init From d8f03cfc3dbdff81a311adce9a05b592fc044092 Mon Sep 17 00:00:00 2001 From: XianlongLi <2286061024@qq.com> Date: Wed, 4 Mar 2026 15:30:40 +0800 Subject: [PATCH 03/14] add docs --- docs/source/Megatron-SWIFT/Command-line-parameters.md | 1 + docs/source_en/Megatron-SWIFT/Command-line-parameters.md | 1 
+ 2 files changed, 2 insertions(+) diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 3ad5f7fefa..38171db3cc 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -349,6 +349,7 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用 - rollout_importance_sampling_threshold: 重要性采样权重的阈值,用于截断或屏蔽极端权重。默认为2.0。 - log_rollout_offpolicy_metrics: 当 `rollout_importance_sampling_mode` 未设置时,是否记录训推不一致诊断指标(KL、PPL、χ²等)。当设置了 `rollout_importance_sampling_mode` 时,指标会自动记录。默认为False。 - off_policy_sequence_mask_delta: Off-Policy Sequence Masking 阈值,来自 DeepSeek-V3.2 论文。当设置此值时,会计算每个序列的 `mean(old_policy_logps - policy_logps)`,若该值大于阈值且该序列的优势为负,则 mask 掉该序列不参与损失计算。默认为None,不启用。具体参考[文档](../Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md#off-policy-sequence-masking)。 +- router_replay_mode: 路由重放模式,可选项为`disabled`、`R2`、`R3`。默认为disabled,不启用路由重放。 内置奖励函数参数参考[文档](../Instruction/Command-line-parameters.md#奖励函数参数) diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 923a8e48ab..ee257038ef 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -375,6 +375,7 @@ In addition to inheriting the training parameters, the following parameters are - rollout_importance_sampling_threshold: Threshold for importance sampling weights, used for truncating or masking extreme weights. Default is 2.0. - log_rollout_offpolicy_metrics: Whether to log training-inference mismatch diagnostic metrics (KL, PPL, χ², etc.) when `rollout_importance_sampling_mode` is not set. When `rollout_importance_sampling_mode` is set, metrics are always logged. Default is False. - off_policy_sequence_mask_delta: Off-Policy Sequence Masking threshold from [DeepSeek-V3.2 paper](https://arxiv.org/abs/2512.02556). 
When set, computes `mean(old_policy_logps - policy_logps)` for each sequence. If this value exceeds the threshold AND the sequence has negative advantage, the sequence is masked out from loss computation. For details, refer to the [documentation](../Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md#off-policy-sequence-masking). +- router_replay_mode: Router replay mode. Options are `disabled`,`R2`,`R3`. Default is disabled. Built-in reward function parameters refer to the [documentation](../Instruction/Command-line-parameters.md#reward-function-parameters). From 1e207324a14f2bf400f81ad92df78840396cb838 Mon Sep 17 00:00:00 2001 From: XianlongLi <2286061024@qq.com> Date: Thu, 5 Mar 2026 16:54:16 +0800 Subject: [PATCH 04/14] fix lint error --- swift/infer_engine/protocol.py | 21 ++---- swift/megatron/model/model_config.py | 4 +- swift/megatron/trainers/base.py | 16 ++-- swift/megatron/trainers/grpo_trainer.py | 46 ++++++------ swift/megatron/utils/__init__.py | 4 +- swift/megatron/utils/router_replay_patch.py | 81 ++++++++------------- swift/megatron/utils/router_replay_utils.py | 37 ++++------ 7 files changed, 87 insertions(+), 122 deletions(-) diff --git a/swift/infer_engine/protocol.py b/swift/infer_engine/protocol.py index 8d8e141b5a..a286c9bbee 100644 --- a/swift/infer_engine/protocol.py +++ b/swift/infer_engine/protocol.py @@ -2,15 +2,15 @@ import base64 import io import json +import numpy as np import os import time import uuid from copy import deepcopy from dataclasses import asdict, dataclass, field, fields from PIL import Image -from pydantic import BaseModel, Field, field_validator, PlainSerializer, AfterValidator -from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Annotated -import numpy as np +from pydantic import AfterValidator, BaseModel, Field, PlainSerializer, field_validator +from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union from swift.template import Messages, Tool from swift.utils import 
remove_response @@ -20,14 +20,10 @@ def serialize_ndarray(value): if value is None: return None if isinstance(value, np.ndarray): - return { - 'data': value.tolist(), - 'shape': value.shape, - 'dtype': str(value.dtype), - '__ndarray__': True - } + return {'data': value.tolist(), 'shape': value.shape, 'dtype': str(value.dtype), '__ndarray__': True} return value + def deserialize_ndarray(value): if value is None: return None @@ -35,11 +31,8 @@ def deserialize_ndarray(value): return np.array(value['data'], dtype=value['dtype']).reshape(value['shape']) return value -NumpyArray = Annotated[ - Any, - PlainSerializer(serialize_ndarray, return_type=Dict), - AfterValidator(deserialize_ndarray) -] + +NumpyArray = Annotated[Any, PlainSerializer(serialize_ndarray, return_type=Dict), AfterValidator(deserialize_ndarray)] @dataclass diff --git a/swift/megatron/model/model_config.py b/swift/megatron/model/model_config.py index 2b61c0f0cc..f723bf115a 100644 --- a/swift/megatron/model/model_config.py +++ b/swift/megatron/model/model_config.py @@ -510,8 +510,8 @@ def get_mcore_model_config(args, hf_config): if num_moe_experts is None: kwargs['expert_model_parallel_size'] = 1 kwargs['expert_tensor_parallel_size'] = 1 - - if args.router_replay_mode != "disabled": + + if args.router_replay_mode != 'disabled': kwargs['enable_routing_replay'] = True config = MegatronModelConfig(**kwargs) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 79893e6ae6..7d651b67b7 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -27,14 +27,12 @@ from swift.megatron.callbacks import megatron_callbacks_map from swift.megatron.model import get_mcore_model from swift.megatron.tuners import LoraParallelLinear -from swift.megatron.utils import (copy_original_module_weight, disable_forward_pre_hook, enable_forward_pre_hook, - get_optimizer_param_scheduler, get_padding_to, init_persistent_async_worker, - initialize_tp_communicators, 
load_mcore_checkpoint, - logical_and_across_model_parallel_group, maybe_finalize_async_save, - prepare_mcore_model, reduce_max_stat_across_model_parallel_group, - save_mcore_checkpoint, should_disable_forward_pre_hook, warmup_jit_function, - wrap_model, - apply_router_replay_patch, RouterReplay, RouterReplayAction) +from swift.megatron.utils import ( + RouterReplay, RouterReplayAction, apply_router_replay_patch, copy_original_module_weight, disable_forward_pre_hook, + enable_forward_pre_hook, get_optimizer_param_scheduler, get_padding_to, init_persistent_async_worker, + initialize_tp_communicators, load_mcore_checkpoint, logical_and_across_model_parallel_group, + maybe_finalize_async_save, prepare_mcore_model, reduce_max_stat_across_model_parallel_group, save_mcore_checkpoint, + should_disable_forward_pre_hook, warmup_jit_function, wrap_model) from swift.template import Template from swift.trainers import dynamic_gradient_checkpointing from swift.trainers.utils import patch_modelscope_hub_timeout @@ -55,7 +53,7 @@ class BaseMegatronTrainer(ABC): def __init__(self, args, template: Template): # patch routing_replay - self.enable_routing_replay = args.router_replay_mode != "disabled" + self.enable_routing_replay = args.router_replay_mode != 'disabled' if self.enable_routing_replay: apply_router_replay_patch() diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 0bb1ca5613..8a57937532 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -23,9 +23,9 @@ from swift.dataset import RowPreprocessor from swift.infer_engine.protocol import RequestConfig, RolloutInferRequest, RolloutOutput from swift.megatron.arguments import MegatronArguments, MegatronRLHFArguments -from swift.megatron.utils import (forward_step_helper, get_padding_to, set_random_seed, - RouterReplay, RouterReplayAction, RouterReplayHelper, - get_router_replay_data, set_router_replay_data, 
get_local_topk_idx_for_current_rank) +from swift.megatron.utils import (RouterReplay, RouterReplayAction, RouterReplayHelper, forward_step_helper, + get_local_topk_idx_for_current_rank, get_padding_to, get_router_replay_data, + set_random_seed, set_router_replay_data) from swift.rewards import orms from swift.rlhf_trainers.grpo_trainer import DataType from swift.rlhf_trainers.utils import (aggressive_empty_cache, nanstd, pad_logps_back_to_batch, profiling_context, @@ -370,18 +370,20 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template): cur_seq_lengths = seq_lengths if (seq_lengths.size(0) > batch_size): cur_seq_lengths = seq_lengths[:batch_size].clone() - cur_seq_lengths[batch_size - 1] = seq_lengths[batch_size-1:].sum() + cur_seq_lengths[batch_size - 1] = seq_lengths[batch_size - 1:].sum() for data, original_seq_len, cur_seq_len in zip(rollout_batch, original_seq_lengths, cur_seq_lengths): routed_experts = data.get('routed_experts') - assert routed_experts is not None, 'When router_replay_mode = R3, routed_experts must be in rollout data' + assert routed_experts is not None, ( + 'When router_replay_mode = R3, routed_experts must be in rollout data') routed_experts = torch.tensor(routed_experts) # The number of experts in the output can be 1 less than (prompt_length + response_token_count) # This gap of 1 is expected # For more details, please refer PR https://github.com/vllm-project/vllm/pull/28284 experts_seq_len = routed_experts.shape[0] - assert (experts_seq_len == original_seq_len - or experts_seq_len + 1 == original_seq_len), \ - f'The seq_len of routed_experts({experts_seq_len}) in output does not match the seq_len of data({original_seq_len}), should be equal to or 1 less than the seq_len of data' + assert (experts_seq_len == original_seq_len or experts_seq_len + 1 + == original_seq_len), (f'The seq_len of routed_experts({experts_seq_len}) in output ', + f'does not match the seq_len of data({original_seq_len}), ' + 'should be equal to or 1 
less than the seq_len of data') # Padding routed_experts(seq_len, layer_num, topk) seq_len to match the seq_len of the input_ids padding_routed_experts = routed_experts padding_to = cur_seq_len if template.padding_free else max_seq_len @@ -389,14 +391,14 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template): if padding_len > 0: padding_right = template.padding_side == 'right' padding_routed_experts = nn.functional.pad(routed_experts, - (0, 0, 0, 0, 0, padding_len) if padding_right else (0, 0, 0, 0, padding_len, 0), - 'constant', 0) + (0, 0, 0, 0, 0, padding_len) if padding_right else + (0, 0, 0, 0, padding_len, 0), 'constant', 0) routed_experts_list.append(padding_routed_experts) if template.padding_free: - gloabl_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0) + global_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0) else: - gloabl_routed_experts = torch.stack(routed_experts_list) - encoded_batch['routed_experts'] = gloabl_routed_experts.to(device=self.device) + global_routed_experts = torch.stack(routed_experts_list) + encoded_batch['routed_experts'] = global_routed_experts.to(device=self.device) return encoded_batch @@ -992,15 +994,14 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: # In non-padding_free mode, logps are already in batch format [batch_size, seq_len] ref_per_token_logps = ref_per_token_logps_raw batch['ref_per_token_logps'] = ref_per_token_logps - + if self.enable_routing_replay: if self.args.router_replay_mode == 'R2': RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD) if self.args.router_replay_mode == 'R3': RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) - output = self.model_forward(self.unwrapped_models[0], iter([deepcopy(inputs)]), - no_grad=True, per_token=True) + output = self.model_forward(self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True) old_per_token_logps_raw = 
output['logps'] if self.template.padding_free: old_per_token_logps, _ = pad_logps_back_to_batch( @@ -1013,9 +1014,9 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: batch['old_per_token_logps'] = old_per_token_logps if self.enable_routing_replay: - batch['routed_experts'] = output['layers_topk_idx'] - RouterReplay.clear_global_indices() - RouterReplay.clear_global_router_replay_action() + batch['routed_experts'] = output['layers_topk_idx'] + RouterReplay.clear_global_indices() + RouterReplay.clear_global_router_replay_action() return batch @@ -1500,8 +1501,9 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False): layers_topk_idx = None global_topk_idx = data.pop('routed_experts', None) if RouterReplayHelper.is_replay_forward_action(model.config): - assert global_topk_idx is not None, "When router_replay_mode = R3, routed_experts must be in data" - layers_topk_idx = get_local_topk_idx_for_current_rank(global_topk_idx, model.config, data.get('packed_seq_params')) + assert global_topk_idx is not None, 'When router_replay_mode = R3, routed_experts must be in data' + layers_topk_idx = get_local_topk_idx_for_current_rank(global_topk_idx, model.config, + data.get('packed_seq_params')) set_router_replay_data(layers_topk_idx, model.config) with context: @@ -1515,7 +1517,7 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False): num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0] data['logps'] = None if labels is None else self.get_logps( output_tensor, labels, packed_seq_params, num_samples, per_token=per_token) - + if RouterReplayHelper.is_r2_record_action(model.config): layers_topk_idx = get_router_replay_data(model.config) if layers_topk_idx is not None: diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index b5609d1d8f..184f65784d 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -9,7 +9,7 @@ from 
.parallel_utils import (logical_and_across_model_parallel_group, reduce_max_stat_across_model_parallel_group, split_cp_inputs) from .patcher import patch_merge_fn, patch_torch_dist_shard -from .utils import (copy_original_module_weight, forward_step_helper, get_local_layer_specs, get_padding_to, - prepare_mcore_model, tuners_sharded_state_dict) from .router_replay_patch import RouterReplay, RouterReplayAction, apply_router_replay_patch from .router_replay_utils import * +from .utils import (copy_original_module_weight, forward_step_helper, get_local_layer_specs, get_padding_to, + prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/utils/router_replay_patch.py b/swift/megatron/utils/router_replay_patch.py index c630f6dd81..1798a5a635 100644 --- a/swift/megatron/utils/router_replay_patch.py +++ b/swift/megatron/utils/router_replay_patch.py @@ -1,26 +1,22 @@ +import torch import warnings from enum import Enum -from typing import List, Optional, Callable - -import torch +from typing import Callable, List, Optional try: - from megatron.core.transformer.moe.moe_utils import ( - apply_router_token_dropping, - compute_routing_scores_for_aux_loss, - group_limited_topk, - ) + from megatron.core.transformer.moe.moe_utils import (apply_router_token_dropping, + compute_routing_scores_for_aux_loss, group_limited_topk) except ImportError: - warnings.warn("NPU not support router replay for now.", stacklevel=2) + warnings.warn('NPU not support router replay for now.', stacklevel=2) pass from megatron.core.transformer.moe.router import TopKRouter from megatron.core.transformer.transformer_config import TransformerConfig class RouterReplayAction(Enum): - RECORD = "record" # Record the topk indices for replay - REPLAY_FORWARD = "replay_forward" # Replay the recorded topk indices for forward pass - REPLAY_BACKWARD = "replay_backward" # Replay topk indices for re-compute during backward pass + RECORD = 'record' # Record the topk indices for replay + 
REPLAY_FORWARD = 'replay_forward' # Replay the recorded topk indices for forward pass + REPLAY_BACKWARD = 'replay_backward' # Replay topk indices for re-compute during backward pass class RouterReplay: @@ -42,10 +38,8 @@ def set_replay_data(all_layers_topk_indices: List[torch.Tensor]): must match the instantiation order of the routers. """ if len(all_layers_topk_indices) != len(RouterReplay.global_router_replay_instances): - raise ValueError( - f"The number of replay tensors ({len(all_layers_topk_indices)}) " - f"does not match router instances ({len(RouterReplay.global_router_replay_instances)})." - ) + raise ValueError(f'The number of replay tensors ({len(all_layers_topk_indices)}) ' + f'does not match router instances ({len(RouterReplay.global_router_replay_instances)}).') for i, router_instance in enumerate(RouterReplay.global_router_replay_instances): router_instance.set_target_indices(all_layers_topk_indices[i]) @@ -55,9 +49,7 @@ def get_recorded_data() -> List[torch.Tensor]: Collects the recorded topk indices from all RouterReplay instances. :return: A list of tensors, each containing the recorded topk indices for a layer. 
""" - return [ - router.get_recorded_indices() for router in RouterReplay.global_router_replay_instances - ] + return [router.get_recorded_indices() for router in RouterReplay.global_router_replay_instances] @staticmethod def clear_global_indices(): @@ -84,9 +76,7 @@ def __init__(self): self.router_replay_action: Optional[RouterReplayAction] = ( None # Router replay action for this layer ) - self.replay_backward_list: List[torch.Tensor] = ( - [] - ) # List of tensors for backward pass replay + self.replay_backward_list: List[torch.Tensor] = ([]) # List of tensors for backward pass replay RouterReplay.global_router_replay_instances.append(self) def set_target_indices(self, topk_indices: torch.Tensor): @@ -122,15 +112,11 @@ def get_replay_topk( topk: int, num_groups: Optional[int] = None, group_topk: Optional[int] = None, - default_compute_topk: Callable[ - [torch.Tensor, int, Optional[int], Optional[int]], torch.Tensor - ] = None, + default_compute_topk: Callable[[torch.Tensor, int, Optional[int], Optional[int]], torch.Tensor] = None, ) -> torch.Tensor: """Returns the target topk indices for replay.""" if self.router_replay_action == RouterReplayAction.RECORD: - probs, top_indices = default_compute_topk( - scores, topk, num_groups=num_groups, group_topk=group_topk - ) + probs, top_indices = default_compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) self.record_indices(top_indices) return probs, top_indices elif self.router_replay_action == RouterReplayAction.REPLAY_FORWARD: @@ -158,7 +144,7 @@ def _patched_topk_routing_with_score_function( num_groups: Optional[int] = None, group_topk: Optional[int] = None, scaling_factor: Optional[float] = None, - score_function: str = "softmax", + score_function: str = 'softmax', expert_bias: Optional[torch.Tensor] = None, fused: bool = False, router_replay: Optional['RouterReplay'] = None, @@ -166,7 +152,7 @@ def _patched_topk_routing_with_score_function( """ Patched version of 
topk_routing_with_score_function that supports router replay. """ - assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}." + assert logits.dim() == 2, f'Expected 2D logits [num_tokens, num_experts], got {logits.dim()}.' num_tokens, num_experts = logits.shape def _compute_topk(scores, topk, num_groups=None, group_topk=None): @@ -188,18 +174,16 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): if router_replay is None: return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) else: - return router_replay.get_replay_topk( - scores, topk, num_groups, group_topk, _compute_topk - ) + return router_replay.get_replay_topk(scores, topk, num_groups, group_topk, _compute_topk) - if score_function == "softmax": + if score_function == 'softmax': if use_pre_softmax: scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) probs, top_indices = compute_topk(scores, topk, num_groups, group_topk) else: scores, top_indices = compute_topk(logits, topk, num_groups, group_topk) probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits) - elif score_function == "sigmoid": + elif score_function == 'sigmoid': scores = torch.sigmoid(logits.float()).type_as(logits) if expert_bias is not None: scores_for_routing = scores + expert_bias @@ -209,7 +193,7 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores else: - raise ValueError(f"Invalid score_function: {score_function}") + raise ValueError(f'Invalid score_function: {score_function}') if scaling_factor: probs = probs * scaling_factor @@ -221,9 +205,7 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): routing_probs.index_put_((rows, top_indices), probs, accumulate=False) routing_map = torch.zeros_like(logits, dtype=logits.dtype) - 
routing_map.index_put_( - (rows, top_indices), torch.ones_like(probs, dtype=routing_map.dtype), accumulate=False - ) + routing_map.index_put_((rows, top_indices), torch.ones_like(probs, dtype=routing_map.dtype), accumulate=False) routing_map = routing_map.bool() else: # TODO Try using element-wise operations instead of scatter? @@ -251,7 +233,7 @@ def patched_routing(self, logits: torch.Tensor): logits = self.apply_z_loss(logits) # Calculate probs and routing_map for token dispatching - if self.routing_type == "sinkhorn": + if self.routing_type == 'sinkhorn': probs, routing_map = self.sinkhorn_load_balancing(logits) else: probs, routing_map = _patched_topk_routing_with_score_function( @@ -282,15 +264,10 @@ def patched_routing(self, logits: torch.Tensor): if self.training and torch.is_grad_enabled() and self.is_aux_loss_enabled(): # Calculate scores and routing_map for aux loss routing_map_for_aux_loss, scores_for_aux_loss = compute_routing_scores_for_aux_loss( - logits, self.topk, self.score_function, fused=self.config.moe_router_fusion - ) + logits, self.topk, self.score_function, fused=self.config.moe_router_fusion) probs = self._apply_aux_loss(probs, scores_for_aux_loss, routing_map_for_aux_loss) - probs = self._apply_seq_aux_loss( - probs, scores_for_aux_loss, routing_map_for_aux_loss, seq_length, bsz - ) - probs = self._apply_global_aux_loss( - probs, scores_for_aux_loss, routing_map_for_aux_loss - ) + probs = self._apply_seq_aux_loss(probs, scores_for_aux_loss, routing_map_for_aux_loss, seq_length, bsz) + probs = self._apply_global_aux_loss(probs, scores_for_aux_loss, routing_map_for_aux_loss) # Optionally apply expert bias # Update expert bias and tokens_per_expert @@ -308,11 +285,11 @@ def apply_router_replay_patch(): This patch dynamically adds the 'enable_routing_replay' attribute to TransformerConfig and modifies the TopKRouter to support recording and replaying of routing decisions. 
""" - print("Applying Router Replay Patch...") + print('Applying Router Replay Patch...') # Clear router instances to avoid state leakage between model initializations. RouterReplay.global_router_replay_instances.clear() # Step 1: Patch TransformerConfig to include the feature flag - if not hasattr(TransformerConfig, "enable_routing_replay"): + if not hasattr(TransformerConfig, 'enable_routing_replay'): # Add class attribute with default value TransformerConfig.enable_routing_replay = False @@ -321,7 +298,7 @@ def apply_router_replay_patch(): # Define new __init__ method that safely handles enable_routing_replay parameter def patched_tf_config_init(self, *args, **kwargs): - enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay) + enable_routing_replay = kwargs.pop('enable_routing_replay', TransformerConfig.enable_routing_replay) # Call original constructor with remaining kwargs original_tf_config_init(self, *args, **kwargs) # Set the instance attribute @@ -331,7 +308,7 @@ def patched_tf_config_init(self, *args, **kwargs): TransformerConfig.__init__ = patched_tf_config_init # Step 2: Patch TopKRouter only once to ensure idempotency. 
- if hasattr(TopKRouter, "_router_replay_patched"): + if hasattr(TopKRouter, '_router_replay_patched'): return original_init = TopKRouter.__init__ diff --git a/swift/megatron/utils/router_replay_utils.py b/swift/megatron/utils/router_replay_utils.py index 1b80428a58..78db76af6c 100644 --- a/swift/megatron/utils/router_replay_utils.py +++ b/swift/megatron/utils/router_replay_utils.py @@ -4,19 +4,18 @@ """ import torch - -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.transformer_layer import get_transformer_layer_offset -from megatron.core.transformer.transformer_block import get_num_layers_to_build +from megatron.core import mpu from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region +from megatron.core.transformer.transformer_block import get_num_layers_to_build +from megatron.core.transformer.transformer_layer import get_transformer_layer_offset -from megatron.core import mpu -from swift.utils.torch_utils import get_current_device -from swift.megatron.utils import RouterReplay, RouterReplayAction from swift.megatron.trainers.utils import split_cp_inputs +from swift.megatron.utils import RouterReplay, RouterReplayAction +from swift.utils.torch_utils import get_current_device device_name = get_current_device() + def get_local_layer_range(tf_config, vp_rank=None): vp_size = tf_config.virtual_pipeline_model_parallel_size if vp_size is not None: @@ -32,6 +31,7 @@ def get_local_layer_range(tf_config, vp_rank=None): count = get_num_layers_to_build(tf_config, vp_rank) return offset, count + def get_local_topk_idx_for_current_rank(global_topk_idx, tf_config, packed_seq_params=None): if global_topk_idx is None: return None @@ -48,6 +48,7 @@ def get_local_topk_idx_for_current_rank(global_topk_idx, tf_config, packed_seq_p local_topk_idx = scatter_to_sequence_parallel_region(local_topk_idx.transpose(0, 1)).transpose(0, 1) return local_topk_idx + def get_router_replay_data(tf_config, 
vp_rank=None): router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) layers_topk_idx = [] @@ -57,6 +58,7 @@ def get_router_replay_data(tf_config, vp_rank=None): layers_topk_idx = torch.stack(layers_topk_idx).transpose(0, 1).unsqueeze(0).to(device_name) return layers_topk_idx + def set_router_replay_data(layers_topk_idx, tf_config, vp_rank=None): # bs, seq_len, layer_num, topk -> layer_num, total_seq_len, topk layers_topk_idx_reshape = layers_topk_idx.flatten(0, 1).transpose(0, 1).to(device_name) @@ -64,7 +66,7 @@ def set_router_replay_data(layers_topk_idx, tf_config, vp_rank=None): offset, _ = get_local_layer_range(tf_config, vp_rank) for i, router in enumerate(router_instances_list): router.set_target_indices(layers_topk_idx_reshape[i + offset].to(torch.int64)) - + class RouterReplayHelper: """Helper class to query router replay state and locate local RouterReplay instances.""" @@ -87,7 +89,7 @@ def get_micro_batch_router_list(tf_config, vp_rank=None): list: A contiguous sublist of RouterReplay.router_instances for the local layer range. """ offset, count = get_local_layer_range(tf_config, vp_rank) - router_instances_list = RouterReplay.global_router_replay_instances[offset : offset + count] + router_instances_list = RouterReplay.global_router_replay_instances[offset:offset + count] return router_instances_list @staticmethod @@ -98,10 +100,7 @@ def is_r2_record_action(tf_config, vp_rank=None) -> bool: RouterReplayAction.RECORD. 
""" router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) - return ( - router_instances_list - and router_instances_list[0].router_replay_action == RouterReplayAction.RECORD - ) + return (router_instances_list and router_instances_list[0].router_replay_action == RouterReplayAction.RECORD) @staticmethod def is_replay_forward_action(tf_config, vp_rank=None) -> bool: @@ -111,10 +110,8 @@ def is_replay_forward_action(tf_config, vp_rank=None) -> bool: RouterReplayAction.REPLAY_FORWARD. """ router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) - return ( - router_instances_list - and router_instances_list[0].router_replay_action == RouterReplayAction.REPLAY_FORWARD - ) + return (router_instances_list + and router_instances_list[0].router_replay_action == RouterReplayAction.REPLAY_FORWARD) @staticmethod def is_replay_backward_action(tf_config, vp_rank=None) -> bool: @@ -124,7 +121,5 @@ def is_replay_backward_action(tf_config, vp_rank=None) -> bool: RouterReplayAction.REPLAY_BACKWARD. 
""" router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) - return ( - router_instances_list - and router_instances_list[0].router_replay_action == RouterReplayAction.REPLAY_BACKWARD - ) + return (router_instances_list + and router_instances_list[0].router_replay_action == RouterReplayAction.REPLAY_BACKWARD) From 75bddf2a03a73bd4a28b73d64553f0d5c9d477cc Mon Sep 17 00:00:00 2001 From: XianlongLi <2286061024@qq.com> Date: Mon, 9 Mar 2026 16:43:47 +0800 Subject: [PATCH 05/14] optimization code --- swift/infer_engine/protocol.py | 10 ++++++++-- swift/megatron/model/model_config.py | 4 ++-- swift/megatron/trainers/rollout_mixin.py | 2 ++ swift/megatron/utils/router_replay_patch.py | 15 ++++++++------- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/swift/infer_engine/protocol.py b/swift/infer_engine/protocol.py index a286c9bbee..c4b5a1a22d 100644 --- a/swift/infer_engine/protocol.py +++ b/swift/infer_engine/protocol.py @@ -20,7 +20,12 @@ def serialize_ndarray(value): if value is None: return None if isinstance(value, np.ndarray): - return {'data': value.tolist(), 'shape': value.shape, 'dtype': str(value.dtype), '__ndarray__': True} + return { + 'data': base64.b64encode(value.tobytes()).decode('ascii'), + 'shape': value.shape, + 'dtype': str(value.dtype), + '__ndarray__': True + } return value @@ -28,7 +33,8 @@ def deserialize_ndarray(value): if value is None: return None if isinstance(value, dict) and value.get('__ndarray__'): - return np.array(value['data'], dtype=value['dtype']).reshape(value['shape']) + data = base64.b64decode(value['data']) + return np.frombuffer(data, dtype=value['dtype']).reshape(value['shape']) return value diff --git a/swift/megatron/model/model_config.py b/swift/megatron/model/model_config.py index f723bf115a..e08b4cfd81 100644 --- a/swift/megatron/model/model_config.py +++ b/swift/megatron/model/model_config.py @@ -157,7 +157,7 @@ class MegatronModelConfig(TransformerConfig): 'none'] = 
'aux_loss' use_shared_expert_gate: bool = False - enable_routing_replay: bool = False + moe_enable_routing_replay: bool = False # mla multi_latent_attention: bool = False @@ -512,7 +512,7 @@ def get_mcore_model_config(args, hf_config): kwargs['expert_tensor_parallel_size'] = 1 if args.router_replay_mode != 'disabled': - kwargs['enable_routing_replay'] = True + kwargs['moe_enable_routing_replay'] = True config = MegatronModelConfig(**kwargs) config.hf_config = hf_config diff --git a/swift/megatron/trainers/rollout_mixin.py b/swift/megatron/trainers/rollout_mixin.py index 7601a35de3..06f7abe01e 100644 --- a/swift/megatron/trainers/rollout_mixin.py +++ b/swift/megatron/trainers/rollout_mixin.py @@ -246,6 +246,8 @@ def _prepare_vllm_engine(self): load_format = vllm_engine_kwargs.pop('load_format', 'dummy') if self.args.router_replay_mode == 'R3': + assert check_vllm_version_ge('0.14.0'), \ + 'The enable_return_routed_experts attribute is not supported. Please upgrade vllm to 0.14.0 or higher' vllm_engine_kwargs['enable_return_routed_experts'] = True engine = GRPOVllmEngine( diff --git a/swift/megatron/utils/router_replay_patch.py b/swift/megatron/utils/router_replay_patch.py index 1798a5a635..4eae0b6b5d 100644 --- a/swift/megatron/utils/router_replay_patch.py +++ b/swift/megatron/utils/router_replay_patch.py @@ -282,27 +282,28 @@ def patched_routing(self, logits: torch.Tensor): def apply_router_replay_patch(): """ Applies the monkey patch for MoE Router Replay functionality. - This patch dynamically adds the 'enable_routing_replay' attribute to TransformerConfig + This patch dynamically adds the 'moe_enable_routing_replay' attribute to TransformerConfig and modifies the TopKRouter to support recording and replaying of routing decisions. """ print('Applying Router Replay Patch...') # Clear router instances to avoid state leakage between model initializations. 
RouterReplay.global_router_replay_instances.clear() # Step 1: Patch TransformerConfig to include the feature flag - if not hasattr(TransformerConfig, 'enable_routing_replay'): + if not hasattr(TransformerConfig, 'moe_enable_routing_replay'): # Add class attribute with default value - TransformerConfig.enable_routing_replay = False + TransformerConfig.moe_enable_routing_replay = False # Store original __init__ method original_tf_config_init = TransformerConfig.__init__ - # Define new __init__ method that safely handles enable_routing_replay parameter + # Define new __init__ method that safely handles moe_enable_routing_replay parameter def patched_tf_config_init(self, *args, **kwargs): - enable_routing_replay = kwargs.pop('enable_routing_replay', TransformerConfig.enable_routing_replay) + moe_enable_routing_replay = kwargs.pop('moe_enable_routing_replay', + TransformerConfig.moe_enable_routing_replay) # Call original constructor with remaining kwargs original_tf_config_init(self, *args, **kwargs) # Set the instance attribute - self.enable_routing_replay = enable_routing_replay + self.moe_enable_routing_replay = moe_enable_routing_replay # Apply the patch TransformerConfig.__init__ = patched_tf_config_init @@ -317,7 +318,7 @@ def patched_tf_config_init(self, *args, **kwargs): def patched_init(self, *args, **kwargs): original_init(self, *args, **kwargs) self.router_replay = None - if self.config.enable_routing_replay: + if self.config.moe_enable_routing_replay: self.router_replay = RouterReplay() # Step 4: Apply the patches From b4113a5ef524d7974f957cc0b8e722e78c0f548f Mon Sep 17 00:00:00 2001 From: XianlongLi <2286061024@qq.com> Date: Wed, 11 Mar 2026 16:19:33 +0800 Subject: [PATCH 06/14] add MoeAlltoAllTokenDispatcher and hybrid model bugfix --- swift/megatron/utils/router_replay_patch.py | 33 +++++++++++-- swift/megatron/utils/router_replay_utils.py | 53 +++++++++++++++++++-- 2 files changed, 77 insertions(+), 9 deletions(-) diff --git 
a/swift/megatron/utils/router_replay_patch.py b/swift/megatron/utils/router_replay_patch.py index 4eae0b6b5d..ccec368724 100644 --- a/swift/megatron/utils/router_replay_patch.py +++ b/swift/megatron/utils/router_replay_patch.py @@ -6,9 +6,10 @@ try: from megatron.core.transformer.moe.moe_utils import (apply_router_token_dropping, compute_routing_scores_for_aux_loss, group_limited_topk) + from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher except ImportError: warnings.warn('NPU not support router replay for now.', stacklevel=2) - pass + MoEAlltoAllTokenDispatcher = None from megatron.core.transformer.moe.router import TopKRouter from megatron.core.transformer.transformer_config import TransformerConfig @@ -246,8 +247,7 @@ def patched_routing(self, logits: torch.Tensor): score_function=self.score_function, expert_bias=self.expert_bias, fused=self.config.moe_router_fusion, - router_replay=self.router_replay, - ) + router_replay=getattr(self, 'router_replay', None)) # Apply token dropping to probs and routing_map. if self.config.moe_expert_capacity_factor is not None: @@ -321,7 +321,32 @@ def patched_init(self, *args, **kwargs): if self.config.moe_enable_routing_replay: self.router_replay = RouterReplay() - # Step 4: Apply the patches + # Step 4: Patch MoEAlltoAllTokenDispatcher.preprocess to handle router replay + # When router replay is enabled, duplicate indices in top_indices can cause + # routing_map.sum() < num_tokens * topk, leading to split size mismatch in alltoall. 
+ if MoEAlltoAllTokenDispatcher is not None and not hasattr(MoEAlltoAllTokenDispatcher, '_preprocess_patched'): + original_preprocess = MoEAlltoAllTokenDispatcher.preprocess + + def patched_preprocess(self, routing_map): + """Patched preprocess that handles router replay correctly for alltoall dispatcher.""" + # Call original preprocess + result = original_preprocess(self, routing_map) + + # Fix num_out_tokens when router replay is enabled + if (getattr(self.config, 'moe_enable_routing_replay', False) and not self.drop_and_pad + and self.config.moe_expert_capacity_factor is None + and not (getattr(self.config, 'moe_router_padding_for_quantization', None) + or getattr(self.config, 'moe_router_padding_for_fp8', None))): + # With router replay, duplicate indices can reduce the actual routed + # token count, so derive it from the routing map instead. + self.num_out_tokens = int(routing_map.sum().item()) + + return result + + MoEAlltoAllTokenDispatcher.preprocess = patched_preprocess + MoEAlltoAllTokenDispatcher._preprocess_patched = True + + # Step 5: Apply the patches TopKRouter.__init__ = patched_init TopKRouter.routing = patched_routing TopKRouter._router_replay_patched = True diff --git a/swift/megatron/utils/router_replay_utils.py b/swift/megatron/utils/router_replay_utils.py index 78db76af6c..011ab8fb73 100644 --- a/swift/megatron/utils/router_replay_utils.py +++ b/swift/megatron/utils/router_replay_utils.py @@ -16,7 +16,39 @@ device_name = get_current_device() -def get_local_layer_range(tf_config, vp_rank=None): +def is_moe_layer(tf_config, layer_idx): + moe_layer_freq = getattr(tf_config, 'moe_layer_freq', None) + if isinstance(moe_layer_freq, int): + return layer_idx % moe_layer_freq == 0 + elif isinstance(moe_layer_freq, list): + return moe_layer_freq[layer_idx] == 1 + else: + raise ValueError(f'Unsupported moe_layer_freq type: {type(moe_layer_freq)}') + + +def get_moe_num_layers_to_build(tf_config, vp_stage=None, pp_rank=None): + """Count the number of MoE 
layers assigned to the current rank. + When ``moe_layer_freq`` is 1 or unset, every transformer layer is an MoE + layer, so the count equals the total layer count. Otherwise only layers + whose global index satisfies the frequency predicate are counted. + Args: + config: Megatron TransformerConfig providing layer layout information. + vp_stage: Virtual-pipeline stage index (None defaults to current). + pp_rank: Pipeline-parallel rank (None defaults to current). + Returns: + Number of MoE layers on the specified rank/stage. + """ + total_layers = get_num_layers_to_build(tf_config, vp_stage=vp_stage, pp_rank=pp_rank) + + layer_offset = get_transformer_layer_offset(tf_config, vp_stage=vp_stage) + local_global_indices = range(layer_offset, layer_offset + total_layers) + + num_moe_layers = sum(1 for idx in local_global_indices if is_moe_layer(tf_config, idx)) + + return num_moe_layers + + +def get_local_layer_range(tf_config, vp_rank=None, only_moe_layer=True): vp_size = tf_config.virtual_pipeline_model_parallel_size if vp_size is not None: vp_rank = 0 if vp_rank is None else vp_rank @@ -24,11 +56,13 @@ def get_local_layer_range(tf_config, vp_rank=None): for pre_vp_stage in range(vp_size): if pre_vp_stage == vp_rank: break - num_layers_to_build = get_num_layers_to_build(tf_config, pre_vp_stage) + num_layers_to_build = get_moe_num_layers_to_build( + tf_config, pre_vp_stage) if only_moe_layer else get_num_layers_to_build(tf_config, pre_vp_stage) offset += num_layers_to_build else: offset = 0 - count = get_num_layers_to_build(tf_config, vp_rank) + count = get_moe_num_layers_to_build(tf_config, vp_rank) if only_moe_layer else get_num_layers_to_build( + tf_config, vp_rank) return offset, count @@ -36,10 +70,19 @@ def get_local_topk_idx_for_current_rank(global_topk_idx, tf_config, packed_seq_p if global_topk_idx is None: return None # 1. 
pp slice + # For the hybrid model, global_topk_idx contains data from all layers + # because vLLM reports routed_experts across all transformer layers(including dense). + # However megatron only has routers for MoE layers. + # So local_topk_idx should filter only data from the MoE layer. layer_offset = get_transformer_layer_offset(tf_config, vp_stage=0) - offset, count = get_local_layer_range(tf_config, tf_config.virtual_pipeline_model_parallel_size) + offset, count = get_local_layer_range( + tf_config, tf_config.virtual_pipeline_model_parallel_size, only_moe_layer=False) num_layers = offset + count - local_topk_idx = torch.narrow(global_topk_idx, dim=2, start=layer_offset, length=num_layers) + moe_layer_idx = torch.tensor([ + layer_idx for layer_idx in range(layer_offset, layer_offset + num_layers) if is_moe_layer(tf_config, layer_idx) + ], + dtype=torch.long) + local_topk_idx = torch.index_select(global_topk_idx, dim=2, index=moe_layer_idx) # 2. cp slice cp_size = mpu.get_context_parallel_world_size() if cp_size > 1: From e0700f66ab9f97bc8ee083895832ac051b0e6f2d Mon Sep 17 00:00:00 2001 From: XianlongLi <2286061024@qq.com> Date: Mon, 16 Mar 2026 19:06:18 +0800 Subject: [PATCH 07/14] bugfix --- swift/megatron/utils/router_replay_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/megatron/utils/router_replay_utils.py b/swift/megatron/utils/router_replay_utils.py index 011ab8fb73..48d40c1d6a 100644 --- a/swift/megatron/utils/router_replay_utils.py +++ b/swift/megatron/utils/router_replay_utils.py @@ -81,7 +81,7 @@ def get_local_topk_idx_for_current_rank(global_topk_idx, tf_config, packed_seq_p moe_layer_idx = torch.tensor([ layer_idx for layer_idx in range(layer_offset, layer_offset + num_layers) if is_moe_layer(tf_config, layer_idx) ], - dtype=torch.long) + dtype=torch.long, device=global_topk_idx.device) local_topk_idx = torch.index_select(global_topk_idx, dim=2, index=moe_layer_idx) # 2. 
cp slice cp_size = mpu.get_context_parallel_world_size() From 278f16ef722ac4b61fe9bf9897823fba3c14a46e Mon Sep 17 00:00:00 2001 From: XianlongLi <2286061024@qq.com> Date: Fri, 20 Mar 2026 14:08:00 +0800 Subject: [PATCH 08/14] fix linting --- swift/megatron/arguments/megatron_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/megatron/arguments/megatron_args.py b/swift/megatron/arguments/megatron_args.py index b3fe73e107..e9c4f003e8 100644 --- a/swift/megatron/arguments/megatron_args.py +++ b/swift/megatron/arguments/megatron_args.py @@ -153,7 +153,7 @@ class RLHFMegatronArgumentsMixin: log_entropy: bool = False # Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939 top_entropy_quantile: float = 1.0 - + router_replay_mode: Literal['disabled', 'R2', 'R3'] = 'disabled' # ─────────────────────────── Not Supported Yet ─────────────────────────── From c139541e569c29849f52e55dda22579ece6fcded Mon Sep 17 00:00:00 2001 From: XianlongLi <2286061024@qq.com> Date: Mon, 23 Mar 2026 19:12:44 +0800 Subject: [PATCH 09/14] native megatron support --- swift/megatron/trainers/base.py | 18 +- swift/megatron/trainers/grpo_trainer.py | 6 +- swift/megatron/utils/__init__.py | 1 - swift/megatron/utils/router_replay_patch.py | 352 -------------------- swift/megatron/utils/router_replay_utils.py | 40 ++- 5 files changed, 50 insertions(+), 367 deletions(-) delete mode 100644 swift/megatron/utils/router_replay_patch.py diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 7214226960..dd42b6b3b1 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -17,6 +17,7 @@ from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.moe.moe_utils import track_moe_metrics +from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction from megatron.core.transformer.multi_token_prediction 
import MTPLossLoggingHelper from modelscope import check_local_model_is_latest from packaging import version @@ -27,12 +28,13 @@ from swift.megatron.callbacks import megatron_callbacks_map from swift.megatron.model import get_mcore_model from swift.megatron.tuners import LoraParallelLinear -from swift.megatron.utils import ( - RouterReplay, RouterReplayAction, apply_router_replay_patch, copy_original_module_weight, disable_forward_pre_hook, - enable_forward_pre_hook, get_optimizer_param_scheduler, get_padding_to, init_persistent_async_worker, - initialize_tp_communicators, load_mcore_checkpoint, logical_and_across_model_parallel_group, - maybe_finalize_async_save, prepare_mcore_model, reduce_max_stat_across_model_parallel_group, save_mcore_checkpoint, - should_disable_forward_pre_hook, warmup_jit_function, wrap_model) +from swift.megatron.utils import (apply_router_replay_patch, copy_original_module_weight, disable_forward_pre_hook, + enable_forward_pre_hook, get_optimizer_param_scheduler, get_padding_to, + init_persistent_async_worker, initialize_tp_communicators, load_mcore_checkpoint, + logical_and_across_model_parallel_group, maybe_finalize_async_save, + prepare_mcore_model, reduce_max_stat_across_model_parallel_group, + save_mcore_checkpoint, should_disable_forward_pre_hook, warmup_jit_function, + wrap_model) from swift.template import Template from swift.trainers import dynamic_gradient_checkpointing from swift.trainers.utils import patch_modelscope_hub_timeout @@ -52,9 +54,11 @@ class BaseMegatronTrainer(ABC): def __init__(self, args, template: Template): - # patch routing_replay + # validate mcore version and patch routing_replay self.enable_routing_replay = args.router_replay_mode != 'disabled' if self.enable_routing_replay: + assert version.parse(megatron.core.__version__) >= version.parse('0.16.0'), \ + 'The routing replay is not supported. 
Please upgrade megatron-core to 0.16.0 or higher' apply_router_replay_patch() self.args = args diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 3c75774e90..47432cd08e 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -18,14 +18,14 @@ from functools import partial from megatron.core import mpu from megatron.core.rerun_state_machine import RerunDataIterator +from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction from typing import Any, Dict, List, Optional, Tuple, Union from swift.dataset import RowPreprocessor from swift.infer_engine.protocol import RequestConfig, RolloutInferRequest, RolloutOutput from swift.megatron.arguments import MegatronArguments, MegatronRLHFArguments -from swift.megatron.utils import (RouterReplay, RouterReplayAction, RouterReplayHelper, forward_step_helper, - get_local_topk_idx_for_current_rank, get_padding_to, get_router_replay_data, - set_random_seed, set_router_replay_data) +from swift.megatron.utils import (RouterReplayHelper, forward_step_helper, get_local_topk_idx_for_current_rank, + get_padding_to, get_router_replay_data, set_random_seed, set_router_replay_data) from swift.rewards import orms from swift.rlhf_trainers.grpo_trainer import DataType from swift.rlhf_trainers.utils import (aggressive_empty_cache, nanstd, pad_logps_back_to_batch, profiling_context, diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py index 184f65784d..4a4f03da11 100644 --- a/swift/megatron/utils/__init__.py +++ b/swift/megatron/utils/__init__.py @@ -9,7 +9,6 @@ from .parallel_utils import (logical_and_across_model_parallel_group, reduce_max_stat_across_model_parallel_group, split_cp_inputs) from .patcher import patch_merge_fn, patch_torch_dist_shard -from .router_replay_patch import RouterReplay, RouterReplayAction, apply_router_replay_patch from .router_replay_utils import * from .utils import 
(copy_original_module_weight, forward_step_helper, get_local_layer_specs, get_padding_to, prepare_mcore_model, tuners_sharded_state_dict) diff --git a/swift/megatron/utils/router_replay_patch.py b/swift/megatron/utils/router_replay_patch.py deleted file mode 100644 index ccec368724..0000000000 --- a/swift/megatron/utils/router_replay_patch.py +++ /dev/null @@ -1,352 +0,0 @@ -import torch -import warnings -from enum import Enum -from typing import Callable, List, Optional - -try: - from megatron.core.transformer.moe.moe_utils import (apply_router_token_dropping, - compute_routing_scores_for_aux_loss, group_limited_topk) - from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher -except ImportError: - warnings.warn('NPU not support router replay for now.', stacklevel=2) - MoEAlltoAllTokenDispatcher = None -from megatron.core.transformer.moe.router import TopKRouter -from megatron.core.transformer.transformer_config import TransformerConfig - - -class RouterReplayAction(Enum): - RECORD = 'record' # Record the topk indices for replay - REPLAY_FORWARD = 'replay_forward' # Replay the recorded topk indices for forward pass - REPLAY_BACKWARD = 'replay_backward' # Replay topk indices for re-compute during backward pass - - -class RouterReplay: - """ - A class to manage the recording and replaying of MoE routing decisions. - It holds all router instances and provides static methods to globally - control recording and replaying. - """ - - # Static variable to hold all router instances, one per MoE layer. - global_router_replay_instances: List['RouterReplay'] = [] - - @staticmethod - def set_replay_data(all_layers_topk_indices: List[torch.Tensor]): - """ - Distributes the topk indices for all layers to their respective RouterReplay instances. - :param all_layers_topk_indices: A list of tensors, where each tensor contains the - topk indices for a specific layer. The order - must match the instantiation order of the routers. 
- """ - if len(all_layers_topk_indices) != len(RouterReplay.global_router_replay_instances): - raise ValueError(f'The number of replay tensors ({len(all_layers_topk_indices)}) ' - f'does not match router instances ({len(RouterReplay.global_router_replay_instances)}).') - for i, router_instance in enumerate(RouterReplay.global_router_replay_instances): - router_instance.set_target_indices(all_layers_topk_indices[i]) - - @staticmethod - def get_recorded_data() -> List[torch.Tensor]: - """ - Collects the recorded topk indices from all RouterReplay instances. - :return: A list of tensors, each containing the recorded topk indices for a layer. - """ - return [router.get_recorded_indices() for router in RouterReplay.global_router_replay_instances] - - @staticmethod - def clear_global_indices(): - """Clears the recorded and target topk indices in all instances.""" - for router in RouterReplay.global_router_replay_instances: - router.clear_indices() - - @staticmethod - def set_global_router_replay_action(router_replay_action: RouterReplayAction): - """Sets the router replay action for all router instances.""" - for router in RouterReplay.global_router_replay_instances: - router.set_router_replay_action(router_replay_action) - - @staticmethod - def clear_global_router_replay_action(): - """Clears the router replay action for all router instances.""" - for router in RouterReplay.global_router_replay_instances: - router.clear_router_replay_action() - - def __init__(self): - """Initializes a RouterReplay instance for a specific layer.""" - self.target_topk_idx: Optional[torch.Tensor] = None # Target topk indices for replay - self.recorded_topk_idx: Optional[torch.Tensor] = None # Recorded topk indices for replay - self.router_replay_action: Optional[RouterReplayAction] = ( - None # Router replay action for this layer - ) - self.replay_backward_list: List[torch.Tensor] = ([]) # List of tensors for backward pass replay - RouterReplay.global_router_replay_instances.append(self) - 
- def set_target_indices(self, topk_indices: torch.Tensor): - """Sets the target topk indices for replay.""" - self.target_topk_idx = topk_indices - self.replay_backward_list.append(topk_indices) - - def get_recorded_indices(self) -> Optional[torch.Tensor]: - """Returns the recorded topk indices.""" - return self.recorded_topk_idx - - def record_indices(self, topk_indices: torch.Tensor): - """Records the topk indices.""" - self.recorded_topk_idx = topk_indices - - def clear_indices(self): - """Clears the recorded and target topk indices.""" - self.recorded_topk_idx = None - self.target_topk_idx = None - self.replay_backward_list = [] - - def set_router_replay_action(self, router_replay_action: RouterReplayAction): - """Sets the router replay action for this layer.""" - self.router_replay_action = router_replay_action - - def clear_router_replay_action(self): - """Clears the router replay action for this layer.""" - self.router_replay_action = None - - def get_replay_topk( - self, - scores: torch.Tensor, - topk: int, - num_groups: Optional[int] = None, - group_topk: Optional[int] = None, - default_compute_topk: Callable[[torch.Tensor, int, Optional[int], Optional[int]], torch.Tensor] = None, - ) -> torch.Tensor: - """Returns the target topk indices for replay.""" - if self.router_replay_action == RouterReplayAction.RECORD: - probs, top_indices = default_compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) - self.record_indices(top_indices) - return probs, top_indices - elif self.router_replay_action == RouterReplayAction.REPLAY_FORWARD: - top_indices = self.target_topk_idx - # Ensure indices are on the correct device - top_indices = top_indices.to(scores.device) - # Gather the scores for the replayed indices to get the probabilities - probs = scores.gather(1, top_indices) - return probs, top_indices - elif self.router_replay_action == RouterReplayAction.REPLAY_BACKWARD: - top_indices = self.replay_backward_list.pop(0) - # Ensure indices are on 
the correct device - top_indices = top_indices.to(scores.device) - # Gather the scores for the replayed indices to get the probabilities - probs = scores.gather(1, top_indices) - return probs, top_indices - else: - return default_compute_topk(scores, topk, num_groups, group_topk) - - -def _patched_topk_routing_with_score_function( - logits: torch.Tensor, - topk: int, - use_pre_softmax: bool = False, - num_groups: Optional[int] = None, - group_topk: Optional[int] = None, - scaling_factor: Optional[float] = None, - score_function: str = 'softmax', - expert_bias: Optional[torch.Tensor] = None, - fused: bool = False, - router_replay: Optional['RouterReplay'] = None, -): - """ - Patched version of topk_routing_with_score_function that supports router replay. - """ - assert logits.dim() == 2, f'Expected 2D logits [num_tokens, num_experts], got {logits.dim()}.' - num_tokens, num_experts = logits.shape - - def _compute_topk(scores, topk, num_groups=None, group_topk=None): - if group_topk: - return group_limited_topk( - scores=scores, - topk=topk, - num_tokens=num_tokens, - num_experts=num_experts, - num_groups=num_groups, - group_topk=group_topk, - ) - else: - return torch.topk(scores, k=topk, dim=1) - - def compute_topk(scores, topk, num_groups=None, group_topk=None): - # Default behavior if no replay is active - - if router_replay is None: - return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) - else: - return router_replay.get_replay_topk(scores, topk, num_groups, group_topk, _compute_topk) - - if score_function == 'softmax': - if use_pre_softmax: - scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) - probs, top_indices = compute_topk(scores, topk, num_groups, group_topk) - else: - scores, top_indices = compute_topk(logits, topk, num_groups, group_topk) - probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits) - elif score_function == 'sigmoid': - scores = 
torch.sigmoid(logits.float()).type_as(logits) - if expert_bias is not None: - scores_for_routing = scores + expert_bias - _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) - scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits) - else: - scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) - probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores - else: - raise ValueError(f'Invalid score_function: {score_function}') - - if scaling_factor: - probs = probs * scaling_factor - - if torch.are_deterministic_algorithms_enabled(): - # build [num_tokens, num_experts] from [num_tokens, topk] - routing_probs = torch.zeros_like(logits) - rows = torch.arange(num_tokens, device=logits.device).unsqueeze(1) - routing_probs.index_put_((rows, top_indices), probs, accumulate=False) - - routing_map = torch.zeros_like(logits, dtype=logits.dtype) - routing_map.index_put_((rows, top_indices), torch.ones_like(probs, dtype=routing_map.dtype), accumulate=False) - routing_map = routing_map.bool() - else: - # TODO Try using element-wise operations instead of scatter? - routing_probs = torch.zeros_like(logits).scatter(1, top_indices, probs) - routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool() - - return routing_probs, routing_map - - -def patched_routing(self, logits: torch.Tensor): - """Top-k routing function - - Args: - logits (torch.Tensor): Logits tensor after gating. - - Returns: - probs (torch.Tensor): The probabilities of token to experts assignment. - routing_map (torch.Tensor): The mapping of token to experts assignment, - with shape [num_tokens, num_experts]. 
- """ - seq_length, bsz = logits.shape[:2] - logits = logits.view(-1, self.config.num_moe_experts) - - # Apply Z-Loss - logits = self.apply_z_loss(logits) - - # Calculate probs and routing_map for token dispatching - if self.routing_type == 'sinkhorn': - probs, routing_map = self.sinkhorn_load_balancing(logits) - else: - probs, routing_map = _patched_topk_routing_with_score_function( - logits, - self.topk, - use_pre_softmax=self.config.moe_router_pre_softmax, - num_groups=self.config.moe_router_num_groups, - group_topk=self.config.moe_router_group_topk, - scaling_factor=self.config.moe_router_topk_scaling_factor, - score_function=self.score_function, - expert_bias=self.expert_bias, - fused=self.config.moe_router_fusion, - router_replay=getattr(self, 'router_replay', None)) - - # Apply token dropping to probs and routing_map. - if self.config.moe_expert_capacity_factor is not None: - probs, routing_map = apply_router_token_dropping( - probs, - routing_map, - router_topk=self.topk, - capacity_factor=self.config.moe_expert_capacity_factor, - drop_policy=self.config.moe_token_drop_policy, - pad_to_capacity=self.config.moe_pad_expert_input_to_capacity, - ) - - # Apply each aux loss type and attach aux loss autograd function to probs - if self.training and torch.is_grad_enabled() and self.is_aux_loss_enabled(): - # Calculate scores and routing_map for aux loss - routing_map_for_aux_loss, scores_for_aux_loss = compute_routing_scores_for_aux_loss( - logits, self.topk, self.score_function, fused=self.config.moe_router_fusion) - probs = self._apply_aux_loss(probs, scores_for_aux_loss, routing_map_for_aux_loss) - probs = self._apply_seq_aux_loss(probs, scores_for_aux_loss, routing_map_for_aux_loss, seq_length, bsz) - probs = self._apply_global_aux_loss(probs, scores_for_aux_loss, routing_map_for_aux_loss) - - # Optionally apply expert bias - # Update expert bias and tokens_per_expert - # Prevent extra local tokens accumulation on evaluation or activation recomputation - if 
self.enable_expert_bias and torch.is_grad_enabled(): - with torch.no_grad(): - self.local_tokens_per_expert += routing_map.sum(dim=0) - - return probs, routing_map - - -def apply_router_replay_patch(): - """ - Applies the monkey patch for MoE Router Replay functionality. - This patch dynamically adds the 'moe_enable_routing_replay' attribute to TransformerConfig - and modifies the TopKRouter to support recording and replaying of routing decisions. - """ - print('Applying Router Replay Patch...') - # Clear router instances to avoid state leakage between model initializations. - RouterReplay.global_router_replay_instances.clear() - # Step 1: Patch TransformerConfig to include the feature flag - if not hasattr(TransformerConfig, 'moe_enable_routing_replay'): - # Add class attribute with default value - TransformerConfig.moe_enable_routing_replay = False - - # Store original __init__ method - original_tf_config_init = TransformerConfig.__init__ - - # Define new __init__ method that safely handles moe_enable_routing_replay parameter - def patched_tf_config_init(self, *args, **kwargs): - moe_enable_routing_replay = kwargs.pop('moe_enable_routing_replay', - TransformerConfig.moe_enable_routing_replay) - # Call original constructor with remaining kwargs - original_tf_config_init(self, *args, **kwargs) - # Set the instance attribute - self.moe_enable_routing_replay = moe_enable_routing_replay - - # Apply the patch - TransformerConfig.__init__ = patched_tf_config_init - - # Step 2: Patch TopKRouter only once to ensure idempotency. 
- if hasattr(TopKRouter, '_router_replay_patched'): - return - - original_init = TopKRouter.__init__ - - # Step 3: Define the new __init__ method - def patched_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - self.router_replay = None - if self.config.moe_enable_routing_replay: - self.router_replay = RouterReplay() - - # Step 4: Patch MoEAlltoAllTokenDispatcher.preprocess to handle router replay - # When router replay is enabled, duplicate indices in top_indices can cause - # routing_map.sum() < num_tokens * topk, leading to split size mismatch in alltoall. - if MoEAlltoAllTokenDispatcher is not None and not hasattr(MoEAlltoAllTokenDispatcher, '_preprocess_patched'): - original_preprocess = MoEAlltoAllTokenDispatcher.preprocess - - def patched_preprocess(self, routing_map): - """Patched preprocess that handles router replay correctly for alltoall dispatcher.""" - # Call original preprocess - result = original_preprocess(self, routing_map) - - # Fix num_out_tokens when router replay is enabled - if (getattr(self.config, 'moe_enable_routing_replay', False) and not self.drop_and_pad - and self.config.moe_expert_capacity_factor is None - and not (getattr(self.config, 'moe_router_padding_for_quantization', None) - or getattr(self.config, 'moe_router_padding_for_fp8', None))): - # With router replay, duplicate indices can reduce the actual routed - # token count, so derive it from the routing map instead. 
- self.num_out_tokens = int(routing_map.sum().item()) - - return result - - MoEAlltoAllTokenDispatcher.preprocess = patched_preprocess - MoEAlltoAllTokenDispatcher._preprocess_patched = True - - # Step 5: Apply the patches - TopKRouter.__init__ = patched_init - TopKRouter.routing = patched_routing - TopKRouter._router_replay_patched = True diff --git a/swift/megatron/utils/router_replay_utils.py b/swift/megatron/utils/router_replay_utils.py index 48d40c1d6a..5a09382ca9 100644 --- a/swift/megatron/utils/router_replay_utils.py +++ b/swift/megatron/utils/router_replay_utils.py @@ -6,11 +6,12 @@ import torch from megatron.core import mpu from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region +from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction +from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher from megatron.core.transformer.transformer_block import get_num_layers_to_build from megatron.core.transformer.transformer_layer import get_transformer_layer_offset from swift.megatron.trainers.utils import split_cp_inputs -from swift.megatron.utils import RouterReplay, RouterReplayAction from swift.utils.torch_utils import get_current_device device_name = get_current_device() @@ -81,7 +82,8 @@ def get_local_topk_idx_for_current_rank(global_topk_idx, tf_config, packed_seq_p moe_layer_idx = torch.tensor([ layer_idx for layer_idx in range(layer_offset, layer_offset + num_layers) if is_moe_layer(tf_config, layer_idx) ], - dtype=torch.long, device=global_topk_idx.device) + dtype=torch.long, + device=global_topk_idx.device) local_topk_idx = torch.index_select(global_topk_idx, dim=2, index=moe_layer_idx) # 2. 
cp slice cp_size = mpu.get_context_parallel_world_size() @@ -105,8 +107,8 @@ def get_router_replay_data(tf_config, vp_rank=None): def set_router_replay_data(layers_topk_idx, tf_config, vp_rank=None): # bs, seq_len, layer_num, topk -> layer_num, total_seq_len, topk layers_topk_idx_reshape = layers_topk_idx.flatten(0, 1).transpose(0, 1).to(device_name) - router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) - offset, _ = get_local_layer_range(tf_config, vp_rank) + offset, count = get_local_layer_range(tf_config, vp_rank) + router_instances_list = RouterReplay.global_router_replay_instances[offset:offset + count] for i, router in enumerate(router_instances_list): router.set_target_indices(layers_topk_idx_reshape[i + offset].to(torch.int64)) @@ -166,3 +168,33 @@ def is_replay_backward_action(tf_config, vp_rank=None) -> bool: router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) return (router_instances_list and router_instances_list[0].router_replay_action == RouterReplayAction.REPLAY_BACKWARD) + + +def apply_router_replay_patch(): + """ + Applies the monkey patch for MoE Router Replay functionality. + """ + print('Applying Router Replay Patch...') + + # Patch MoEAlltoAllTokenDispatcher.preprocess to handle router replay + # When router replay is enabled, duplicate indices in top_indices can cause + # routing_map.sum() < num_tokens * topk, leading to split size mismatch in alltoall. 
+ if MoEAlltoAllTokenDispatcher is not None and not hasattr(MoEAlltoAllTokenDispatcher, '_preprocess_patched'): + original_preprocess = MoEAlltoAllTokenDispatcher.preprocess + + def patched_preprocess(self, routing_map): + """Patched preprocess that handles router replay correctly for alltoall dispatcher.""" + # Call original preprocess + result = original_preprocess(self, routing_map) + # Fix num_out_tokens when router replay is enabled + if (getattr(self.config, 'moe_enable_routing_replay', False) and not self.drop_and_pad + and self.config.moe_expert_capacity_factor is None + and not (getattr(self.config, 'moe_router_padding_for_quantization', None) + or getattr(self.config, 'moe_router_padding_for_fp8', None))): + # With router replay, duplicate indices can reduce the actual routed + # token count, so derive it from the routing map instead. + self.num_out_tokens = int(routing_map.sum().item()) + return result + + MoEAlltoAllTokenDispatcher.preprocess = patched_preprocess + MoEAlltoAllTokenDispatcher._preprocess_patched = True \ No newline at end of file From 177a4d91e43879cda110717fca8a04566183543f Mon Sep 17 00:00:00 2001 From: XianlongLi <2286061024@qq.com> Date: Mon, 23 Mar 2026 19:37:38 +0800 Subject: [PATCH 10/14] delete unused code --- swift/megatron/model/model_config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/swift/megatron/model/model_config.py b/swift/megatron/model/model_config.py index 4311e84cdc..212b701bc0 100644 --- a/swift/megatron/model/model_config.py +++ b/swift/megatron/model/model_config.py @@ -176,8 +176,6 @@ class MegatronModelConfig(TransformerConfig): 'none'] = 'aux_loss' moe_shared_expert_gate: bool = False - moe_enable_routing_replay: bool = False - # mla multi_latent_attention: bool = False q_lora_rank: Optional[int] = None From 35c8b804b2d0438e28cd53fd8efd1cd283730f13 Mon Sep 17 00:00:00 2001 From: XianlongLi <2286061024@qq.com> Date: Thu, 26 Mar 2026 18:48:04 +0800 Subject: [PATCH 11/14] fix merged code --- 
swift/megatron/trainers/grpo_trainer.py | 21 +++++++++++++-------- swift/megatron/trainers/rlhf_mixin.py | 19 ++++++++++++++++--- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 98c983c72f..922ee2ea24 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -24,8 +24,7 @@ from swift.dataset import RowPreprocessor from swift.infer_engine.protocol import RequestConfig, RolloutInferRequest, RolloutOutput from swift.megatron.arguments import MegatronArguments, MegatronRLHFArguments -from swift.megatron.utils import (RouterReplayHelper, forward_step_helper, get_local_topk_idx_for_current_rank, - get_padding_to, get_router_replay_data, set_random_seed, set_router_replay_data) +from swift.megatron.utils import RouterReplayHelper, get_padding_to, set_random_seed, set_router_replay_data from swift.rewards import orms from swift.rlhf_trainers.grpo_trainer import DataType from swift.rlhf_trainers.utils import (aggressive_empty_cache, nanstd, pad_logps_back_to_batch, profiling_context, @@ -982,7 +981,7 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: with self.null_ref_context() as ref_models: assert len(ref_models) == 1, 'GRPO currently does not support VPP.' 
ref_model = ref_models[0] - ref_per_token_logps_packed = self.compute_per_token_logps( + ref_per_token_logps_packed, _ = self.compute_per_token_logps( ref_model, iter([deepcopy(inputs)]), temperature=self.temperature) if self.template.padding_free: ref_per_token_logps, _ = pad_logps_back_to_batch( @@ -994,7 +993,13 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: ref_per_token_logps = ref_per_token_logps_packed batch['ref_per_token_logps'] = ref_per_token_logps - old_per_token_logps_packed = self.compute_per_token_logps( + if self.enable_routing_replay: + if self.args.router_replay_mode == 'R2': + RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD) + if self.args.router_replay_mode == 'R3': + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + old_per_token_logps_packed, routing_topk_idx = self.compute_per_token_logps( self.unwrapped_models[0], iter([deepcopy(inputs)]), temperature=self.temperature) if self.template.padding_free: old_per_token_logps, _ = pad_logps_back_to_batch( @@ -1006,10 +1011,10 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: old_per_token_logps = old_per_token_logps_packed batch['old_per_token_logps'] = old_per_token_logps - # if self.enable_routing_replay: - # batch['routed_experts'] = output['layers_topk_idx'] - # RouterReplay.clear_global_indices() - # RouterReplay.clear_global_router_replay_action() + if self.enable_routing_replay: + batch['routed_experts'] = routing_topk_idx + RouterReplay.clear_global_indices() + RouterReplay.clear_global_router_replay_action() return batch diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index f61b000fc6..792087f875 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -6,7 +6,8 @@ from transformers.utils import ContextManagers from swift.megatron.model import get_mcore_model -from swift.megatron.utils import 
forward_step_helper, load_mcore_checkpoint +from swift.megatron.utils import (RouterReplayHelper, forward_step_helper, get_local_topk_idx_for_current_rank, + get_router_replay_data, load_mcore_checkpoint, set_router_replay_data) from swift.rlhf_trainers.utils import identity_data_collator from swift.utils import get_logger from .base import BaseMegatronTrainer @@ -107,18 +108,30 @@ def compute_per_token_logps(self, model, data_iterator, no_grad=True, temperatur Returns: per_token_logps tensor, or None if on a non-last PP stage + routing_topk_idx tensor, or None if router replay is disabled """ data = self.get_batch(data_iterator) data.pop('loss_scale', None) labels = data.get('labels') + routing_topk_idx = None + global_topk_idx = data.pop('routed_experts', None) + if RouterReplayHelper.is_replay_forward_action(model.config): + assert global_topk_idx is not None, 'When router_replay_mode = R3, routed_experts must be in data' + routing_topk_idx = get_local_topk_idx_for_current_rank(global_topk_idx, model.config, + data.get('packed_seq_params')) + set_router_replay_data(routing_topk_idx, model.config) + data_for_forward = {k: v for k, v in data.items() if k != 'labels'} context = torch.no_grad() if no_grad else nullcontext() with context: output_tensor = forward_step_helper(self.args, model, data_for_forward) + if RouterReplayHelper.is_r2_record_action(model.config): + routing_topk_idx = get_router_replay_data(model.config) + if labels is None or output_tensor is None: - return None + return None, routing_topk_idx if temperature != 1.0: output_tensor.div_(temperature) @@ -133,7 +146,7 @@ def compute_per_token_logps(self, model, data_iterator, no_grad=True, temperatur if self.args.context_parallel_size > 1: per_token_logps = self._postprocess_packed_tensor_cp(per_token_logps, packed_seq_params, num_samples) - return per_token_logps + return per_token_logps, routing_topk_idx def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples): """ From 
4ee1125fd0290d14f9c0fc6b5be572af5afe9a97 Mon Sep 17 00:00:00 2001 From: XianlongLi <2286061024@qq.com> Date: Thu, 26 Mar 2026 19:01:32 +0800 Subject: [PATCH 12/14] fix --- swift/megatron/trainers/grpo_trainer.py | 28 ++----------------------- 1 file changed, 2 insertions(+), 26 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 922ee2ea24..106bf9e1ec 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -1151,32 +1151,8 @@ def forward_step(self, data_iterator, model): per_token_logps = per_token_logps_packed per_token_entropy = per_token_entropy_packed - output_tensor = per_token_logps - data['per_token_entropy'] = per_token_entropy - else: - # Standard forward with labels, returns per-token loss (more efficient) - output_tensor = model(**inputs) - - # Convert output_tensor (per-token loss) to per_token_logps on PP last stage - if is_pp_last_stage and output_tensor is not None: - per_token_logps_raw = self.get_logps( - output_tensor, - labels, - packed_seq_params, - packed_seq_params.num_samples if args.padding_free else micro_batch_size, - per_token=True) - - if args.padding_free: - per_token_logps, _ = pad_logps_back_to_batch( - logps_rmpad=per_token_logps_raw, - logits_to_keep=max_seq_len, - batch_size=micro_batch_size, - seq_lengths=seq_lengths) - else: - per_token_logps = per_token_logps_raw - - data['per_token_logps'] = per_token_logps - data['per_token_entropy'] = None + output_tensor = per_token_logps + data['per_token_entropy'] = per_token_entropy if RouterReplayHelper.is_replay_forward_action(model.config): router_instance_list = RouterReplayHelper.get_micro_batch_router_list(model.config) From 44f76e82e34ed514448c79aa8acd04aac42ff413 Mon Sep 17 00:00:00 2001 From: XianlongLi <2286061024@qq.com> Date: Fri, 27 Mar 2026 10:21:02 +0800 Subject: [PATCH 13/14] version compatibility --- swift/megatron/trainers/base.py | 6 +++--- 
swift/megatron/trainers/grpo_trainer.py | 13 ++++++++---- swift/megatron/trainers/rlhf_mixin.py | 4 ++-- swift/megatron/utils/router_replay_utils.py | 23 +++++++++++++++++---- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 24ced94b22..236c7a4f62 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -17,7 +17,6 @@ from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.moe.moe_utils import track_moe_metrics -from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper from modelscope import check_local_model_is_latest from packaging import version @@ -45,8 +44,11 @@ try: from megatron.core.optimizer import param_group_identifier_keys + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction except ImportError: param_group_identifier_keys = None + RouterReplay = None + RouterReplayAction = None mcore_016 = version.parse(megatron.core.__version__) >= version.parse('0.16.0rc0') @@ -59,8 +61,6 @@ def __init__(self, args, template: Template): # validate mcore version and patch routing_replay self.enable_routing_replay = args.router_replay_mode != 'disabled' if self.enable_routing_replay: - assert version.parse(megatron.core.__version__) >= version.parse('0.16.0'), \ - 'The routing replay is not supported. 
Please upgrade megatron-core to 0.16.0 or higher' apply_router_replay_patch() self.args = args diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 106bf9e1ec..ddbf691132 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -18,7 +18,6 @@ from functools import partial from megatron.core import mpu from megatron.core.rerun_state_machine import RerunDataIterator -from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction from typing import Any, Dict, List, Optional, Tuple, Union from swift.dataset import RowPreprocessor @@ -39,6 +38,12 @@ from .utils import gather, gather_object from .vocab_parallel_utils import compute_logps_and_entropy_from_logits +try: + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction +except ImportError: + RouterReplay = None + RouterReplayAction = None + logger = get_logger() @@ -1096,11 +1101,11 @@ def forward_step(self, data_iterator, model): }) data.pop('loss_scale', None) - if RouterReplayHelper.is_replay_backward_action(model.config): + if self.enable_routing_replay and RouterReplayHelper.is_replay_backward_action(model.config): router_instance_list = RouterReplayHelper.get_micro_batch_router_list(model.config) for router in router_instance_list: router.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD) - if RouterReplayHelper.is_replay_forward_action(model.config): + if self.enable_routing_replay and RouterReplayHelper.is_replay_forward_action(model.config): layers_topk_idx = data.pop('routed_experts', None) set_router_replay_data(layers_topk_idx, model.config) @@ -1154,7 +1159,7 @@ def forward_step(self, data_iterator, model): output_tensor = per_token_logps data['per_token_entropy'] = per_token_entropy - if RouterReplayHelper.is_replay_forward_action(model.config): + if self.enable_routing_replay and RouterReplayHelper.is_replay_forward_action(model.config): 
router_instance_list = RouterReplayHelper.get_micro_batch_router_list(model.config) for router in router_instance_list: router.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD) diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index 792087f875..fb3737081a 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -116,7 +116,7 @@ def compute_per_token_logps(self, model, data_iterator, no_grad=True, temperatur routing_topk_idx = None global_topk_idx = data.pop('routed_experts', None) - if RouterReplayHelper.is_replay_forward_action(model.config): + if self.enable_routing_replay and RouterReplayHelper.is_replay_forward_action(model.config): assert global_topk_idx is not None, 'When router_replay_mode = R3, routed_experts must be in data' routing_topk_idx = get_local_topk_idx_for_current_rank(global_topk_idx, model.config, data.get('packed_seq_params')) @@ -127,7 +127,7 @@ def compute_per_token_logps(self, model, data_iterator, no_grad=True, temperatur with context: output_tensor = forward_step_helper(self.args, model, data_for_forward) - if RouterReplayHelper.is_r2_record_action(model.config): + if self.enable_routing_replay and RouterReplayHelper.is_r2_record_action(model.config): routing_topk_idx = get_router_replay_data(model.config) if labels is None or output_tensor is None: diff --git a/swift/megatron/utils/router_replay_utils.py b/swift/megatron/utils/router_replay_utils.py index 5a09382ca9..6cd5d485b0 100644 --- a/swift/megatron/utils/router_replay_utils.py +++ b/swift/megatron/utils/router_replay_utils.py @@ -6,14 +6,26 @@ import torch from megatron.core import mpu from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region -from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction -from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher from megatron.core.transformer.transformer_block 
import get_num_layers_to_build from megatron.core.transformer.transformer_layer import get_transformer_layer_offset from swift.megatron.trainers.utils import split_cp_inputs +from swift.utils import get_logger from swift.utils.torch_utils import get_current_device +logger = get_logger() + +try: + from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction + from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher + ROUTER_REPLAY_AVAILABLE = True +except ImportError: + logger.warning('RouterReplay not available in current megatron-core version') + RouterReplay = None + RouterReplayAction = None + MoEAlltoAllTokenDispatcher = None + ROUTER_REPLAY_AVAILABLE = False + device_name = get_current_device() @@ -174,7 +186,10 @@ def apply_router_replay_patch(): """ Applies the monkey patch for MoE Router Replay functionality. """ - print('Applying Router Replay Patch...') + logger.info('Applying Router Replay Patch...') + + assert ROUTER_REPLAY_AVAILABLE, \ + 'The routing replay is not supported. 
Please upgrade megatron-core to 0.16.0 or higher' # Patch MoEAlltoAllTokenDispatcher.preprocess to handle router replay # When router replay is enabled, duplicate indices in top_indices can cause @@ -197,4 +212,4 @@ def patched_preprocess(self, routing_map): return result MoEAlltoAllTokenDispatcher.preprocess = patched_preprocess - MoEAlltoAllTokenDispatcher._preprocess_patched = True \ No newline at end of file + MoEAlltoAllTokenDispatcher._preprocess_patched = True From a1edcc28dcb26937d44aed2e3ae22c06d35726c2 Mon Sep 17 00:00:00 2001 From: XianlongLi <2286061024@qq.com> Date: Tue, 31 Mar 2026 14:08:29 +0800 Subject: [PATCH 14/14] fix lint --- swift/infer_engine/grpo_vllm_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/swift/infer_engine/grpo_vllm_engine.py b/swift/infer_engine/grpo_vllm_engine.py index a435dd228b..9037b23d93 100644 --- a/swift/infer_engine/grpo_vllm_engine.py +++ b/swift/infer_engine/grpo_vllm_engine.py @@ -122,8 +122,7 @@ def _create_chat_completion_response(self, result, inputs, request_config, reque finish_reason=output.finish_reason, logprobs=logprobs, token_ids=token_ids, - routed_experts=getattr(output, 'routed_experts', None) - ) + routed_experts=getattr(output, 'routed_experts', None)) choices.append(choice) prompt_token_ids = None images_size = None