diff --git a/verl/utils/megatron/router_replay_patch.py b/verl/utils/megatron/router_replay_patch.py index c618608801f..328b93f99dd 100644 --- a/verl/utils/megatron/router_replay_patch.py +++ b/verl/utils/megatron/router_replay_patch.py @@ -22,9 +22,10 @@ compute_routing_scores_for_aux_loss, group_limited_topk, ) + from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher except ImportError: warnings.warn("NPU not support router replay for now.", stacklevel=2) - pass + MoEAlltoAllTokenDispatcher = None from megatron.core.transformer.moe.router import TopKRouter from megatron.core.transformer.transformer_config import TransformerConfig @@ -343,7 +344,34 @@ def patched_init(self, *args, **kwargs): if self.config.enable_routing_replay: self.router_replay = RouterReplay() - # Step 4: Apply the patches + # Step 4: Patch MoEAlltoAllTokenDispatcher.preprocess to handle router replay + # When router replay is enabled, duplicate indices in top_indices can cause + # routing_map.sum() < num_tokens * topk, leading to split size mismatch in alltoall. + if MoEAlltoAllTokenDispatcher is not None and not hasattr(MoEAlltoAllTokenDispatcher, "_preprocess_patched"): + original_preprocess = MoEAlltoAllTokenDispatcher.preprocess + + def patched_preprocess(self, routing_map): + """Patched preprocess that handles router replay correctly for alltoall dispatcher.""" + # Call original preprocess + result = original_preprocess(self, routing_map) + + # Fix num_out_tokens when router replay is enabled + if ( + getattr(self.config, "enable_routing_replay", False) + and not self.drop_and_pad + and self.config.moe_expert_capacity_factor is None + and not self.config.moe_router_padding_for_quantization + ): + # With router replay, duplicate indices can reduce the actual routed + # token count, so derive it from the routing map instead. + self.num_out_tokens = int(routing_map.sum().item()) + + return result + + MoEAlltoAllTokenDispatcher.preprocess = patched_preprocess + MoEAlltoAllTokenDispatcher._preprocess_patched = True + + # Step 5: Apply the patches TopKRouter.__init__ = patched_init TopKRouter.routing = patched_routing TopKRouter._router_replay_patched = True