13 changes: 13 additions & 0 deletions python/sglang/srt/configs/model_config.py
@@ -129,6 +129,19 @@ def get_nsa_index_n_heads(config: PretrainedConfig) -> int:
return config.index_n_heads


def get_num_indexer_layers(config: PretrainedConfig) -> int:
    """Layer count for the global indexer-topk capturer's host buffer.

    NSA models (DeepSeek-V3.2) instantiate an Indexer on every transformer
    layer. With index_topk_freq > 1, some layers reuse the previous layer's
    topk; those still get a slot (mirrored at the MLA call site). Other
    architectures can set num_indexer_layers on hf_text_config; 0 disables
    the capturer.
    """
    if is_deepseek_nsa(config):
        return config.num_hidden_layers
    return getattr(config, "num_indexer_layers", 0)


class ModelConfig:
def __init__(
self,
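A quick sketch of how this helper resolves, using hypothetical config objects (neither corresponds to a real checkpoint):

from types import SimpleNamespace

# NSA-style config: is_deepseek_nsa(config) is True, so the capturer buffer
# gets one slot per hidden layer.
nsa_like = SimpleNamespace(num_hidden_layers=61)
# get_num_indexer_layers(nsa_like) -> 61, assuming is_deepseek_nsa passes

# Non-NSA config: opt in by setting num_indexer_layers on hf_text_config ...
opted_in = SimpleNamespace(num_indexer_layers=4)
# get_num_indexer_layers(opted_in) -> 4

# ... or leave it unset, which yields 0 and disables the capturer.
plain = SimpleNamespace()
# get_num_indexer_layers(plain) -> 0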
3 changes: 3 additions & 0 deletions python/sglang/srt/disaggregation/prefill.py
@@ -478,6 +478,9 @@ def process_batch_result_disagg_prefill(
if result.routed_experts_output is not None:
result.routed_experts_output.finalize()
result.routed_experts_output = None
if result.indexer_topk_output is not None:
result.indexer_topk_output.finalize()
result.indexer_topk_output = None

logprob_pt = 0
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
103 changes: 103 additions & 0 deletions python/sglang/srt/layers/attention/indexer_topk_capturer.py
@@ -0,0 +1,103 @@
import logging
from typing import Optional

import numpy as np
import pybase64
import torch

from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.topk_capturer_base import BaseTopkCapturer

logger = logging.getLogger(__name__)


class IndexerTopkCapturer(BaseTopkCapturer):
def __init__(
self,
num_tokens: int,
num_indexer_layers: int,
index_topk: int,
max_running_requests: int,
device: str,
):
from sglang.srt.server_args import get_global_server_args

self.num_indexer_layers = num_indexer_layers
self.index_topk = index_topk

attn_tp_size = get_attention_tp_size()
assert attn_tp_size == 1, "IndexerTopkCapturer currently only supports DP attention"

# DP-attention capture is per-rank-local: each rank writes [:local_batch, ...]
# to its own device_cache, so the buffer only needs to fit one rank's batch.
server_args = get_global_server_args()
max_batch_size = max(server_args.chunked_prefill_size, max_running_requests)

super().__init__(
num_tokens=num_tokens,
max_batch_size=max_batch_size,
num_layers=self.num_indexer_layers,
topk_size=self.index_topk,
device=device,
name="indexer_topk",
)


_global_indexer_capturer: Optional[IndexerTopkCapturer] = None


def get_global_indexer_capturer() -> Optional[IndexerTopkCapturer]:
return _global_indexer_capturer


def set_global_indexer_capturer(capturer: Optional[IndexerTopkCapturer]):
global _global_indexer_capturer
_global_indexer_capturer = capturer


def maybe_capture_indexer_topk(
layer_id: int, topk_indices: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
"""Capture topk for layer_id if a capturer is set; pass through unchanged.

Works in both expression context (`return maybe_capture_indexer_topk(...)`)
and statement context (call for side-effect, ignore return).
"""
if topk_indices is None:
return None
if (cap := get_global_indexer_capturer()) is not None:
cap.capture(layer_id=layer_id, topk_indices=topk_indices)
return topk_indices


def extract_indexer_topk_from_meta_info(data) -> Optional[np.ndarray]:
    # Mirrors extract_routed_experts_from_meta_info: indices are returned as
    # base64-encoded int32 bytes. Caller reshapes to (seqlen-1, num_indexer_layers,
    # index_topk).
    indexer_topk_base64 = data["meta_info"].get("indexer_topk", None)
    if indexer_topk_base64 is None:
        return None
    indexer_topk = np.frombuffer(
        pybase64.b64decode(indexer_topk_base64.encode("utf-8")), dtype=np.int32
    )
    return indexer_topk


def create_indexer_capturer(
enable: bool,
num_indexer_layers: int,
index_topk: int,
num_tokens: int,
max_running_requests: int,
device: str,
) -> Optional[IndexerTopkCapturer]:
if not enable:
return None
if num_indexer_layers == 0:
logger.warning("No indexer layers found, IndexerTopkCapturer disabled")
return None
return IndexerTopkCapturer(
num_tokens=num_tokens,
num_indexer_layers=num_indexer_layers,
index_topk=index_topk,
max_running_requests=max_running_requests,
device=device,
)
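Taken together, the capture path wires up as in the sketch below. This is an illustration, not part of the PR: it assumes a live SGLang worker (get_attention_tp_size() and the global server args must already be initialized), and all sizes and layer ids are made up.

import torch

from sglang.srt.layers.attention.indexer_topk_capturer import (
    create_indexer_capturer,
    maybe_capture_indexer_topk,
    set_global_indexer_capturer,
)

# Hypothetical sizes; real values come from the model config
# (get_num_indexer_layers / index_topk) and the scheduler.
capturer = create_indexer_capturer(
    enable=True,
    num_indexer_layers=61,
    index_topk=2048,
    num_tokens=4096,
    max_running_requests=32,
    device="cuda",
)
set_global_indexer_capturer(capturer)  # with None, capture becomes a no-op

topk = torch.full((8, 2048), -1, dtype=torch.int, device="cuda")

# Expression context: record the indices (if a capturer is installed) and
# pass them through unchanged.
out = maybe_capture_indexer_topk(layer_id=0, topk_indices=topk)
assert out is topk

# Statement context: call purely for the capture side effect.
maybe_capture_indexer_topk(layer_id=1, topk_indices=topk)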
42 changes: 27 additions & 15 deletions python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -12,6 +12,9 @@
fused_store_index_k_cache,
)
from sglang.srt.environ import envs
from sglang.srt.layers.attention.indexer_topk_capturer import (
maybe_capture_indexer_topk,
)
from sglang.srt.layers.dp_attention import attn_tp_all_gather_into_tensor
from sglang.srt.layers.layernorm import LayerNorm
from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype, is_fp8_fnuz
@@ -1121,15 +1124,18 @@ def forward_cuda(

# Optimization: fast path when skipping topk computation
if skip_logits_computation and (not self.nsa_enable_prefill_cp):
return self._forward_cuda_k_only(
x,
positions,
forward_batch,
return maybe_capture_indexer_topk(
layer_id,
act_quant,
enable_dual_stream,
metadata,
return_indices,
self._forward_cuda_k_only(
x,
positions,
forward_batch,
layer_id,
act_quant,
enable_dual_stream,
metadata,
return_indices,
),
)

if enable_dual_stream and forward_batch.forward_mode.is_decode_or_idle():
@@ -1227,11 +1233,14 @@ def forward_cuda(
# print(
# "HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result"
# )
return torch.full(
(x_meta.shape[0], self.index_topk),
-1,
dtype=torch.int,
device=x_meta.device,
return maybe_capture_indexer_topk(
layer_id,
torch.full(
(x_meta.shape[0], self.index_topk),
-1,
dtype=torch.int,
device=x_meta.device,
),
)

if (
@@ -1281,7 +1290,10 @@ def forward_cuda(
kv_len_next,
actual_seq_q_next,
)
return torch.cat([topk_result_prev, topk_result_next], dim=0)
return maybe_capture_indexer_topk(
layer_id,
torch.cat([topk_result_prev, topk_result_next], dim=0),
)
else:
topk_result = self._get_topk_ragged(
enable_dual_stream,
@@ -1299,7 +1311,7 @@
topk=self.index_topk,
layer_id=layer_id,
)
return topk_result
return maybe_capture_indexer_topk(layer_id, topk_result)

def forward_npu(
self,
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/detokenizer_manager.py
@@ -349,6 +349,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
else []
)
routed_experts = self._b64_encode_per_request(recv_obj.routed_experts)
indexer_topk = self._b64_encode_per_request(recv_obj.indexer_topk)
return BatchStrOutput(
rids=recv_obj.rids,
http_worker_ipcs=recv_obj.http_worker_ipcs,
@@ -378,6 +379,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
output_token_entropy_val=recv_obj.output_token_entropy_val,
output_hidden_states=recv_obj.output_hidden_states,
routed_experts=routed_experts,
indexer_topk=indexer_topk,
customized_info=recv_obj.customized_info,
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
8 changes: 8 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -177,6 +177,7 @@ class GenerateReqInput(BaseReq):
return_hidden_states: Union[List[bool], bool] = False
# Whether to return captured routed experts
return_routed_experts: bool = False
# Whether to return captured indexer topk indices
return_indexer_topk: bool = False
# The start location in the prompt for returning routed experts.
routed_experts_start_len: int = 0

@@ -653,6 +654,7 @@ def __getitem__(self, i):
else self.return_hidden_states
),
return_routed_experts=self.return_routed_experts,
return_indexer_topk=self.return_indexer_topk,
modalities=self.modalities[i] if self.modalities else None,
session_params=self.session_params,
lora_path=self.lora_path[i] if self.lora_path is not None else None,
@@ -731,6 +733,8 @@ class TokenizedGenerateReqInput(BaseReq):
# The start location in the prompt for returning routed experts.
routed_experts_start_len: int = 0

# Whether to return captured indexer topk indices
return_indexer_topk: bool = False

# The input embeds
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

@@ -1107,6 +1111,8 @@ class BatchTokenIDOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin):
# straight to TokenizerManager, which encodes on demand.
routed_experts: List[Optional[torch.Tensor]]

# Captured indexer topk indices; handled the same way as routed_experts above.
indexer_topk: List[Optional[torch.Tensor]]

# The information of placeholder tokens (e.g., image token)
# idx is the index of the token in the prompt after expansion.
# val is the length of padded tokens after expansion.
@@ -1170,6 +1176,8 @@ class BatchStrOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin):
# see BatchTokenIDOutput.routed_experts.
routed_experts: List[Optional[str]]

# Base64-encoded indexer topk indices; see BatchTokenIDOutput.indexer_topk.
indexer_topk: List[Optional[str]]

# The information of placeholder tokens (e.g., image token)
# idx is the index of the token in the prompt after expansion.
# val is the length of padded tokens after expansion.
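For context, the new flag is a per-request field that rides the normal generate path. A hedged client-side sketch (the URL is a placeholder, layer/topk sizes are illustrative, and the reshape follows the comment in extract_indexer_topk_from_meta_info):

import requests

from sglang.srt.layers.attention.indexer_topk_capturer import (
    extract_indexer_topk_from_meta_info,
)

resp = requests.post(
    "http://localhost:30000/generate",  # hypothetical local server
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 8},
        "return_indexer_topk": True,
    },
).json()

flat = extract_indexer_topk_from_meta_info(resp)  # 1-D int32 array, or None
if flat is not None:
    num_indexer_layers, index_topk = 61, 2048  # model-dependent values
    topk = flat.reshape(-1, num_indexer_layers, index_topk)
    # topk[t, l] holds layer l's selected KV indices at output token t;
    # -1 marks invalid/padded slots (cf. the all-invalid fallback in
    # nsa_indexer.forward_cuda).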
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/multi_tokenizer_mixin.py
@@ -272,6 +272,9 @@ def _handle_output_by_index(output, i):
routed_experts=_extract_field_by_index(
output, "routed_experts", i, check_length=False
),
indexer_topk=_extract_field_by_index(
output, "indexer_topk", i, check_length=False
),
customized_info=_extract_field_by_index(
output, "customized_info", i, check_length=False
),
10 changes: 10 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -595,6 +595,7 @@ def __init__(
require_reasoning: bool = False,
return_hidden_states: bool = False,
return_routed_experts: bool = False,
return_indexer_topk: bool = False,
eos_token_ids: Optional[Set[int]] = None,
bootstrap_host: Optional[str] = None,
bootstrap_port: Optional[int] = None,
@@ -814,6 +815,11 @@ def __init__(
self.routed_experts: Optional[torch.Tensor] = (
None # cpu tensor: shape (seqlen, topk)
)

self.return_indexer_topk = return_indexer_topk
self.indexer_topk: Optional[torch.Tensor] = (
None # cpu tensor: shape (seqlen, num_indexer_layers, index_topk)
)
# Customized info
self.customized_info: Optional[Dict[str, List[Any]]] = None

@@ -1229,6 +1235,7 @@ def reset_for_retract(self):

self.prefix_indices = torch.empty((0,), dtype=torch.int64)
self.routed_experts = None
self.indexer_topk = None
self.last_node = None
self.cache_protected_len = 0
self.swa_uuid_for_lock = None
@@ -1469,6 +1476,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Whether to return captured experts
return_routed_experts: bool = False

# Whether to return captured indexer topk indices
return_indexer_topk: bool = False

# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False

@@ -1522,6 +1531,7 @@ def init_new(
spec_algorithm=spec_algorithm,
return_hidden_states=any(req.return_hidden_states for req in reqs),
return_routed_experts=any(req.return_routed_experts for req in reqs),
return_indexer_topk=any(req.return_indexer_topk for req in reqs),
is_prefill_only=all(req.is_prefill_only for req in reqs),
chunked_req=chunked_req,
dllm_config=dllm_config,
1 change: 1 addition & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -1998,6 +1998,7 @@ def handle_generate_request(
require_reasoning=recv_req.require_reasoning,
return_hidden_states=recv_req.return_hidden_states,
return_routed_experts=recv_req.return_routed_experts,
return_indexer_topk=recv_req.return_indexer_topk,
eos_token_ids=self.model_config.hf_eos_token_id,
bootstrap_host=recv_req.bootstrap_host,
bootstrap_port=recv_req.bootstrap_port,