def get_num_indexer_layers(config) -> int:
    """Return how many layer slots the indexer-topk capturer should reserve.

    For configs that ``is_deepseek_nsa`` accepts, the result is
    ``config.num_hidden_layers``: every transformer layer gets a slot,
    including layers that reuse the previous layer's top-k when
    ``index_topk_freq > 1``. Any other config may opt in by defining
    ``num_indexer_layers``; when that attribute is absent the result is 0,
    which disables the capturer.
    """
    if not is_deepseek_nsa(config):
        # Non-NSA architectures advertise their own count (default: none).
        return getattr(config, "num_indexer_layers", 0)
    return config.num_hidden_layers
class IndexerTopkCapturer(BaseTopkCapturer):
    """Top-k capturer configured for indexer layers.

    Stores the indexer geometry (layer count and top-k width) on the instance
    and forwards it, together with a derived batch capacity, to
    ``BaseTopkCapturer`` under the name ``"indexer_topk"``.
    """

    def __init__(
        self,
        num_tokens: int,
        num_indexer_layers: int,
        index_topk: int,
        max_running_requests: int,
        device: str,
    ):
        # Function-scope import kept exactly as in the original constructor.
        from sglang.srt.server_args import get_global_server_args

        self.num_indexer_layers = num_indexer_layers
        self.index_topk = index_topk

        assert (
            get_attention_tp_size() == 1
        ), "IndexerTopkCapturer now only supports DP attention"

        # With DP attention each rank captures only its local batch, so the
        # buffer is sized for a single rank: the larger of the chunked-prefill
        # size and the running-request limit.
        batch_capacity = max(
            get_global_server_args().chunked_prefill_size,
            max_running_requests,
        )

        super().__init__(
            num_tokens=num_tokens,
            max_batch_size=batch_capacity,
            num_layers=num_indexer_layers,
            topk_size=index_topk,
            device=device,
            name="indexer_topk",
        )
def extract_indexer_topk_from_meta_info(data):
    """Decode the base64 ``indexer_topk`` payload of a response.

    Mirrors ``extract_routed_experts_from_meta_info``: the indices are sent
    as base64-encoded int32 bytes.

    Args:
        data: Response dict containing a ``"meta_info"`` mapping.

    Returns:
        A flat ``np.int32`` array built from the decoded bytes, or ``None``
        when ``meta_info`` carries no ``"indexer_topk"`` entry. The caller
        reshapes the array to
        ``(seqlen - 1, num_indexer_layers, index_topk)``.
    """
    encoded = data["meta_info"].get("indexer_topk")
    if encoded is None:
        # The lookup is optional (``.get``); without this guard a missing
        # payload failed with an opaque AttributeError on ``None.encode``.
        return None
    raw = pybase64.b64decode(encoded.encode("utf-8"))
    return np.frombuffer(raw, dtype=np.int32)
