diff --git a/examples/router_replay/run_qwen30_a3b_megatron_sglang.sh b/examples/router_replay/run_qwen30_a3b_megatron_sglang.sh
index a2936ae65b6..70ba749c7af 100644
--- a/examples/router_replay/run_qwen30_a3b_megatron_sglang.sh
+++ b/examples/router_replay/run_qwen30_a3b_megatron_sglang.sh
@@ -40,6 +40,13 @@ actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
 infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
 
 USE_LEGACY_WORKER_IMPL="enable" # disable, enable
+if [ "$USE_LEGACY_WORKER_IMPL" = "disable" ]; then
+    ROUTING_REPLAY_MODE_ARG="actor_rollout_ref.actor.megatron.router_replay.mode=${ROUTING_REPLAY_MODE}"
+    remove_padding=True
+else
+    ROUTING_REPLAY_MODE_ARG="actor_rollout_ref.actor.router_replay.mode=${ROUTING_REPLAY_MODE}"
+    remove_padding=False
+fi
 
 exper_name=Node${NODES}_bs${bs}_${PP}${TP}${EP}${ETP}_${SGLANG_INFER_TP}_minbs${ppo_mini_batch_size}_micro_bs${micro_bs}
 python3 -m verl.trainer.main_ppo --config-path=config \
@@ -54,10 +61,10 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     data.truncation='error' \
     actor_rollout_ref.model.use_fused_kernels=True \
     actor_rollout_ref.model.path=$HF_MODEL_PATH \
-    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.model.use_remove_padding=${remove_padding} \
     actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
     actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
-    actor_rollout_ref.actor.megatron.router_replay.mode=${ROUTING_REPLAY_MODE} \
+    ${ROUTING_REPLAY_MODE_ARG} \
     +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \
     +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \
     +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \
diff --git a/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh b/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh
index 43eb56cc157..2b9d9cece3c 100644
--- a/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh
+++ b/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh
@@ -41,7 +41,14 @@ ppo_mini_batch_size=8
 actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
 infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
 
-USE_LEGACY_WORKER_IMPL="enable" # disable, enable
+USE_LEGACY_WORKER_IMPL="disable" # disable, enable
+if [ "$USE_LEGACY_WORKER_IMPL" = "disable" ]; then
+    ROUTING_REPLAY_MODE_ARG="actor_rollout_ref.actor.megatron.router_replay.mode=${ROUTING_REPLAY_MODE}"
+    remove_padding=True
+else
+    ROUTING_REPLAY_MODE_ARG="actor_rollout_ref.actor.router_replay.mode=${ROUTING_REPLAY_MODE}"
+    remove_padding=False
+fi
 
 exper_name=Node${NODES}_bs${bs}_${PP}${TP}${EP}${ETP}_${VLLM_INFER_TP}_minbs${ppo_mini_batch_size}_micro_bs${micro_bs}
 python3 -m verl.trainer.main_ppo --config-path=config \
@@ -56,10 +63,10 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     data.truncation='error' \
     actor_rollout_ref.model.use_fused_kernels=True \
     actor_rollout_ref.model.path=$HF_MODEL_PATH \
-    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.model.use_remove_padding=$remove_padding \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
     actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
-    actor_rollout_ref.actor.megatron.router_replay.mode=${ROUTING_REPLAY_MODE} \
+    ${ROUTING_REPLAY_MODE_ARG} \
     +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \
     +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \
     +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \
@@ -84,6 +91,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     actor_rollout_ref.actor.kl_loss_coef=0.001 \
     actor_rollout_ref.actor.kl_loss_type=low_var_kl \
     actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.rollout.calculate_log_probs=True \
     actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
     actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
     actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_bs \
diff --git a/verl/utils/megatron/router_replay_patch.py b/verl/utils/megatron/router_replay_patch.py
index 50eb4a128d3..91fbd93e88a 100644
--- a/verl/utils/megatron/router_replay_patch.py
+++ b/verl/utils/megatron/router_replay_patch.py
@@ -84,7 +84,6 @@ def __init__(self):
         self.recorded_topk_idx = None  # For recording
         self.router_replay_action = None  # Router replay action for this layer
         self.replay_backward_list = []  # List of tensors for backward pass replay
