Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions docker/patch/latest/megatron.patch
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,142 @@ index 63ee9d1f..b90b744c 100644
)
ops.append(recv_next_op)
if len(ops) > 0:
diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
index 235b6f6a..640273ba 100644
--- a/megatron/core/transformer/moe/moe_utils.py
+++ b/megatron/core/transformer/moe/moe_utils.py
@@ -1,9 +1,11 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

+import os
import math
from typing import List, Optional, Union

import torch
+import torch.distributed as dist

from megatron.core import parallel_state
from megatron.core.process_groups_config import ModelCommProcessGroups
@@ -506,6 +508,48 @@ def pad_routing_map(routing_map: torch.Tensor, pad_multiple: int) -> torch.Tenso
return routing_map


# Replay object currently in effect for this module; unset until
# set_routing_replay() is called.
ROUTING_REPLAY = None


def set_routing_replay(replay):
    """Make ``replay`` the module-level ``ROUTING_REPLAY``.

    Args:
        replay: The object to install. Passing ``None`` restores the
            initial "unset" state.
    """
    globals()["ROUTING_REPLAY"] = replay
+
+
class RoutingReplay:
    """Ordered log of recorded ``top_indices`` values with forward/backward replay.

    ``record`` appends entries. ``pop_forward`` hands them back in recording
    order and remembers each position it served; ``pop_backward`` then returns
    the entries for those positions in reverse (last served, first returned).

    Every instance registers itself in the class-level ``all_routing_replays``
    list so that ``clear_all`` can reset all of them in one call.
    """

    # Every instance ever constructed; appended to by __init__ and walked by
    # clear_all(). Instances are never removed from this list.
    all_routing_replays = []

    def __init__(self):
        # Position in top_indices_list that the next pop_forward() will return.
        self.forward_indices = 0
        # Stack of positions already served by pop_forward(), consumed by
        # pop_backward() from the end.
        self.backward_indices = []
        # Recorded entries, in the order record() received them.
        self.top_indices_list = []
        RoutingReplay.all_routing_replays.append(self)

    def record(self, top_indices):
        """Append ``top_indices`` to the end of the log."""
        self.top_indices_list.append(top_indices)

    def pop_forward(self):
        """Return the next recorded entry in recording order.

        The entry stays in the log; only the read cursor advances, and the
        served position is pushed for a later ``pop_backward``.

        Raises:
            IndexError: If no recorded entry exists at the current cursor.
        """
        position = self.forward_indices
        try:
            top_indices = self.top_indices_list[position]
        except IndexError:
            # Same exception type as the raw list lookup, but say what ran out.
            raise IndexError(
                f"routing replay exhausted in pop_forward: position {position} requested "
                f"but only {len(self.top_indices_list)} entries were recorded"
            ) from None
        self.backward_indices.append(position)
        self.forward_indices = position + 1
        return top_indices

    def pop_backward(self):
        """Return the entry for the most recently served forward position.

        Raises:
            IndexError: If there is no served forward position left to replay.
        """
        if not self.backward_indices:
            raise IndexError(
                "routing replay exhausted in pop_backward: no forward position left to replay "
                f"({len(self.top_indices_list)} entries recorded, "
                f"{self.forward_indices} consumed by pop_forward)"
            )
        position = self.backward_indices.pop()
        return self.top_indices_list[position]

    def clear(self):
        """Drop all recorded entries and reset both replay cursors."""
        self.forward_indices = 0
        self.backward_indices = []
        self.top_indices_list = []

    @staticmethod
    def clear_all():
        """Call ``clear`` on every instance registered in ``all_routing_replays``."""
        for replay in RoutingReplay.all_routing_replays:
            replay.clear()