self._forward_cuda_k_only( - x, - positions, - forward_batch, + return maybe_capture_indexer_topk( layer_id, - act_quant, - enable_dual_stream, - metadata, - return_indices, + self._forward_cuda_k_only( + x, + positions, + forward_batch, + layer_id, + act_quant, + enable_dual_stream, + metadata, + return_indices, + ), ) if enable_dual_stream and forward_batch.forward_mode.is_decode_or_idle(): @@ -1227,11 +1233,14 @@ def forward_cuda( # print( # "HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result" # ) - return torch.full( - (x_meta.shape[0], self.index_topk), - -1, - dtype=torch.int, - device=x_meta.device, + return maybe_capture_indexer_topk( + layer_id, + torch.full( + (x_meta.shape[0], self.index_topk), + -1, + dtype=torch.int, + device=x_meta.device, + ), ) if ( @@ -1281,7 +1290,10 @@ def forward_cuda( kv_len_next, actual_seq_q_next, ) - return torch.cat([topk_result_prev, topk_result_next], dim=0) + return maybe_capture_indexer_topk( + layer_id, + torch.cat([topk_result_prev, topk_result_next], dim=0), + ) else: topk_result = self._get_topk_ragged( enable_dual_stream, @@ -1299,7 +1311,7 @@ def forward_cuda( topk=self.index_topk, layer_id=layer_id, ) - return topk_result + return maybe_capture_indexer_topk(layer_id, topk_result) def forward_npu( self, diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 9497e57ff34e..9b507601c555 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -349,6 +349,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): else [] ) routed_experts = self._b64_encode_per_request(recv_obj.routed_experts) + indexer_topk = self._b64_encode_per_request(recv_obj.indexer_topk) return BatchStrOutput( rids=recv_obj.rids, http_worker_ipcs=recv_obj.http_worker_ipcs, @@ -378,6 +379,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): 
output_token_entropy_val=recv_obj.output_token_entropy_val, output_hidden_states=recv_obj.output_hidden_states, routed_experts=routed_experts, + indexer_topk=indexer_topk, customized_info=recv_obj.customized_info, placeholder_tokens_idx=None, placeholder_tokens_val=None, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 3a1838792e3d..a573e672a8ae 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -177,6 +177,7 @@ class GenerateReqInput(BaseReq): return_hidden_states: Union[List[bool], bool] = False # Whether to return captured routed experts return_routed_experts: bool = False + return_indexer_topk: bool = False # The start location in the prompt for returning routed experts. routed_experts_start_len: int = 0 @@ -653,6 +654,7 @@ def __getitem__(self, i): else self.return_hidden_states ), return_routed_experts=self.return_routed_experts, + return_indexer_topk=self.return_indexer_topk, modalities=self.modalities[i] if self.modalities else None, session_params=self.session_params, lora_path=self.lora_path[i] if self.lora_path is not None else None, @@ -731,6 +733,8 @@ class TokenizedGenerateReqInput(BaseReq): # The start location in the prompt for returning routed experts. routed_experts_start_len: int = 0 + return_indexer_topk: bool = False + # The input embeds input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None @@ -1107,6 +1111,8 @@ class BatchTokenIDOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin): # straight to TokenizerManager, which encodes on demand. routed_experts: List[Optional[torch.Tensor]] + indexer_topk: List[Optional[torch.Tensor]] + # The information of placeholder tokens (e.g., image token) # idx is the index of the token in the prompt after expansion. # val is the length of padded tokens after expansion. 
@@ -1170,6 +1176,8 @@ class BatchStrOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin): # see BatchTokenIDOutput.routed_experts. routed_experts: List[Optional[str]] + indexer_topk: List[Optional[str]] + # The information of placeholder tokens (e.g., image token) # idx is the index of the token in the prompt after expansion. # val is the length of padded tokens after expansion. diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index 7a5dcfff6295..0620cb3d6b42 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -272,6 +272,9 @@ def _handle_output_by_index(output, i): routed_experts=_extract_field_by_index( output, "routed_experts", i, check_length=False ), + indexer_topk=_extract_field_by_index( + output, "indexer_topk", i, check_length=False + ), customized_info=_extract_field_by_index( output, "customized_info", i, check_length=False ), diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ee6c1c77b795..6f6b404757ce 100755 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -595,6 +595,7 @@ def __init__( require_reasoning: bool = False, return_hidden_states: bool = False, return_routed_experts: bool = False, + return_indexer_topk: bool = False, eos_token_ids: Optional[Set[int]] = None, bootstrap_host: Optional[str] = None, bootstrap_port: Optional[int] = None, @@ -814,6 +815,11 @@ def __init__( self.routed_experts: Optional[torch.Tensor] = ( None # cpu tensor: shape (seqlen, topk) ) + + self.return_indexer_topk = return_indexer_topk + self.indexer_topk: Optional[torch.Tensor] = ( + None # cpu tensor: shape (seqlen, num_indexer_layers, index_topk) + ) # Customized info self.customized_info: Optional[Dict[str, List[Any]]] = None @@ -1229,6 +1235,7 @@ def reset_for_retract(self): self.prefix_indices = 
torch.empty((0,), dtype=torch.int64) self.routed_experts = None + self.indexer_topk = None self.last_node = None self.cache_protected_len = 0 self.swa_uuid_for_lock = None @@ -1469,6 +1476,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Whether to return captured experts return_routed_experts: bool = False + return_indexer_topk: bool = False + # Whether this batch is prefill-only (no token generation needed) is_prefill_only: bool = False @@ -1522,6 +1531,7 @@ def init_new( spec_algorithm=spec_algorithm, return_hidden_states=any(req.return_hidden_states for req in reqs), return_routed_experts=any(req.return_routed_experts for req in reqs), + return_indexer_topk=any(req.return_indexer_topk for req in reqs), is_prefill_only=all(req.is_prefill_only for req in reqs), chunked_req=chunked_req, dllm_config=dllm_config, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 56a752b75907..22710db85e21 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1998,6 +1998,7 @@ def handle_generate_request( require_reasoning=recv_req.require_reasoning, return_hidden_states=recv_req.return_hidden_states, return_routed_experts=recv_req.return_routed_experts, + return_indexer_topk=recv_req.return_indexer_topk, eos_token_ids=self.model_config.hf_eos_token_id, bootstrap_host=recv_req.bootstrap_host, bootstrap_port=recv_req.bootstrap_port, diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 76983add6d53..22cfc475e785 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -7,6 +7,9 @@ from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.environ import envs +from sglang.srt.layers.attention.indexer_topk_capturer import ( + get_global_indexer_capturer, +) from 
def maybe_collect_indexer_topk(self: Scheduler, req: Req):
    """Attach captured indexer top-k indices to a finished request.

    Does nothing when the request did not ask for them or when no global
    indexer capturer is installed. Otherwise stores the capturer's
    ``get_topk`` result for this request on ``req.indexer_topk``.

    Args:
        req: The finished request; ``req_pool_idx`` and ``seqlen`` select
            which captured entries are gathered.
    """
    # The output path only reads ``req.indexer_topk`` for requests with
    # ``return_indexer_topk`` set, so skip the gather for all others.
    if not req.return_indexer_topk:
        return
    capturer = get_global_indexer_capturer()
    if capturer is None:
        return
    req.indexer_topk = capturer.get_topk(
        req_pool_idx=req.req_pool_idx,
        seqlen=req.seqlen,
        req_to_token_pool=self.req_to_token_pool,
    )
# Asynchronously offload KV cache; release_kv_cache will be called after Device->Host transfer completes @@ -980,6 +1001,7 @@ def stream_output_generation( output_hidden_states = None load = self.get_loads(GetLoadsReqInput(include=["core"])) routed_experts = None + indexer_topk = None customized_info = {} time_stats = [] @@ -1163,6 +1185,10 @@ def stream_output_generation( if routed_experts is None: routed_experts = [] routed_experts.append(req.routed_experts) + if req.return_indexer_topk: + if indexer_topk is None: + indexer_topk = [] + indexer_topk.append(req.indexer_topk) if req.customized_info is not None: for k, v in req.customized_info.items(): @@ -1219,6 +1245,7 @@ def stream_output_generation( output_token_entropy_val=None, output_hidden_states=output_hidden_states, routed_experts=routed_experts, + indexer_topk=indexer_topk, customized_info=customized_info, placeholder_tokens_idx=None, placeholder_tokens_val=None, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 49a7f86ed459..a3b266731319 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1013,6 +1013,7 @@ def _create_tokenized_object( require_reasoning=obj.require_reasoning, return_hidden_states=obj.return_hidden_states, return_routed_experts=obj.return_routed_experts, + return_indexer_topk=obj.return_indexer_topk, routed_dp_rank=obj.routed_dp_rank, disagg_prefill_dp_rank=obj.disagg_prefill_dp_rank, priority=obj.priority, @@ -1710,6 +1711,12 @@ async def _handle_batch_output( if isinstance(val, torch.Tensor): val = pybase64.b64encode(val.numpy().tobytes()).decode("utf-8") meta_info["routed_experts"] = val + if getattr(recv_obj, "indexer_topk", None): + val = recv_obj.indexer_topk[i] + if val is not None: + if isinstance(val, torch.Tensor): + val = pybase64.b64encode(val.numpy().tobytes()).decode("utf-8") + meta_info["indexer_topk"] = val if getattr(recv_obj, 
"customized_info", None): for k, v in recv_obj.customized_info.items(): meta_info[k] = v[i] diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index efb7f7563b09..ace1f19504be 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -478,6 +478,7 @@ def forward_batch_generation( can_run_cuda_graph=can_run_cuda_graph, expert_distribution_metrics=out.expert_distribution_metrics, routed_experts_output=out.routed_experts_output, + indexer_topk_output=out.indexer_topk_output, ) if is_verify: diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py index 9d34ae31df20..3f2911afc273 100644 --- a/python/sglang/srt/managers/utils.py +++ b/python/sglang/srt/managers/utils.py @@ -49,6 +49,7 @@ class GenerationBatchResult: # Routed experts: pending async D2H for overlap scheduling routed_experts_output: Optional[TopkCaptureOutput] = None + indexer_topk_output: Optional[TopkCaptureOutput] = None # metrics expert_distribution_metrics: Optional[ExpertDistributionMetrics] = None @@ -94,6 +95,9 @@ def copy_to_cpu(self, return_logprob: bool): if self.routed_experts_output is not None: self.routed_experts_output.copy_to_cpu() + if self.indexer_topk_output is not None: + self.indexer_topk_output.copy_to_cpu() + if (x := self.expert_distribution_metrics) is not None: x.copy_to_cpu() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 565e20bfe8fe..4b70e4da0775 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -55,7 +55,12 @@ from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.linear_attn_model_registry import get_linear_attn_config from sglang.srt.configs.load_config import LoadConfig, LoadFormat -from sglang.srt.configs.model_config import AttentionArch, ModelConfig, ModelImpl +from 
def init_indexer_capturer(self):
    """Install (or clear) the process-wide indexer-topk capturer.

    Reads ``enable_return_indexer_topk`` from the global server args. If
    capture is requested on a non-CUDA device, a warning is logged and the
    global capturer is set to ``None``. Otherwise the capturer is built from
    the text config's layer count and ``index_topk`` (0 when absent) and
    registered globally; ``create_indexer_capturer`` decides whether that
    yields a real capturer or ``None``.
    """
    capture_requested = get_global_server_args().enable_return_indexer_topk

    # Only the CUDA path (Indexer.forward_cuda and the MLA skip_topk branch)
    # feeds the capturer; on other backends it would stay empty.
    if capture_requested and self.device != "cuda":
        logger.warning(
            "indexer-topk capture is CUDA-only; %s backend not yet wired. "
            "Disabling capturer.",
            self.device,
        )
        set_global_indexer_capturer(None)
        return

    text_cfg = self.model_config.hf_text_config
    capturer = create_indexer_capturer(
        enable=capture_requested,
        num_indexer_layers=get_num_indexer_layers(text_cfg),
        index_topk=getattr(text_cfg, "index_topk", 0),
        num_tokens=self.max_total_num_tokens + self.page_size,
        max_running_requests=self.max_running_requests,
        device=self.device,
    )
    set_global_indexer_capturer(capturer)
