From abe258aa6537356ceac9a4b96679da3d01e08b9c Mon Sep 17 00:00:00 2001
From: Zilin Zhu
Date: Fri, 26 Sep 2025 11:48:34 +0000
Subject: [PATCH] [feat] add --use-routing-replay

---
Review notes (commentary only; `git am` ignores the text between the
three-dash line above and the first file diff):

What the patch wires up
  * `--use-routing-replay` makes the Ray actor group export
    ENABLE_ROUTING_REPLAY=1 to the training actors.
  * train_actor() drives ROUTING_REPLAY_STAGE: "fallthrough" for the ref
    log-probs, "record" for the actor log-probs, "replay_backward" for
    training; forward_step() switches to "replay_forward" around the model
    call and restores the previous stage afterwards.
  * In Megatron, each TopKRouter owns a RoutingReplay and selects it through
    pre-forward / pre-backward hooks; compute_topk() records or replays the
    top-k expert indices according to the stage.

Changes made during review
  * moe_utils.compute_topk: an unrecognised ROUTING_REPLAY_STAGE used to reach
    `return probs, top_indices` with both names unbound; it now raises
    ValueError. The two identical replay branches are merged.
  * router.py hunk: `self.routing_replay = RoutingReplay()` carried a context
    marker although the hunk header counts it as an added line; it is now
    marked as added.
  * Docstrings/comments added to the new Megatron code and a help string to
    the new CLI flag. Hunk headers and the diffstat were recomputed.

To confirm before applying
  * Line breaks and indentation were restored from a whitespace-collapsed
    copy; context-line whitespace must be checked against the target trees.
  * The `index` blob ids below were not regenerated.
  * tests/test_quick_start_glm4-9B.sh also gains --fp8-format/--fp8-recipe,
    which the subject line does not mention; kept as submitted.

 docker/patch/latest/megatron.patch     | 157 +++++++++++++++++++++++++
 slime/backends/megatron_utils/actor.py |  12 ++
 slime/backends/megatron_utils/model.py |   8 ++
 slime/ray/actor_group.py               |   3 +
 slime/utils/arguments.py               |   7 ++
 tests/test_quick_start_glm4-9B.sh      |   3 +
 tests/test_qwen3-30B-A3B.sh            |   1 +
 7 files changed, 191 insertions(+)

diff --git a/docker/patch/latest/megatron.patch b/docker/patch/latest/megatron.patch
index 45f625556..1148507ff 100644
--- a/docker/patch/latest/megatron.patch
+++ b/docker/patch/latest/megatron.patch
@@ -363,6 +363,163 @@ index 63ee9d1f..b90b744c 100644
          )
          ops.append(recv_next_op)
      if len(ops) > 0:
+diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
+index 235b6f6a..640273ba 100644
+--- a/megatron/core/transformer/moe/moe_utils.py
++++ b/megatron/core/transformer/moe/moe_utils.py
+@@ -1,9 +1,11 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ 
++import os
+ import math
+ from typing import List, Optional, Union
+ 
+ import torch
++import torch.distributed as dist
+ 
+ from megatron.core import parallel_state
+ from megatron.core.process_groups_config import ModelCommProcessGroups
+@@ -506,6 +508,61 @@ def pad_routing_map(routing_map: torch.Tensor, pad_multiple: int) -> torch.Tenso
+     return routing_map
+ 
+ 
++ROUTING_REPLAY = None
++
++
++def set_routing_replay(replay):
++    """Select the RoutingReplay that compute_topk records to or replays from."""
++    global ROUTING_REPLAY
++    ROUTING_REPLAY = replay
++
++
++class RoutingReplay:
++    """Per-router store of recorded top-k expert indices.
++
++    record() appends indices, pop_forward() reads them back in recording order,
++    and pop_backward() reads them in reverse order of the pop_forward() calls.
++    """
++
++    # Every instance ever created, so clear_all() can reset them in one call.
++    all_routing_replays = []
++
++    def __init__(self):
++        self.forward_indices = 0
++        self.backward_indices = []
++        self.top_indices_list = []
++        RoutingReplay.all_routing_replays.append(self)
++
++    def record(self, top_indices):
++        """Append the expert indices chosen by one routing call."""
++        self.top_indices_list.append(top_indices)
++
++    def pop_forward(self):
++        """Return the next recorded indices and remember the slot for pop_backward()."""
++        top_indices = self.top_indices_list[self.forward_indices]
++        self.backward_indices.append(self.forward_indices)
++        self.forward_indices += 1
++        return top_indices
++
++    def pop_backward(self):
++        """Return the indices for the most recent slot not yet consumed by backward."""
++        backward_indices = self.backward_indices.pop()
++        top_indices = self.top_indices_list[backward_indices]
++        return top_indices
++
++    def clear(self):
++        """Forget all recorded indices and reset both cursors."""
++        self.forward_indices = 0
++        self.backward_indices = []
++        self.top_indices_list = []
++
++    @staticmethod
++    def clear_all():
++        """Clear every RoutingReplay created in this process."""
++        for replay in RoutingReplay.all_routing_replays:
++            replay.clear()
++
++
+ def topk_routing_with_score_function(
+     logits: torch.Tensor,
+     topk: int,
+@@ -553,7 +610,7 @@ def topk_routing_with_score_function(
+             expert_bias=expert_bias,
+         )
+ 
+-    def compute_topk(scores, topk, num_groups=None, group_topk=None):
++    def _compute_topk(scores, topk, num_groups=None, group_topk=None):
+         if group_topk:
+             return group_limited_topk(
+                 scores=scores,
+@@ -566,6 +623,36 @@ def topk_routing_with_score_function(
+         else:
+             return torch.topk(scores, k=topk, dim=1)
+ 
++    def compute_topk(scores, topk, num_groups=None, group_topk=None):
++        """Top-k selection that can record or replay the selected expert indices.
++
++        Replay is active only when ENABLE_ROUTING_REPLAY=1; ROUTING_REPLAY_STAGE
++        must then be "fallthrough", "record", "replay_forward" or "replay_backward".
++        """
++        if os.environ.get("ENABLE_ROUTING_REPLAY", "0") != "1":
++            return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk)
++
++        routing_replay_stage = os.environ["ROUTING_REPLAY_STAGE"]
++        if routing_replay_stage == "fallthrough":
++            return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk)
++        if routing_replay_stage == "record":
++            probs, top_indices = _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk)
++            ROUTING_REPLAY.record(top_indices)
++            return probs, top_indices
++
++        if routing_replay_stage == "replay_forward":
++            top_indices = ROUTING_REPLAY.pop_forward()
++        elif routing_replay_stage == "replay_backward":
++            top_indices = ROUTING_REPLAY.pop_backward()
++        else:
++            # Reject typos in the stage name instead of continuing without indices.
++            raise ValueError(f"Unknown ROUTING_REPLAY_STAGE: {routing_replay_stage!r}")
++        assert top_indices.shape[0] == scores.shape[0] and top_indices.shape[1] == topk, (
++            f"top_indices shape {top_indices.shape} does not match scores shape {scores.shape} and topk {topk}"
++        )
++        # Gather the current scores of the recorded experts instead of re-running top-k.
++        return scores.gather(1, top_indices), top_indices
++
+     if score_function == "softmax":
+         if use_pre_softmax:
+             scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
+diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py
+index 6b20b862..80786f84 100644
+--- a/megatron/core/transformer/moe/router.py
++++ b/megatron/core/transformer/moe/router.py
+@@ -1,5 +1,6 @@
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+ 
++import os
+ from abc import ABC, abstractmethod
+ from typing import Optional
+ 
+@@ -156,6 +157,21 @@ class TopKRouter(Router):
+             self.local_tokens_per_expert = None
+             self.expert_bias = None
+ 
++        if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1":
++            from .moe_utils import RoutingReplay, set_routing_replay
++            self.routing_replay = RoutingReplay()
++
++            # compute_topk uses the module-level ROUTING_REPLAY, so point it at this
++            # router's buffer right before each forward and backward of this module.
++            def forward_hook(*args, **kwargs):
++                set_routing_replay(self.routing_replay)
++
++            def backward_hook(*args, **kwargs):
++                set_routing_replay(self.routing_replay)
++
++            self.register_forward_pre_hook(forward_hook)
++            self.register_full_backward_pre_hook(backward_hook)
++
+     def _maintain_float32_expert_bias(self):
+         """
+         Maintain the expert bias in float32.
 diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
 index 6f557e1f..b295fd35 100644
 --- a/megatron/core/transformer/transformer_config.py
diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py
index 1aa7bc8d7..1439fa621 100644
--- a/slime/backends/megatron_utils/actor.py
+++ b/slime/backends/megatron_utils/actor.py
@@ -1,3 +1,4 @@
+import os
 import socket
 import time
 from contextlib import nullcontext
@@ -305,6 +306,8 @@ def train_actor(self, rollout_id, rollout_data):
         with timer("train"):
             if self.args.compute_advantages_and_returns:
                 if "ref" in self.weights:
+                    if self.args.use_routing_replay:
+                        os.environ["ROUTING_REPLAY_STAGE"] = "fallthrough"
                     ref_log_probs = self.compute_log_prob(
                         "ref",
                         data_iterator,
@@ -313,6 +316,8 @@ def train_actor(self, rollout_id, rollout_data):
                     )
                     rollout_data.update(ref_log_probs)
 
+                if self.args.use_routing_replay:
+                    os.environ["ROUTING_REPLAY_STAGE"] = "record"
                 log_probs = self.compute_log_prob(
                     "old_actor" if self.args.keep_old_actor else "actor",
                     data_iterator,
@@ -348,6 +353,8 @@ def train_actor(self, rollout_id, rollout_data):
                 self.prof.step()
 
             # Train
+            if self.args.use_routing_replay:
+                os.environ["ROUTING_REPLAY_STAGE"] = "replay_backward"
             with timer("actor_train"):
                 train(
                     rollout_id,
@@ -383,6 +390,11 @@ def train_actor(self, rollout_id, rollout_data):
                 path,
             )
 
+        if self.args.use_routing_replay:
+            from megatron.core.transformer.moe.moe_utils import RoutingReplay
+
+            RoutingReplay.clear_all()
+
         # update the cpu actor weight to the latest model
         self.update_cpu_params_dict(self.weights["actor"])
 
diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py
index d4985b0f6..f09f52dc8 100644
--- a/slime/backends/megatron_utils/model.py
+++ b/slime/backends/megatron_utils/model.py
@@ -1,6 +1,7 @@
 import dataclasses
 import gc
 import math
+import os
 from contextlib import nullcontext
 from functools import partial
 
@@ -297,6 +298,10 @@ def forward_step(data_iterator, model: GPTModel):
             ],
         )
 
+        if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1":
+            old_stage = os.environ["ROUTING_REPLAY_STAGE"]
+            os.environ["ROUTING_REPLAY_STAGE"] = "replay_forward"
+
         output_tensor = model(
             input_ids=batch["tokens"],
             position_ids=None,
@@ -305,6 +310,9 @@ def forward_step(data_iterator, model: GPTModel):
             packed_seq_params=batch["packed_seq_params"],
         )
 
+        if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1":
+            os.environ["ROUTING_REPLAY_STAGE"] = old_stage
+
         return output_tensor, partial(loss_function, args, batch, num_microbatches)
 
     # Forward pass.
diff --git a/slime/ray/actor_group.py b/slime/ray/actor_group.py
index 7b581d8c5..0f6d4be0c 100644
--- a/slime/ray/actor_group.py
+++ b/slime/ray/actor_group.py
@@ -75,6 +75,9 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor, wandb_run_id: Optiona
             env_vars["TMS_INIT_ENABLE"] = "1"
             env_vars["TMS_INIT_ENABLE_CPU_BACKUP"] = "1"
 
+        if self.args.use_routing_replay:
+            env_vars["ENABLE_ROUTING_REPLAY"] = "1"
+
         backend = os.environ.get("SLIME_BACKEND", "megatron").lower()
         if backend == "megatron":
             from slime.backends.megatron_utils import MegatronTrainRayActor
diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py
index bdc4bb48e..aeafedd3c 100644
--- a/slime/utils/arguments.py
+++ b/slime/utils/arguments.py
@@ -635,6 +635,13 @@ def add_algo_arguments(parser):
                 default=0,
                 help="Lower bound clipping threshold C for importance sampling ratios to control variance.",
             )
+
+            parser.add_argument(
+                "--use-routing-replay",
+                action="store_true",
+                default=False,
+                help="Record MoE top-k routing while computing actor log-probs and replay it during training.",
+            )
             return parser
 
         def add_router_arguments(parser):
diff --git a/tests/test_quick_start_glm4-9B.sh b/tests/test_quick_start_glm4-9B.sh
index 9a0f8a015..53413a3ff 100644
--- a/tests/test_quick_start_glm4-9B.sh
+++ b/tests/test_quick_start_glm4-9B.sh
@@ -19,6 +19,9 @@ source "${SCRIPT_DIR}/../scripts/models/glm4-9B.sh"
 CKPT_ARGS=(
    --hf-checkpoint /root/models/GLM-Z1-9B-0414/
    --ref-load /root/GLM-Z1-9B-0414_torch_dist
+
+   --fp8-format e4m3
+   --fp8-recipe blockwise
 )
 
 ROLLOUT_ARGS=(
diff --git a/tests/test_qwen3-30B-A3B.sh b/tests/test_qwen3-30B-A3B.sh
index 188537657..125efc827 100644
--- a/tests/test_qwen3-30B-A3B.sh
+++ b/tests/test_qwen3-30B-A3B.sh
@@ -74,6 +74,7 @@ GRPO_ARGS=(
    --eps-clip 4e-4
 
    --use-tis
+   --use-routing-replay
 )
 
 OPTIMIZER_ARGS=(