-        self.layer_number = None  # Global layer index if available
         RouterReplay.router_instances.append(self)
 
     def set_target_indices(self, topk_indices: torch.Tensor):
@@ -268,7 +267,7 @@ def patched_routing(self, logits: torch.Tensor, *args, **kwargs):
             score_function=self.score_function,
             expert_bias=self.expert_bias,
             fused=self.config.moe_router_fusion,
-            router_replay=self.router_replay,
+            router_replay=getattr(self, "router_replay", None),
         )
 
         # Apply token dropping to probs and routing_map.
@@ -337,12 +336,6 @@ def patched_tf_config_init(self, *args, **kwargs):
         return
 
     original_init = TopKRouter.__init__
-    original_set_layer_number = TopKRouter.set_layer_number
-
-    def patched_set_layer_number(self, layer_number: int):
-        original_set_layer_number(self, layer_number)
-        if self.router_replay is not None:
-            self.router_replay.layer_number = layer_number
 
     # Step 3: Define the new __init__ method
     def patched_init(self, *args, **kwargs):
@@ -384,5 +377,4 @@ def patched_preprocess(self, routing_map):
     # Step 5: Apply the patches
     TopKRouter.__init__ = patched_init
     TopKRouter.routing = patched_routing
-    TopKRouter.set_layer_number = patched_set_layer_number
     TopKRouter._router_replay_patched = True
