Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions verl/utils/megatron/router_replay_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
compute_routing_scores_for_aux_loss,
group_limited_topk,
)
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
except ImportError:
warnings.warn("NPU not support router replay for now.", stacklevel=2)
Comment thread
HollowMan6 marked this conversation as resolved.
pass
MoEAlltoAllTokenDispatcher = None
from megatron.core.transformer.moe.router import TopKRouter
from megatron.core.transformer.transformer_config import TransformerConfig

Expand Down Expand Up @@ -343,7 +344,34 @@ def patched_init(self, *args, **kwargs):
if self.config.enable_routing_replay:
self.router_replay = RouterReplay()

# Step 4: Apply the patches
# Step 4: Patch MoEAlltoAllTokenDispatcher.preprocess to handle router replay
# When router replay is enabled, duplicate indices in top_indices can cause
# routing_map.sum() < num_tokens * topk, leading to split size mismatch in alltoall.
if MoEAlltoAllTokenDispatcher is not None and not hasattr(MoEAlltoAllTokenDispatcher, "_preprocess_patched"):
original_preprocess = MoEAlltoAllTokenDispatcher.preprocess

def patched_preprocess(self, routing_map):
    """Wrap the stock alltoall ``preprocess`` and correct ``num_out_tokens`` for router replay.

    The unpatched ``preprocess`` is always invoked first and its return value is
    handed back unchanged. Afterwards, when routing replay is enabled and none of
    the drop-and-pad / capacity-factor / quantization-padding modes is active,
    ``self.num_out_tokens`` is overwritten with the number of set entries in
    ``routing_map``.

    Args:
        self: The dispatcher instance this function is bound to as ``preprocess``.
        routing_map: Routing map forwarded to the original ``preprocess``; its
            ``sum()`` supplies the corrected token count.

    Returns:
        Whatever the original ``preprocess`` returned.
    """
    # Delegate to the unpatched implementation before touching any state.
    outcome = original_preprocess(self, routing_map)

    # Guard clauses, checked in the same order as the upstream condition so that
    # later attributes are only read when the earlier checks let us through.
    if not getattr(self.config, "enable_routing_replay", False):
        return outcome
    if self.drop_and_pad:
        return outcome
    if self.config.moe_expert_capacity_factor is not None:
        return outcome
    if self.config.moe_router_padding_for_quantization:
        return outcome

    # Replayed routing may contain duplicate expert indices, which can make the
    # real number of routed tokens smaller than expected; count them directly
    # from the routing map instead.
    routed_total = routing_map.sum().item()
    self.num_out_tokens = int(routed_total)
    return outcome
Comment thread
HollowMan6 marked this conversation as resolved.

MoEAlltoAllTokenDispatcher.preprocess = patched_preprocess
MoEAlltoAllTokenDispatcher._preprocess_patched = True

# Step 5: Apply the patches
TopKRouter.__init__ = patched_init
TopKRouter.routing = patched_routing
TopKRouter._router_replay_patched = True
Loading