Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 52 additions & 41 deletions docker/patch/latest/sglang.patch
Original file line number Diff line number Diff line change
Expand Up @@ -814,27 +814,23 @@ index 7f6f6a010..c4a673145 100644
if not get_global_server_args().sampling_backend == "ascend" or (
return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index 87922077e..8cb6bad8d 100644
index 87922077e..6507d8bf5 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -247,6 +247,16 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
@@ -247,6 +247,12 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
s.sent_offset = len(output_str)
output_strs.append(incremental_output)

+ output_routed_experts = []
+ if recv_obj.output_routed_experts is not None:
+ output_routed_experts = [
+ (
+ output_routed_experts.tolist()
+ if output_routed_experts is not None
+ else []
+ )
+ output_routed_experts
+ for output_routed_experts in recv_obj.output_routed_experts
+ ]
return BatchStrOutput(
rids=recv_obj.rids,
http_worker_ipcs=recv_obj.http_worker_ipcs,
@@ -272,6 +282,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
@@ -272,6 +278,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
output_token_entropy_val=recv_obj.output_token_entropy_val,
output_hidden_states=recv_obj.output_hidden_states,
Expand Down Expand Up @@ -1165,29 +1161,72 @@ index f8ebfc1f4..48b9a1a3b 100644
return ResumeMemoryOccupationReqOutput()

def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput):
diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
index edbc52526..2cdc42755 100644
--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -421,6 +421,11 @@ class TokenizerCommunicatorMixin:
result = (await self.update_weights_from_distributed_communicator(obj))[
0
]
+ if result.success and obj.weight_version is not None:
+ self._update_weight_version_if_provided(obj.weight_version)
+ result.message += (
+ f" Weight version updated to {obj.weight_version}."
+ )
return result.success, result.message

# This means that weight sync
@@ -480,6 +485,11 @@ class TokenizerCommunicatorMixin:
async with self.is_pause_cond:
if self.is_pause:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
+ if result.success and obj.weight_version is not None:
+ self._update_weight_version_if_provided(obj.weight_version)
+ result.message += (
+ f" Weight version updated to {obj.weight_version}."
+ )
return result.success, result.message

# This means that weight sync
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index b90cf0616..98d71d896 100644
index b90cf0616..9b0992655 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -888,6 +888,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
@@ -20,6 +20,7 @@ import logging
import math
import os
import pickle
+import pybase64
import signal
import sys
import threading
@@ -888,6 +889,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
session_params=session_params,
custom_logit_processor=obj.custom_logit_processor,
return_hidden_states=obj.return_hidden_states,
+ return_routed_experts=obj.return_routed_experts,
data_parallel_rank=obj.data_parallel_rank,
priority=obj.priority,
extra_key=obj.extra_key,
@@ -1621,6 +1622,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
@@ -1621,6 +1623,16 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if getattr(recv_obj, "output_hidden_states", None):
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

+ if getattr(recv_obj, "output_routed_experts", None):
+ meta_info["routed_experts"] = recv_obj.output_routed_experts[i]
+ if recv_obj.output_routed_experts[i] is not None:
+ # print(f"{recv_obj.output_routed_experts[i].shape=}, {recv_obj.output_routed_experts[i].dtype=}")
+ # torch.save(recv_obj.output_routed_experts[i], f"/root/{recv_obj.output_routed_experts[i].shape[0]}.pt")
+ meta_info["routed_experts"] = pybase64.b64encode(
+ recv_obj.output_routed_experts[i].contiguous().numpy().tobytes(order="C")
+ ).decode("ascii")
+ else:
+ meta_info["routed_experts"] = None
+
if isinstance(recv_obj, BatchStrOutput):
state.text += recv_obj.output_strs[i]
if self.server_args.stream_output and state.obj.stream:
@@ -1747,12 +1751,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
@@ -1747,12 +1759,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
return

if len(recv_obj.input_token_logprobs_val) > 0:
Expand Down Expand Up @@ -1975,31 +2014,3 @@ index b3d72df05..ddfe0b178 100644


@dataclass
diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
index edbc52526..2cdc42755 100644
--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -421,6 +421,11 @@ class TokenizerCommunicatorMixin:
result = (await self.update_weights_from_distributed_communicator(obj))[
0
]
+ if result.success and obj.weight_version is not None:
+ self._update_weight_version_if_provided(obj.weight_version)
+ result.message += (
+ f" Weight version updated to {obj.weight_version}."
+ )
return result.success, result.message

# This means that weight sync
@@ -480,6 +485,11 @@ class TokenizerCommunicatorMixin:
async with self.is_pause_cond:
if self.is_pause:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
+ if result.success and obj.weight_version is not None:
+ self._update_weight_version_if_provided(obj.weight_version)
+ result.message += (
+ f" Weight version updated to {obj.weight_version}."
+ )
return result.success, result.message

# This means that weight sync
2 changes: 1 addition & 1 deletion docker/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
nightly-dev-20251215a
nightly-dev-20251216a
11 changes: 9 additions & 2 deletions slime/rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any

import numpy as np
import pybase64
import sglang_router
from packaging.version import parse
from tqdm import tqdm
Expand Down Expand Up @@ -186,8 +187,14 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A
sample.weight_versions.append(output["meta_info"]["weight_version"])

if "routed_experts" in output["meta_info"]:
assert len(output["meta_info"]["routed_experts"]) == len(sample.tokens) - 1
sample.rollout_routed_experts = np.array(output["meta_info"]["routed_experts"])
sample.rollout_routed_experts = np.frombuffer(
pybase64.b64decode(output["meta_info"]["routed_experts"].encode("ascii")),
dtype=np.int32,
).reshape(
len(sample.tokens) - 1,
args.num_layers,
args.moe_router_topk,
)

match output["meta_info"]["finish_reason"]["type"]:
case "length":
Expand Down