diff --git a/verl/utils/megatron/router_replay_utils.py b/verl/utils/megatron/router_replay_utils.py
index 3aec85c24b7..98df5cf1d53 100644
--- a/verl/utils/megatron/router_replay_utils.py
+++ b/verl/utils/megatron/router_replay_utils.py
@@ -171,6 +171,41 @@ def get_num_layers_to_build(
     return num_layers_to_build
 
 
+def is_moe_layer(tf_config, layer_idx):
+    moe_layer_freq = getattr(tf_config, "moe_layer_freq", None)
+
+    if isinstance(moe_layer_freq, int):
+        return layer_idx % moe_layer_freq == 0
+    elif isinstance(moe_layer_freq, list):
+        return moe_layer_freq[layer_idx] == 1
+    else:
+        raise ValueError(f"Unsupported moe_layer_freq type: {type(moe_layer_freq)}")
+
+
+def get_moe_num_layers_to_build(
+    config: TransformerConfig, vp_stage: Optional[int] = None, pp_rank: Optional[int] = None
+) -> int:
+    """Count the number of MoE layers assigned to the current rank.
+    When ``moe_layer_freq`` is 1 (the default), every transformer layer is an MoE
+    layer, so the count equals the total layer count. Otherwise only layers
+    whose global index satisfies the frequency predicate are counted.
+    Args:
+        config: Megatron TransformerConfig providing layer layout information.
+        vp_stage: Virtual-pipeline stage index (None defaults to current).
+        pp_rank: Pipeline-parallel rank (None defaults to current).
+    Returns:
+        Number of MoE layers on the specified rank/stage.
+    """
+    total_layers = get_num_layers_to_build(config, vp_stage=vp_stage, pp_rank=pp_rank)
+
+    layer_offset = get_transformer_layer_offset(config, vp_stage=vp_stage)
+    local_global_indices = range(layer_offset, layer_offset + total_layers)
+
+    num_moe_layers = sum(1 for idx in local_global_indices if is_moe_layer(config, idx))
+
+    return num_moe_layers
+
+
 def merge_router_topk_indices(attention_mask, input_ids, mini_layer_topk_idx_list, tf_config, vp_rank=None):
     """
     Merge recorded router top-k indices across sequence-parallel ranks for all router instances,
@@ -249,24 +284,27 @@ def set_router_replay_data(layers_topk_idx, attention_mask, tf_config, vp_rank=N
     layers_topk_idx_reshape = layers_topk_idx_rmpad_split.permute(0, 2, 1, 3).squeeze(
         dim=0
     )  # layer_num, dynamic_bs_all, topk
-    num_layers_in_data = layers_topk_idx_reshape.shape[0]
-    use_global_layer_index = getattr(tf_config, "num_layers", None) == num_layers_in_data
     local_rank_info = get_current_rank_layer_info(tf_config, vp_rank)
-    offset, _ = local_rank_info["start"], local_rank_info["end"]
+    offset, end = local_rank_info["start"], local_rank_info["end"]
     router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank)
-    for i, router in enumerate(router_instances_list):
-        layer_idx = None
-        if use_global_layer_index:
-            layer_number = getattr(router, "layer_number", None)
-            if layer_number is not None:
-                layer_idx = layer_number - 1
-        if layer_idx is None:
-            layer_idx = i + offset
-        if layer_idx < 0 or layer_idx >= num_layers_in_data:
-            raise ValueError(
-                f"router replay layer index {layer_idx} out of range for data with {num_layers_in_data} layers"
-            )
-        router.set_target_indices(layers_topk_idx_reshape[layer_idx].to(torch.int64))
+
+    # When dim-0 covers all layers (e.g. R3, or R2 with all-MoE models),
+    # index by absolute layer_idx; otherwise (R2 with mixed dense/MoE),
+    # dim-0 only contains MoE layers, index by MoE-layer ordinal.
+    index_by_layer = len(layers_topk_idx_reshape) == tf_config.num_layers
+
+    # For R2: count MoE layers before `offset` as the starting position.
+    moe_idx = sum(1 for i in range(offset) if is_moe_layer(tf_config, i))
+
+    router_offset = 0
+    for layer_idx in range(offset, end):
+        if not is_moe_layer(tf_config, layer_idx):
+            continue
+        router = router_instances_list[router_offset]
+        idx = layer_idx if index_by_layer else moe_idx
+        router.set_target_indices(layers_topk_idx_reshape[idx].to(torch.int64))
+        router_offset += 1
+        moe_idx += 1
 
 
 def reorder_and_merge_vpp_layers(
@@ -378,7 +416,7 @@ def pp_gather(local_layers_router_map, tf_config):
     for pp_stage in range(pp_size):
         vpp_router_map_offset[pp_stage].append(0)
         for vp_stage in range(vp_size):
-            num_layers_to_build = get_num_layers_to_build(tf_config, vp_stage, pp_stage)
+            num_layers_to_build = get_moe_num_layers_to_build(tf_config, vp_stage, pp_stage)
             vpp_router_map_offset[pp_stage].append(num_layers_to_build + vpp_router_map_offset[pp_stage][-1])
     layers_topk_idx_global = []
     for vp_stage in range(vp_size):
@@ -419,8 +457,7 @@ def get_micro_batch_router_list(tf_config, vp_rank=None):
             for pre_vp_stage in range(vp_size):
                 if pre_vp_stage == vp_rank:
                     break
-                num_layers_to_build = get_num_layers_to_build(tf_config, pre_vp_stage)
-                offset += num_layers_to_build
+                offset += get_moe_num_layers_to_build(tf_config, pre_vp_stage)
         else:
             offset = 0
 
diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py
index 5cb0824a96b..8775627ee5d 100644
--- a/verl/workers/engine/megatron/transformer_impl.py
+++ b/verl/workers/engine/megatron/transformer_impl.py
@@ -673,7 +673,7 @@ def prepare_model_inputs(self, batch: TensorDict):
         loss_mask = batch["loss_mask"].to(bool)
         multi_modal_inputs = extract_multi_modal_inputs(batch.get("multi_modal_inputs", []))
 
-        routed_experts = batch.get("routed_experts", [])
+        routed_experts = batch.get("routed_experts", None)
 
         return {
             "input_ids": input_ids,