Merged
Changes from all commits (70 commits)
cf8d0fd
init
ocss884 Oct 26, 2025
a720fae
Merge branch 'sgl-project:main' into return_routed_expert
ocss884 Oct 26, 2025
542d490
more
ocss884 Oct 26, 2025
cfc330b
small fix
ocss884 Oct 26, 2025
1b9e9fa
small fix
ocss884 Oct 26, 2025
5a70eb7
refactor
ocss884 Oct 28, 2025
6924725
add layer_id to select_experts
ocss884 Oct 28, 2025
1f40d43
more
ocss884 Oct 29, 2025
794af1f
more
ocss884 Oct 30, 2025
e15db98
rm
ocss884 Oct 30, 2025
f92a997
copy to req when finished
ocss884 Oct 30, 2025
41bebeb
misc
ocss884 Oct 30, 2025
4ec04be
lint
ocss884 Oct 31, 2025
fa916bc
small fix
ocss884 Oct 31, 2025
98c8fe0
nit
ocss884 Nov 3, 2025
8da8de1
Merge branch 'main' into return_routed_expert
ocss884 Nov 3, 2025
6c92475
use key word args
ocss884 Nov 3, 2025
812421a
more
ocss884 Nov 3, 2025
ded85ca
adapt for cuda graph
yizhang2077 Nov 10, 2025
ce7463d
enable for engine
ocss884 Nov 11, 2025
bf63df9
Merge branch 'main' into return_routed_expert
ocss884 Nov 11, 2025
22e3910
fix for schedule overlap
ocss884 Nov 11, 2025
b2af1cb
Merge branch 'main' into return_routed_expert
zhaochenyang20 Nov 11, 2025
911bd1a
Merge branch 'main' into return_routed_expert
ocss884 Nov 12, 2025
0a0645f
bugfix
ocss884 Nov 12, 2025
b041ddc
add arg for bench_serving
ocss884 Nov 12, 2025
c70bded
more
ocss884 Nov 12, 2025
3d54694
add out_cache_loc_cpu to improve performance
ocss884 Nov 13, 2025
1b701a8
Merge branch 'main' into return_routed_expert
ocss884 Nov 13, 2025
31d9509
more
ocss884 Nov 13, 2025
b72a6f2
move topk_ids tolist into detokenizer
yizhang2077 Nov 24, 2025
fdb5c47
rollback
ocss884 Nov 19, 2025
b77a5fb
add ut
ocss884 Nov 23, 2025
83192fb
dp attn
ocss884 Dec 2, 2025
aa4df5e
fix dp attn
ocss884 Dec 3, 2025
e7b1fc9
Merge branch 'main' into return_routed_expert
ocss884 Dec 3, 2025
65471d5
Merge branch 'main' into return_routed_expert
ocss884 Dec 3, 2025
4c23438
fix a bug when dp attn + r3 + cuda graph
yizhang2077 Dec 8, 2025
aaca6e3
Merge branch 'main' into return_routed_expert
ocss884 Dec 8, 2025
823ea6b
small fix & support dpsk
ocss884 Dec 8, 2025
54ba922
rollback
ocss884 Dec 8, 2025
7d9a69d
add to test suite
ocss884 Dec 9, 2025
d49fe1a
Merge branch 'main' into return_routed_expert
ocss884 Dec 9, 2025
d8e2e34
Merge branch 'main' into return_routed_expert
zhaochenyang20 Dec 9, 2025
e8c276a
Merge branch 'main' into return_routed_expert
Kangyan-Zhou Dec 9, 2025
4457751
update test
ocss884 Dec 9, 2025
b4f4aaf
Merge branch 'main' into return_routed_expert
ocss884 Dec 9, 2025
5a3e735
fix test
ocss884 Dec 10, 2025
f9a442b
fix test
ocss884 Dec 10, 2025
676612f
Merge branch 'main' into return_routed_expert
ocss884 Dec 10, 2025
041dfcd
tiny fix for retract
yizhang2077 Dec 10, 2025
e6af67d
cleanup code
ocss884 Dec 11, 2025
067f185
bugfix
ocss884 Dec 11, 2025
be8ce97
bugfix
ocss884 Dec 11, 2025
09b203b
Merge branch 'main' into return_routed_expert
ocss884 Dec 12, 2025
017d0b9
Merge branch 'main' into return_routed_expert
ocss884 Dec 12, 2025
808d29f
Merge branch 'main' into return_routed_expert
hnyls2002 Dec 18, 2025
dbe9775
fix merge conflicts
hnyls2002 Dec 18, 2025
33fe8ee
fix glm4 moe
hnyls2002 Dec 18, 2025
5770aa2
remove out cache loc cpu
hnyls2002 Dec 18, 2025
2f580f5
clean up
hnyls2002 Dec 18, 2025
b32e693
tiny fix
hnyls2002 Dec 18, 2025
7da8cfa
split function out
hnyls2002 Dec 18, 2025
37a6c76
remove duplicate code
hnyls2002 Dec 18, 2025
0189d33
use b64 return
ocss884 Dec 19, 2025
ee2a780
Merge branch 'main' into return_routed_expert
ocss884 Dec 19, 2025
f486811
small bugfix
ocss884 Dec 19, 2025
ca4b3a4
fix test
ocss884 Dec 20, 2025
efee10b
Merge branch 'main' into return_routed_expert
ocss884 Dec 20, 2025
67210c2
lint
ocss884 Dec 20, 2025
6 changes: 6 additions & 0 deletions python/sglang/bench_serving.py
@@ -565,6 +565,7 @@ async def async_request_sglang_generate(
"stream": not args.disable_stream,
"lora_path": request_func_input.lora_name,
"return_logprob": args.return_logprob,
"return_routed_experts": args.return_routed_experts,
"logprob_start_len": -1,
**request_func_input.extra_request_body,
}
@@ -2809,6 +2810,11 @@ def __call__(self, parser, namespace, values, option_string=None):
action="store_true",
help="Return logprob.",
)
parser.add_argument(
"--return-routed-experts",
action="store_true",
help="Return routed experts.",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--disable-ignore-eos",
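For context, a minimal sketch of the request body that async_request_sglang_generate sends once the new flag is set. Only the field names come from the diff above; the prompt, sampling parameters, and endpoint URL are placeholder assumptions.

import requests

# Hypothetical standalone request mirroring bench_serving's payload when
# --return-routed-experts is passed; URL and prompt are illustrative only.
payload = {
    "text": "The capital of France is",
    "sampling_params": {"temperature": 0.0, "max_new_tokens": 8},
    "stream": False,
    "return_logprob": False,
    "return_routed_experts": True,  # the flag added by this PR
    "logprob_start_len": -1,
}
resp = requests.post("http://localhost:30000/generate", json=payload)
print(resp.json()["meta_info"].keys())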
2 changes: 2 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
@@ -286,6 +286,7 @@ def generate(
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None,
return_hidden_states: bool = False,
return_routed_experts: bool = False,
stream: bool = False,
bootstrap_host: Optional[Union[List[str], str]] = None,
bootstrap_port: Optional[Union[List[int], int]] = None,
@@ -321,6 +322,7 @@
lora_path=lora_path,
custom_logit_processor=custom_logit_processor,
return_hidden_states=return_hidden_states,
return_routed_experts=return_routed_experts,
stream=stream,
bootstrap_host=bootstrap_host,
bootstrap_port=bootstrap_port,
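A usage sketch for the new keyword argument on the offline Engine. The model path is a placeholder (any MoE model applies), and the routed_experts key in meta_info follows the helper in routed_experts_capturer.py below.

import sglang as sgl

# Placeholder MoE model path; assumes a local offline Engine.
engine = sgl.Engine(model_path="Qwen/Qwen1.5-MoE-A2.7B")
out = engine.generate(
    prompt="Hello, my name is",
    sampling_params={"temperature": 0.0, "max_new_tokens": 4},
    return_routed_experts=True,  # forwarded with the other generate kwargs
)
# Expert ids arrive base64-encoded in meta_info (see the capturer module below).
print(out["meta_info"].get("routed_experts"))
engine.shutdown()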
289 changes: 289 additions & 0 deletions python/sglang/srt/layers/moe/routed_experts_capturer.py
@@ -0,0 +1,289 @@
import logging
from abc import ABC
from typing import Optional

import numpy as np
import pybase64
import torch

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.dp_attention import (
get_attention_dp_rank,
get_dp_local_info,
is_dp_attention_enabled,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args

logger = logging.getLogger(__name__)

_GB = 1024 * 1024 * 1024
_MB = 1024 * 1024


def get_tensor_size_bytes(t: torch.Tensor):
return np.prod(t.shape) * t.dtype.itemsize


class _RoutedExpertsDeviceCache:
def __init__(
self,
max_running_requests: int,
num_hidden_layers: int,
num_experts_per_tok: int,
num_fused_shared_experts: int,
device: str,
) -> None:
self.buffer = torch.zeros(
(
max(
get_global_server_args().chunked_prefill_size
* get_global_server_args().dp_size,
max_running_requests,
),
num_hidden_layers,
num_experts_per_tok + num_fused_shared_experts,
),
dtype=torch.int32,
device=device,
)
self._finalize_allocation_log()

def get_buffer_size_bytes(self):
assert hasattr(self, "buffer")
return get_tensor_size_bytes(self.buffer)

def capture_fwd_routed_experts(self, layer_id: int, topk_ids: torch.Tensor):
assert layer_id is not None, "capturing routed experts but got layer_id=None"
batch, _ = topk_ids.shape
self.buffer[:batch, layer_id, :] = topk_ids

def _finalize_allocation_log(self):
"""Common logging and memory usage computation for captured experts buffers."""
buffer_size_MB = self.get_buffer_size_bytes() / _MB
logger.info(
f"Routing experts device buffer allocated. #shape: {tuple(self.buffer.shape)}, size: {buffer_size_MB:.2f} MB"
)


class _RoutedExpertsHostCache:
def __init__(
self,
num_tokens: int,
num_hidden_layers: int,
num_experts_per_tok: int,
) -> None:
self.num_tokens = num_tokens
self.buffer = torch.zeros(
(
num_tokens,
num_hidden_layers,
num_experts_per_tok,
),
dtype=torch.int32,
device="cpu",
pin_memory=True,
)
self._finalize_allocation_log()

def get_buffer_size_bytes(self):
assert hasattr(self, "buffer")
return get_tensor_size_bytes(self.buffer)

def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor):
# The buffer layout is (num_tokens, num_hidden_layers, num_experts_per_tok),
# so token locations index the first dimension and the layer the second.
self.buffer[loc, layer_id, :] = top_k.to(device="cpu", non_blocking=True)

def _finalize_allocation_log(self):
"""Common logging and memory usage computation for captured experts buffers."""
buffer_size_GB = self.get_buffer_size_bytes() / _GB
logger.info(
f"Routing experts host buffer allocated. #tokens: {self.num_tokens}, size: {buffer_size_GB:.2f} GB"
)


class RoutedExpertsCapturer(ABC):
@staticmethod
def create(
enable: bool,
model_config: ModelConfig,
num_fused_shared_experts: int,
num_tokens: int,
max_running_requests: int,
device: str,
):
if enable:
return _RoutedExpertsCapturerReal(
model_config,
num_tokens=num_tokens,
max_running_requests=max_running_requests,
num_fused_shared_experts=num_fused_shared_experts,
device=device,
)
else:
return _RoutedExpertsCapturerNoop()

def _sync_fwd_experts_buffer_DtoH(
self,
forward_batch: ForwardBatch,
can_run_graph: bool,
cuda_graph_batch: int,
):
raise NotImplementedError

def capture(self, layer_id: int, topk_ids: torch.Tensor):
raise NotImplementedError

def get_routed_experts(
self,
req_pool_idx: int,
seqlen: int,
req_to_token_pool: ReqToTokenPool,
):
raise NotImplementedError

def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
raise NotImplementedError

def get_host_cache(self):
raise NotImplementedError

def get_device_cache(self):
raise NotImplementedError


class _RoutedExpertsCapturerReal(RoutedExpertsCapturer):
"""Capturer for routed experts with host buffer"""

def __init__(
self,
model_config: ModelConfig,
num_tokens: int,
max_running_requests: int,
num_fused_shared_experts: int,
device: str,
):
self.num_fused_shared_experts = num_fused_shared_experts
self.num_hidden_layers = model_config.hf_text_config.num_hidden_layers
self.num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok

self.host_cache = _RoutedExpertsHostCache(
num_tokens=num_tokens,
num_hidden_layers=self.num_hidden_layers,
num_experts_per_tok=self.num_experts_per_tok,
)

self.device_cache = _RoutedExpertsDeviceCache(
max_running_requests=max_running_requests,
num_hidden_layers=self.num_hidden_layers,
num_experts_per_tok=self.num_experts_per_tok,
num_fused_shared_experts=self.num_fused_shared_experts,
device=device,
)

def _sync_fwd_experts_buffer_DtoH(
self,
forward_batch: ForwardBatch,
can_run_graph: bool,
cuda_graph_batch: int,
):
if is_dp_attention_enabled():
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
# Handle CUDA graph padding: each DP rank's slice starts at a fixed offset in the padded batch.
if can_run_graph:
local_start_pos = get_attention_dp_rank() * cuda_graph_batch
local_end_pos = local_start_pos + local_num_tokens
else:
local_end_pos = local_start_pos + local_num_tokens
else:
local_start_pos = 0
local_end_pos = forward_batch.out_cache_loc.shape[0]

# FIXME: the .cpu() below forces an explicit device sync; the overlap scheduler breaks here.
out_cache_loc_cpu = forward_batch.out_cache_loc.cpu()
self.host_cache.buffer[out_cache_loc_cpu] = self.device_cache.buffer[
local_start_pos:local_end_pos, :, : self.num_experts_per_tok
].cpu()

def capture(self, layer_id: int, topk_ids: torch.Tensor):
self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids)

def get_routed_experts(
self,
req_pool_idx: int,
seqlen: int,
req_to_token_pool: ReqToTokenPool,
):
cache_pool_idx = (
req_to_token_pool.req_to_token[req_pool_idx][: seqlen - 1].cpu().clone()
)
return self.get_host_cache().buffer[cache_pool_idx]

def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
self._sync_fwd_experts_buffer_DtoH(
forward_batch=forward_batch,
can_run_graph=can_run_graph,
cuda_graph_batch=cuda_graph_batch,
)

def get_host_cache(self):
return self.host_cache

def get_device_cache(self):
return self.device_cache


class _RoutedExpertsCapturerNoop(RoutedExpertsCapturer):
def __init__(self):
pass

def _sync_fwd_experts_buffer_DtoH(
self,
forward_batch: ForwardBatch,
can_run_graph: bool,
cuda_graph_batch: int,
):
pass

def capture(self, layer_id: int, topk_ids: torch.Tensor):
pass

def get_routed_experts(
self,
req_pool_idx: int,
seqlen: int,
req_to_token_pool: ReqToTokenPool,
):
pass

def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
pass

def get_host_cache(self):
pass

def get_device_cache(self):
pass


_global_expert_capturer: Optional[RoutedExpertsCapturer] = _RoutedExpertsCapturerNoop()


def get_global_experts_capturer():
return _global_expert_capturer


def set_global_experts_capturer(capturer: RoutedExpertsCapturer):
global _global_expert_capturer
_global_expert_capturer = capturer


def extract_routed_experts_from_meta_info(data):
# For performance, the expert ids are returned base64-encoded.
# This helper converts them back into a plain int32 array.
# See detokenizer_manager::_extract_routed_experts.
routed_experts_base64 = data["meta_info"].get("routed_experts", None)
if routed_experts_base64 is None:
return None
routed_experts = np.frombuffer(
pybase64.b64decode(routed_experts_base64.encode("utf-8")), dtype=np.int32
)
return routed_experts
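A round-trip sketch of the base64 encoding that extract_routed_experts_from_meta_info undoes. The fake payload and the final reshape to (num_tokens, num_hidden_layers, num_experts_per_tok) are illustrative assumptions based on the host buffer layout above.

import numpy as np
import pybase64

from sglang.srt.layers.moe.routed_experts_capturer import (
    extract_routed_experts_from_meta_info,
)

# Fake payload: 2 tokens x 3 layers x top-2 experts, flattened int32 ids,
# base64-encoded the same way the server returns them.
ids = np.arange(12, dtype=np.int32)
data = {
    "meta_info": {
        "routed_experts": pybase64.b64encode(ids.tobytes()).decode("utf-8")
    }
}

flat = extract_routed_experts_from_meta_info(data)  # flat int32 array
# Reshaping to the host-buffer layout is an assumption for illustration.
print(flat.reshape(2, 3, 2))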
12 changes: 11 additions & 1 deletion python/sglang/srt/layers/moe/topk.py
@@ -48,6 +48,7 @@
)
from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
@@ -203,6 +204,7 @@ def __init__(
self,
top_k: int,
*,
layer_id: Optional[int] = None,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
@@ -224,6 +226,7 @@
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None

self.layer_id = layer_id
self.topk_config = TopKConfig(
top_k=top_k,
use_grouped_topk=use_grouped_topk,
@@ -251,6 +254,7 @@ def forward_native(
self.topk_config.torch_native = True
return select_experts(
hidden_states=hidden_states,
layer_id=self.layer_id,
router_logits=router_logits,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
@@ -300,6 +304,7 @@
):
topk_output = select_experts(
hidden_states=hidden_states,
layer_id=self.layer_id,
router_logits=router_logits,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
@@ -317,6 +322,7 @@
) -> TopKOutput:
return select_experts(
hidden_states=hidden_states,
layer_id=self.layer_id,
router_logits=router_logits,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
@@ -847,6 +853,7 @@ def select_experts(
router_logits: torch.Tensor,
topk_config: TopKConfig,
*,
layer_id: Optional[int] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> StandardTopKOutput:
@@ -974,7 +981,10 @@ def select_experts(
)

get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

get_global_experts_capturer().capture(
layer_id=layer_id,
topk_ids=topk_ids,
)
return StandardTopKOutput(topk_weights, topk_ids, router_logits)

Review comment (Collaborator), on how layer_id is passed in: ExpertDistributionRecorder already has a with_current_layer. What about refactoring into something like

class CurrentLayerManager:
    def with_current_layer(self, layer_id): ...
    def get_current_layer_id(self): ...

and letting both ExpertDistributionRecorder and the new RoutedExpertsCapturer use current_layer_mgr.get_current_layer_id()?
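A minimal sketch of the refactor proposed in that comment; CurrentLayerManager and its method names are hypothetical, not existing sglang APIs.

from contextlib import contextmanager
from typing import Optional


class CurrentLayerManager:
    """Hypothetical shared tracker for the layer currently running forward."""

    def __init__(self) -> None:
        self._layer_id: Optional[int] = None

    @contextmanager
    def with_current_layer(self, layer_id: int):
        # Mark layer_id active for the duration of one layer's forward pass.
        prev = self._layer_id
        self._layer_id = layer_id
        try:
            yield
        finally:
            self._layer_id = prev

    def get_current_layer_id(self) -> Optional[int]:
        return self._layer_id


current_layer_mgr = CurrentLayerManager()

# Both recorders could then read the layer id instead of receiving it:
#   with current_layer_mgr.with_current_layer(layer_id):
#       topk_output = layer.topk(hidden_states, router_logits)
# and inside select_experts:
#   layer_id = current_layer_mgr.get_current_layer_id()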

