4 changes: 2 additions & 2 deletions docs_new/docs/basic_usage/openai_api_completions.ipynb
@@ -374,7 +374,7 @@
"source": [
"#### Returning Routed Experts (MoE Models)\n",
"\n",
"For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`."
"For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`. By default this returns `[0, seqlen - 1)`, the full available sequence, because RL workflows need routed experts for the full sequence. Set `routed_experts_start_len` in `extra_body` to an absolute prefix length to return only `[routed_experts_start_len, seqlen - 1)`. For example, in multi-turn RL rollouts, routed experts for tokens from previous turns have already been collected, so setting this value avoids unnecessary transfer that cause bottlenecks."
]
},
{
@@ -468,7 +468,7 @@
"source": [
"#### Returning Routed Experts (MoE Models)\n",
"\n",
"For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`."
"For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`. By default this returns `[0, seqlen - 1)`, the full available sequence, because RL workflows need routed experts for the full sequence. Set `routed_experts_start_len` in `extra_body` to an absolute prefix length to return only `[routed_experts_start_len, seqlen - 1)`. For example, in multi-turn RL rollouts, routed experts for tokens from previous turns have already been collected, so setting this value avoids unnecessary transfer that cause bottlenecks."
]
},
{
4 changes: 2 additions & 2 deletions docs_new/docs/basic_usage/openai_api_completions.mdx
@@ -342,7 +342,7 @@ for chunk in stream:

#### Returning Routed Experts (MoE Models)

For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`.
For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires the `--enable-return-routed-experts` server flag. The `routed_experts` field is returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`. By default this covers the full available sequence `[0, seqlen - 1)`, since RL workflows typically need routed experts for every token. Set `routed_experts_start_len` in `extra_body` to an absolute prefix length to return only `[routed_experts_start_len, seqlen - 1)`. For example, in multi-turn RL rollouts, routed experts for tokens from previous turns have already been collected, so setting this value avoids redundant transfers that can become a bottleneck.
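
The sketch below is illustrative rather than part of the original docs build. It assumes a local SGLang server started with `--enable-return-routed-experts`, and the attribute path used to reach the non-standard `sgl_ext` field on a choice object is an assumption about how the OpenAI client exposes extra response fields.

```python Example
# Hedged sketch: request routed experts via the completions API and decode them.
import base64

import numpy as np
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")

response = client.completions.create(
    model="default",
    prompt="List three uses of a paperclip.",
    max_tokens=32,
    extra_body={
        "return_routed_experts": True,   # needs --enable-return-routed-experts
        "routed_experts_start_len": 0,   # 0 = full available sequence [0, seqlen - 1)
    },
)

choice = response.choices[0]
# `sgl_ext` is not part of the OpenAI schema, so it usually surfaces through the
# model's extra fields; adjust this access for your client version.
sgl_ext = getattr(choice, "sgl_ext", None) or (choice.model_extra or {}).get("sgl_ext", {})

# Decode base64 -> flat int32 array -> [num_tokens, num_layers, top_k].
flat = np.frombuffer(base64.b64decode(sgl_ext["routed_experts"]), dtype=np.int32)
num_layers, top_k = 24, 8  # placeholders; use your model's actual MoE geometry
routed = flat.reshape(-1, num_layers, top_k)
print(routed.shape)
```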

```python Example
# Example with logit_bias parameter for completions API
@@ -406,7 +406,7 @@ print_highlight(f"Response: {response}")

#### Returning Routed Experts (MoE Models)

For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`.
For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires the `--enable-return-routed-experts` server flag. The `routed_experts` field is returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`. By default this covers the full available sequence `[0, seqlen - 1)`, since RL workflows typically need routed experts for every token. Set `routed_experts_start_len` in `extra_body` to an absolute prefix length to return only `[routed_experts_start_len, seqlen - 1)`. For example, in multi-turn RL rollouts, routed experts for tokens from previous turns have already been collected, so setting this value avoids redundant transfers that can become a bottleneck.
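
As a hedged illustration of the multi-turn case (not from the original docs): the loop below carries an accumulated prefix length between turns so only the new suffix of routed experts is transferred. How the prefix length is tracked (here via the `usage` token counts) is an assumption, and the sketch reuses the `client` from the completions example above.

```python Example
# Hedged sketch: multi-turn rollout that only fetches routed experts
# for tokens produced since the previous turn.
collected_prefix_len = 0
messages = []

for user_turn in ["Summarize MoE routing in one line.", "Now give a concrete example."]:
    messages.append({"role": "user", "content": user_turn})
    response = client.chat.completions.create(
        model="default",
        messages=messages,
        extra_body={
            "return_routed_experts": True,
            # Skip rows already collected in earlier turns.
            "routed_experts_start_len": collected_prefix_len,
        },
    )
    reply = response.choices[0].message.content
    messages.append({"role": "assistant", "content": reply})
    # Advance the marker past everything seen so far (assumed tracking scheme;
    # exact token accounting depends on your chat template).
    collected_prefix_len = response.usage.total_tokens
```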

## Structured Outputs (JSON, Regex, EBNF)

7 changes: 6 additions & 1 deletion docs_new/docs/basic_usage/sampling_params.mdx
@@ -107,7 +107,12 @@ The `/generate` endpoint accepts the following parameters in JSON format. For de
<tr>
<td style={{padding: "9px 12px", fontWeight: 500, backgroundColor: "rgba(255,255,255,0.02)"}}>return_routed_experts</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>`bool = False`</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.02)"}}>Whether to return routed experts for MoE models. Requires `--enable-return-routed-experts` server flag. Returns base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`.</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.02)"}}>Whether to return routed experts for MoE models. Requires the `--enable-return-routed-experts` server flag. With the default `routed_experts_start_len=0`, the full available sequence `[0, seqlen - 1)` is returned, since RL workflows typically need routed experts for every token. The result is base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`.</td>
</tr>
<tr>
<td style={{padding: "9px 12px", fontWeight: 500, backgroundColor: "rgba(255,255,255,0.02)"}}>routed_experts_start_len</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>`int = 0`</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.02)"}}>When `return_routed_experts` is set, the absolute start position of the returned routed-experts rows. `0` preserves the default full sequence; set it to an accumulated prefix length to return only `[routed_experts_start_len, seqlen - 1)`. For example, in multi-turn RL rollouts, routed experts for tokens from previous turns have already been collected, so setting this value avoids redundant transfers that can become a bottleneck. Must be in `[0, prompt_tokens]`. A request sketch follows this table.</td>
</tr>
</tbody>
</table>
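
The sketch below is a hypothetical request against the native `/generate` endpoint using the two parameters above; where exactly the encoded `routed_experts` payload appears in the response body is an assumption, so adapt the lookup to the actual response.

```python
# Hedged sketch: native /generate request with routed-experts capture.
import base64

import numpy as np
import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {"max_new_tokens": 8},
    "return_routed_experts": True,    # server must run with --enable-return-routed-experts
    "routed_experts_start_len": 0,    # 0 keeps the default full sequence
}
out = requests.post("http://127.0.0.1:30000/generate", json=payload).json()

# Assumed response location of the base64 payload; adjust if it differs.
encoded = out["meta_info"]["routed_experts"]
flat = np.frombuffer(base64.b64decode(encoded), dtype=np.int32)
num_layers, top_k = 24, 8  # placeholders for the model's MoE geometry
routed = flat.reshape(-1, num_layers, top_k)  # [num_tokens, num_layers, top_k]
```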
1 change: 1 addition & 0 deletions python/sglang/srt/disaggregation/encode_receiver.py
@@ -973,6 +973,7 @@ def create_req(self, recv_req: TokenizedGenerateReqInput):
require_reasoning=recv_req.require_reasoning,
return_hidden_states=recv_req.return_hidden_states,
return_routed_experts=recv_req.return_routed_experts,
routed_experts_start_len=recv_req.routed_experts_start_len,
eos_token_ids=self.scheduler.model_config.hf_eos_token_id,
bootstrap_host=recv_req.bootstrap_host,
bootstrap_port=recv_req.bootstrap_port,
4 changes: 4 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
@@ -333,6 +333,7 @@ def generate(
custom_logit_processor: Optional[Union[List[str], str]] = None,
return_hidden_states: bool = False,
return_routed_experts: bool = False,
routed_experts_start_len: int = 0,
stream: bool = False,
bootstrap_host: Optional[Union[List[str], str]] = None,
bootstrap_port: Optional[Union[List[int], int]] = None,
@@ -369,6 +370,7 @@
custom_logit_processor=custom_logit_processor,
return_hidden_states=return_hidden_states,
return_routed_experts=return_routed_experts,
routed_experts_start_len=routed_experts_start_len,
stream=stream,
bootstrap_host=bootstrap_host,
bootstrap_port=bootstrap_port,
@@ -423,6 +425,7 @@ async def async_generate(
custom_logit_processor: Optional[Union[List[str], str]] = None,
return_hidden_states: bool = False,
return_routed_experts: bool = False,
routed_experts_start_len: int = 0,
stream: bool = False,
bootstrap_host: Optional[Union[List[str], str]] = None,
bootstrap_port: Optional[Union[List[int], int]] = None,
@@ -458,6 +461,7 @@
lora_path=lora_path,
return_hidden_states=return_hidden_states,
return_routed_experts=return_routed_experts,
routed_experts_start_len=routed_experts_start_len,
stream=stream,
custom_logit_processor=custom_logit_processor,
bootstrap_host=bootstrap_host,
2 changes: 2 additions & 0 deletions python/sglang/srt/entrypoints/openai/protocol.py
@@ -285,6 +285,7 @@ class CompletionRequest(BaseModel):
user: Optional[str] = None
return_hidden_states: bool = False
return_routed_experts: bool = False
routed_experts_start_len: int = 0
return_cached_tokens_details: bool = False

# Extra parameters for SRT backend only and will be ignored by OpenAI models.
@@ -632,6 +633,7 @@ class ChatCompletionRequest(BaseModel):
parallel_tool_calls: bool = True
return_hidden_states: bool = False
return_routed_experts: bool = False
routed_experts_start_len: int = 0
return_cached_tokens_details: bool = False
reasoning_effort: Optional[Literal["none", "low", "medium", "high", "max"]] = Field(
default=None,
1 change: 1 addition & 0 deletions python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -437,6 +437,7 @@ def _convert_to_internal_request(
disagg_prefill_dp_rank=request.disagg_prefill_dp_rank,
return_hidden_states=request.return_hidden_states,
return_routed_experts=request.return_routed_experts,
routed_experts_start_len=request.routed_experts_start_len,
rid=request.rid,
extra_key=self._compute_extra_key(request),
require_reasoning=self._get_reasoning_from_request(request),
@@ -123,6 +123,7 @@ def _convert_to_internal_request(
disagg_prefill_dp_rank=request.disagg_prefill_dp_rank,
return_hidden_states=request.return_hidden_states,
return_routed_experts=request.return_routed_experts,
routed_experts_start_len=request.routed_experts_start_len,
rid=request.rid,
extra_key=self._compute_extra_key(request),
priority=request.priority,
7 changes: 5 additions & 2 deletions python/sglang/srt/managers/io_struct.py
@@ -174,7 +174,9 @@ class GenerateReqInput(BaseReq):
# Whether to return captured routed experts
return_routed_experts: bool = False
return_indexer_topk: bool = False
# The start location in the prompt for returning routed experts.
# Absolute start position of the returned routed experts; the response covers
# `[routed_experts_start_len, seqlen - 1)`. Must be in [0, prompt_tokens].
# 0 = full sequence.
routed_experts_start_len: int = 0

# The modalities of the image data [image, multi-images, video]
@@ -654,6 +656,7 @@ def __getitem__(self, i):
else self.return_hidden_states
),
return_routed_experts=self.return_routed_experts,
routed_experts_start_len=self.routed_experts_start_len,
return_indexer_topk=self.return_indexer_topk,
modalities=self.modalities[i] if self.modalities else None,
session_params=self.session_params,
@@ -730,7 +733,7 @@ class TokenizedGenerateReqInput(BaseReq):

# Whether to return captured routed experts
return_routed_experts: bool = False
# The start location in the prompt for returning routed experts.
# See GenerateReqInput.routed_experts_start_len.
routed_experts_start_len: int = 0

return_indexer_topk: bool = False
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -599,6 +599,7 @@ def __init__(
require_reasoning: bool = False,
return_hidden_states: bool = False,
return_routed_experts: bool = False,
routed_experts_start_len: int = 0,
return_indexer_topk: bool = False,
eos_token_ids: Optional[Set[int]] = None,
bootstrap_host: Optional[str] = None,
@@ -818,6 +819,7 @@ def __init__(

# capture routed experts
self.return_routed_experts = return_routed_experts
self.routed_experts_start_len = routed_experts_start_len
self.routed_experts: Optional[torch.Tensor] = (
None # cpu tensor: shape (seqlen, topk)
)
22 changes: 22 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -2001,6 +2001,7 @@ def handle_generate_request(
require_reasoning=recv_req.require_reasoning,
return_hidden_states=recv_req.return_hidden_states,
return_routed_experts=recv_req.return_routed_experts,
routed_experts_start_len=recv_req.routed_experts_start_len,
return_indexer_topk=recv_req.return_indexer_topk,
eos_token_ids=self.model_config.hf_eos_token_id,
bootstrap_host=recv_req.bootstrap_host,
@@ -2158,6 +2159,27 @@ def handle_generate_request(
self._add_request_to_queue(req)
return

if recv_req.return_routed_experts:
error_msg = None
if recv_req.routed_experts_start_len < 0:
error_msg = (
f"{recv_req.routed_experts_start_len=} is lower than 0. "
"Please use a non-negative routed_experts_start_len."
)

if recv_req.routed_experts_start_len > len(req.origin_input_ids):
error_msg = (
f"{recv_req.routed_experts_start_len=} is higher than the "
f"number of input tokens {len(req.origin_input_ids)=}. Please "
f"use a smaller routed_experts_start_len."
)

if error_msg is not None:
req.routed_experts_start_len = 0
req.set_finish_with_abort(error_msg)
self._add_request_to_queue(req)
return

added_to_grammar_queue = self.grammar_manager.process_req_with_grammar(req)
if not added_to_grammar_queue:
self._add_request_to_queue(req)
34 changes: 33 additions & 1 deletion python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -108,16 +108,48 @@ def process_batch_result_prebuilt(self: Scheduler, batch: ScheduleBatch):
self.token_to_kv_pool_allocator.free_group_end()

def maybe_collect_routed_experts(self: Scheduler, req: Req):
"""Collect routed experts for a finished request."""
"""Collect routed experts for a finished request.

Returns immediately if `return_routed_experts` was not set on the
request, so non-opted-in reqs don't pay the host-gather cost.

Honors the caller's absolute start so the response covers
`[start_len, seqlen - 1)`. The default start_len is 0, which returns
the full sequence.

Logs a soft warning if the resulting tensor's row count differs from
the expected `seqlen - 1 - start_len`, to catch silent regressions.
"""
if not req.return_routed_experts:
return
capturer = get_global_experts_capturer()
if capturer is None:
return
start_len = req.routed_experts_start_len
req.routed_experts = capturer.get_topk(
req_pool_idx=req.req_pool_idx,
seqlen=req.seqlen,
req_to_token_pool=self.req_to_token_pool,
start_len=start_len,
)

expected_rows = max(0, req.seqlen - 1 - start_len)
if (
req.routed_experts is not None
and req.routed_experts.shape[0] != expected_rows
):
logger.warning(
"routed_experts row-count mismatch for req %s: got %d, "
"expected %d (seqlen=%d, cached_tokens=%d, start_len=%s). "
"This indicates a silent bug.",
req.rid,
req.routed_experts.shape[0],
expected_rows,
req.seqlen,
req.cached_tokens,
req.routed_experts_start_len,
)

def maybe_collect_indexer_topk(self: Scheduler, req: Req):
capturer = get_global_indexer_capturer()
if capturer is None:
1 change: 1 addition & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -1018,6 +1018,7 @@ def _create_tokenized_object(
require_reasoning=obj.require_reasoning,
return_hidden_states=obj.return_hidden_states,
return_routed_experts=obj.return_routed_experts,
routed_experts_start_len=obj.routed_experts_start_len,
return_indexer_topk=obj.return_indexer_topk,
routed_dp_rank=obj.routed_dp_rank,
disagg_prefill_dp_rank=obj.disagg_prefill_dp_rank,
1 change: 1 addition & 0 deletions python/sglang/srt/session/session_controller.py
@@ -233,6 +233,7 @@ def create_req(
require_reasoning=req.require_reasoning,
return_hidden_states=req.return_hidden_states,
return_routed_experts=req.return_routed_experts,
routed_experts_start_len=req.routed_experts_start_len,
priority=req.priority,
routing_key=req.routing_key,
extra_key=req.extra_key,
12 changes: 9 additions & 3 deletions python/sglang/srt/state_capturer/base.py
@@ -147,10 +147,16 @@ def get_topk(
req_pool_idx: int,
seqlen: int,
req_to_token_pool: ReqToTokenPool,
start_len: int = 0,
) -> torch.Tensor:
cache_pool_idx = req_to_token_pool.req_to_token[req_pool_idx][
: seqlen - 1
].cpu()
if start_len < 0:
raise ValueError(f"{start_len=} must be non-negative")
start_len = min(start_len, seqlen - 1)
cache_pool_idx = (
req_to_token_pool.req_to_token[req_pool_idx][start_len : seqlen - 1]
.cpu()
.clone()
)
return self.host_cache.buffer[cache_pool_idx]

def on_forward_end(