diff --git a/examples/router_replay/README.md b/examples/router_replay/README.md new file mode 100644 index 00000000000..25ccb67fab2 --- /dev/null +++ b/examples/router_replay/README.md @@ -0,0 +1,71 @@ +# Router Replay + +Router Replay is a routing replay feature in the verl framework for Mixture of Experts (MoE) models. It enables deterministic training by recording and replaying routing decisions, ensuring consistent model behavior across training runs. + + +## Key Features + +### Multiple Operating Modes +- **`disabled`**: Router replay functionality is completely disabled +- **`R2`**: Standard router replay mode for recording and replaying routing decisions +- **`R3`**: Rollout-specific router replay mode optimized for reinforcement learning workflows + +### Core Capabilities +- **Seamless Integration**: Works with reinforcement learning pipelines including PPO +- **Distributed Training Support**: Compatible with multi-GPU and multi-node training environments +- **Flexible Configuration**: Easy to configure via YAML files or command-line parameters + +## Configuration + +### RouterReplayConfig Parameters + +```yaml +router_replay: + mode: "disabled" # Available options: disabled, R2, R3 + record_file: null # Path for recording routing decisions + replay_file: null # Path for replaying recorded decisions +``` + +## Quick Start Guide + +### Enabling R2 Mode + +#### Configuration File Method +Add the following to your training configuration: + +```yaml +actor: + router_replay: + mode: "R2" +``` + +#### Command Line Method +Enable R2 mode via command-line parameters (R2 is actor-side only and does not require the rollout flag): + +```bash +actor_rollout_ref.actor.router_replay.mode="R2" +``` + +### Enabling R3 Mode + +#### Configuration File Method +Configure both actor and rollout settings: + +```yaml +# Actor configuration +router_replay: + mode: "R3" + +# Rollout configuration +enable_rollout_routing_replay: True +``` + +#### Command Line Method +Enable R3 mode via command-line parameters: + +```bash +actor_rollout_ref.actor.router_replay.mode="R3" +actor_rollout_ref.rollout.enable_rollout_routing_replay=True +``` + +R3 mode requires the rollout backend to support returning router selection results. This functionality is currently being tested against the vLLM implementation in https://github.com/vllm-project/vllm/pull/28284.
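For orientation, the record/replay cycle implemented by the `router_replay_patch` module added in this PR can be sketched as follows. This is a minimal illustration rather than trainer code: `build_megatron_model` and `batch` are hypothetical placeholders, and in practice the Megatron worker applies the patch and drives these calls automatically whenever `actor.router_replay.mode` is not `disabled`.

```python
# Minimal sketch of the record/replay cycle, using only APIs defined in this PR.
# Assumptions (not part of the PR): `build_megatron_model` and `batch` are
# hypothetical placeholders for model construction and a forward input.
from verl.utils.megatron.router_replay_patch import (
    RouterReplay,
    RouterReplayAction,
    apply_router_replay_patch,
)

apply_router_replay_patch()  # patch TransformerConfig/TopKRouter before building the model
model = build_megatron_model(enable_routing_replay=True)  # hypothetical helper

# 1) Record: each patched TopKRouter stores the top-k expert indices it selects.
RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD)
_ = model(batch)
recorded = RouterReplay.get_recorded_data()  # one [num_tokens, topk] tensor per MoE layer

# 2) Replay: push the recorded indices back into the routers so a later forward
#    pass reproduces exactly the same token-to-expert assignment.
RouterReplay.set_replay_data(recorded)
RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
_ = model(batch)

# 3) Clear per-layer state between iterations to avoid leakage.
RouterReplay.clear_global_router_replay_action()
RouterReplay.clear_global_indices()
```

Inside verl, `MegatronPPOActor.update_policy` drives this same cycle: it sets `REPLAY_FORWARD` before each mini-batch and clears the global action and indices afterwards (see `verl/workers/actor/megatron_actor.py` below).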
diff --git a/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh b/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh new file mode 100644 index 00000000000..9a45bdaac73 --- /dev/null +++ b/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh @@ -0,0 +1,108 @@ + +set -x + +NODES=1 + +# R2: enable routing replay +# R3: enable rollout routing replay +# If enabling R3, please set actor_rollout_ref.rollout.enable_rollout_routing_replay=True +# The R3 example is based on the vLLM PR https://github.com/vllm-project/vllm/pull/28284 + +ROUTING_REPLAY_MODE="R2" + +DIST_CKPT_PATH="" +HF_MODEL_PATH="" +TRAIN_DATA_PATH="" +TEST_DATA_PATH="" + +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping +PP=1 +VPP=None +TP=2 +EP=8 +ETP=1 +VLLM_INFER_TP=2 +offload=True +gpu_memory_utilization=0.65 +bs=8 +micro_bs=3 +use_dynamic_bsz=True +max_prompt_length=1024 +max_response_length=1024 +ppo_mini_batch_size=8 +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) + + +exper_name=Node${NODES}_bs${bs}_${PP}${TP}${EP}${ETP}_${VLLM_INFER_TP}_minbs${ppo_mini_batch_size}_micro_bs${micro_bs} + +python3 -m verl.trainer.main_ppo --config-path=config \ --config-name='ppo_megatron_trainer.yaml' \ algorithm.adv_estimator=grpo \ data.train_files=$TRAIN_DATA_PATH \ data.val_files=$TEST_DATA_PATH \ data.train_batch_size=$bs \ data.max_prompt_length=$max_prompt_length \ data.max_response_length=$max_response_length \ data.filter_overlong_prompts=True \ data.truncation='error' \ actor_rollout_ref.model.use_fused_kernels=True \ actor_rollout_ref.model.path=$HF_MODEL_PATH \ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ actor_rollout_ref.actor.router_replay.mode=${ROUTING_REPLAY_MODE} \ +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \ +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \ +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ +actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \ +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ actor_rollout_ref.actor.megatron.param_offload=${offload} \ actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ actor_rollout_ref.actor.megatron.grad_offload=${offload} \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_bs \ actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \ actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \ actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \ actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \ actor_rollout_ref.actor.use_kl_loss=False \ +
actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_bs \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_INFER_TP \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.actor.megatron.use_mbridge=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=$gpu_memory_utilization \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$micro_bs \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \ + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl_grpo_example_gsm8k_math' \ + trainer.experiment_name="$exper_name" \ + trainer.nnodes=$NODES \ + trainer.n_gpus_per_node=8 \ + trainer.save_freq=-1 \ + trainer.test_freq=10 \ + trainer.total_training_steps=50000 \ + trainer.balance_batch=False \ + trainer.val_before_train=False 2>&1 diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index 883534bdc1f..1fbb8c6c9ce 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -134,6 +134,8 @@ class AgentLoopOutput(BaseModel): """Response mask, 1 for LLM generated token, 0 for tool response token.""" response_logprobs: Optional[list[float]] = None """Log probabilities for the response tokens.""" + routed_experts: Optional[Any] = None + """Routed experts for all tokens (prompt and response).""" multi_modal_data: Optional[dict[str, Any]] = None """Multi-modal data for multi-modal tools.""" reward_score: Optional[float] = None @@ -165,6 +167,8 @@ class _InternalAgentLoopOutput(AgentLoopOutput): """Padded attention mask.""" response_logprobs: Optional[torch.Tensor] = None """Padded log probabilities for the response tokens.""" + routed_experts: Optional[torch.Tensor] = None + """Padded routed experts for all tokens (prompt and response).""" multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None """Multi-modal inputs for processors (e.g., pixel_values, image_grid_thw).""" extra_fields: dict[str, Any] = {} @@ -487,6 +491,25 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO attention_mask = torch.cat([prompt_output["attention_mask"], response_output["attention_mask"]], dim=1) input_ids = torch.cat([prompt_output["input_ids"], response_output["input_ids"]], dim=1) + routed_experts = None + if output.routed_experts is not None: + total_length = input_ids.shape[1] + length, layer_num, topk_num = output.routed_experts.shape + experts_tensor = torch.from_numpy(output.routed_experts) + routed_experts = torch.zeros(1, total_length, layer_num, topk_num,
dtype=experts_tensor.dtype) + + # Calculate start position: left padding means original prompt starts at the end + start_pos = prompt_output["input_ids"].shape[1] - len(output.prompt_ids) + end_pos = min(start_pos + length, total_length) + + # Add boundary checks for robustness + if start_pos < 0 or end_pos > total_length: + raise ValueError( + f"Invalid position range: start_pos={start_pos}, end_pos={end_pos}, total_length={total_length}" + ) + + routed_experts[:, start_pos:end_pos] = experts_tensor.unsqueeze(0) + # Handle multi-modal inputs and position_ids calculation # Only support Qwen2VLImageProcessor for multi-modal processing currently # TODO: support other multi-modal inputs @@ -560,6 +583,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopO response_mask=response_mask, attention_mask=attention_mask, response_logprobs=response_logprobs, + routed_experts=routed_experts, multi_modal_inputs=multi_modal_inputs, multi_modal_data=output.multi_modal_data, reward_score=output.reward_score, @@ -580,6 +604,8 @@ def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto: optional_outputs = {} if inputs[0].response_logprobs is not None: optional_outputs["rollout_log_probs"] = torch.cat([input.response_logprobs for input in inputs], dim=0) + if inputs[0].routed_experts is not None: + optional_outputs["routed_experts"] = torch.cat([input.routed_experts for input in inputs], dim=0) batch = TensorDict( { diff --git a/verl/experimental/agent_loop/single_turn_agent_loop.py b/verl/experimental/agent_loop/single_turn_agent_loop.py index 62dcd02f1cf..da3189874ba 100644 --- a/verl/experimental/agent_loop/single_turn_agent_loop.py +++ b/verl/experimental/agent_loop/single_turn_agent_loop.py @@ -73,6 +73,11 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu response_ids=output.token_ids[: self.response_length], response_mask=response_mask[: self.response_length], response_logprobs=output.log_probs[: self.response_length] if output.log_probs else None, + routed_experts=( + output.routed_experts[: len(prompt_ids) + self.response_length] + if output.routed_experts is not None + else None + ), multi_modal_data={"image": image_data} if image_data is not None else {}, num_turns=2, metrics=metrics, diff --git a/verl/experimental/agent_loop/tool_agent_loop.py b/verl/experimental/agent_loop/tool_agent_loop.py index 80d1ee3d0c2..e107bb37b51 100644 --- a/verl/experimental/agent_loop/tool_agent_loop.py +++ b/verl/experimental/agent_loop/tool_agent_loop.py @@ -233,6 +233,9 @@ async def _handle_generating_state( if output.log_probs: agent_data.response_logprobs += output.log_probs + if output.routed_experts is not None: + agent_data.routed_experts = output.routed_experts + # Check termination conditions if not ignore_termination and len(agent_data.response_mask) >= self.response_length: return AgentState.TERMINATED diff --git a/verl/models/mcore/util.py b/verl/models/mcore/util.py index 6ca270c6fb6..6cb0f9e5cce 100644 --- a/verl/models/mcore/util.py +++ b/verl/models/mcore/util.py @@ -144,7 +144,7 @@ def postprocess_packed_seqs( if cp_size > 1: # output shape: [1, packed_len, hidden_dim] # need to gather across cp group and concatenate in sequence dimension - output_list = [torch.empty_like(output) for _ in range(cp_size)] + output_list = [torch.empty_like(output, dtype=output.dtype) for _ in range(cp_size)] torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group()) 
output_list[mpu.get_context_parallel_rank()] = output else: @@ -159,7 +159,7 @@ def postprocess_packed_seqs( half_seqlen = s_len_padded_chunk // 2 s_len = seq_lens_cpu[i] s_len_padded = s_len_padded_chunk * cp_size - tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device) + tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device, dtype=output.dtype) for j in range(cp_size): o = output_list[j][0] # split to 2 chunks diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 7bf04848ea4..6565fe660b1 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -121,6 +121,11 @@ actor_rollout_ref: _target_: verl.utils.profiler.config.TorchMemoryToolConfig trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null data_loader_seed: 42 load_weight: true ref: @@ -156,6 +161,11 @@ actor_rollout_ref: _target_: verl.utils.profiler.config.TorchMemoryToolConfig trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null megatron: _target_: verl.workers.config.McoreEngineConfig param_offload: ${oc.select:actor_rollout_ref.actor.megatron.param_offload,False} @@ -259,6 +269,7 @@ actor_rollout_ref: skip_rollout: false skip_dump_dir: /tmp/rollout_dump skip_tokenizer_init: true + enable_rollout_routing_replay: false profiler: _target_: verl.utils.profiler.ProfilerConfig tool: ${oc.select:global_profiler.tool,null} diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index 6a76e2993b0..86687831ad2 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -107,6 +107,11 @@ actor_rollout_ref: _target_: verl.utils.profiler.config.TorchMemoryToolConfig trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null grad_clip: 1.0 ulysses_sequence_parallel_size: 1 entropy_from_logits_with_chunking: false @@ -145,6 +150,11 @@ actor_rollout_ref: _target_: verl.utils.profiler.config.TorchMemoryToolConfig trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null fsdp_config: _target_: verl.workers.config.FSDPEngineConfig wrap_policy: @@ -245,6 +255,7 @@ actor_rollout_ref: skip_rollout: false skip_dump_dir: /tmp/rollout_dump skip_tokenizer_init: true + enable_rollout_routing_replay: false profiler: _target_: verl.utils.profiler.ProfilerConfig tool: 
${oc.select:global_profiler.tool,null} diff --git a/verl/trainer/config/actor/actor.yaml b/verl/trainer/config/actor/actor.yaml index b61f898018e..f5f1d15eee5 100644 --- a/verl/trainer/config/actor/actor.yaml +++ b/verl/trainer/config/actor/actor.yaml @@ -220,3 +220,23 @@ profiler: # Stack trace depth for memory allocations stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + +# Router replay configuration for MoE models +router_replay: + + # Target dataclass for this configuration + _target_: verl.workers.config.RouterReplayConfig + + # Router replay mode: disabled, R2, R3 + # - R2: record routing decisions during training and replay them for deterministic recompute + # - R3: replay routing decisions recorded by the rollout engine + mode: disabled + + # File path to save recorded routing decisions + # Optional; may be used when recording is enabled (R2/R3) + record_file: null + + # File path to load recorded routing decisions for replay + # Optional; may be used when replaying recorded decisions + replay_file: null + diff --git a/verl/trainer/config/ref/ref.yaml b/verl/trainer/config/ref/ref.yaml index 72b7ff048b2..ec566c25b9a 100644 --- a/verl/trainer/config/ref/ref.yaml +++ b/verl/trainer/config/ref/ref.yaml @@ -100,3 +100,22 @@ profiler: # Stack trace depth for memory allocations stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + +# Router replay configuration for MoE models +router_replay: + + # Target dataclass for this configuration + _target_: verl.workers.config.RouterReplayConfig + + # Router replay mode: disabled, R2, R3 + # - R2: record routing decisions during training and replay them for deterministic recompute + # - R3: replay routing decisions recorded by the rollout engine + mode: disabled + + # File path to save recorded routing decisions + # Optional; may be used when recording is enabled (R2/R3) + record_file: null + + # File path to load recorded routing decisions for replay + # Optional; may be used when replaying recorded decisions + replay_file: null \ No newline at end of file diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml index b1931344bcc..968d9e11277 100644 --- a/verl/trainer/config/rollout/rollout.yaml +++ b/verl/trainer/config/rollout/rollout.yaml @@ -277,6 +277,11 @@ skip_dump_dir: /tmp/rollout_dump # When enabled (True), the rollout assume token in token out for generation skip_tokenizer_init: True +# Whether to enable rollout routing replay for MoE models +# When enabled (True), the rollout will record the routing decisions. +enable_rollout_routing_replay: False + + # profile the rollout model in `generate_sequence` profiler: diff --git a/verl/utils/megatron/router_replay_patch.py b/verl/utils/megatron/router_replay_patch.py new file mode 100644 index 00000000000..69f9aaeb153 --- /dev/null +++ b/verl/utils/megatron/router_replay_patch.py @@ -0,0 +1,343 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+from enum import Enum + +import torch +from megatron.core.transformer.moe.moe_utils import ( + apply_router_token_dropping, + compute_routing_scores_for_aux_loss, + group_limited_topk, +) +from megatron.core.transformer.moe.router import TopKRouter +from megatron.core.transformer.transformer_config import TransformerConfig + +# https://github.com/THUDM/slime/blob/main/slime/utils/routing_replay.py + + +class RouterReplayAction(Enum): + RECORD = "record" + REPLAY_FORWARD = "replay_forward" + REPLAY_BACKWARD = "replay_backward" + + +class RouterReplay: + """ + A class to manage the recording and replaying of MoE routing decisions. + It holds all router instances and provides static methods to globally + control recording and replaying. + """ + + # Static variable to hold all router instances, one per MoE layer. + router_instances = [] + + @staticmethod + def set_replay_data(all_layers_topk_indices: list): + """ + Distributes the topk indices for all layers to their respective RouterReplay instances. + :param all_layers_topk_indices: A list of tensors, where each tensor contains the + topk indices for a specific layer. The order + must match the instantiation order of the routers. + """ + if len(all_layers_topk_indices) != len(RouterReplay.router_instances): + raise ValueError( + f"The number of replay tensors ({len(all_layers_topk_indices)}) " + f"does not match the number of router instances ({len(RouterReplay.router_instances)})." + ) + for i, router_instance in enumerate(RouterReplay.router_instances): + router_instance.set_target_indices(all_layers_topk_indices[i]) + + @staticmethod + def get_recorded_data() -> list: + """ + Collects the recorded topk indices from all RouterReplay instances. + :return: A list of tensors, each containing the recorded topk indices for a layer. 
+ """ + return [router.get_recorded_indices() for router in RouterReplay.router_instances] + + @staticmethod + def clear_global_indices(): + """Clears the recorded and target topk indices in all instances.""" + for router in RouterReplay.router_instances: + router.clear_indices() + + def __init__(self): + """Initializes a RouterReplay instance for a specific layer.""" + self.target_topk_idx = None # For replay + self.recorded_topk_idx = None # For recording + self.router_replay_action = None # Router replay action for this layer + self.replay_backward_list = [] # List of tensors for backward pass replay + RouterReplay.router_instances.append(self) + + def set_target_indices(self, topk_indices: torch.Tensor): + """Sets the target topk indices for replay.""" + self.target_topk_idx = topk_indices + self.replay_backward_list.append(topk_indices) + + def get_recorded_indices(self): + """Returns the recorded topk indices.""" + return self.recorded_topk_idx + + def record_indices(self, topk_indices: torch.Tensor): + """Records the topk indices.""" + self.recorded_topk_idx = topk_indices + + def clear_indices(self): + """Clears the recorded and target topk indices.""" + self.recorded_topk_idx = None + self.target_topk_idx = None + self.replay_backward_list = [] + + def set_router_replay_action(self, router_replay_action: RouterReplayAction): + """Sets the router replay action for this layer.""" + self.router_replay_action = router_replay_action + + def clear_router_replay_action(self): + """Clears the router replay action for this layer.""" + self.router_replay_action = None + + @staticmethod + def set_global_router_replay_action(router_replay_action: RouterReplayAction): + """Sets the router replay action for all router instances.""" + for router in RouterReplay.router_instances: + router.set_router_replay_action(router_replay_action) + + @staticmethod + def clear_global_router_replay_action(): + """Clears the router replay action for all router instances.""" + for router in RouterReplay.router_instances: + router.clear_router_replay_action() + + +def _patched_topk_routing_with_score_function( + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool, + num_groups: int, + group_topk: int, + score_function: str, + expert_bias: torch.Tensor, + fused: bool, + router_replay: RouterReplay, + scaling_factor: float, +): + """ + Patched version of topk_routing_with_score_function that supports router replay. 
+ """ + num_tokens, num_experts = logits.shape + + def _compute_topk(scores, topk, num_groups=None, group_topk=None): + if group_topk: + return group_limited_topk( + scores=scores, + topk=topk, + num_tokens=num_tokens, + num_experts=num_experts, + num_groups=num_groups, + group_topk=group_topk, + ) + else: + return torch.topk(scores, k=topk, dim=1) + + def compute_topk(scores, topk, num_groups=None, group_topk=None): + # Default behavior if no replay is active + + routing_action = router_replay.router_replay_action if router_replay is not None else None + + if routing_action is None: + return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + + if routing_action == RouterReplayAction.RECORD: + probs, top_indices = _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + if router_replay is not None: + router_replay.record_indices(top_indices) + return probs, top_indices + + elif routing_action == RouterReplayAction.REPLAY_FORWARD: + if router_replay is None or router_replay.target_topk_idx is None: + # Fallback if replay data is not available + return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + + # Use the provided indices for replay + top_indices = router_replay.target_topk_idx + # Ensure indices are on the correct device + top_indices = top_indices.to(scores.device) + # Gather the scores for the replayed indices to get the probabilities + probs = scores.gather(1, top_indices) + return probs, top_indices + elif routing_action == RouterReplayAction.REPLAY_BACKWARD: + if router_replay is None or not router_replay.replay_backward_list: + # Fallback if replay data is not available + return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + + # Pop the earliest queued indices (FIFO) for backward replay + top_indices = router_replay.replay_backward_list.pop(0) + # Ensure indices are on the correct device + top_indices = top_indices.to(scores.device) + # Gather the scores for the replayed indices to get the probabilities + probs = scores.gather(1, top_indices) + return probs, top_indices + else: # Unknown action, fallback + return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) + probs, top_indices = compute_topk(scores, topk, num_groups, group_topk) + else: + scores, top_indices = compute_topk(logits, topk, num_groups, group_topk) + probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits) + elif score_function == "sigmoid": + scores = torch.sigmoid(logits.float()).type_as(logits) + if expert_bias is not None: + scores_for_routing = scores + expert_bias + _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) + scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits) + else: + scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) + probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + else: + raise ValueError(f"Invalid score_function: {score_function}") + + if scaling_factor: + probs = probs * scaling_factor + + if torch.are_deterministic_algorithms_enabled(): + # build [num_tokens, num_experts] from [num_tokens, topk] + routing_probs = torch.zeros_like(logits) + rows = torch.arange(num_tokens, device=logits.device).unsqueeze(1) + routing_probs.index_put_((rows, top_indices), probs, accumulate=False) + + routing_map =
torch.zeros_like(logits, dtype=logits.dtype) + routing_map.index_put_((rows, top_indices), torch.ones_like(probs, dtype=routing_map.dtype), accumulate=False) + routing_map = routing_map.bool() + else: + # TODO Try using element-wise operations instead of scatter? + routing_probs = torch.zeros_like(logits).scatter(1, top_indices, probs) + routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool() + + return routing_probs, routing_map + + +def patched_routing(self, logits: torch.Tensor): + """Top-k routing function + + Args: + logits (torch.Tensor): Logits tensor after gating. + + Returns: + probs (torch.Tensor): The probabilities of token to experts assignment. + routing_map (torch.Tensor): The mapping of token to experts assignment, + with shape [num_tokens, num_experts]. + """ + seq_length, bsz = logits.shape[:2] + logits = logits.view(-1, self.config.num_moe_experts) + + # Apply Z-Loss + logits = self.apply_z_loss(logits) + + # Calculate probs and routing_map for token dispatching + if self.routing_type == "sinkhorn": + probs, routing_map = self.sinkhorn_load_balancing(logits) + else: + probs, routing_map = _patched_topk_routing_with_score_function( + logits=logits, + topk=self.topk, + use_pre_softmax=self.config.moe_router_pre_softmax, + num_groups=self.config.moe_router_num_groups, + group_topk=self.config.moe_router_group_topk, + scaling_factor=self.config.moe_router_topk_scaling_factor, + score_function=self.score_function, + expert_bias=self.expert_bias, + fused=self.config.moe_router_fusion, + router_replay=self.router_replay, + ) + + # Apply token dropping to probs and routing_map. + if self.config.moe_expert_capacity_factor is not None: + probs, routing_map = apply_router_token_dropping( + probs, + routing_map, + router_topk=self.topk, + capacity_factor=self.config.moe_expert_capacity_factor, + drop_policy=self.config.moe_token_drop_policy, + pad_to_capacity=self.config.moe_pad_expert_input_to_capacity, + ) + + # Apply each aux loss type and attach aux loss autograd function to probs + if self.training and torch.is_grad_enabled() and self.is_aux_loss_enabled(): + # Calculate scores and routing_map for aux loss + routing_map_for_aux_loss, scores_for_aux_loss = compute_routing_scores_for_aux_loss( + logits, self.topk, self.score_function, fused=self.config.moe_router_fusion + ) + probs = self._apply_aux_loss(probs, scores_for_aux_loss, routing_map_for_aux_loss) + probs = self._apply_seq_aux_loss(probs, scores_for_aux_loss, routing_map_for_aux_loss, seq_length, bsz) + probs = self._apply_global_aux_loss(probs, scores_for_aux_loss, routing_map_for_aux_loss) + + # Update expert bias and tokens_per_expert + # Prevent extra local tokens accumulation on evaluation or activation recomputation + if self.enable_expert_bias and torch.is_grad_enabled(): + with torch.no_grad(): + self.local_tokens_per_expert += routing_map.sum(dim=0) + + return probs, routing_map + + +def apply_router_replay_patch(): + """ + Applies the monkey patch for MoE Router Replay functionality. + This patch dynamically adds the 'enable_routing_replay' attribute to TransformerConfig + and modifies the TopKRouter to support recording and replaying of routing decisions. + """ + print("Applying Router Replay Patch...") + # Clear router instances to avoid state leakage between model initializations. 
+ RouterReplay.router_instances.clear() + # Step 1: Patch TransformerConfig to include the feature flag + if not hasattr(TransformerConfig, "enable_routing_replay"): + # Add class attribute with default value + TransformerConfig.enable_routing_replay = False + + # Store original __init__ method + original_tf_config_init = TransformerConfig.__init__ + + # Define new __init__ method that safely handles enable_routing_replay parameter + def patched_tf_config_init(self, *args, **kwargs): + # Simple solution: remove the unknown parameter before calling original constructor + enable_routing_replay = kwargs.pop("enable_routing_replay", TransformerConfig.enable_routing_replay) + + # Call original constructor with remaining kwargs + original_tf_config_init(self, *args, **kwargs) + + # Set the instance attribute + self.enable_routing_replay = enable_routing_replay + + # Apply the patch + TransformerConfig.__init__ = patched_tf_config_init + + # Step 2: Patch TopKRouter only once to ensure idempotency. + if hasattr(TopKRouter, "_router_replay_patched"): + return + + original_init = TopKRouter.__init__ + + # Step 3: Define the new __init__ method + def patched_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + self.router_replay = None + if self.config.enable_routing_replay: + self.router_replay = RouterReplay() + + # Step 4: Apply the patches + TopKRouter.__init__ = patched_init + TopKRouter.routing = patched_routing + TopKRouter._router_replay_patched = True diff --git a/verl/utils/megatron/router_replay_utils.py b/verl/utils/megatron/router_replay_utils.py new file mode 100644 index 00000000000..584603956bf --- /dev/null +++ b/verl/utils/megatron/router_replay_utils.py @@ -0,0 +1,437 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Router Replay Utilities +Utilities for handling router replay functionality in Megatron models. +""" + +from typing import Optional + +import torch +from megatron.core import parallel_state as mpu +from megatron.core.pipeline_parallel.schedules import get_schedule_table +from megatron.core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage +from megatron.core.tensor_parallel import gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import get_transformer_layer_offset + +from verl.models.mcore.util import postprocess_packed_seqs, preprocess_packed_seqs +from verl.utils.device import get_device_name +from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction + +device_name = get_device_name() + + +# from megatron.core.transformer.transformer_block import get_num_layers_to_build +def get_num_layers_to_build( + config: TransformerConfig, vp_stage: Optional[int] = None, pp_rank: Optional[int] = None +) -> int: + """ + Determine the number of transformer layers to build for the current pipeline stage. 
+ Args: + config (TransformerConfig): Configuration object containing transformer model parameters. + vp_stage (Optional[int]): Virtual pipeline stage number. + pp_rank (Optional[int]): Pipeline parallel rank. + + Returns: + int: The number of layers to be built for the current pipeline stage. + """ + # If we have a custom PP layout, straightforwardly + # return the number of decoders in the layout array. + if config.pipeline_model_parallel_layout is not None: + from megatron.core.transformer.enums import LayerType + + return config.pipeline_model_parallel_layout.get_num_layers_to_build( + layer_type=LayerType.decoder, vp_stage=vp_stage + ) + + # Fallback for legacy tests. + if pp_rank is None: + pp_rank = mpu.get_pipeline_model_parallel_rank() + + is_first_pp_stage = pp_rank == 0 + is_last_pp_stage = pp_rank == config.pipeline_model_parallel_size - 1 + + if config.num_layers_in_first_pipeline_stage is not None or config.num_layers_in_last_pipeline_stage is not None: + assert not (config.account_for_embedding_in_pipeline_split or config.account_for_loss_in_pipeline_split), ( + " \ + Does not support standalone embedding stage and standalone loss stage with uneven pp" + ) + # Number of layers to distribute over rest of pipeline stages + layers_to_distribute = config.num_layers + # Number of pipeline stages left for distributing transformer layers + pipeline_stages_left = config.pipeline_model_parallel_size + + # If the uneven first (last) pipeline stage is enabled, remove the specified number + # of layers to calculate the number of layers on each middle pipeline stage. + if config.num_layers_in_first_pipeline_stage is not None: + layers_to_distribute -= config.num_layers_in_first_pipeline_stage + pipeline_stages_left -= 1 + + if config.num_layers_in_last_pipeline_stage is not None: + layers_to_distribute -= config.num_layers_in_last_pipeline_stage + pipeline_stages_left -= 1 + + # If pp_size <= 2, we do not have any intermediate pipeline stages, and we do not + # need to check if the left over layers are divisible by the left over stages. + if pipeline_stages_left > 0: + assert layers_to_distribute % pipeline_stages_left == 0, ( + "With uneven pipelineing the left over layers must be divisible by left over stages" + ) + num_layers_per_pipeline_rank = layers_to_distribute // pipeline_stages_left + else: + num_layers_per_pipeline_rank = 0 + + # If the uneven first (last) pipeline stage is enabled, return the specified number + # of layers for all virtual pipeline parallel stages within the first (last) pipeline + # parallel stage. 
+ + if is_first_pp_stage and config.num_layers_in_first_pipeline_stage is not None: + num_layers_per_pipeline_rank = config.num_layers_in_first_pipeline_stage + + if is_last_pp_stage and config.num_layers_in_last_pipeline_stage is not None: + num_layers_per_pipeline_rank = config.num_layers_in_last_pipeline_stage + else: + # Include the embedding layer and loss layer into pipeline parallelism partition + num_layers = config.num_layers + if config.account_for_embedding_in_pipeline_split: + num_layers += 1 + + if config.account_for_loss_in_pipeline_split: + num_layers += 1 + + assert num_layers % config.pipeline_model_parallel_size == 0, ( + "num_layers should be divisible by pipeline_model_parallel_size" + ) + num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size + + vp_size = config.virtual_pipeline_model_parallel_size + if vp_size is not None and config.pipeline_model_parallel_size > 1: + # Interleaved pipeline parallelism: + # Number of layers in each model chunk is the number of layers in the stage, + # divided by the number of model chunks in a stage. + # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0] [2] [4] [6] + # Stage 1: [1] [3] [5] [7] + # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0, 1] [4, 5] + # Stage 1: [2, 3] [6, 7] + + assert num_layers_per_pipeline_rank % vp_size == 0, ( + f"num_layers_per_pipeline_rank {num_layers_per_pipeline_rank} \ + should be divisible by vp_size {vp_size}" + ) + num_layers_per_virtual_stage = num_layers_per_pipeline_rank // vp_size + + num_layers_to_build = num_layers_per_virtual_stage + + else: + # Non-interleaved pipeline parallelism: + # Each stage gets a contiguous set of layers. + num_layers_to_build = num_layers_per_pipeline_rank + + # The embedding (or loss) layer cannot function as a standalone transformer layer + # Reduce the number of layers to construct by 1 on the first (or last) stage if the + # embedding (or loss) layer is included in the pipeline parallelism partition and placement. + if config.account_for_embedding_in_pipeline_split: + if is_vp_first_stage(vp_stage, vp_size) and is_first_pp_stage: + num_layers_to_build -= 1 + assert num_layers_to_build >= 0, "Not enough layers in the first virtual pipeline stage" + + if config.account_for_loss_in_pipeline_split: + if is_vp_last_stage(vp_stage, vp_size) and is_last_pp_stage: + num_layers_to_build -= 1 + assert num_layers_to_build >= 0, "Not enough layers in the last virtual pipeline stage" + + return num_layers_to_build + + +def merge_router_topk_indices(attention_mask, input_ids, mini_layer_topk_idx_list, tf_config, vp_rank=None): + """ + Merge recorded router top-k indices across sequence-parallel ranks for all router instances, + then pack/unpack them to align with the original (batch, seq_len) layout and append the result. + + Args: + attention_mask (torch.Tensor): Attention mask of shape [batch_size, seq_len]. Used to determine + the valid token positions during pack/unpack. + input_ids (torch.Tensor): Input token IDs of shape [batch_size, seq_len]. Used together with + attention_mask for sequence packing/unpacking. + mini_layer_topk_idx_list (list): A Python list to which the merged top-k indices tensor will be appended. + tf_config: Megatron/Transformer engine configuration object. Used to locate router instances for + the current micro-batch. 
+ vp_rank (Optional[int]): Virtual pipeline stage rank override. If None, the current VP rank from + Megatron parallel state will be used. + + Returns: + None: The function has side effects only; it appends a tensor of shape + [1, dynamic_bs_all, layer_num, topk] to mini_layer_topk_idx_list. + """ + with torch.no_grad(): + router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + layers_topk_idx = [] + for router in router_instances_list: + layers_topk_idx.append(router.recorded_topk_idx.to(torch.uint8)) # dynamic_bs, topk + + # layer_num, dynamic_bs, topk -> dynamic_bs, layer_num, topk + layers_topk_idx = torch.stack(layers_topk_idx).permute(1, 0, 2).to(device_name) + # dynamic_bs, layer_num, topk -> 1, dynamic_bs_all, layer_num, topk + layers_topk_idx = ( + gather_from_sequence_parallel_region(layers_topk_idx, tensor_parallel_output_grad=False) + .unsqueeze(0) + .contiguous() + ) + + batch_size, seq_len = attention_mask.shape[:2] + _, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True) + layers_topk_idx = postprocess_packed_seqs( + layers_topk_idx, packed_seq_params, attention_mask, batch_size, seq_len, post_process=True + ) + mini_layer_topk_idx_list.append(layers_topk_idx.cpu()) + + +def set_router_replay_data(layers_topk_idx, attention_mask, tf_config, vp_rank=None): + """ + Scatter the packed router top-k indices back to sequence-parallel ranks and update each local + RouterReplay instance with target indices for replay mode. + + This function prepares the per-layer, per-sample top-k routing decisions (recorded during an earlier + forward) so that subsequent replay passes can follow exactly the same routing. + + Args: + layers_topk_idx (torch.Tensor): Router top-k indices with shape [bs, max_seq_len, layer_num, topk]. + This should be the merged output produced by merge_router_topk_indices. + attention_mask (torch.Tensor): Attention mask [batch_size, seq_len] used for pack/unpack alignment. + tf_config: Megatron/Transformer engine configuration object. + vp_rank (Optional[int]): Virtual pipeline stage rank override. If None, the current VP rank from + Megatron parallel state will be used. + + Returns: + None: The function updates internal RouterReplay instances in-place. + """ + with torch.no_grad(): + layers_topk_idx_rmpad, _ = preprocess_packed_seqs(layers_topk_idx, attention_mask, pre_process=True) + layers_topk_idx_rmpad = layers_topk_idx_rmpad.contiguous() # 1, dynamic_bs_all, layer_num, topk + + # 1, dynamic_bs_split, layer_num, topk + layers_topk_idx_rmpad_split = scatter_to_sequence_parallel_region( + layers_topk_idx_rmpad.to(device_name).squeeze(dim=0) + ).unsqueeze(dim=0) + + # dynamic_bs_split, layer_num, topk -> layer_num, dynamic_bs_split, topk + layers_topk_idx_reshape = layers_topk_idx_rmpad_split.permute(0, 2, 1, 3).squeeze( + dim=0 + ) # layer_num, dynamic_bs_all, topk + local_rank_info = get_current_rank_layer_info(tf_config, vp_rank) + offset, _ = local_rank_info["start"], local_rank_info["end"] + router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + for i, router in enumerate(router_instances_list): + router.set_target_indices(layers_topk_idx_reshape[i + offset].to(torch.int64)) + + +def reorder_and_merge_vpp_layers( + micro_batch_tensor_list, + num_microbatches: int, + vpp_size: int, + microbatch_group_size_per_vp_stage: int, +) -> torch.Tensor: + """ + Reorder and merge per-VPP layer blocks into a contiguous layer dimension. 
+ + Given a tensor shaped as [bs*vpp_size, max_token_len, layer_num_per_vpp, topk], this function: + 1) Builds the schedule table for virtual microbatches and reorders the first dimension so that entries + belonging to the same model chunk (VPP stage) become contiguous. + 2) Reshapes and merges the (vpp_size, layer_num_per_vpp) into a single layer dimension, producing + [bs, max_token_len, layer_num, topk]. + + Args: + micro_batch_tensor_list (list): List of input tensors, one per virtual microbatch. + num_microbatches (int): Number of microbatches per pipeline stage (bs). + vpp_size (int): Virtual pipeline parallel size (number of model chunks). + microbatch_group_size_per_vp_stage (int): Number of consecutive microbatches processed per VPP stage. + + Returns: + torch.Tensor: Output tensor of shape [bs, max_token_len, layer_num, topk]. + + Note: + The schedule table must cover every entry of micro_batch_tensor_list; a length or + shape mismatch will surface as an indexing or concatenation error. + """ + # 1) Build schedule table: map each virtual_microbatch_id -> (microbatch_id, model_chunk_id) + schedule_table = get_schedule_table(num_microbatches, vpp_size, microbatch_group_size_per_vp_stage) + + # 2) Group by model_chunk_id to build reorder indices so entries of the same chunk become contiguous along dim 0 + tensor_by_chunk = [[] for _ in range(vpp_size)] + mini_tensor_list = [] + + for vidx, (_mb, chunk_id) in enumerate(schedule_table): + tensor_by_chunk[chunk_id].append(micro_batch_tensor_list[vidx]) + + for chunk_id in range(vpp_size): + mini_tensor_list.append(torch.cat(tensor_by_chunk[chunk_id], dim=0)) + + out = torch.cat(mini_tensor_list, dim=2) + return out + + +def get_current_rank_layer_info(tf_config, vp_rank=None): + # NOTE: vp_rank defaults to 0 when not provided. + """Return the local layer range/count for the current (pp_rank, vp_stage). + + Args: + tf_config: Configuration object used to compute the pipeline layer assignment. + vp_rank (Optional[int]): Explicit virtual pipeline stage rank to query. + If None, defaults to 0. + + Returns: + dict: A dict with keys {"start", "end", "count"} describing the layer range + owned by the current (pp_rank, vp_stage). + """ + if vp_rank is None: + vp_rank = 0 + num_layers_to_build = get_num_layers_to_build(tf_config, vp_stage=vp_rank) + offset = get_transformer_layer_offset(tf_config, vp_stage=vp_rank) + local = {} + local["start"] = offset + local["end"] = offset + num_layers_to_build + local["count"] = num_layers_to_build + return local + + +def pp_gather(local_layers_router_map, tf_config): + # TODO: Consider non-uniform layer allocation cases. + """ + Gather local router maps from all PP ranks into a global router map. + + Args: + local_layers_router_map (torch.Tensor): Local router map of shape + [bs, max_seq_len, local_num_layers, topk]. + tf_config: Configuration providing pipeline_model_parallel_size. + + Returns: + torch.Tensor: Global router map of shape [bs, max_seq_len, num_layers, topk] placed on CPU.
+ """ + pp_size = tf_config.pipeline_model_parallel_size + if pp_size <= 1: + return local_layers_router_map + + pp_group = mpu.get_pipeline_model_parallel_group() + world_size = torch.distributed.get_world_size(pp_group) + local_layers_router_map = local_layers_router_map.to(device_name) + layers_topk_idx_global_list = [ + torch.empty( + size=local_layers_router_map.shape, + dtype=local_layers_router_map.dtype, + device=local_layers_router_map.device, + ) + for _ in range(world_size) + ] + torch.distributed.all_gather( + tensor=local_layers_router_map, + tensor_list=layers_topk_idx_global_list, + group=pp_group, + async_op=False, + ) + vp_size = tf_config.virtual_pipeline_model_parallel_size + if vp_size is not None: + vpp_router_map_offset = [[] for _ in range(pp_size)] + for pp_stage in range(pp_size): + vpp_router_map_offset[pp_stage].append(0) + for vp_stage in range(vp_size): + num_layers_to_build = get_num_layers_to_build(tf_config, vp_stage, pp_stage) + vpp_router_map_offset[pp_stage].append(num_layers_to_build + vpp_router_map_offset[pp_stage][-1]) + layers_topk_idx_global = [] + for vp_stage in range(vp_size): + for pp_stage in range(pp_size): + piece = slice(vpp_router_map_offset[pp_stage][vp_stage], vpp_router_map_offset[pp_stage][vp_stage + 1]) + layers_topk_idx_global.append(layers_topk_idx_global_list[pp_stage][:, :, piece, :]) + global_router_map = torch.cat(layers_topk_idx_global, dim=2).to("cpu") + else: + global_router_map = torch.cat(layers_topk_idx_global_list, dim=2).to("cpu") + + return global_router_map + + +class RouterReplayHelper: + """Helper class to query router replay state and locate local RouterReplay instances.""" + + @staticmethod + def get_micro_batch_router_list(tf_config, vp_rank=None): + """ + Return the list of RouterReplay instances corresponding to the current micro-batch and local + (pp_rank, vp_stage) layer range. + + When virtual pipeline (VPP) is enabled, the local range for the PP rank is expanded to include + all VP stages by multiplying the per-VP count by vp_size. The returned slice is taken from the + global RouterReplay.router_instances list. + + Args: + tf_config: Configuration object used to compute layer assignments. + vp_rank (Optional[int]): Explicit virtual pipeline stage to query. If None, the current VP + rank from Megatron parallel state is used when available. + Returns: + list: A contiguous sublist of RouterReplay.router_instances for the local layer range. + """ + vp_size = tf_config.virtual_pipeline_model_parallel_size + if vp_size is not None: + vp_rank = 0 if vp_rank is None else vp_rank + offset = 0 + for pre_vp_stage in range(vp_size): + if pre_vp_stage == vp_rank: + break + num_layers_to_build = get_num_layers_to_build(tf_config, pre_vp_stage) + offset += num_layers_to_build + else: + offset = 0 + + num_layers_to_build = get_num_layers_to_build(tf_config, vp_rank) + router_instances_list = RouterReplay.router_instances[offset : offset + num_layers_to_build] + return router_instances_list + + @staticmethod + def is_r2_record_action(tf_config, vp_rank=None) -> bool: + """Return True if the current router_replay_action is RECORD (R2) for the local router instances. + + This inspects the first local RouterReplay instance's router_replay_action and compares it to + RouterReplayAction.RECORD. 
+ """ + router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + return router_instances_list and router_instances_list[0].router_replay_action == RouterReplayAction.RECORD + + @staticmethod + def is_replay_forward_action(tf_config, vp_rank=None) -> bool: + """Return True if the current router_replay_action is REPLAY_FORWARD for the local router instances. + + This inspects the first local RouterReplay instance's router_replay_action and compares it to + RouterReplayAction.REPLAY_FORWARD. + """ + router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + return ( + router_instances_list and router_instances_list[0].router_replay_action == RouterReplayAction.REPLAY_FORWARD + ) + + @staticmethod + def is_replay_backward_action(tf_config, vp_rank=None) -> bool: + """Return True if the current router_replay_action is REPLAY_BACKWARD for the local router instances. + + This inspects the first local RouterReplay instance's router_replay_action and compares it to + RouterReplayAction.REPLAY_BACKWARD. + """ + router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) + return ( + router_instances_list + and router_instances_list[0].router_replay_action == RouterReplayAction.REPLAY_BACKWARD + ) diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 4e9e1d62dd6..d8e802c7564 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -40,8 +40,16 @@ from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty from verl.utils.device import get_device_id, get_torch_device from verl.utils.megatron.pipeline_parallel import make_batch_generator +from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction +from verl.utils.megatron.router_replay_utils import ( + RouterReplayHelper, + merge_router_topk_indices, + pp_gather, + reorder_and_merge_vpp_layers, + set_router_replay_data, +) from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits -from verl.utils.megatron_utils import get_model_config +from verl.utils.megatron_utils import get_model_config, unwrap_model from verl.utils.profiler import GPUMemoryLogger from verl.utils.profiler.profile import Profiler from verl.utils.py_functional import append_to_dict @@ -150,6 +158,11 @@ def __init__( } ) + self.router_replay = self.config.router_replay + self.enable_routing_replay = self.router_replay.mode != "disabled" + if self.enable_routing_replay: + self.mini_layer_topk_idx_list = [] + config = get_model_config(self.actor_module[0]) print(config) config.finalize_model_grads_func = finalize_model_grads @@ -208,6 +221,11 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None): entropys = torch.Tensor() if recompute_old_log_prob: select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + + if self.enable_routing_replay and self.config.router_replay.mode == "R3": + assert "routed_experts" in data.batch.keys(), "routed_experts must be in data.batch.keys()" + select_keys.append("routed_experts") + batch = data.select(batch_keys=select_keys).batch input_ids = batch["input_ids"] batch_size = input_ids.size(0) @@ -273,11 +291,22 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None): async_op=False, ) entropys = entropys.to("cpu") + layers_topk_idx = None + if RouterReplayHelper.is_r2_record_action(self.tf_config): + # (bs, 
max_seq_len/response_len,local_layer_num,topk) + layers_topk_idx = output["mini_layer_topk_idx_tensor"].to(torch.uint8) + if use_dynamic_bsz: + indices = output["indices"] + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == layers_topk_idx.size(0), f"{len(indices)} vs. {layers_topk_idx.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + layers_topk_idx = layers_topk_idx[revert_indices] + layers_topk_idx = pp_gather(layers_topk_idx, self.tf_config) # add empty cache after each compute get_torch_device().empty_cache() - return log_probs, entropys + return log_probs, entropys, layers_topk_idx def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: """Make minibatch iterator for updating the actor @@ -324,10 +353,14 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: if "rollout_log_probs" in data.batch.keys(): select_keys.append("rollout_log_probs") self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + # router replay + if self.enable_routing_replay: + select_keys.append("routed_experts") if self.has_multi_modal_inputs: data = data.select(select_keys, ["multi_modal_inputs"]) else: data = data.select(batch_keys=select_keys) + return data.make_iterator( mini_batch_size=self.config.ppo_mini_batch_size, epochs=self.config.ppo_epochs, @@ -537,6 +570,12 @@ def forward_step(batch_iter, model, return_schedule_plan: bool = False): attention_mask = batch["attention_mask"].to(bool) position_ids = batch["position_ids"] + unwrapped_model = unwrap_model(model) + if hasattr(unwrapped_model, "vp_stage"): + vp_rank = unwrapped_model.vp_stage + else: + vp_rank = 0 + multi_modal_inputs = {} if "multi_modal_inputs" in batch: from verl.utils.model import extract_multi_modal_inputs @@ -551,6 +590,15 @@ def forward_step(batch_iter, model, return_schedule_plan: bool = False): label_mask[:, : -response_length - 1] = False label_mask[:, -1] = False + if RouterReplayHelper.is_replay_backward_action(self.tf_config, vp_rank): + router_instance_list = RouterReplayHelper.get_micro_batch_router_list(self.tf_config, vp_rank) + for router in router_instance_list: + router.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + + if RouterReplayHelper.is_replay_forward_action(self.tf_config, vp_rank): + layers_topk_idx = batch["routed_experts"] + set_router_replay_data(layers_topk_idx, attention_mask, self.tf_config, vp_rank) + from verl.models.mcore import get_mcore_forward_fn, get_mcore_forward_fused_fn if self.use_fused_kernels: @@ -612,6 +660,17 @@ def logits_processor(logits, label, label_mask): "entropy_coeff": self.config.entropy_coeff, "clip_ratio_c": clip_ratio_c, } + + if RouterReplayHelper.is_r2_record_action(self.tf_config, vp_rank): + merge_router_topk_indices( + attention_mask, input_ids, self.mini_layer_topk_idx_list, self.tf_config, vp_rank + ) + + if RouterReplayHelper.is_replay_forward_action(self.tf_config, vp_rank): + router_instance_list = RouterReplayHelper.get_micro_batch_router_list(self.tf_config, vp_rank) + for router in router_instance_list: + router.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD) + return output, partial(loss_func, data=batch, meta_info=meta_info) # batch should be a list of batches inside micro-batches @@ -649,6 +708,19 @@ def logits_processor(logits, label, label_mask): losses_reduced = {"output": losses_reduced} if use_dynamic_bsz: losses_reduced["indices"] = indices + if 
@@ -324,10 +353,14 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
         if "rollout_log_probs" in data.batch.keys():
             select_keys.append("rollout_log_probs")
         self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
+        # router replay
+        if self.enable_routing_replay:
+            select_keys.append("routed_experts")
         if self.has_multi_modal_inputs:
             data = data.select(select_keys, ["multi_modal_inputs"])
         else:
             data = data.select(batch_keys=select_keys)
+
         return data.make_iterator(
             mini_batch_size=self.config.ppo_mini_batch_size,
             epochs=self.config.ppo_epochs,
@@ -537,6 +570,12 @@ def forward_step(batch_iter, model, return_schedule_plan: bool = False):
             attention_mask = batch["attention_mask"].to(bool)
             position_ids = batch["position_ids"]
 
+            unwrapped_model = unwrap_model(model)
+            # Virtual-pipeline rank of this model chunk; 0 when VPP is not used.
+            if hasattr(unwrapped_model, "vp_stage"):
+                vp_rank = unwrapped_model.vp_stage
+            else:
+                vp_rank = 0
+
             multi_modal_inputs = {}
             if "multi_modal_inputs" in batch:
                 from verl.utils.model import extract_multi_modal_inputs
@@ -551,6 +590,15 @@
             label_mask[:, : -response_length - 1] = False
             label_mask[:, -1] = False
 
+            if RouterReplayHelper.is_replay_backward_action(self.tf_config, vp_rank):
+                # A new micro-batch is starting: flip the routers back to forward replay.
+                router_instance_list = RouterReplayHelper.get_micro_batch_router_list(self.tf_config, vp_rank)
+                for router in router_instance_list:
+                    router.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
+
+            if RouterReplayHelper.is_replay_forward_action(self.tf_config, vp_rank):
+                layers_topk_idx = batch["routed_experts"]
+                set_router_replay_data(layers_topk_idx, attention_mask, self.tf_config, vp_rank)
+
             from verl.models.mcore import get_mcore_forward_fn, get_mcore_forward_fused_fn
 
             if self.use_fused_kernels:
@@ -612,6 +660,17 @@ def logits_processor(logits, label, label_mask):
                 "entropy_coeff": self.config.entropy_coeff,
                 "clip_ratio_c": clip_ratio_c,
             }
+
+            if RouterReplayHelper.is_r2_record_action(self.tf_config, vp_rank):
+                merge_router_topk_indices(
+                    attention_mask, input_ids, self.mini_layer_topk_idx_list, self.tf_config, vp_rank
+                )
+
+            if RouterReplayHelper.is_replay_forward_action(self.tf_config, vp_rank):
+                # Forward pass done; the backward pass must replay the same routes.
+                router_instance_list = RouterReplayHelper.get_micro_batch_router_list(self.tf_config, vp_rank)
+                for router in router_instance_list:
+                    router.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD)
+
             return output, partial(loss_func, data=batch, meta_info=meta_info)
 
         # batch should be a list of batches inside micro-batches
@@ -649,6 +708,19 @@
         losses_reduced = {"output": losses_reduced}
         if use_dynamic_bsz:
             losses_reduced["indices"] = indices
+        if RouterReplayHelper.is_r2_record_action(self.tf_config):
+            if self.tf_config.virtual_pipeline_model_parallel_size is not None:
+                # With virtual pipeline parallelism, micro-batch chunks arrive
+                # interleaved across VPP stages and must be reordered before merging.
+                vp_size = len(self.actor_module)
+                microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage
+                bs = n_micro_batch
+                losses_reduced["mini_layer_topk_idx_tensor"] = reorder_and_merge_vpp_layers(
+                    self.mini_layer_topk_idx_list, bs, vp_size, microbatch_group_size_per_vp_stage
+                )
+            else:
+                losses_reduced["mini_layer_topk_idx_tensor"] = torch.cat(self.mini_layer_topk_idx_list, dim=0)
+            self.mini_layer_topk_idx_list = []
+
         return losses_reduced
 
     @GPUMemoryLogger(role="megatron actor", logger=logger)
@@ -668,6 +740,8 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict:
         if self.use_torch_profiler and self.prof and self.prof.enable:
             self.prof.start()
         for data in dataloader:
+            if self.config.router_replay.mode in ["R2", "R3"]:
+                RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
            self.actor_optimizer.zero_grad()
             # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
             for chunk in self.actor_module:
@@ -706,6 +780,11 @@
                 raise NotImplementedError
             if self.use_torch_profiler and self.prof and self.prof.enable:
                 self.prof.step()
+
+        if self.config.router_replay.mode in ["R2", "R3"]:
+            RouterReplay.clear_global_router_replay_action()
+            RouterReplay.clear_global_indices()
+
         # add empty cache after each compute
         if self.use_torch_profiler and self.prof and self.prof.enable:
             self.prof.stop_and_save()
diff --git a/verl/workers/config/actor.py b/verl/workers/config/actor.py
index ecc392d47dd..38aeb4bf5f9 100644
--- a/verl/workers/config/actor.py
+++ b/verl/workers/config/actor.py
@@ -25,7 +25,36 @@
 from .model import HFModelConfig
 from .optimizer import OptimizerConfig
 
-__all__ = ["PolicyLossConfig", "ActorConfig", "FSDPActorConfig", "McoreActorConfig"]
+__all__ = ["PolicyLossConfig", "RouterReplayConfig", "ActorConfig", "FSDPActorConfig", "McoreActorConfig"]
+
+
+@dataclass
+class RouterReplayConfig(BaseConfig):
+    """Configuration for router replay in MoE models.
+
+    This configuration controls the routing behavior for Mixture of Experts (MoE) models,
+    allowing for deterministic training through route recording and replay.
+
+    Args:
+        mode (str): Router replay mode. Options: 'disabled', 'R2', 'R3'.
+            - 'disabled': No router replay functionality
+            - 'R2': Use Router Replay routing strategy
+            - 'R3': Use Rollout Router Replay routing strategy
+        record_file (Optional[str]): File path to save recorded routing decisions.
+        replay_file (Optional[str]): File path to load recorded routing decisions for replay.
+    """
+
+    mode: str = "disabled"
+    record_file: Optional[str] = None
+    replay_file: Optional[str] = None
+
+    def __post_init__(self):
+        """Validate router replay configuration."""
+        valid_modes = ["disabled", "R2", "R3"]
+        if self.mode not in valid_modes:
+            raise ValueError(f"Invalid router_replay mode: {self.mode}. Must be one of {valid_modes}")
 
 
 @dataclass
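A quick check of the validation above, assuming a build where `RouterReplayConfig` is importable from this module and `BaseConfig` imposes no extra required fields:

```python
from verl.workers.config.actor import RouterReplayConfig

cfg = RouterReplayConfig(mode="R2")  # valid modes: 'disabled', 'R2', 'R3'
print(cfg.mode, cfg.record_file, cfg.replay_file)  # R2 None None

try:
    RouterReplayConfig(mode="record")  # 'record' is not a mode
except ValueError as err:
    print(err)
```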
@@ -84,6 +113,7 @@ class ActorConfig(BaseConfig):
         optim (OptimizerConfig): Configuration for optimizer.
         use_fused_kernels (bool): Whether to use custom fused kernels (e.g., FlashAttention, fused MLP).
         data_loader_seed (int): Seed for data loader. If None, uses global seed.
+        router_replay (RouterReplayConfig): Configuration for router replay in MoE models.
     """
 
     _mutable_fields = BaseConfig._mutable_fields | {
@@ -127,6 +157,7 @@ class ActorConfig(BaseConfig):
     engine: BaseConfig = field(default_factory=BaseConfig)
     rollout_n: int = MISSING  # must be override by sampling config
     model_config: HFModelConfig = field(default_factory=BaseConfig)
+    router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig)
 
     # Store global batch info for loss aggregation:
     # dp_size: data parallel size
diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py
index fb186ca17f8..c0a8e5b8b8f 100644
--- a/verl/workers/config/rollout.py
+++ b/verl/workers/config/rollout.py
@@ -204,6 +204,7 @@ class RolloutConfig(BaseConfig):
 
     skip_tokenizer_init: bool = False
     quantization: Optional[str] = None
+    enable_rollout_routing_replay: bool = False
 
     def __post_init__(self):
         """Validate the rollout config"""
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index 83b7314266b..877df2dc4cf 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -51,6 +51,7 @@
 from verl.utils.distributed import set_numa_affinity
 from verl.utils.flops_counter import FlopsCounter
 from verl.utils.fs import copy_to_local
+from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction, apply_router_replay_patch
 from verl.utils.megatron_utils import (
     load_megatron_model_to_gpu,
     load_megatron_optimizer,
@@ -285,6 +286,15 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
             mesh_name="actor", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
         )
         only_rollout = self._is_rollout and not self._is_actor
+
+        self.enable_routing_replay = False
+        if self._is_actor:
+            self.router_replay = self.config.actor.router_replay
+            self.enable_routing_replay = self.router_replay.mode != "disabled"
+
+            if self.enable_routing_replay:
+                apply_router_replay_patch()
+
         set_random_seed(seed=self.config.actor.megatron.seed, only_rollout=only_rollout)
 
         if self._is_actor:
@@ -540,6 +550,8 @@ def init_model(self):
             override_transformer_config = OmegaConf.to_container(
                 OmegaConf.create(self.config.actor.megatron.get("override_transformer_config", {}))
             )
+            if self.enable_routing_replay:
+                override_transformer_config["enable_routing_replay"] = True
             override_ddp_config = OmegaConf.to_container(
                 OmegaConf.create(self.config.actor.megatron.get("override_ddp_config", {}))
             )
@@ -585,6 +597,7 @@ def init_model(self):
                 actor_module=self.actor_module,
                 actor_optimizer=self.actor_optimizer,
             )
+            print(f"routing replay layers: {len(RouterReplay.router_instances)}")
             log_gpu_memory_usage("After MegatronPPOActor init", logger=logger)
 
         if self._is_rollout:
@@ -812,7 +825,7 @@ def compute_ref_log_prob(self, data: DataProto):
         data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu
         data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
         data.meta_info["temperature"] = self.config.rollout.temperature
-        output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
+        output, _, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
         output = DataProto.from_dict(tensors={"ref_log_prob": output})
         output = output.to("cpu")
         if self._ref_is_offload_param:
@@ -834,11 +847,25 @@
         data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu
         data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz
         data.meta_info["temperature"] = self.config.rollout.temperature
-        output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
+
+        if self.enable_routing_replay and self.config.actor.router_replay.mode == "R2":
+            # R2: record the routes taken while recomputing old log probs.
+            RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD)
+
+        if self.enable_routing_replay and self.config.actor.router_replay.mode == "R3":
+            # R3: replay the routes captured by the rollout engine.
+            RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
+
+        output, entropys, layers_topk_idx = self.actor.compute_log_prob(data=data, calculate_entropy=True)
         output = DataProto.from_dict(
             tensors={"old_log_probs": output, "entropys": entropys},
             meta_info={"temperature": self.config.rollout.temperature},
         )
+        if self.config.actor.router_replay.mode == "R2":
+            output.batch["routed_experts"] = layers_topk_idx
+
+        if self.config.actor.router_replay.mode in ["R2", "R3"]:
+            RouterReplay.clear_global_indices()
+            RouterReplay.clear_global_router_replay_action()
+
         output = output.to("cpu")
 
         # clear kv cache
         if self._is_offload_param:
diff --git a/verl/workers/rollout/replica.py b/verl/workers/rollout/replica.py
index 44f0b5366bf..fbb3ae8efc4 100644
--- a/verl/workers/rollout/replica.py
+++ b/verl/workers/rollout/replica.py
@@ -16,7 +16,7 @@
 import os
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 
 from omegaconf import DictConfig, OmegaConf
 from pydantic import BaseModel
@@ -35,6 +35,8 @@ class TokenOutput(BaseModel):
     """response token ids"""
     log_probs: Optional[list[float]] = None
    """logprobs of response token ids"""
+    routed_experts: Optional[Any] = None
+    """routed experts of response token ids"""
 
 
 class RolloutMode(Enum):
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index db9e234f3fb..b42f9c251b2 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -294,6 +294,9 @@ async def launch_server(self, master_address: str = None, master_port: int = None):
             }
         )
 
+        if self.config.enable_rollout_routing_replay:
+            args.update({"enable_return_routed_experts": True})
+
         server_args = ["serve", self.model_config.local_path]
         for k, v in args.items():
             if isinstance(v, bool):
@@ -430,7 +433,12 @@ async def generate(
         log_probs = None
         if sampling_params.logprobs is not None:
             log_probs = [logprobs[token_ids[i]].logprob for i, logprobs in enumerate(final_res.outputs[0].logprobs)]
-        return TokenOutput(token_ids=token_ids, log_probs=log_probs)
+
+        routed_experts = None
+        if self.config.enable_rollout_routing_replay:
+            routed_experts = final_res.outputs[0].routed_experts
+
+        return TokenOutput(token_ids=token_ids, log_probs=log_probs, routed_experts=routed_experts)
 
     async def wake_up(self):
         if self.rollout_mode == RolloutMode.HYBRID:
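For consumers, `routed_experts` behaves like `log_probs`: it stays `None` unless the feature is enabled. A toy construction is sketched below; the array shape and dtype are invented, and the actual payload depends on the vLLM PR linked in the README:

```python
import numpy as np

from verl.workers.rollout.replica import TokenOutput

out = TokenOutput(
    token_ids=[101, 7592, 102],
    log_probs=None,
    # Hypothetical payload: one (num_layers, topk) row of expert ids per token.
    routed_experts=np.zeros((3, 4, 2), dtype=np.int64),
)
print(out.routed_experts.shape)  # (3, 4, 2)
```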
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
index 6a8cd55b9c6..ccdf7f9f75b 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
@@ -111,6 +111,19 @@ def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> list[int]:
     return token_ids
 
 
+def pad_first_dim_tail(x: torch.Tensor, M: int, value: float = 0.0, left_pad: bool = True):
+    """Pad the first dimension of `x` to length `M` with `value`.
+
+    Despite the name, padding is prepended when `left_pad` is True (the default)
+    and appended otherwise.
+    """
+    N = x.size(0)
+    if M < N:
+        raise ValueError(f"M ({M}) must be >= N ({N})")
+    if M == N:
+        return x
+    pad_shape = (M - N, *x.shape[1:])
+    pad_tensor = torch.full(pad_shape, value, dtype=x.dtype, device=x.device)
+    if left_pad:
+        return torch.cat([pad_tensor, x], dim=0)
+    return torch.cat([x, pad_tensor], dim=0)
+
+
 if is_version_ge(pkg="vllm", minver="0.7.3"):
     VLLMHijack.hijack()
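A quick sanity check of the helper above, assuming it is in scope (shapes invented):

```python
import torch

x = torch.ones(3, 4, 2)  # e.g. (tokens, layers, topk)
y = pad_first_dim_tail(x, M=5, value=0.0, left_pad=True)
print(y.shape)            # torch.Size([5, 4, 2])
print(y[:2].abs().sum())  # tensor(0.) -- the two pad rows were prepended
```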
+ """ + len_list = [sub_response.shape[0] for sub_response in routed_experts] + # response_length = max(len(sub_list.shape[0]) for sub_list in routed_experts) + response_length = max(len_list) + target_length = ( + max_length if max_length is not None and max_length > response_length else response_length + ) + new_sub_resposne_list = [] + for sub_response in routed_experts: + pad_shape = (target_length - sub_response.shape[0], *sub_response.shape[1:]) + pad_tensor = torch.full( + pad_shape, pad_token_id, dtype=sub_response.dtype, device=sub_response.device + ) + new_sub_response = torch.concat([sub_response, pad_tensor], dim=0) + new_sub_resposne_list.append(new_sub_response) + tensor = torch.stack(new_sub_resposne_list, dim=0) + return tensor + + output_routed_experts = pad_3d_list_to_length( + output_routed_experts, 0, max_length=self.config.response_length + ) + routed_experts = torch.cat([input_routed_experts, output_routed_experts], dim=1) + response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to( idx.device ) @@ -453,6 +528,9 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # we will recompute old log prob with actor batch["rollout_log_probs"] = rollout_log_probs + if self.config.enable_rollout_routing_replay: + batch["routed_experts"] = routed_experts + return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) async def resume(self, tags: list[str]):