11 changes: 9 additions & 2 deletions examples/router_replay/run_qwen30_a3b_megatron_sglang.sh
@@ -40,6 +40,13 @@ actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))

USE_LEGACY_WORKER_IMPL="enable" # disable, enable
if [ "$USE_LEGACY_WORKER_IMPL" = "disable" ]; then
ROUTING_REPLAY_MODE_ARG="actor_rollout_ref.actor.megatron.router_replay.mode=${ROUTING_REPLAY_MODE}"
remove_padding=True
else
ROUTING_REPLAY_MODE_ARG="actor_rollout_ref.actor.router_replay.mode=${ROUTING_REPLAY_MODE}"
remove_padding=False
fi
exper_name=Node${NODES}_bs${bs}_${PP}${TP}${EP}${ETP}_${SGLANG_INFER_TP}_minbs${ppo_mini_batch_size}_micro_bs${micro_bs}

python3 -m verl.trainer.main_ppo --config-path=config \
@@ -54,10 +61,10 @@ python3 -m verl.trainer.main_ppo --config-path=config \
data.truncation='error' \
actor_rollout_ref.model.use_fused_kernels=True \
actor_rollout_ref.model.path=$HF_MODEL_PATH \
-actor_rollout_ref.model.use_remove_padding=True \
+actor_rollout_ref.model.use_remove_padding=${remove_padding} \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
-actor_rollout_ref.actor.megatron.router_replay.mode=${ROUTING_REPLAY_MODE} \
+${ROUTING_REPLAY_MODE_ARG} \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \
+actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \
14 changes: 11 additions & 3 deletions examples/router_replay/run_qwen30_a3b_megatron_vllm.sh
@@ -41,7 +41,14 @@ ppo_mini_batch_size=8
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))

USE_LEGACY_WORKER_IMPL="enable" # disable, enable
USE_LEGACY_WORKER_IMPL="disable" # disable, enable
if [ "$USE_LEGACY_WORKER_IMPL" = "disable" ]; then
ROUTING_REPLAY_MODE_ARG="actor_rollout_ref.actor.megatron.router_replay.mode=${ROUTING_REPLAY_MODE}"
remove_padding=True
else
ROUTING_REPLAY_MODE_ARG="actor_rollout_ref.actor.router_replay.mode=${ROUTING_REPLAY_MODE}"
remove_padding=False
fi
exper_name=Node${NODES}_bs${bs}_${PP}${TP}${EP}${ETP}_${VLLM_INFER_TP}_minbs${ppo_mini_batch_size}_micro_bs${micro_bs}

python3 -m verl.trainer.main_ppo --config-path=config \
@@ -56,10 +63,10 @@ python3 -m verl.trainer.main_ppo --config-path=config \
data.truncation='error' \
actor_rollout_ref.model.use_fused_kernels=True \
actor_rollout_ref.model.path=$HF_MODEL_PATH \
-actor_rollout_ref.model.use_remove_padding=True \
+actor_rollout_ref.model.use_remove_padding=$remove_padding \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
-actor_rollout_ref.actor.megatron.router_replay.mode=${ROUTING_REPLAY_MODE} \
+${ROUTING_REPLAY_MODE_ARG} \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \
+actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \
@@ -84,6 +91,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
+actor_rollout_ref.rollout.calculate_log_probs=True \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_bs \
10 changes: 1 addition & 9 deletions verl/utils/megatron/router_replay_patch.py
@@ -84,7 +84,6 @@ def __init__(self):
        self.recorded_topk_idx = None  # For recording
        self.router_replay_action = None  # Router replay action for this layer
        self.replay_backward_list = []  # List of tensors for backward pass replay
-        self.layer_number = None  # Global layer index if available
        RouterReplay.router_instances.append(self)

    def set_target_indices(self, topk_indices: torch.Tensor):
@@ -268,7 +267,7 @@ def patched_routing(self, logits: torch.Tensor, *args, **kwargs):
            score_function=self.score_function,
            expert_bias=self.expert_bias,
            fused=self.config.moe_router_fusion,
-            router_replay=self.router_replay,
+            router_replay=getattr(self, "router_replay", None),
        )

        # Apply token dropping to probs and routing_map.
@@ -337,12 +336,6 @@ def patched_tf_config_init(self, *args, **kwargs):
        return

    original_init = TopKRouter.__init__
-    original_set_layer_number = TopKRouter.set_layer_number
-
-    def patched_set_layer_number(self, layer_number: int):
-        original_set_layer_number(self, layer_number)
-        if self.router_replay is not None:
-            self.router_replay.layer_number = layer_number

    # Step 3: Define the new __init__ method
    def patched_init(self, *args, **kwargs):
@@ -384,5 +377,4 @@ def patched_preprocess(self, routing_map):
    # Step 5: Apply the patches
    TopKRouter.__init__ = patched_init
    TopKRouter.routing = patched_routing
-    TopKRouter.set_layer_number = patched_set_layer_number
    TopKRouter._router_replay_patched = True
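For illustration (not part of this PR): the getattr change in patched_routing guards against router instances that were constructed before the patch ran, or through a code path that never sets router_replay. A minimal sketch of the pattern, with hypothetical names:

class Router:
    pass  # hypothetical: instances built before patching have no router_replay attribute

def routing(self):
    # Plain attribute access (self.router_replay) would raise AttributeError here;
    # getattr with a default degrades gracefully to the no-replay path.
    return "replay" if getattr(self, "router_replay", None) is not None else "normal"

print(routing(Router()))  # -> normal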
75 changes: 56 additions & 19 deletions verl/utils/megatron/router_replay_utils.py
@@ -171,6 +171,41 @@ def get_num_layers_to_build(
    return num_layers_to_build


+def is_moe_layer(tf_config, layer_idx):
+    moe_layer_freq = getattr(tf_config, "moe_layer_freq", None)
+
+    if isinstance(moe_layer_freq, int):
+        return layer_idx % moe_layer_freq == 0
+    elif isinstance(moe_layer_freq, list):
+        return moe_layer_freq[layer_idx] == 1
+    else:
+        raise ValueError(f"Unsupported moe_layer_freq type: {type(moe_layer_freq)}")
Comment on lines +174 to +182

Contributor

critical

The is_moe_layer function has two potential issues that could lead to runtime errors:

1. Unset moe_layer_freq: The docstring for get_moe_num_layers_to_build specifies that if moe_layer_freq is unset, all layers should be treated as MoE layers. However, the current implementation getattr(tf_config, "moe_layer_freq", None) will cause moe_layer_freq to be None, which then leads to a ValueError. This will cause a crash when router replay is used with a model that doesn't explicitly define moe_layer_freq.
2. moe_layer_freq is zero: If moe_layer_freq is set to 0, the expression layer_idx % moe_layer_freq will cause a ZeroDivisionError.

To make the function more robust, it should default to 1 for an unset moe_layer_freq and explicitly handle non-positive values.

Suggested change:

-def is_moe_layer(tf_config, layer_idx):
-    moe_layer_freq = getattr(tf_config, "moe_layer_freq", None)
-    if isinstance(moe_layer_freq, int):
-        return layer_idx % moe_layer_freq == 0
-    elif isinstance(moe_layer_freq, list):
-        return moe_layer_freq[layer_idx] == 1
-    else:
-        raise ValueError(f"Unsupported moe_layer_freq type: {type(moe_layer_freq)}")
+def is_moe_layer(tf_config, layer_idx):
+    moe_layer_freq = getattr(tf_config, "moe_layer_freq", 1)
+    if isinstance(moe_layer_freq, int):
+        if moe_layer_freq <= 0:
+            return False
+        return layer_idx % moe_layer_freq == 0
+    elif isinstance(moe_layer_freq, list):
+        return moe_layer_freq[layer_idx] == 1
+    else:
+        raise ValueError(f"Unsupported moe_layer_freq type: {type(moe_layer_freq)}")



+def get_moe_num_layers_to_build(
+    config: TransformerConfig, vp_stage: Optional[int] = None, pp_rank: Optional[int] = None
+) -> int:
+    """Count the number of MoE layers assigned to the current rank.
+    When ``moe_layer_freq`` is 1 or unset, every transformer layer is an MoE
+    layer, so the count equals the total layer count. Otherwise only layers
+    whose global index satisfies the frequency predicate are counted.
+    Args:
+        config: Megatron TransformerConfig providing layer layout information.
+        vp_stage: Virtual-pipeline stage index (None defaults to current).
+        pp_rank: Pipeline-parallel rank (None defaults to current).
+    Returns:
+        Number of MoE layers on the specified rank/stage.
+    """
+    total_layers = get_num_layers_to_build(config, vp_stage=vp_stage, pp_rank=pp_rank)
+
+    layer_offset = get_transformer_layer_offset(config, vp_stage=vp_stage)
+    local_global_indices = range(layer_offset, layer_offset + total_layers)
+
+    num_moe_layers = sum(1 for idx in local_global_indices if is_moe_layer(config, idx))
+
+    return num_moe_layers
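For illustration (not part of this PR): a worked example of the counting semantics, with a hypothetical per-layer pattern instead of a real Megatron config:

moe_layer_freq = [0, 1, 0, 1]      # hypothetical: global layers 1 and 3 are MoE
layer_offset, total_layers = 2, 2  # this stage owns global layers 2..3

num_moe = sum(1 for idx in range(layer_offset, layer_offset + total_layers) if moe_layer_freq[idx] == 1)
print(num_moe)  # -> 1: only global layer 3 is MoE on this stage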


def merge_router_topk_indices(attention_mask, input_ids, mini_layer_topk_idx_list, tf_config, vp_rank=None):
"""
Merge recorded router top-k indices across sequence-parallel ranks for all router instances,
@@ -249,24 +284,27 @@ def set_router_replay_data(layers_topk_idx, attention_mask, tf_config, vp_rank=None):
    layers_topk_idx_reshape = layers_topk_idx_rmpad_split.permute(0, 2, 1, 3).squeeze(
        dim=0
    )  # layer_num, dynamic_bs_all, topk
-    num_layers_in_data = layers_topk_idx_reshape.shape[0]
-    use_global_layer_index = getattr(tf_config, "num_layers", None) == num_layers_in_data
    local_rank_info = get_current_rank_layer_info(tf_config, vp_rank)
-    offset, _ = local_rank_info["start"], local_rank_info["end"]
+    offset, end = local_rank_info["start"], local_rank_info["end"]
    router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank)
-    for i, router in enumerate(router_instances_list):
-        layer_idx = None
-        if use_global_layer_index:
-            layer_number = getattr(router, "layer_number", None)
-            if layer_number is not None:
-                layer_idx = layer_number - 1
-        if layer_idx is None:
-            layer_idx = i + offset
-        if layer_idx < 0 or layer_idx >= num_layers_in_data:
-            raise ValueError(
-                f"router replay layer index {layer_idx} out of range for data with {num_layers_in_data} layers"
-            )
-        router.set_target_indices(layers_topk_idx_reshape[layer_idx].to(torch.int64))
+
+    # When dim-0 covers all layers (e.g. R3, or R2 with all-MoE models),
+    # index by absolute layer_idx; otherwise (R2 with mixed dense/MoE),
+    # dim-0 only contains MoE layers, index by MoE-layer ordinal.
+    index_by_layer = len(layers_topk_idx_reshape) == tf_config.num_layers
+
+    # For R2: count MoE layers before `offset` as the starting position.
+    moe_idx = sum(1 for i in range(offset) if is_moe_layer(tf_config, i))
+
+    router_offset = 0
+    for layer_idx in range(offset, end):
+        if not is_moe_layer(tf_config, layer_idx):
+            continue
+        router = router_instances_list[router_offset]
+        idx = layer_idx if index_by_layer else moe_idx
+        router.set_target_indices(layers_topk_idx_reshape[idx].to(torch.int64))
+        router_offset += 1
+        moe_idx += 1
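For illustration (not part of this PR): the two indexing modes from the comments above, on hypothetical shapes (4 global layers, layers 1 and 3 MoE, this rank owning layers 2..3):

moe_pattern = [0, 1, 0, 1]
offset, end = 2, 4

for index_by_layer in (True, False):
    moe_idx = sum(moe_pattern[:offset])  # MoE layers before this rank's first layer
    picks = []
    for layer_idx in range(offset, end):
        if moe_pattern[layer_idx] != 1:
            continue
        picks.append(layer_idx if index_by_layer else moe_idx)
        moe_idx += 1
    print(index_by_layer, picks)  # True [3] (absolute index), False [1] (MoE ordinal)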


def reorder_and_merge_vpp_layers(
@@ -378,7 +416,7 @@ def pp_gather(local_layers_router_map, tf_config):
    for pp_stage in range(pp_size):
        vpp_router_map_offset[pp_stage].append(0)
        for vp_stage in range(vp_size):
-            num_layers_to_build = get_num_layers_to_build(tf_config, vp_stage, pp_stage)
+            num_layers_to_build = get_moe_num_layers_to_build(tf_config, vp_stage, pp_stage)
            vpp_router_map_offset[pp_stage].append(num_layers_to_build + vpp_router_map_offset[pp_stage][-1])
    layers_topk_idx_global = []
    for vp_stage in range(vp_size):
@@ -419,8 +457,7 @@ def get_micro_batch_router_list(tf_config, vp_rank=None):
        for pre_vp_stage in range(vp_size):
            if pre_vp_stage == vp_rank:
                break
-            num_layers_to_build = get_num_layers_to_build(tf_config, pre_vp_stage)
-            offset += num_layers_to_build
+            offset += get_moe_num_layers_to_build(tf_config, pre_vp_stage)
    else:
        offset = 0

2 changes: 1 addition & 1 deletion verl/workers/engine/megatron/transformer_impl.py
@@ -673,7 +673,7 @@ def prepare_model_inputs(self, batch: TensorDict):
        loss_mask = batch["loss_mask"].to(bool)
        multi_modal_inputs = extract_multi_modal_inputs(batch.get("multi_modal_inputs", []))

-        routed_experts = batch.get("routed_experts", [])
+        routed_experts = batch.get("routed_experts", None)

        return {
            "input_ids": input_ids,