diff --git a/examples/router_replay/README.md b/examples/router_replay/README.md index c8c556b8f8b..93006431ee2 100644 --- a/examples/router_replay/README.md +++ b/examples/router_replay/README.md @@ -68,4 +68,4 @@ actor_rollout_ref.actor.router_replay.mode="R3" actor_rollout_ref.rollout.enable_rollout_routing_replay=True ``` -R3 mode requires the rollout backend to support returning router selection results. Currently, this functionality is being tested based on the vllm implementation at https://github.com/vllm-project/vllm/pull/28284 and SGLang implementation at https://github.com/sgl-project/sglang/commit/bed301a5acaa9577c9aa706468bdf242f6a43051. +R3 mode requires the rollout backend to support returning router selection results. Currently, this functionality is being tested based on the vllm implementation at https://github.com/vllm-project/vllm/pull/28284 as well as a bug fix at https://github.com/vllm-project/vllm/pull/33013 and the SGLang implementation at https://github.com/sgl-project/sglang/commit/bed301a5acaa9577c9aa706468bdf242f6a43051. 
diff --git a/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh b/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh index 9a45bdaac73..74e7af0dee0 100644 --- a/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh +++ b/examples/router_replay/run_qwen30_a3b_megatron_vllm.sh @@ -6,7 +6,9 @@ NODES=1 # R2: enable routing replay # R3: enable rollout routing replay # If enabling R3, please set actor_rollout_ref.rollout.enable_rollout_routing_replay=True -# R3 example is based on vllm related pr https://github.com/vllm-project/vllm/pull/5322 +# R3 example is based on vllm related pr: +# - https://github.com/vllm-project/vllm/pull/28284 +# - https://github.com/vllm-project/vllm/pull/33013 ROUTING_REPLAY_MODE="R2" diff --git a/verl/utils/megatron/router_replay_patch.py b/verl/utils/megatron/router_replay_patch.py index 328b93f99dd..8311126043e 100644 --- a/verl/utils/megatron/router_replay_patch.py +++ b/verl/utils/megatron/router_replay_patch.py @@ -84,6 +84,7 @@ def __init__(self): self.recorded_topk_idx = None # For recording self.router_replay_action = None # Router replay action for this layer self.replay_backward_list = [] # List of tensors for backward pass replay + self.layer_number = None # Global layer index if available RouterReplay.router_instances.append(self) def set_target_indices(self, topk_indices: torch.Tensor): @@ -336,6 +337,12 @@ def patched_tf_config_init(self, *args, **kwargs): return original_init = TopKRouter.__init__ + original_set_layer_number = TopKRouter.set_layer_number + + def patched_set_layer_number(self, layer_number: int): + original_set_layer_number(self, layer_number) + if self.router_replay is not None: + self.router_replay.layer_number = layer_number # Step 3: Define the new __init__ method def patched_init(self, *args, **kwargs): @@ -374,4 +381,5 @@ def patched_preprocess(self, routing_map): # Step 5: Apply the patches TopKRouter.__init__ = patched_init TopKRouter.routing = patched_routing + TopKRouter.set_layer_number = 
patched_set_layer_number TopKRouter._router_replay_patched = True diff --git a/verl/utils/megatron/router_replay_utils.py b/verl/utils/megatron/router_replay_utils.py index b1a42840377..b3774102c14 100644 --- a/verl/utils/megatron/router_replay_utils.py +++ b/verl/utils/megatron/router_replay_utils.py @@ -242,11 +242,24 @@ def set_router_replay_data(layers_topk_idx, attention_mask, tf_config, vp_rank=N layers_topk_idx_reshape = layers_topk_idx_rmpad_split.permute(0, 2, 1, 3).squeeze( dim=0 ) # layer_num, dynamic_bs_all, topk + num_layers_in_data = layers_topk_idx_reshape.shape[0] + use_global_layer_index = getattr(tf_config, "num_layers", None) == num_layers_in_data local_rank_info = get_current_rank_layer_info(tf_config, vp_rank) offset, _ = local_rank_info["start"], local_rank_info["end"] router_instances_list = RouterReplayHelper.get_micro_batch_router_list(tf_config, vp_rank) for i, router in enumerate(router_instances_list): - router.set_target_indices(layers_topk_idx_reshape[i + offset].to(torch.int64)) + layer_idx = None + if use_global_layer_index: + layer_number = getattr(router, "layer_number", None) + if layer_number is not None: + layer_idx = layer_number - 1 + if layer_idx is None: + layer_idx = i + offset + if layer_idx < 0 or layer_idx >= num_layers_in_data: + raise ValueError( + f"router replay layer index {layer_idx} out of range for data with {num_layers_in_data} layers" + ) + router.set_target_indices(layers_topk_idx_reshape[layer_idx].to(torch.int64)) def reorder_and_merge_vpp_layers(