diff --git a/docs_new/docs/basic_usage/openai_api_completions.ipynb b/docs_new/docs/basic_usage/openai_api_completions.ipynb index ffa576ae52c5..8d417dab465d 100644 --- a/docs_new/docs/basic_usage/openai_api_completions.ipynb +++ b/docs_new/docs/basic_usage/openai_api_completions.ipynb @@ -374,7 +374,7 @@ "source": [ "#### Returning Routed Experts (MoE Models)\n", "\n", - "For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`." + "For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`. By default this returns `[0, seqlen - 1)`, the full available sequence, because RL workflows need routed experts for every token. Set `routed_experts_start_len` in `extra_body` to an absolute prefix length to return only `[routed_experts_start_len, seqlen - 1)`. For example, in multi-turn RL rollouts, routed experts for tokens from previous turns have already been collected, so setting this value avoids unnecessary transfers that cause bottlenecks." ] }, { @@ -468,7 +468,7 @@ "source": [ "#### Returning Routed Experts (MoE Models)\n", "\n", - "For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`." + "For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`. By default this returns `[0, seqlen - 1)`, the full available sequence, because RL workflows need routed experts for every token. Set `routed_experts_start_len` in `extra_body` to an absolute prefix length to return only `[routed_experts_start_len, seqlen - 1)`. For example, in multi-turn RL rollouts, routed experts for tokens from previous turns have already been collected, so setting this value avoids unnecessary transfers that cause bottlenecks." ] }, { diff --git a/docs_new/docs/basic_usage/openai_api_completions.mdx b/docs_new/docs/basic_usage/openai_api_completions.mdx index c463fcca5ac7..0cc08ad8ad96 100644 --- a/docs_new/docs/basic_usage/openai_api_completions.mdx +++ b/docs_new/docs/basic_usage/openai_api_completions.mdx @@ -342,7 +342,7 @@ for chunk in stream: #### Returning Routed Experts (MoE Models) -For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`. 
+For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`. By default this returns `[0, seqlen - 1)`, the full available sequence, because RL workflows need routed experts for every token. Set `routed_experts_start_len` in `extra_body` to an absolute prefix length to return only `[routed_experts_start_len, seqlen - 1)`. For example, in multi-turn RL rollouts, routed experts for tokens from previous turns have already been collected, so setting this value avoids unnecessary transfers that cause bottlenecks. ```python Example # Example with logit_bias parameter for completions API @@ -406,7 +406,7 @@ print_highlight(f"Response: {response}") #### Returning Routed Experts (MoE Models) -For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`. +For MoE models, set `return_routed_experts: true` in `extra_body` to return expert routing data. Requires `--enable-return-routed-experts` server flag. The `routed_experts` field will be returned in the `sgl_ext` object on each choice, containing base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`. By default this returns `[0, seqlen - 1)`, the full available sequence, because RL workflows need routed experts for every token. Set `routed_experts_start_len` in `extra_body` to an absolute prefix length to return only `[routed_experts_start_len, seqlen - 1)`. For example, in multi-turn RL rollouts, routed experts for tokens from previous turns have already been collected, so setting this value avoids unnecessary transfers that cause bottlenecks. ## Structured Outputs (JSON, Regex, EBNF) diff --git a/docs_new/docs/basic_usage/sampling_params.mdx b/docs_new/docs/basic_usage/sampling_params.mdx index 4b271a229c9c..1e49a7783d75 100644 --- a/docs_new/docs/basic_usage/sampling_params.mdx +++ b/docs_new/docs/basic_usage/sampling_params.mdx @@ -107,7 +107,12 @@ The `/generate` endpoint accepts the following parameters in JSON format. For de return_routed_experts `bool = False` - Whether to return routed experts for MoE models. Requires `--enable-return-routed-experts` server flag. Returns base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`. + Whether to return routed experts for MoE models. Requires `--enable-return-routed-experts` server flag. With the default `routed_experts_start_len=0`, returns the full available sequence `[0, seqlen - 1)` because RL workflows need routed experts for every token. The result is base64-encoded int32 expert IDs as a flattened array with logical shape `[num_tokens, num_layers, top_k]`. + + + routed_experts_start_len + `int = 0` + If `return_routed_experts` is set, the absolute start position of the returned routed-experts rows. `0` preserves the default full sequence; set it to an accumulated prefix length to return only `[routed_experts_start_len, seqlen - 1)`. 
For example, in multi-turn RL rollouts, routed experts for tokens from previous turns have already been collected, so setting this value avoids unnecessary transfers that cause bottlenecks. Must be in `[0, prompt_tokens]`. diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 1a5028ef9dbd..754d9ad18406 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -973,6 +973,7 @@ def create_req(self, recv_req: TokenizedGenerateReqInput): require_reasoning=recv_req.require_reasoning, return_hidden_states=recv_req.return_hidden_states, return_routed_experts=recv_req.return_routed_experts, + routed_experts_start_len=recv_req.routed_experts_start_len, eos_token_ids=self.scheduler.model_config.hf_eos_token_id, bootstrap_host=recv_req.bootstrap_host, bootstrap_port=recv_req.bootstrap_port, diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 7f44c88d85c0..a94f9d7ed26d 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -333,6 +333,7 @@ def generate( custom_logit_processor: Optional[Union[List[str], str]] = None, return_hidden_states: bool = False, return_routed_experts: bool = False, + routed_experts_start_len: int = 0, stream: bool = False, bootstrap_host: Optional[Union[List[str], str]] = None, bootstrap_port: Optional[Union[List[int], int]] = None, @@ -369,6 +370,7 @@ def generate( custom_logit_processor=custom_logit_processor, return_hidden_states=return_hidden_states, return_routed_experts=return_routed_experts, + routed_experts_start_len=routed_experts_start_len, stream=stream, bootstrap_host=bootstrap_host, bootstrap_port=bootstrap_port, @@ -423,6 +425,7 @@ async def async_generate( custom_logit_processor: Optional[Union[List[str], str]] = None, return_hidden_states: bool = False, return_routed_experts: bool = False, + routed_experts_start_len: int = 0, stream: bool = False, bootstrap_host: Optional[Union[List[str], str]] = None, bootstrap_port: Optional[Union[List[int], int]] = None, @@ -458,6 +461,7 @@ async def async_generate( lora_path=lora_path, return_hidden_states=return_hidden_states, return_routed_experts=return_routed_experts, + routed_experts_start_len=routed_experts_start_len, stream=stream, custom_logit_processor=custom_logit_processor, bootstrap_host=bootstrap_host, diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index a19744243f09..58a137b19e5e 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -285,6 +285,7 @@ class CompletionRequest(BaseModel): user: Optional[str] = None return_hidden_states: bool = False return_routed_experts: bool = False + routed_experts_start_len: int = 0 return_cached_tokens_details: bool = False # Extra parameters for SRT backend only and will be ignored by OpenAI models. 
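Reviewer note (illustrative only, not part of the patch): a minimal client-side sketch of how the fields documented above could be consumed over the OpenAI-compatible completions endpoint. The server URL, model name, and the `NUM_LAYERS` / `TOP_K` values are placeholder assumptions; the `sgl_ext.routed_experts` field and the flattened int32 layout follow the documentation text in this diff.

```python
# Illustrative sketch -- not part of this diff. Assumes a server launched with
# --enable-return-routed-experts serving an MoE model; the URL, model name,
# NUM_LAYERS, and TOP_K below are placeholders, not values from the patch.
import base64

import numpy as np
import requests

NUM_LAYERS = 48  # assumed layer count of the served MoE model
TOP_K = 8        # assumed number of experts routed per token

resp = requests.post(
    "http://localhost:30000/v1/completions",  # assumed server address
    json={
        "model": "my-moe-model",               # placeholder model name
        "prompt": "Tell me a fact about cats.",
        "max_tokens": 8,
        "return_routed_experts": True,
        # 0 (the default) returns rows for the full sequence [0, seqlen - 1).
        "routed_experts_start_len": 0,
    },
    timeout=120,
)
choice = resp.json()["choices"][0]

# Per the docs above: base64-encoded int32 expert IDs, flattened, with
# logical shape [num_tokens, num_layers, top_k].
raw = base64.b64decode(choice["sgl_ext"]["routed_experts"])
experts = np.frombuffer(raw, dtype=np.int32).reshape(-1, NUM_LAYERS, TOP_K)
print(experts.shape)  # (seqlen - 1 - routed_experts_start_len, NUM_LAYERS, TOP_K)
```

With the OpenAI Python client, the same two fields would be passed through `extra_body` rather than the top-level JSON body.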
@@ -632,6 +633,7 @@ class ChatCompletionRequest(BaseModel): parallel_tool_calls: bool = True return_hidden_states: bool = False return_routed_experts: bool = False + routed_experts_start_len: int = 0 return_cached_tokens_details: bool = False reasoning_effort: Optional[Literal["none", "low", "medium", "high", "max"]] = Field( default=None, diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 7bf381d629e4..26293fc27dcc 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -437,6 +437,7 @@ def _convert_to_internal_request( disagg_prefill_dp_rank=request.disagg_prefill_dp_rank, return_hidden_states=request.return_hidden_states, return_routed_experts=request.return_routed_experts, + routed_experts_start_len=request.routed_experts_start_len, rid=request.rid, extra_key=self._compute_extra_key(request), require_reasoning=self._get_reasoning_from_request(request), diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index 8598620c4f75..eba04f463b1f 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -123,6 +123,7 @@ def _convert_to_internal_request( disagg_prefill_dp_rank=request.disagg_prefill_dp_rank, return_hidden_states=request.return_hidden_states, return_routed_experts=request.return_routed_experts, + routed_experts_start_len=request.routed_experts_start_len, rid=request.rid, extra_key=self._compute_extra_key(request), priority=request.priority, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 1722f75d6af3..33bfc1bc6051 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -174,7 +174,9 @@ class GenerateReqInput(BaseReq): # Whether to return captured routed experts return_routed_experts: bool = False return_indexer_topk: bool = False - # The start location in the prompt for returning routed experts. + # Absolute start position for returned routings; response covers + # `[routed_experts_start_len, seqlen - 1)`. Must be in [0, prompt_tokens]. + # 0 = full sequence. routed_experts_start_len: int = 0 # The modalities of the image data [image, multi-images, video] @@ -654,6 +656,7 @@ def __getitem__(self, i): else self.return_hidden_states ), return_routed_experts=self.return_routed_experts, + routed_experts_start_len=self.routed_experts_start_len, return_indexer_topk=self.return_indexer_topk, modalities=self.modalities[i] if self.modalities else None, session_params=self.session_params, @@ -730,7 +733,7 @@ class TokenizedGenerateReqInput(BaseReq): # Whether to return captured routed experts return_routed_experts: bool = False - # The start location in the prompt for returning routed experts. + # See GenerateReqInput.routed_experts_start_len. 
routed_experts_start_len: int = 0 return_indexer_topk: bool = False diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 2450d49fb2a0..99f2744ee763 100755 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -599,6 +599,7 @@ def __init__( require_reasoning: bool = False, return_hidden_states: bool = False, return_routed_experts: bool = False, + routed_experts_start_len: int = 0, return_indexer_topk: bool = False, eos_token_ids: Optional[Set[int]] = None, bootstrap_host: Optional[str] = None, @@ -818,6 +819,7 @@ def __init__( # capture routed experts self.return_routed_experts = return_routed_experts + self.routed_experts_start_len = routed_experts_start_len self.routed_experts: Optional[torch.Tensor] = ( None # cpu tensor: shape (seqlen, topk) ) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 6e662e77baac..8f4dfe996afb 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -2001,6 +2001,7 @@ def handle_generate_request( require_reasoning=recv_req.require_reasoning, return_hidden_states=recv_req.return_hidden_states, return_routed_experts=recv_req.return_routed_experts, + routed_experts_start_len=recv_req.routed_experts_start_len, return_indexer_topk=recv_req.return_indexer_topk, eos_token_ids=self.model_config.hf_eos_token_id, bootstrap_host=recv_req.bootstrap_host, @@ -2158,6 +2159,27 @@ def handle_generate_request( self._add_request_to_queue(req) return + if recv_req.return_routed_experts: + error_msg = None + if recv_req.routed_experts_start_len < 0: + error_msg = ( + f"{recv_req.routed_experts_start_len=} is lower than 0. " + "Please use a non-negative routed_experts_start_len." + ) + + if recv_req.routed_experts_start_len > len(req.origin_input_ids): + error_msg = ( + f"{recv_req.routed_experts_start_len=} is higher than the " + f"number of input tokens {len(req.origin_input_ids)=}. Please " + f"use a smaller routed_experts_start_len." + ) + + if error_msg is not None: + req.routed_experts_start_len = 0 + req.set_finish_with_abort(error_msg) + self._add_request_to_queue(req) + return + added_to_grammar_queue = self.grammar_manager.process_req_with_grammar(req) if not added_to_grammar_queue: self._add_request_to_queue(req) diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index d572039e6fb2..997073ab8ca4 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -108,16 +108,48 @@ def process_batch_result_prebuilt(self: Scheduler, batch: ScheduleBatch): self.token_to_kv_pool_allocator.free_group_end() def maybe_collect_routed_experts(self: Scheduler, req: Req): - """Collect routed experts for a finished request.""" + """Collect routed experts for a finished request. + + Returns immediately if `return_routed_experts` was not set on the + request, so non-opted-in reqs don't pay the host-gather cost. + + Honors the caller's absolute start so the response covers + `[start_len, seqlen - 1)`. The default start_len is 0, which returns + the full sequence. + + Logs a soft warning if the resulting tensor's row count differs from + the expected `seqlen - 1 - start_len`, to catch silent regressions. 
+ """ + if not req.return_routed_experts: + return capturer = get_global_experts_capturer() if capturer is None: return + start_len = req.routed_experts_start_len req.routed_experts = capturer.get_topk( req_pool_idx=req.req_pool_idx, seqlen=req.seqlen, req_to_token_pool=self.req_to_token_pool, + start_len=start_len, ) + expected_rows = max(0, req.seqlen - 1 - start_len) + if ( + req.routed_experts is not None + and req.routed_experts.shape[0] != expected_rows + ): + logger.warning( + "routed_experts row-count mismatch for req %s: got %d, " + "expected %d (seqlen=%d, cached_tokens=%d, start_len=%s). " + "This indicates a silent bug.", + req.rid, + req.routed_experts.shape[0], + expected_rows, + req.seqlen, + req.cached_tokens, + req.routed_experts_start_len, + ) + def maybe_collect_indexer_topk(self: Scheduler, req: Req): capturer = get_global_indexer_capturer() if capturer is None: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 1eea93d2eb65..0c368367bdf9 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1018,6 +1018,7 @@ def _create_tokenized_object( require_reasoning=obj.require_reasoning, return_hidden_states=obj.return_hidden_states, return_routed_experts=obj.return_routed_experts, + routed_experts_start_len=obj.routed_experts_start_len, return_indexer_topk=obj.return_indexer_topk, routed_dp_rank=obj.routed_dp_rank, disagg_prefill_dp_rank=obj.disagg_prefill_dp_rank, diff --git a/python/sglang/srt/session/session_controller.py b/python/sglang/srt/session/session_controller.py index ce98514b8568..1915c370a540 100644 --- a/python/sglang/srt/session/session_controller.py +++ b/python/sglang/srt/session/session_controller.py @@ -233,6 +233,7 @@ def create_req( require_reasoning=req.require_reasoning, return_hidden_states=req.return_hidden_states, return_routed_experts=req.return_routed_experts, + routed_experts_start_len=req.routed_experts_start_len, priority=req.priority, routing_key=req.routing_key, extra_key=req.extra_key, diff --git a/python/sglang/srt/state_capturer/base.py b/python/sglang/srt/state_capturer/base.py index 8620804e7b6f..3931c431573b 100644 --- a/python/sglang/srt/state_capturer/base.py +++ b/python/sglang/srt/state_capturer/base.py @@ -147,10 +147,16 @@ def get_topk( req_pool_idx: int, seqlen: int, req_to_token_pool: ReqToTokenPool, + start_len: int = 0, ) -> torch.Tensor: - cache_pool_idx = req_to_token_pool.req_to_token[req_pool_idx][ - : seqlen - 1 - ].cpu() + if start_len < 0: + raise ValueError(f"{start_len=} must be non-negative") + start_len = min(start_len, seqlen - 1) + cache_pool_idx = ( + req_to_token_pool.req_to_token[req_pool_idx][start_len : seqlen - 1] + .cpu() + .clone() + ) return self.host_cache.buffer[cache_pool_idx] def on_forward_end( diff --git a/test/registered/rl/test_return_routed_experts.py b/test/registered/rl/test_return_routed_experts.py index 8caa66795c5e..33643207a052 100644 --- a/test/registered/rl/test_return_routed_experts.py +++ b/test/registered/rl/test_return_routed_experts.py @@ -5,6 +5,8 @@ from typing import List import aiohttp +import numpy as np +import requests import torch from torch.nn.utils.rnn import pad_sequence @@ -15,6 +17,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.test_utils import ( + DEFAULT_ENABLE_ROUTED_EXPERTS_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, 
CustomTestCase, @@ -32,6 +35,9 @@ SHAREGPT_FILENAME = "ShareGPT_V3_unfiltered_cleaned_split.json" logger = logging.getLogger(__name__) +_QWEN3_30B_A3B_NUM_LAYERS = 48 +_QWEN3_30B_A3B_TOPK = 8 + class TestReturnRoutedExperts(CustomTestCase): """End-to-end check that --enable-return-routed-experts stays correct @@ -263,5 +269,195 @@ def compare_baseline_w_reference(baseline, reference): return num_total_mismatches +class TestRoutedExpertsStartLen(CustomTestCase): + """Verify the `routed_experts_start_len` parameter: + + - default (0) returns the full sequence + - explicit start_len crops the response and the cropped tail matches + the corresponding tail of the full response + """ + + MAX_NEW_TOKENS = 8 + + @classmethod + def setUpClass(cls): + cls.process = popen_launch_server( + DEFAULT_ENABLE_ROUTED_EXPERTS_MODEL_NAME_FOR_TEST, + DEFAULT_URL_FOR_TEST, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--enable-return-routed-experts", + "--enable-deterministic-inference", + "--tp", + 2, + ], + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "process") and cls.process: + kill_process_tree(cls.process.pid) + + def _send(self, payload: dict) -> dict: + resp = requests.post( + f"{DEFAULT_URL_FOR_TEST}/generate", json=payload, timeout=120 + ) + return resp + + def _build_payload(self, **extra) -> dict: + payload = { + "text": "User: Tell me a fact about cats.\nAssistant:", + "sampling_params": { + "temperature": 0, + "max_new_tokens": self.MAX_NEW_TOKENS, + "ignore_eos": True, + }, + "return_routed_experts": True, + } + payload.update(extra) + return payload + + def _routed_experts(self, resp_json: dict): + return extract_routed_experts_from_meta_info(resp_json).reshape( + -1, _QWEN3_30B_A3B_NUM_LAYERS, _QWEN3_30B_A3B_TOPK + ) + + def _seqlen(self, resp_json: dict) -> int: + meta = resp_json["meta_info"] + return meta["prompt_tokens"] + meta["completion_tokens"] + + def test_start_len_zero_is_default(self): + """Omitting the field must match `routed_experts_start_len=0`, + which returns the full sequence (start_len=0).""" + resp_default = self._send(self._build_payload()).json() + resp_zero = self._send(self._build_payload(routed_experts_start_len=0)).json() + + rows_default = self._routed_experts(resp_default) + rows_zero = self._routed_experts(resp_zero) + + seqlen_default = self._seqlen(resp_default) + seqlen_zero = self._seqlen(resp_zero) + self.assertEqual(seqlen_default, seqlen_zero) + self.assertEqual(rows_default.shape[0], seqlen_default - 1) + self.assertEqual(rows_zero.shape[0], seqlen_zero - 1) + self.assertTrue( + np.array_equal(rows_default, rows_zero), + "default and explicit 0 must produce identical routed experts", + ) + + def test_start_len_controls_row_count(self): + """`routed_experts_start_len=N` must return `seqlen - 1 - N` rows + and the returned tail must match the corresponding tail of the + full sequence (start_len omitted).""" + full_resp = self._send(self._build_payload()).json() + full_rows = self._routed_experts(full_resp) + seqlen = self._seqlen(full_resp) + self.assertEqual(full_rows.shape[0], seqlen - 1) + + start_len = max(1, full_resp["meta_info"]["prompt_tokens"] // 2) + + cropped_resp = self._send( + self._build_payload(routed_experts_start_len=start_len) + ).json() + cropped_rows = self._routed_experts(cropped_resp) + cropped_seqlen = self._seqlen(cropped_resp) + + self.assertEqual(seqlen, cropped_seqlen) + expected_rows = seqlen - 1 - start_len + self.assertEqual( + cropped_rows.shape[0], + expected_rows, + f"expected 
{expected_rows} rows, got {cropped_rows.shape[0]}", + ) + self.assertTrue( + np.array_equal(full_rows[start_len:], cropped_rows), + "cropped routed experts must match the tail of the full sequence", + ) + + def test_start_len_exceeds_prompt_tokens_aborts(self): + """`routed_experts_start_len > prompt_tokens` must abort the request: + the caller cannot meaningfully reference positions that don't exist + in the prompt yet.""" + baseline = self._send(self._build_payload()).json() + prompt_tokens = baseline["meta_info"]["prompt_tokens"] + + ok = self._send(self._build_payload(routed_experts_start_len=prompt_tokens)) + self.assertEqual( + ok.status_code, + 200, + f"start_len=={prompt_tokens} should pass, got {ok.text}", + ) + + too_big = self._send( + self._build_payload(routed_experts_start_len=prompt_tokens + 1) + ) + self._assert_aborted(too_big, "is higher than the number of input tokens") + + def test_start_len_with_cache_hit(self): + """`start_len` must allow the radix prefix to extend past it. The + first request seeds the cache; the second sends the same prompt + with `start_len` somewhere inside the prompt. We verify: + + - meta_info.cached_tokens > start_len (would be impossible if a + cap forced the prefix match to <= start_len), + - the response row count still equals `seqlen - 1 - start_len`. + """ + cache_salt = "cache-hit-test" + first = self._send(self._build_payload(extra_key=cache_salt)).json() + self.assertEqual( + first["meta_info"].get("cached_tokens", 0), + 0, + "first request must be a cold miss", + ) + + prompt_tokens = first["meta_info"]["prompt_tokens"] + start_len = max(1, prompt_tokens // 2) + second = self._send( + self._build_payload( + extra_key=cache_salt, + routed_experts_start_len=start_len, + ) + ).json() + + cached = second["meta_info"].get("cached_tokens", 0) + self.assertGreater( + cached, + start_len, + f"expected radix prefix past start_len={start_len}, " + f"got cached_tokens={cached} (cap not removed?)", + ) + + rows = self._routed_experts(second) + expected = self._seqlen(second) - 1 - start_len + self.assertEqual( + rows.shape[0], + expected, + f"expected {expected} rows, got {rows.shape[0]}", + ) + + def _assert_aborted(self, resp, expected_substring: str): + """Assert a request was aborted with `expected_substring` in the + error message.""" + if resp.status_code == 200: + body = resp.json() + meta = body.get("meta_info", {}) + finish_reason = meta.get("finish_reason") or {} + message = ( + str(finish_reason.get("message", "")) + + " " + + str(body.get("text", "")) + + " " + + str(body.get("error", "")) + ) + self.assertIn( + expected_substring, + message, + f"expected abort with '{expected_substring}', got body={body}", + ) + else: + self.assertGreaterEqual(resp.status_code, 400) + self.assertIn(expected_substring, resp.text) + + if __name__ == "__main__": unittest.main()
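Closing reviewer note (illustrative only, not part of the patch): a sketch of the multi-turn RL rollout pattern that motivates `routed_experts_start_len`, written against the native `/generate` endpoint. The URL, the `NUM_LAYERS` / `TOP_K` values, and the exact key for routed experts inside `meta_info` are assumptions; the bookkeeping follows the `[routed_experts_start_len, seqlen - 1)` semantics introduced by this diff.

```python
# Illustrative sketch -- not part of this diff. URL, NUM_LAYERS, TOP_K, and the
# meta_info key for routed experts are assumptions made for this example.
import base64

import numpy as np
import requests

URL = "http://localhost:30000/generate"  # assumed native endpoint
NUM_LAYERS, TOP_K = 48, 8                # assumed values for the served MoE model

collected = []       # routed-expert rows gathered across turns
collected_rows = 0   # absolute prefix length already covered, i.e. the next start_len

conversation = "User: Hi.\nAssistant:"
for _ in range(3):
    resp = requests.post(
        URL,
        json={
            "text": conversation,
            "sampling_params": {"temperature": 0, "max_new_tokens": 16},
            "return_routed_experts": True,
            # Only rows [collected_rows, seqlen - 1) come back, so routings
            # already gathered in earlier turns are not transferred again.
            "routed_experts_start_len": collected_rows,
        },
        timeout=120,
    ).json()

    # Assumed response layout: base64 int32 payload under meta_info.
    raw = base64.b64decode(resp["meta_info"]["routed_experts"])
    rows = np.frombuffer(raw, dtype=np.int32).reshape(-1, NUM_LAYERS, TOP_K)
    collected.append(rows)
    collected_rows += rows.shape[0]  # now equals seqlen - 1 for this turn

    conversation += resp["text"] + "\nUser: Tell me more.\nAssistant:"

all_rows = np.concatenate(collected)  # full-sequence routings, stitched turn by turn
print(all_rows.shape)                 # (final seqlen - 1, NUM_LAYERS, TOP_K)
```

In a real rollout loop the conversation would be carried as token IDs rather than re-tokenized text, so the accumulated prefix length lines up exactly with the server-side sequence and the radix prefix.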