1 change: 1 addition & 0 deletions python/sglang/srt/disaggregation/encode_receiver.py
@@ -945,6 +945,7 @@ def create_req(self, recv_req: TokenizedGenerateReqInput):
             require_reasoning=recv_req.require_reasoning,
             return_hidden_states=recv_req.return_hidden_states,
             return_routed_experts=recv_req.return_routed_experts,
+            routed_experts_start_len=recv_req.routed_experts_start_len,
             eos_token_ids=self.scheduler.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
             bootstrap_port=recv_req.bootstrap_port,
4 changes: 4 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
@@ -295,6 +295,7 @@ def generate(
         custom_logit_processor: Optional[Union[List[str], str]] = None,
         return_hidden_states: bool = False,
         return_routed_experts: bool = False,
+        routed_experts_start_len: int = 0,
         stream: bool = False,
         bootstrap_host: Optional[Union[List[str], str]] = None,
         bootstrap_port: Optional[Union[List[int], int]] = None,
@@ -331,6 +332,7 @@ def generate(
             custom_logit_processor=custom_logit_processor,
             return_hidden_states=return_hidden_states,
             return_routed_experts=return_routed_experts,
+            routed_experts_start_len=routed_experts_start_len,
             stream=stream,
             bootstrap_host=bootstrap_host,
             bootstrap_port=bootstrap_port,
@@ -385,6 +387,7 @@ async def async_generate(
         custom_logit_processor: Optional[Union[List[str], str]] = None,
         return_hidden_states: bool = False,
         return_routed_experts: bool = False,
+        routed_experts_start_len: int = 0,
         stream: bool = False,
         bootstrap_host: Optional[Union[List[str], str]] = None,
         bootstrap_port: Optional[Union[List[int], int]] = None,
@@ -420,6 +423,7 @@ async def async_generate(
             lora_path=lora_path,
             return_hidden_states=return_hidden_states,
             return_routed_experts=return_routed_experts,
+            routed_experts_start_len=routed_experts_start_len,
             stream=stream,
             custom_logit_processor=custom_logit_processor,
             bootstrap_host=bootstrap_host,
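A minimal usage sketch for the new parameter at the Engine API level. The model path, prompt, and output handling are illustrative placeholders, not taken from this PR; only the two routed-experts kwargs come from the diff above:

import sglang as sgl

# Placeholder checkpoint; any MoE model whose routings sglang can capture would do.
engine = sgl.Engine(model_path="some-org/some-moe-model")

out = engine.generate(
    prompt="The capital of France is",
    sampling_params={"max_new_tokens": 8},
    return_routed_experts=True,     # opt in to routing capture
    routed_experts_start_len=4,     # skip routings for the first 4 positions
)
# The returned routings cover positions [4, seqlen - 1); how they surface in
# `out` is not shown in this diff, so inspect the response meta info.
print(out)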
2 changes: 2 additions & 0 deletions python/sglang/srt/entrypoints/openai/protocol.py
@@ -273,6 +273,7 @@ class CompletionRequest(BaseModel):
     user: Optional[str] = None
     return_hidden_states: bool = False
     return_routed_experts: bool = False
+    routed_experts_start_len: int = 0
     return_cached_tokens_details: bool = False
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
@@ -588,6 +589,7 @@ class ChatCompletionRequest(BaseModel):
     parallel_tool_calls: bool = True
     return_hidden_states: bool = False
     return_routed_experts: bool = False
+    routed_experts_start_len: int = 0
     return_cached_tokens_details: bool = False
     return_prompt_token_ids: bool = False
     return_meta_info: bool = False
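Since both CompletionRequest and ChatCompletionRequest now carry the fields, a client can pass them straight through the OpenAI-compatible HTTP API. A hedged sketch, assuming a server on sglang's default port; the model name and prompt are placeholders:

import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "placeholder-model",
        "messages": [{"role": "user", "content": "Say hello."}],
        "return_routed_experts": True,
        "routed_experts_start_len": 8,  # must be in [0, prompt_tokens]
    },
)
print(resp.json())

Per the scheduler validation later in this PR, a value that is negative or larger than the number of input tokens aborts the request with an error message.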
1 change: 1 addition & 0 deletions python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -317,6 +317,7 @@ def _convert_to_internal_request(
             disagg_prefill_dp_rank=request.disagg_prefill_dp_rank,
             return_hidden_states=request.return_hidden_states,
             return_routed_experts=request.return_routed_experts,
+            routed_experts_start_len=request.routed_experts_start_len,
             rid=request.rid,
             extra_key=self._compute_extra_key(request),
             require_reasoning=self._get_reasoning_from_request(request),
@@ -122,6 +122,7 @@ def _convert_to_internal_request(
             disagg_prefill_dp_rank=request.disagg_prefill_dp_rank,
             return_hidden_states=request.return_hidden_states,
             return_routed_experts=request.return_routed_experts,
+            routed_experts_start_len=request.routed_experts_start_len,
             rid=request.rid,
             extra_key=self._compute_extra_key(request),
             priority=request.priority,
10 changes: 9 additions & 1 deletion python/sglang/srt/layers/moe/routed_experts_capturer.py
@@ -164,6 +164,7 @@ def get_routed_experts(
         req_pool_idx: int,
         seqlen: int,
         req_to_token_pool: ReqToTokenPool,
+        start_len: int = 0,
     ):
         raise NotImplementedError
 
@@ -275,9 +276,15 @@ def get_routed_experts(
         req_pool_idx: int,
         seqlen: int,
         req_to_token_pool: ReqToTokenPool,
+        start_len: int = 0,
     ):
+        if start_len < 0:
+            raise ValueError(f"{start_len=} must be non-negative")
+        start_len = min(start_len, seqlen - 1)
         cache_pool_idx = (
-            req_to_token_pool.req_to_token[req_pool_idx][: seqlen - 1].cpu().clone()
+            req_to_token_pool.req_to_token[req_pool_idx][start_len : seqlen - 1]
+            .cpu()
+            .clone()
         )
         return self.get_host_cache().buffer[cache_pool_idx]
 
@@ -325,6 +332,7 @@ def get_routed_experts(
         req_pool_idx: int,
         seqlen: int,
         req_to_token_pool: ReqToTokenPool,
+        start_len: int = 0,
     ):
         pass
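A standalone sketch of the clamp-and-slice behavior added above, with toy tensors standing in for the real ReqToTokenPool and host cache; all names and shapes here are illustrative only:

import torch

def get_routed_experts_sketch(
    req_to_token: torch.Tensor,
    host_buffer: torch.Tensor,
    req_pool_idx: int,
    seqlen: int,
    start_len: int = 0,
) -> torch.Tensor:
    if start_len < 0:
        raise ValueError(f"{start_len=} must be non-negative")
    # Clamp so an overly large start still yields a valid (possibly empty) slice.
    start_len = min(start_len, seqlen - 1)
    cache_pool_idx = req_to_token[req_pool_idx][start_len : seqlen - 1].cpu().clone()
    return host_buffer[cache_pool_idx]

# One request slot, 6 tokens, top-2 routing; buffer row i holds position i's experts.
req_to_token = torch.arange(6).unsqueeze(0)          # token-cache slots for req 0
host_buffer = torch.tensor([[i, i + 10] for i in range(6)])
out = get_routed_experts_sketch(req_to_token, host_buffer, 0, seqlen=6, start_len=2)
assert out.shape == (3, 2)  # rows for positions 2..4, i.e. seqlen - 1 - start_len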
7 changes: 5 additions & 2 deletions python/sglang/srt/managers/io_struct.py
@@ -169,7 +169,9 @@ class GenerateReqInput(BaseReq):
     return_hidden_states: Union[List[bool], bool] = False
     # Whether to return captured routed experts
     return_routed_experts: bool = False
-    # The start location in the prompt for returning routed experts.
+    # Absolute start position for returned routings; response covers
+    # `[routed_experts_start_len, seqlen - 1)`. Must be in [0, prompt_tokens].
+    # 0 = full sequence.
     routed_experts_start_len: int = 0
 
     # The modalities of the image data [image, multi-images, video]
@@ -628,6 +630,7 @@ def __getitem__(self, i):
                 else self.return_hidden_states
             ),
             return_routed_experts=self.return_routed_experts,
+            routed_experts_start_len=self.routed_experts_start_len,
             modalities=self.modalities[i] if self.modalities else None,
             session_params=self.session_params,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
@@ -696,7 +699,7 @@ class TokenizedGenerateReqInput(BaseReq):
 
     # Whether to return captured routed experts
     return_routed_experts: bool = False
-    # The start location in the prompt for returning routed experts.
+    # See GenerateReqInput.routed_experts_start_len.
     routed_experts_start_len: int = 0
 
     # The input embeds
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -577,6 +577,7 @@ def __init__(
         require_reasoning: bool = False,
         return_hidden_states: bool = False,
         return_routed_experts: bool = False,
+        routed_experts_start_len: int = 0,
         eos_token_ids: Optional[Set[int]] = None,
         bootstrap_host: Optional[str] = None,
         bootstrap_port: Optional[int] = None,
@@ -789,6 +790,7 @@ def __init__(
 
         # capture routed experts
         self.return_routed_experts = return_routed_experts
+        self.routed_experts_start_len = routed_experts_start_len
         self.routed_experts: Optional[torch.Tensor] = (
             None  # cpu tensor: shape (seqlen, topk)
         )
22 changes: 22 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -1797,6 +1797,7 @@ def handle_generate_request(
             require_reasoning=recv_req.require_reasoning,
             return_hidden_states=recv_req.return_hidden_states,
             return_routed_experts=recv_req.return_routed_experts,
+            routed_experts_start_len=recv_req.routed_experts_start_len,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
             bootstrap_port=recv_req.bootstrap_port,
@@ -1936,6 +1937,27 @@ def handle_generate_request(
             self._add_request_to_queue(req)
             return
 
+        if recv_req.return_routed_experts:
+            error_msg = None
+            if recv_req.routed_experts_start_len < 0:
+                error_msg = (
+                    f"{recv_req.routed_experts_start_len=} is lower than 0. "
+                    "Please use a non-negative routed_experts_start_len."
+                )
+
+            if recv_req.routed_experts_start_len > len(req.origin_input_ids):
+                error_msg = (
+                    f"{recv_req.routed_experts_start_len=} is higher than the "
+                    f"number of input tokens {len(req.origin_input_ids)=}. Please "
+                    f"use a smaller routed_experts_start_len."
+                )
+
+            if error_msg is not None:
+                req.routed_experts_start_len = 0
+                req.set_finish_with_abort(error_msg)
+                self._add_request_to_queue(req)
+                return
+
         added_to_grammar_queue = self.grammar_manager.process_req_with_grammar(req)
         if not added_to_grammar_queue:
             self._add_request_to_queue(req)
39 changes: 37 additions & 2 deletions python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -97,13 +97,48 @@ def process_batch_result_prebuilt(self: Scheduler, batch: ScheduleBatch):
         self.stream_output(batch.reqs, batch.return_logprob)
 
     def maybe_collect_routed_experts(self: Scheduler, req: Req):
-        """Collect routed experts for a finished request."""
-        req.routed_experts = get_global_experts_capturer().get_routed_experts(
+        """Collect routed experts for a finished request.
+
+        Returns immediately if `return_routed_experts` was not set on the
+        request, so non-opted-in reqs don't pay the host-gather cost.
+
+        Honors the caller's absolute start so the response covers
+        `[start_len, seqlen - 1)`. The default start_len is 0, which returns
+        the full sequence.
+
+        Logs a soft warning if the resulting tensor's row count differs from
+        the expected `seqlen - 1 - start_len`, to catch silent regressions.
+        """
+        if not req.return_routed_experts:
+            return
+        capturer = get_global_experts_capturer()
+        if capturer is None:
+            return
+        start_len = req.routed_experts_start_len
+        req.routed_experts = capturer.get_routed_experts(
             req_pool_idx=req.req_pool_idx,
             seqlen=req.seqlen,
             req_to_token_pool=self.req_to_token_pool,
+            start_len=start_len,
         )
 
+        expected_rows = max(0, req.seqlen - 1 - start_len)
+        if (
+            req.routed_experts is not None
+            and req.routed_experts.shape[0] != expected_rows
+        ):
+            logger.warning(
+                "routed_experts row-count mismatch for req %s: got %d, "
+                "expected %d (seqlen=%d, cached_tokens=%d, start_len=%s). "
+                "This indicates a silent bug.",
+                req.rid,
+                req.routed_experts.shape[0],
+                expected_rows,
+                req.seqlen,
+                req.cached_tokens,
+                req.routed_experts_start_len,
+            )
+
Comment on lines +130 to +140

Contributor (severity: high): The `logger` variable is used here but it does not appear to be imported or defined within this mixin file. This will likely result in a NameError at runtime when a row-count mismatch occurs. Please ensure `logger` is imported from sglang.srt.utils or defined at the module level.

Collaborator (author): this is wrong. it is above

     def maybe_collect_customized_info(
         self: Scheduler, i: int, req: Req, logits_output: LogitsProcessorOutput
     ):
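A worked instance of the row-count invariant the new warning checks, using plain numbers rather than sglang objects:

seqlen = 12     # prompt + generated tokens for a finished request
start_len = 3   # caller-requested absolute start
# Captured routings cover [start_len, seqlen - 1), one row per position:
expected_rows = max(0, seqlen - 1 - start_len)
assert expected_rows == 8
# With start_len clamped to seqlen - 1 (see routed_experts_capturer.py),
# the slice is empty and max(0, ...) keeps the expectation consistent at 0.
assert max(0, seqlen - 1 - (seqlen - 1)) == 0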
1 change: 1 addition & 0 deletions python/sglang/srt/managers/session_controller.py
@@ -226,6 +226,7 @@ def create_req(
             require_reasoning=req.require_reasoning,
             return_hidden_states=req.return_hidden_states,
             return_routed_experts=req.return_routed_experts,
+            routed_experts_start_len=req.routed_experts_start_len,
             priority=req.priority,
             routing_key=req.routing_key,
             extra_key=req.extra_key,
1 change: 1 addition & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -979,6 +979,7 @@ def _create_tokenized_object(
             require_reasoning=obj.require_reasoning,
             return_hidden_states=obj.return_hidden_states,
             return_routed_experts=obj.return_routed_experts,
+            routed_experts_start_len=obj.routed_experts_start_len,
             routed_dp_rank=obj.routed_dp_rank,
             disagg_prefill_dp_rank=obj.disagg_prefill_dp_rank,
             priority=obj.priority,