+
+
def topk_routing_with_score_function(
logits: torch.Tensor,
topk: int,
@@ -553,7 +597,7 @@ def topk_routing_with_score_function(
expert_bias=expert_bias,
)

- def compute_topk(scores, topk, num_groups=None, group_topk=None):
+ def _compute_topk(scores, topk, num_groups=None, group_topk=None):
if group_topk:
return group_limited_topk(
scores=scores,
@@ -566,6 +610,30 @@ def topk_routing_with_score_function(
else:
return torch.topk(scores, k=topk, dim=1)

def compute_topk(scores, topk, num_groups=None, group_topk=None):
    """Pick top-k entries of ``scores``, optionally recording or replaying the choice.

    When the ``ENABLE_ROUTING_REPLAY`` environment variable is not ``"1"`` this
    simply delegates to ``_compute_topk``. Otherwise the ``ROUTING_REPLAY_STAGE``
    environment variable selects the mode:

    * ``"fallthrough"``: delegate to ``_compute_topk``; nothing is recorded.
    * ``"record"``: delegate to ``_compute_topk`` and store the resulting
      indices via ``ROUTING_REPLAY.record``.
    * ``"replay_forward"`` / ``"replay_backward"``: take the indices from
      ``ROUTING_REPLAY.pop_forward()`` / ``ROUTING_REPLAY.pop_backward()`` and
      gather the matching entries of ``scores`` along dim 1 instead of running
      top-k again.

    Args:
        scores: Score tensor; replayed indices must match its first dimension.
        topk: Number of entries selected per row.
        num_groups: Forwarded unchanged to ``_compute_topk``.
        group_topk: Forwarded unchanged to ``_compute_topk``.

    Returns:
        Whatever ``_compute_topk`` returns when delegating; otherwise the pair
        ``(probs, top_indices)``.

    Raises:
        KeyError: If replay is enabled but ``ROUTING_REPLAY_STAGE`` is not set.
        ValueError: If ``ROUTING_REPLAY_STAGE`` holds an unrecognised value.
    """
    if os.environ.get("ENABLE_ROUTING_REPLAY", "0") != "1":
        return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk)

    routing_replay_stage = os.environ["ROUTING_REPLAY_STAGE"]
    valid_stages = ("fallthrough", "record", "replay_forward", "replay_backward")
    if routing_replay_stage not in valid_stages:
        # Without this check an unknown stage would skip every branch and fail
        # later on unbound locals, hiding the actual misconfiguration.
        raise ValueError(
            f"Unknown ROUTING_REPLAY_STAGE {routing_replay_stage!r}; "
            f"expected one of {valid_stages}"
        )

    if routing_replay_stage == "fallthrough":
        return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk)

    if routing_replay_stage == "record":
        probs, top_indices = _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk)
        ROUTING_REPLAY.record(top_indices)
        return probs, top_indices

    # Both replay stages differ only in which end of the log they read from.
    if routing_replay_stage == "replay_forward":
        top_indices = ROUTING_REPLAY.pop_forward()
    else:
        top_indices = ROUTING_REPLAY.pop_backward()
    assert top_indices.shape[0] == scores.shape[0] and top_indices.shape[1] == topk, (
        f"top_indices shape {top_indices.shape} does not match scores shape {scores.shape} and topk {topk}"
    )
    probs = scores.gather(1, top_indices)
    return probs, top_indices
+
if score_function == "softmax":
if use_pre_softmax:
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py
index 6b20b862..80786f84 100644
--- a/megatron/core/transformer/moe/router.py
+++ b/megatron/core/transformer/moe/router.py
@@ -1,5 +1,6 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

+import os
from abc import ABC, abstractmethod
from typing import Optional

@@ -156,6 +157,19 @@ class TopKRouter(Router):
self.local_tokens_per_expert = None
self.expert_bias = None