layer_id=self.layer_id, ) else: - topk_indices = prev_topk_indices + # skip_topk reuses prev layer's indices; mirror into this + # layer's slot so the captured buffer matches what's used. + topk_indices = maybe_capture_indexer_topk( + self.layer_id, prev_topk_indices + ) current_stream.wait_stream(self.alt_stream) else: k_nope = k_nope.unsqueeze(1) @@ -220,7 +227,9 @@ def forward_absorb_prepare( layer_id=self.layer_id, ) else: - topk_indices = prev_topk_indices + topk_indices = maybe_capture_indexer_topk( + self.layer_id, prev_topk_indices + ) else: q = self.q_proj(hidden_states)[0].view( -1, self.num_local_heads, self.qk_head_dim diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0a554e40d022..5691ee71efc4 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -693,6 +693,7 @@ class ServerArgs: keep_mm_feature_on_device: bool = False enable_return_hidden_states: bool = False enable_return_routed_experts: bool = False + enable_return_indexer_topk: bool = False scheduler_recv_interval: int = 1 numa_node: Optional[List[int]] = None enable_deterministic_inference: bool = False @@ -6266,6 +6267,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable returning routed experts of each layer with responses.", ) + parser.add_argument( + "--enable-return-indexer-topk", + action="store_true", + help="Enable returning indexer topk indices of layers with indexer with responses.", + ) parser.add_argument( "--scheduler-recv-interval", type=int, diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index 3fa43aa37b0d..81552057669b 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -893,6 +893,7 @@ def verify(self, batch: ModelWorkerBatch): next_draft_input=next_draft_input, accept_lens=accept_lens, 
routed_experts_output=forward_batch_output.routed_experts_output, + indexer_topk_output=forward_batch_output.indexer_topk_output, ) def _mamba_verify_update( diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py index 9a5937beb49b..9eac638eeab1 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py @@ -790,6 +790,7 @@ def verify( next_draft_input=next_draft_input, accept_lens=accept_lens, routed_experts_output=forward_batch_output.routed_experts_output, + indexer_topk_output=forward_batch_output.indexer_topk_output, ) def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): diff --git a/test/registered/8-gpu-models/test_return_indexer_topk.py b/test/registered/8-gpu-models/test_return_indexer_topk.py new file mode 100644 index 000000000000..e7a45c6e8fec --- /dev/null +++ b/test/registered/8-gpu-models/test_return_indexer_topk.py @@ -0,0 +1,151 @@ +import asyncio +import logging +import unittest + +import aiohttp +import numpy as np + +from sglang.srt.layers.attention.indexer_topk_capturer import ( + extract_indexer_topk_from_meta_info, +) +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +register_cuda_ci(est_time=600, suite="stage-c-test-8-gpu-h200") + +DEEPSEEK_V32_MODEL_PATH = "deepseek-ai/DeepSeek-V3.2" + +# V3.2 config — hardcoded for response decoding (mirrors test_return_routed_experts.py). +NUM_INDEXER_LAYERS = 61 +INDEX_TOPK = 2048 + +# index_topk_freq=2 → layers 2,4,6,... reuse layer L-1's topk; exercises the +# forward_mla.py skip_topk capture path. 
class TestReturnIndexerTopk(CustomTestCase):
    """End-to-end check of indexer-topk capture on DeepSeek-V3.2 (NSA).

    One server is launched with ``--enable-return-indexer-topk`` and
    ``index_topk_freq`` overridden to ``INDEX_TOPK_FREQ``. Only the native
    ``/generate`` endpoint is exercised.

    For every response the test verifies that:
      1. the payload reshapes to (seqlen-1, num_indexer_layers, index_topk);
      2. every index is >= -1, where -1 is the padding sentinel;
      3. even layers L >= 2 are identical to layer L-1, which guards the
         skip_topk capture path in forward_mla.py.
    """

    @classmethod
    def setUpClass(cls):
        """Start the server and collect captured top-k for all prompts."""
        launch_flags = [
            "--trust-remote-code",
            "--tp",
            "8",
            "--dp",
            "8",
            "--enable-dp-attention",
            "--enable-return-indexer-topk",
            # Keep the KV pool small: the indexer-topk host buffer costs
            # about 488 KB per token for V3.2, and the default pool size
            # across 8 processes would exhaust pinned host memory on CI.
            "--max-total-tokens",
            "32768",
            "--model-loader-extra-config",
            '{"enable_multithread_load": true, "num_threads": 64}',
            "--json-model-override-args",
            f'{{"index_topk_freq": {INDEX_TOPK_FREQ}}}',
        ]
        cls.other_args = launch_flags
        cls.sampling_args = {"temperature": 0, "max_new_tokens": 16}
        cls.texts = [
            "What is the capital of France?",
            "Solve: 2 + 3 = ?",
        ]
        cls.process = popen_launch_server(
            DEEPSEEK_V32_MODEL_PATH,
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.other_args,
        )
        try:
            cls.captured = asyncio.run(cls._collect_async())
        except Exception:
            # Do not leak the server process when collection fails.
            kill_process_tree(cls.process.pid)
            raise

    @classmethod
    def tearDownClass(cls):
        """Shut down the server process tree."""
        kill_process_tree(cls.process.pid)

    def test_indexer_topk_generate(self):
        """Run both structural checks on every captured response."""
        for captured_topk in self.captured:
            self._check_shape_and_range(captured_topk)
            self._check_skip_topk_equality(captured_topk)

    def _check_shape_and_range(self, topk: np.ndarray):
        """Assert the 3-D shape and that no index is below -1."""
        self.assertEqual(topk.ndim, 3)
        num_positions, layer_count, width = topk.shape
        self.assertGreater(num_positions, 0)
        self.assertEqual(layer_count, NUM_INDEXER_LAYERS)
        self.assertEqual(width, INDEX_TOPK)
        # Entries are token positions; -1 is the padding sentinel.
        in_range = (topk >= -1).all()
        self.assertTrue(in_range, f"min index {topk.min()} < -1")

    def _check_skip_topk_equality(self, topk: np.ndarray):
        """Assert each even layer L >= 2 equals layer L-1 (freq=2).

        With ``skip_topk = max(layer_id - 1, 0) % freq != 0`` and freq=2,
        skipping happens exactly for even L >= 2, whose slot is filled with
        the previous layer's indices.
        """
        for layer in range(2, NUM_INDEXER_LAYERS, 2):
            current = topk[:, layer, :]
            previous = topk[:, layer - 1, :]
            np.testing.assert_array_equal(
                current,
                previous,
                err_msg=f"layer {layer} should reuse layer {layer - 1}'s topk (skip_topk path)",
            )

    @classmethod
    async def _collect_async(cls):
        """POST every prompt to ``/generate`` and decode the top-k payloads."""
        endpoint = f"{DEFAULT_URL_FOR_TEST}/generate"
        async with aiohttp.ClientSession() as session:
            pending = []
            for prompt in cls.texts:
                payload = {
                    "text": prompt,
                    "sampling_params": cls.sampling_args,
                    "return_indexer_topk": True,
                }
                pending.append(
                    asyncio.create_task(make_request(session, endpoint, payload))
                )
            responses = await asyncio.gather(*pending)
            # Raw int32 bytes -> (seqlen-1, num_indexer_layers, index_topk).
            decoded = []
            for response in responses:
                flat = extract_indexer_topk_from_meta_info(response)
                decoded.append(flat.reshape(-1, NUM_INDEXER_LAYERS, INDEX_TOPK))
            return decoded