diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index a178f6f628f0..c895eb516bd5 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -586,6 +586,7 @@ async def async_request_sglang_generate(
             "lora_path": request_func_input.lora_name,
             "return_logprob": args.return_logprob,
             "return_routed_experts": args.return_routed_experts,
+            "return_dsa_topk_indices": args.return_dsa_topk_indices,
             "logprob_start_len": -1,
             **request_func_input.extra_request_body,
         }
@@ -3028,6 +3029,11 @@ def __call__(self, parser, namespace, values, option_string=None):
         action="store_true",
         help="Return routed experts.",
     )
+    parser.add_argument(
+        "--return-dsa-topk-indices",
+        action="store_true",
+        help="Return DSA topk indices.",
+    )
     parser.add_argument("--seed", type=int, default=1, help="The random seed.")
     parser.add_argument(
         "--disable-ignore-eos",
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 667923c91ea2..e0cbeb5ed8e5 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -226,6 +226,7 @@ def generate(
         custom_logit_processor: Optional[Union[List[str], str]] = None,
         return_hidden_states: bool = False,
         return_routed_experts: bool = False,
+        return_dsa_topk_indices: bool = False,
         stream: bool = False,
         bootstrap_host: Optional[Union[List[str], str]] = None,
         bootstrap_port: Optional[Union[List[int], int]] = None,
@@ -263,6 +264,7 @@ def generate(
             custom_logit_processor=custom_logit_processor,
             return_hidden_states=return_hidden_states,
             return_routed_experts=return_routed_experts,
+            return_dsa_topk_indices=return_dsa_topk_indices,
             stream=stream,
             bootstrap_host=bootstrap_host,
             bootstrap_port=bootstrap_port,
diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py
index b464f3f8d728..d6c46ad608d2 100644
--- a/python/sglang/srt/layers/attention/nsa_backend.py
+++ b/python/sglang/srt/layers/attention/nsa_backend.py
@@ -1239,6 +1239,7 @@ def forward_extend(
         if envs.SGLANG_NSA_FUSE_TOPK.get():
             page_table_1 = topk_indices
         else:
+            # Non-fused top-k path: transform the top-k indices into page table indices (page_size = 1).
             if topk_transform_method == TopkTransformMethod.RAGGED:
                 topk_indices_offset = metadata.topk_indices_offset
                 assert topk_indices_offset is not None
@@ -1376,6 +1377,7 @@ def forward_decode(
         if envs.SGLANG_NSA_FUSE_TOPK.get():
             page_table_1 = topk_indices
         else:
+            # Non-fused top-k path: transform the top-k indices into page table indices (page_size = 1).
             page_table_1 = transform_index_page_table_decode(
                 page_table=metadata.page_table_1,
                 topk_indices=topk_indices,
diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py
index 00bd68755587..b6ec3fc7736d 100644
--- a/python/sglang/srt/layers/moe/routed_experts_capturer.py
+++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py
@@ -34,6 +34,8 @@ def __init__(
         num_experts_per_tok: int,
         num_fused_shared_experts: int,
         device: str,
+        enable_capture_dsa_topk_indices: bool = False,
+        num_dsa_topk_indices: int = 2048,
     ) -> None:
         self.buffer = torch.zeros(
             (
@@ -48,23 +50,56 @@ def __init__(
             dtype=torch.int32,
             device=device,
         )
+        self.dsa_topk_indices_buffer = None
+        if enable_capture_dsa_topk_indices:
+            self.dsa_topk_indices_buffer = torch.zeros(
+                (
+                    max(
+                        get_global_server_args().chunked_prefill_size
+                        * get_global_server_args().dp_size,
+                        max_running_requests,
+                    ),
+                    num_hidden_layers,
+                    num_dsa_topk_indices,
+                ),
+                dtype=torch.int32,
+                device=device,
+            )
         self._finalize_allocation_log()
 
     def get_buffer_size_bytes(self):
         assert hasattr(self, "buffer")
         return get_tensor_size_bytes(self.buffer)
 
+    def get_dsa_topk_indices_buffer_size_bytes(self):
+        assert hasattr(self, "dsa_topk_indices_buffer")
+        return get_tensor_size_bytes(self.dsa_topk_indices_buffer)
+
     def capture_fwd_routed_experts(self, layer_id: int, topk_ids: torch.Tensor):
         assert layer_id is not None, "capturing routing experts but get layer_id None"
         batch, _ = topk_ids.shape
         self.buffer[:batch, layer_id, :] = topk_ids
 
+    def capture_fwd_dsa_topk_indices(
+        self, layer_id: int, dsa_topk_indices: torch.Tensor
+    ):
+        assert layer_id is not None, "capturing DSA topk indices but got layer_id None"
+        batch, _ = dsa_topk_indices.shape
+        self.dsa_topk_indices_buffer[:batch, layer_id, :] = dsa_topk_indices
+
     def _finalize_allocation_log(self):
         """Common logging and memory usage computation for captured experts buffers."""
         buffer_size_MB = self.get_buffer_size_bytes() / _MB
         logger.info(
             f"Routing experts device buffer allocated. #shape: {tuple(self.buffer.shape)}, size: {buffer_size_MB:.2f} MB"
         )
+        if self.dsa_topk_indices_buffer is not None:
+            dsa_topk_indices_size_MB = (
+                self.get_dsa_topk_indices_buffer_size_bytes() / _MB
+            )
+            logger.info(
+                f"DSA topk indices device buffer allocated. #shape: {tuple(self.dsa_topk_indices_buffer.shape)}, size: {dsa_topk_indices_size_MB:.2f} MB"
+            )
 
 
 class _RoutedExpertsHostCache:
@@ -73,6 +108,8 @@ def __init__(
         num_tokens: int,
         num_hidden_layers: int,
         num_experts_per_tok: int,
+        enable_capture_dsa_topk_indices: bool = False,
+        num_dsa_topk_indices: int = 2048,
     ) -> None:
         self.num_tokens = num_tokens
         self.buffer = torch.zeros(
@@ -85,12 +122,28 @@ def __init__(
             device="cpu",
             pin_memory=True,
         )
+        self.dsa_topk_indices_buffer = None
+        if enable_capture_dsa_topk_indices:
+            self.dsa_topk_indices_buffer = torch.zeros(
+                (
+                    num_tokens,
+                    num_hidden_layers,
+                    num_dsa_topk_indices,
+                ),
+                dtype=torch.int32,
+                device="cpu",
+                pin_memory=True,
+            )
         self._finalize_allocation_log()
 
     def get_buffer_size_bytes(self):
         assert hasattr(self, "buffer")
         return get_tensor_size_bytes(self.buffer)
 
+    def get_dsa_topk_indices_buffer_size_bytes(self):
+        assert hasattr(self, "dsa_topk_indices_buffer")
+        return get_tensor_size_bytes(self.dsa_topk_indices_buffer)
+
     def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor):
         self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True)
 
@@ -100,6 +153,13 @@ def _finalize_allocation_log(self):
         logger.info(
             f"Routing experts host buffer allocated. #tokens: {self.num_tokens}, size: {buffer_size_GB:.2f} GB"
         )
+        if self.dsa_topk_indices_buffer is not None:
+            dsa_topk_indices_size_GB = (
+                self.get_dsa_topk_indices_buffer_size_bytes() / _GB
+            )
+            logger.info(
+                f"DSA topk indices host buffer allocated. #tokens: {self.num_tokens}, size: {dsa_topk_indices_size_GB:.2f} GB"
+            )
 
 
 class RoutedExpertsCapturer(ABC):
@@ -111,6 +171,7 @@ def create(
         num_tokens: int,
         max_running_requests: int,
         device: str,
+        enable_capture_dsa_topk_indices: bool = False,
     ):
         if enable:
             return _RoutedExpertsCapturerReal(
@@ -119,6 +180,7 @@ def create(
                 max_running_requests=max_running_requests,
                 num_fused_shared_experts=num_fused_shared_experts,
                 device=device,
+                enable_capture_dsa_topk_indices=enable_capture_dsa_topk_indices,
             )
         else:
             return _RoutedExpertsCapturerNoop()
@@ -134,6 +196,9 @@ def _sync_fwd_experts_buffer_DtoH(
     def capture(self, layer_id: int, topk_ids: torch.Tensor):
         raise NotImplementedError
 
+    def capture_dsa_topk_indices(self, layer_id: int, dsa_topk_indices: torch.Tensor):
+        raise NotImplementedError
+
     def get_routed_experts(
         self,
         req_pool_idx: int,
@@ -142,6 +207,14 @@ def get_routed_experts(
     ):
         raise NotImplementedError
 
+    def get_dsa_topk_indices(
+        self,
+        req_pool_idx: int,
+        seqlen: int,
+        req_to_token_pool: ReqToTokenPool,
+    ):
+        raise NotImplementedError
+
     def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
         raise NotImplementedError
 
@@ -162,15 +235,24 @@ def __init__(
         max_running_requests: int,
         num_fused_shared_experts: int,
         device: str,
+        enable_capture_dsa_topk_indices: bool = False,
    ):
         self.num_fused_shared_experts = num_fused_shared_experts
         self.num_hidden_layers = model_config.hf_text_config.num_hidden_layers
         self.num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok
+        self.enable_capture_dsa_topk_indices = enable_capture_dsa_topk_indices
+        if enable_capture_dsa_topk_indices:
+            self.num_dsa_topk_indices = model_config.hf_text_config.index_topk
+        else:
+            self.num_dsa_topk_indices = None
+
         self.host_cache = _RoutedExpertsHostCache(
             num_tokens=num_tokens,
             num_hidden_layers=self.num_hidden_layers,
             num_experts_per_tok=self.num_experts_per_tok,
+            enable_capture_dsa_topk_indices=self.enable_capture_dsa_topk_indices,
+            num_dsa_topk_indices=self.num_dsa_topk_indices,
         )
 
         self.device_cache = _RoutedExpertsDeviceCache(
@@ -179,6 +261,8 @@ def __init__(
             num_experts_per_tok=self.num_experts_per_tok,
             num_fused_shared_experts=self.num_fused_shared_experts,
             device=device,
+            enable_capture_dsa_topk_indices=self.enable_capture_dsa_topk_indices,
+            num_dsa_topk_indices=self.num_dsa_topk_indices,
         )
 
     def _sync_fwd_experts_buffer_DtoH(
@@ -205,9 +289,20 @@ def _sync_fwd_experts_buffer_DtoH(
             local_start_pos:local_end_pos, :, : self.num_experts_per_tok
         ].cpu()
 
+        if self.enable_capture_dsa_topk_indices:
+            self.host_cache.dsa_topk_indices_buffer[out_cache_loc_cpu] = (
+                self.device_cache.dsa_topk_indices_buffer[
+                    local_start_pos:local_end_pos, :, : self.num_dsa_topk_indices
+                ].cpu()
+            )
+
     def capture(self, layer_id: int, topk_ids: torch.Tensor):
         self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids)
 
+    def capture_dsa_topk_indices(self, layer_id: int, dsa_topk_indices: torch.Tensor):
+        if self.enable_capture_dsa_topk_indices:
+            self.device_cache.capture_fwd_dsa_topk_indices(layer_id, dsa_topk_indices)
+
     def get_routed_experts(
         self,
         req_pool_idx: int,
@@ -219,6 +314,18 @@ def get_routed_experts(
         )
         return self.get_host_cache().buffer[cache_pool_idx]
 
+    def get_dsa_topk_indices(
+        self,
+        req_pool_idx: int,
+        seqlen: int,
+        req_to_token_pool: ReqToTokenPool,
+    ):
+        if self.enable_capture_dsa_topk_indices:
+            cache_pool_idx = (
+                req_to_token_pool.req_to_token[req_pool_idx][: seqlen - 1].cpu().clone()
+            )
+            return self.get_host_cache().dsa_topk_indices_buffer[cache_pool_idx]
+
     def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
         self._sync_fwd_experts_buffer_DtoH(
             forward_batch=forward_batch,
@@ -248,6 +355,9 @@ def _sync_fwd_experts_buffer_DtoH(
     def capture(self, layer_id: int, topk_ids: torch.Tensor):
         pass
 
+    def capture_dsa_topk_indices(self, layer_id: int, dsa_topk_indices: torch.Tensor):
+        pass
+
     def get_routed_experts(
         self,
         req_pool_idx: int,
@@ -256,6 +366,14 @@ def get_routed_experts(
     ):
         pass
 
+    def get_dsa_topk_indices(
+        self,
+        req_pool_idx: int,
+        seqlen: int,
+        req_to_token_pool: ReqToTokenPool,
+    ):
+        pass
+
     def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
         pass
 
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index 33cbacfa2629..6033d3979747 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -328,6 +328,7 @@ def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput):
 
     def _extract_routed_experts(self, recv_obj: BatchTokenIDOutput) -> List[List[int]]:
         output_routed_experts = None
+        output_dsa_topk_indices = None
         if recv_obj.output_routed_experts is not None:
             output_routed_experts = [
                 (
@@ -339,7 +340,18 @@ def _extract_routed_experts(self, recv_obj: BatchTokenIDOutput) -> List[List[int
                 )
                 for output_routed_experts in recv_obj.output_routed_experts
             ]
-        return output_routed_experts
+        if recv_obj.output_dsa_topk_indices is not None:
+            output_dsa_topk_indices = [
+                (
+                    pybase64.b64encode(
+                        output_dsa_topk_indices.numpy().tobytes()
+                    ).decode("utf-8")
+                    if output_dsa_topk_indices is not None
+                    else []
+                )
+                for output_dsa_topk_indices in recv_obj.output_dsa_topk_indices
+            ]
+        return output_routed_experts, output_dsa_topk_indices
 
     def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
         # If handling idle batch, set output_strs to [].
@@ -348,7 +360,9 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
             if len(recv_obj.rids) > 0
             else []
         )
-        output_routed_experts = self._extract_routed_experts(recv_obj)
+        output_routed_experts, output_dsa_topk_indices = self._extract_routed_experts(
+            recv_obj
+        )
 
         return BatchStrOutput(
             rids=recv_obj.rids,
@@ -376,6 +390,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
             output_token_entropy_val=recv_obj.output_token_entropy_val,
             output_hidden_states=recv_obj.output_hidden_states,
             output_routed_experts=output_routed_experts,
+            output_dsa_topk_indices=output_dsa_topk_indices,
             customized_info=recv_obj.customized_info,
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 18cc3d2aa636..13bf2ce0bd33 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -202,6 +202,8 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin):
     return_hidden_states: Union[List[bool], bool] = False
     # Whether to return captured routed experts
     return_routed_experts: bool = False
+    # Whether to return DSA topk indices
+    return_dsa_topk_indices: bool = False
 
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
@@ -637,6 +639,7 @@ def __getitem__(self, i):
                 else self.return_hidden_states
             ),
             return_routed_experts=self.return_routed_experts,
+            return_dsa_topk_indices=self.return_dsa_topk_indices,
             modalities=self.modalities[i] if self.modalities else None,
             session_params=self.session_params,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
@@ -710,6 +713,9 @@ class TokenizedGenerateReqInput(BaseReq):
     # Whether to return captured routed experts
     return_routed_experts: bool = False
 
+    # Whether to return DSA topk indices
+    return_dsa_topk_indices: bool = False
+
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
@@ -984,6 +990,9 @@ class BatchTokenIDOutput(
     # The routed experts for each output token
     output_routed_experts: List[torch.Tensor]
 
+    # The DSA topk indices for each output token
+    output_dsa_topk_indices: List[torch.Tensor]
+
     # The information of placeholder tokens (e.g., image token)
     # idx is the index of the token in the prompt after expansion.
     # val is the length of padded tokens after expansion.
@@ -1071,6 +1080,9 @@ class BatchStrOutput(
     # The routed experts for each output token
     output_routed_experts: List[List[int]]
 
+    # The DSA topk indices for each output token
+    output_dsa_topk_indices: List[List[int]]
+
     # The information of placeholder tokens (e.g., image token)
     # idx is the index of the token in the prompt after expansion.
     # val is the length of padded tokens after expansion.
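
Decoding note: the detokenizer above serializes each request's DSA top-k indices with `pybase64.b64encode(tensor.numpy().tobytes())`, mirroring the routed-experts path, and the tokenizer manager (further below) surfaces that string as `meta_info["dsa_topk_indices"]`. A minimal client-side decoding sketch, assuming the `int32` dtype and the `(tokens, num_hidden_layers, index_topk)` host-buffer layout from `routed_experts_capturer.py`; the layer and top-k counts here are illustrative placeholders and should come from the served model's config:

```python
import base64

import numpy as np

# Illustrative values only; read num_hidden_layers and index_topk from the
# served model's config in practice.
NUM_HIDDEN_LAYERS = 61
INDEX_TOPK = 2048


def decode_dsa_topk_indices(b64_payload: str) -> np.ndarray:
    """Reverse b64encode(tensor.numpy().tobytes()) back into an index array."""
    flat = np.frombuffer(base64.b64decode(b64_payload), dtype=np.int32)
    # Host cache layout is (tokens, num_hidden_layers, index_topk); the
    # capturer returns the first seqlen - 1 token rows of that buffer.
    return flat.reshape(-1, NUM_HIDDEN_LAYERS, INDEX_TOPK)
```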
diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py
index 0e9d314e337c..48112ea5ff09 100644
--- a/python/sglang/srt/managers/multi_tokenizer_mixin.py
+++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py
@@ -281,6 +281,9 @@ def _handle_output_by_index(output, i):
             output_routed_experts=_extract_field_by_index(
                 output, "output_routed_experts", i, check_length=False
             ),
+            output_dsa_topk_indices=_extract_field_by_index(
+                output, "output_dsa_topk_indices", i, check_length=False
+            ),
             customized_info=_extract_field_by_index(
                 output, "customized_info", i, check_length=False
             ),
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 5f168595c5e5..8ae4dc7b19c5 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -511,6 +511,7 @@ def __init__(
         require_reasoning: bool = False,
         return_hidden_states: bool = False,
         return_routed_experts: bool = False,
+        return_dsa_topk_indices: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
         bootstrap_host: Optional[str] = None,
         bootstrap_port: Optional[int] = None,
@@ -712,6 +713,11 @@ def __init__(
         self.routed_experts: Optional[torch.Tensor] = (
             None  # cpu tensor: shape (seqlen, topk)
         )
+        # Capture DSA topk indices
+        self.return_dsa_topk_indices = return_dsa_topk_indices
+        self.dsa_topk_indices: Optional[torch.Tensor] = (
+            None  # cpu tensor: shape (seqlen - 1, num_layers, index_topk)
+        )
 
         # Customized info
         self.customized_info: Optional[Dict[str, List[Any]]] = None
@@ -1071,6 +1077,7 @@ def reset_for_retract(self):
 
         self.prefix_indices = torch.empty((0,), dtype=torch.int64)
         self.routed_experts = None
+        self.dsa_topk_indices = None
         self.last_node = None
         self.swa_uuid_for_lock = None
         self.extend_input_len = 0
@@ -1272,6 +1279,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Whether to return captured experts
     return_routed_experts: bool = False
 
+    # Whether to return DSA topk indices
+    return_dsa_topk_indices: bool = False
+
     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False
 
@@ -1318,6 +1328,7 @@ def init_new(
             spec_algorithm=spec_algorithm,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
             return_routed_experts=any(req.return_routed_experts for req in reqs),
+            return_dsa_topk_indices=any(req.return_dsa_topk_indices for req in reqs),
             is_prefill_only=all(req.is_prefill_only for req in reqs),
             chunked_req=chunked_req,
             dllm_config=dllm_config,
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index f06a9a8e8f02..f5da62459531 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1438,6 +1438,7 @@ def handle_generate_request(
             require_reasoning=recv_req.require_reasoning,
             return_hidden_states=recv_req.return_hidden_states,
             return_routed_experts=recv_req.return_routed_experts,
+            return_dsa_topk_indices=recv_req.return_dsa_topk_indices,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
             bootstrap_port=recv_req.bootstrap_port,
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index ed614fea9d35..d810fade093b 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -71,6 +71,14 @@ def maybe_collect_routed_experts(self: Scheduler, req: Req):
             req_to_token_pool=self.req_to_token_pool,
         )
 
+    def maybe_collect_dsa_topk_indices(self: Scheduler, req: Req):
+        """Collect DSA topk indices for a finished request."""
+        req.dsa_topk_indices = get_global_experts_capturer().get_dsa_topk_indices(
+            req_pool_idx=req.req_pool_idx,
+            seqlen=req.seqlen,
+            req_to_token_pool=self.req_to_token_pool,
+        )
+
     def maybe_collect_customized_info(
         self: Scheduler, i: int, req: Req, logits_output: LogitsProcessorOutput
     ):
@@ -137,6 +145,7 @@ def process_batch_result_prefill(
 
                     if req.finished():
                         self.maybe_collect_routed_experts(req)
+                        self.maybe_collect_dsa_topk_indices(req)
                         release_kv_cache(req, self.tree_cache)
                         req.time_stats.completion_time = time.perf_counter()
                     elif not batch.decoding_reqs or req not in batch.decoding_reqs:
@@ -411,6 +420,7 @@ def process_batch_result_decode(
 
             if req.finished():
                 self.maybe_collect_routed_experts(req)
+                self.maybe_collect_dsa_topk_indices(req)
 
                 if self.server_args.disaggregation_decode_enable_offload_kvcache:
                     # Asynchronously offload KV cache; release_kv_cache will be called after Device->Host transfer completes
@@ -853,6 +863,7 @@ def stream_output_generation(
             output_hidden_states = None
             load = self.get_load()
             output_routed_experts = None
+            output_dsa_topk_indices = None
             customized_info = {}
 
             queue_times = []
@@ -1052,6 +1063,10 @@ def stream_output_generation(
                     if output_routed_experts is None:
                         output_routed_experts = []
                     output_routed_experts.append(req.routed_experts)
+                if req.return_dsa_topk_indices:
+                    if output_dsa_topk_indices is None:
+                        output_dsa_topk_indices = []
+                    output_dsa_topk_indices.append(req.dsa_topk_indices)
 
                 if req.customized_info is not None:
                     for k, v in req.customized_info.items():
@@ -1108,6 +1123,7 @@ def stream_output_generation(
                 output_token_entropy_val=None,
                 output_hidden_states=output_hidden_states,
                 output_routed_experts=output_routed_experts,
+                output_dsa_topk_indices=output_dsa_topk_indices,
                 customized_info=customized_info,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index a433a0597bf1..b39ddb436a97 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -926,6 +926,7 @@ def _create_tokenized_object(
             require_reasoning=obj.require_reasoning,
             return_hidden_states=obj.return_hidden_states,
             return_routed_experts=obj.return_routed_experts,
+            return_dsa_topk_indices=obj.return_dsa_topk_indices,
             data_parallel_rank=obj.data_parallel_rank,
             priority=obj.priority,
             extra_key=obj.extra_key,
@@ -1523,6 +1524,8 @@ def _handle_batch_output(
             meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
         if getattr(recv_obj, "output_routed_experts", None):
             meta_info["routed_experts"] = recv_obj.output_routed_experts[i]
+        if getattr(recv_obj, "output_dsa_topk_indices", None):
+            meta_info["dsa_topk_indices"] = recv_obj.output_dsa_topk_indices[i]
         if getattr(recv_obj, "customized_info", None):
             for k, v in recv_obj.customized_info.items():
                 meta_info[k] = v[i]
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index f3ce8cf4a983..d374d1c64989 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -600,6 +600,7 @@ def init_routed_experts_capturer(self):
                 num_tokens=self.max_total_num_tokens + self.page_size,
                 max_running_requests=self.max_running_requests,
                 device=self.device,
+                enable_capture_dsa_topk_indices=get_global_server_args().enable_return_dsa_topk_indices,
             )
         )
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 75746bfe7164..e502c833c980 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -97,6 +97,7 @@
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.kt_ep_wrapper import KTEPWrapperMethod
+from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer
 from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     CombineInput,
@@ -1857,6 +1858,11 @@ def forward_absorb_prepare(
             latent_cache, forward_batch, k_nope, k_pe
         )
 
+        get_global_experts_capturer().capture_dsa_topk_indices(
+            layer_id=self.layer_id,
+            dsa_topk_indices=topk_indices,
+        )
+
         return (
             q_pe,
             k_pe,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 4451a08dd914..8063681fb11f 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -580,6 +580,7 @@ class ServerArgs:
     keep_mm_feature_on_device: bool = False
     enable_return_hidden_states: bool = False
     enable_return_routed_experts: bool = False
+    enable_return_dsa_topk_indices: bool = False
     scheduler_recv_interval: int = 1
     numa_node: Optional[List[int]] = None
     enable_deterministic_inference: bool = False
@@ -4312,6 +4313,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
             action="store_true",
             help="Enable returning routed experts of each layer with responses.",
         )
+        parser.add_argument(
+            "--enable-return-dsa-topk-indices",
+            action="store_true",
+            help="Enable returning DSA topk indices of each layer with responses.",
+        )
         parser.add_argument(
             "--scheduler-recv-interval",
             type=int,
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index ca0e7c735c58..60419c9f1633 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -802,6 +802,7 @@ def get_benchmark_args(
         disable_stream=disable_stream,
         return_logprob=False,
         return_routed_experts=False,
+        return_dsa_topk_indices=False,
         seed=seed,
         disable_ignore_eos=disable_ignore_eos,
         extra_request_body=None,
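
End to end, the plumbing above works as follows: `--enable-return-dsa-topk-indices` sizes the capture buffers, `forward_absorb_prepare` in `deepseek_v2.py` writes each layer's indexer top-k indices into them, and the per-request `return_dsa_topk_indices` flag ships the captured indices back under `meta_info`. A hypothetical smoke test of that flow; treat the model path, port, and the pairing with `--enable-return-routed-experts` (the DSA buffers live inside the routed-experts capturer, which must be created) as assumptions rather than requirements stated by this PR:

```python
# Assumed server launch (flags from this PR; model path illustrative):
#   python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3.2-Exp \
#       --enable-return-routed-experts --enable-return-dsa-topk-indices
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16},
        "return_dsa_topk_indices": True,  # per-request flag added by this PR
    },
).json()

# A base64-encoded int32 buffer (see tokenizer_manager.py above).
print(resp["meta_info"]["dsa_topk_indices"][:64], "...")
```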