+ if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1":
+ from .moe_utils import RoutingReplay, set_routing_replay
+ self.routing_replay = RoutingReplay()
+
+ def forward_hook(*args, **kwargs):
+ set_routing_replay(self.routing_replay)
+
+ def backward_hook(*args, **kwargs):
+ set_routing_replay(self.routing_replay)
+
+ self.register_forward_pre_hook(forward_hook)
+ self.register_full_backward_pre_hook(backward_hook)
+
def _maintain_float32_expert_bias(self):
"""
Maintain the expert bias in float32.
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index 6f557e1f..b295fd35 100644
--- a/megatron/core/transformer/transformer_config.py
Expand Down
12 changes: 12 additions & 0 deletions slime/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import socket
import time
from contextlib import nullcontext
Expand Down Expand Up @@ -305,6 +306,8 @@ def train_actor(self, rollout_id, rollout_data):
with timer("train"):
if self.args.compute_advantages_and_returns:
if "ref" in self.weights:
if self.args.use_routing_replay:
os.environ["ROUTING_REPLAY_STAGE"] = "fallthrough"
ref_log_probs = self.compute_log_prob(
"ref",
data_iterator,
Expand All @@ -313,6 +316,8 @@ def train_actor(self, rollout_id, rollout_data):
)
rollout_data.update(ref_log_probs)

if self.args.use_routing_replay:
os.environ["ROUTING_REPLAY_STAGE"] = "record"
log_probs = self.compute_log_prob(
"old_actor" if self.args.keep_old_actor else "actor",
data_iterator,
Expand Down Expand Up @@ -348,6 +353,8 @@ def train_actor(self, rollout_id, rollout_data):
self.prof.step()

# Train
if self.args.use_routing_replay:
os.environ["ROUTING_REPLAY_STAGE"] = "replay_backward"
with timer("actor_train"):
train(
rollout_id,
Expand Down Expand Up @@ -383,6 +390,11 @@ def train_actor(self, rollout_id, rollout_data):
path,
)

if self.args.use_routing_replay:
from megatron.core.transformer.moe.moe_utils import RoutingReplay

RoutingReplay.clear_all()

# update the cpu actor weight to the latest model
self.update_cpu_params_dict(self.weights["actor"])

Expand Down
8 changes: 8 additions & 0 deletions slime/backends/megatron_utils/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses
import gc
import math
import os
from contextlib import nullcontext
from functools import partial

Expand Down Expand Up @@ -297,6 +298,10 @@ def forward_step(data_iterator, model: GPTModel):
],
)

if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1":
old_stage = os.environ["ROUTING_REPLAY_STAGE"]
os.environ["ROUTING_REPLAY_STAGE"] = "replay_forward"

output_tensor = model(
input_ids=batch["tokens"],
position_ids=None,
Expand All @@ -305,6 +310,9 @@ def forward_step(data_iterator, model: GPTModel):
packed_seq_params=batch["packed_seq_params"],
)

if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1":
os.environ["ROUTING_REPLAY_STAGE"] = old_stage

return output_tensor, partial(loss_function, args, batch, num_microbatches)

# Forward pass.
Expand Down
3 changes: 3 additions & 0 deletions slime/ray/actor_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor, wandb_run_id: Optiona
env_vars["TMS_INIT_ENABLE"] = "1"
env_vars["TMS_INIT_ENABLE_CPU_BACKUP"] = "1"

if self.args.use_routing_replay:
env_vars["ENABLE_ROUTING_REPLAY"] = "1"

backend = os.environ.get("SLIME_BACKEND", "megatron").lower()
if backend == "megatron":
from slime.backends.megatron_utils import MegatronTrainRayActor
Expand Down
6 changes: 6 additions & 0 deletions slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,12 @@ def add_algo_arguments(parser):
default=0,
help="Lower bound clipping threshold C for importance sampling ratios to control variance.",
)

parser.add_argument(
"--use-routing-replay",
action="store_true",
default=False,
)
return parser

def add_router_arguments(parser):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_quick_start_glm4-9B.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ source "${SCRIPT_DIR}/../scripts/models/glm4-9B.sh"
CKPT_ARGS=(
--hf-checkpoint /root/models/GLM-Z1-9B-0414/
--ref-load /root/GLM-Z1-9B-0414_torch_dist

--fp8-format e4m3
--fp8-recipe blockwise
)

ROLLOUT_ARGS=(
Expand Down
1 change: 1 addition & 0 deletions tests/test_qwen3-30B-A3B.sh
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ GRPO_ARGS=(
--eps-clip 4e-4

--use-tis
--use-routing-replay
)

OPTIMIZER_ARGS=(
Expand Down
Loading