diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 4b6568aaf0a5..3154802d9ab1 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -565,6 +565,7 @@ async def async_request_sglang_generate(
         "stream": not args.disable_stream,
         "lora_path": request_func_input.lora_name,
         "return_logprob": args.return_logprob,
+        "return_routed_experts": args.return_routed_experts,
         "logprob_start_len": -1,
         **request_func_input.extra_request_body,
     }
@@ -2809,6 +2810,11 @@ def __call__(self, parser, namespace, values, option_string=None):
        action="store_true",
        help="Return logprob.",
    )
+    parser.add_argument(
+        "--return-routed-experts",
+        action="store_true",
+        help="Return routed experts.",
+    )
    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
    parser.add_argument(
        "--disable-ignore-eos",
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 0f37b4a870da..5050a65e0b01 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -286,6 +286,7 @@ def generate(
         lora_path: Optional[List[Optional[str]]] = None,
         custom_logit_processor: Optional[Union[List[str], str]] = None,
         return_hidden_states: bool = False,
+        return_routed_experts: bool = False,
         stream: bool = False,
         bootstrap_host: Optional[Union[List[str], str]] = None,
         bootstrap_port: Optional[Union[List[int], int]] = None,
@@ -321,6 +322,7 @@ def generate(
             lora_path=lora_path,
             custom_logit_processor=custom_logit_processor,
             return_hidden_states=return_hidden_states,
+            return_routed_experts=return_routed_experts,
             stream=stream,
             bootstrap_host=bootstrap_host,
             bootstrap_port=bootstrap_port,
diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py
new file mode 100644
index 000000000000..00bd68755587
--- /dev/null
+++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py
@@ -0,0 +1,289 @@
+import logging
+from abc import ABC
+from typing import Optional
+
+import numpy as np
+import pybase64
+import torch
+
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
+    get_dp_local_info,
+    is_dp_attention_enabled,
+)
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import get_global_server_args
+
+logger = logging.getLogger(__name__)
+
+_GB = 1024 * 1024 * 1024
+_MB = 1024 * 1024
+
+
+def get_tensor_size_bytes(t: torch.Tensor):
+    return np.prod(t.shape) * t.dtype.itemsize
+
+
+class _RoutedExpertsDeviceCache:
+    def __init__(
+        self,
+        max_running_requests: int,
+        num_hidden_layers: int,
+        num_experts_per_tok: int,
+        num_fused_shared_experts: int,
+        device: str,
+    ) -> None:
+        self.buffer = torch.zeros(
+            (
+                max(
+                    get_global_server_args().chunked_prefill_size
+                    * get_global_server_args().dp_size,
+                    max_running_requests,
+                ),
+                num_hidden_layers,
+                num_experts_per_tok + num_fused_shared_experts,
+            ),
+            dtype=torch.int32,
+            device=device,
+        )
+        self._finalize_allocation_log()
+
+    def get_buffer_size_bytes(self):
+        assert hasattr(self, "buffer")
+        return get_tensor_size_bytes(self.buffer)
+
+    def capture_fwd_routed_experts(self, layer_id: int, topk_ids: torch.Tensor):
+        assert layer_id is not None, "capturing routed experts, but layer_id is None"
+        batch, _ = topk_ids.shape
+        self.buffer[:batch, layer_id, :] = topk_ids
+
+    def _finalize_allocation_log(self):
+        """Common logging and memory usage computation for captured experts buffers."""
+        buffer_size_MB = self.get_buffer_size_bytes() / _MB
+        logger.info(
+            f"Routed experts device buffer allocated. shape: {tuple(self.buffer.shape)}, size: {buffer_size_MB:.2f} MB"
+        )
+
+
+class _RoutedExpertsHostCache:
+    def __init__(
+        self,
+        num_tokens: int,
+        num_hidden_layers: int,
+        num_experts_per_tok: int,
+    ) -> None:
+        self.num_tokens = num_tokens
+        self.buffer = torch.zeros(
+            (
+                num_tokens,
+                num_hidden_layers,
+                num_experts_per_tok,
+            ),
+            dtype=torch.int32,
+            device="cpu",
+            pin_memory=True,
+        )
+        self._finalize_allocation_log()
+
+    def get_buffer_size_bytes(self):
+        assert hasattr(self, "buffer")
+        return get_tensor_size_bytes(self.buffer)
+
+    def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor):
+        # The buffer is laid out as (num_tokens, num_layers, top_k).
+        self.buffer[loc, layer_id, :] = top_k.to(device="cpu", non_blocking=True)
+
+    def _finalize_allocation_log(self):
+        """Common logging and memory usage computation for captured experts buffers."""
+        buffer_size_GB = self.get_buffer_size_bytes() / _GB
+        logger.info(
+            f"Routed experts host buffer allocated. #tokens: {self.num_tokens}, size: {buffer_size_GB:.2f} GB"
+        )
+
+
+class RoutedExpertsCapturer(ABC):
+    @staticmethod
+    def create(
+        enable: bool,
+        model_config: ModelConfig,
+        num_fused_shared_experts: int,
+        num_tokens: int,
+        max_running_requests: int,
+        device: str,
+    ):
+        if enable:
+            return _RoutedExpertsCapturerReal(
+                model_config,
+                num_tokens=num_tokens,
+                max_running_requests=max_running_requests,
+                num_fused_shared_experts=num_fused_shared_experts,
+                device=device,
+            )
+        else:
+            return _RoutedExpertsCapturerNoop()
+
+    def _sync_fwd_experts_buffer_DtoH(
+        self,
+        forward_batch: ForwardBatch,
+        can_run_graph: bool,
+        cuda_graph_batch: int,
+    ):
+        raise NotImplementedError
+
+    def capture(self, layer_id: int, topk_ids: torch.Tensor):
+        raise NotImplementedError
+
+    def get_routed_experts(
+        self,
+        req_pool_idx: int,
+        seqlen: int,
+        req_to_token_pool: ReqToTokenPool,
+    ):
+        raise NotImplementedError
+
+    def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
+        raise NotImplementedError
+
+    def get_host_cache(self):
+        raise NotImplementedError
+
+    def get_device_cache(self):
+        raise NotImplementedError
+
+
+class _RoutedExpertsCapturerReal(RoutedExpertsCapturer):
+    """Capturer for routed experts, backed by a pinned host buffer."""
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        num_tokens: int,
+        max_running_requests: int,
+        num_fused_shared_experts: int,
+        device: str,
+    ):
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.num_hidden_layers = model_config.hf_text_config.num_hidden_layers
+        self.num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok
+
+        self.host_cache = _RoutedExpertsHostCache(
+            num_tokens=num_tokens,
+            num_hidden_layers=self.num_hidden_layers,
+            num_experts_per_tok=self.num_experts_per_tok,
+        )
+
+        self.device_cache = _RoutedExpertsDeviceCache(
+            max_running_requests=max_running_requests,
+            num_hidden_layers=self.num_hidden_layers,
+            num_experts_per_tok=self.num_experts_per_tok,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            device=device,
+        )
+
+    def _sync_fwd_experts_buffer_DtoH(
+        self,
+        forward_batch: ForwardBatch,
+        can_run_graph: bool,
+        cuda_graph_batch: int,
+    ):
+        if is_dp_attention_enabled():
+            local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
+            # With CUDA graph padding, each DP rank starts at a fixed offset.
+            if can_run_graph:
+                local_start_pos = get_attention_dp_rank() * cuda_graph_batch
+            local_end_pos = local_start_pos + local_num_tokens
+        else:
+            local_start_pos = 0
+            local_end_pos = forward_batch.out_cache_loc.shape[0]
+
+        # FIXME: we sync explicitly here; the overlap scheduler breaks at this point.
+        out_cache_loc_cpu = forward_batch.out_cache_loc.cpu()
+        self.host_cache.buffer[out_cache_loc_cpu] = self.device_cache.buffer[
+            local_start_pos:local_end_pos, :, : self.num_experts_per_tok
+        ].cpu()
+
+    def capture(self, layer_id: int, topk_ids: torch.Tensor):
+        self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids)
+
+    def get_routed_experts(
+        self,
+        req_pool_idx: int,
+        seqlen: int,
+        req_to_token_pool: ReqToTokenPool,
+    ):
+        cache_pool_idx = (
+            req_to_token_pool.req_to_token[req_pool_idx][: seqlen - 1].cpu().clone()
+        )
+        return self.get_host_cache().buffer[cache_pool_idx]
+
+    def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
+        self._sync_fwd_experts_buffer_DtoH(
+            forward_batch=forward_batch,
+            can_run_graph=can_run_graph,
+            cuda_graph_batch=cuda_graph_batch,
+        )
+
+    def get_host_cache(self):
+        return self.host_cache
+
+    def get_device_cache(self):
+        return self.device_cache
+
+
+class _RoutedExpertsCapturerNoop(RoutedExpertsCapturer):
+    def __init__(self):
+        pass
+
+    def _sync_fwd_experts_buffer_DtoH(
+        self,
+        forward_batch: ForwardBatch,
+        can_run_graph: bool,
+        cuda_graph_batch: int,
+    ):
+        pass
+
+    def capture(self, layer_id: int, topk_ids: torch.Tensor):
+        pass
+
+    def get_routed_experts(
+        self,
+        req_pool_idx: int,
+        seqlen: int,
+        req_to_token_pool: ReqToTokenPool,
+    ):
+        pass
+
+    def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
+        pass
+
+    def get_host_cache(self):
+        pass
+
+    def get_device_cache(self):
+        pass
+
+
+_global_expert_capturer: Optional[RoutedExpertsCapturer] = _RoutedExpertsCapturerNoop()
+
+
+def get_global_experts_capturer():
+    return _global_expert_capturer
+
+
+def set_global_experts_capturer(capturer: RoutedExpertsCapturer):
+    global _global_expert_capturer
+    _global_expert_capturer = capturer
+
+
+def extract_routed_experts_from_meta_info(data):
+    # For performance, the expert IDs are returned base64-encoded.
+    # This helper decodes them back to a flat int32 array.
+    # See detokenizer_manager::_extract_routed_experts for the encoding side.
+    routed_experts_base64 = data["meta_info"].get("routed_experts", None)
+    if routed_experts_base64 is None:
+        return None
+    routed_experts = np.frombuffer(
+        pybase64.b64decode(routed_experts_base64.encode("utf-8")), dtype=np.int32
+    )
+    return routed_experts
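To make the capture path above concrete: `select_experts` (next hunk) calls `capture(layer_id, topk_ids)` once per MoE layer, and the device cache slots each layer's top-k IDs into a `(tokens, layers, slots)` buffer. A minimal sketch with invented sizes; nothing below is part of the real API:

```python
import torch

num_tokens, num_layers, top_k = 4, 2, 2
# Mirrors _RoutedExpertsDeviceCache.buffer: one row per in-flight token.
buffer = torch.zeros((num_tokens, num_layers, top_k), dtype=torch.int32)

def capture(layer_id: int, topk_ids: torch.Tensor) -> None:
    # Mirrors capture_fwd_routed_experts: write this layer's expert choices
    # for the first `batch` token slots.
    batch = topk_ids.shape[0]
    buffer[:batch, layer_id, :] = topk_ids

for layer_id in range(num_layers):
    # In the real code this happens inside select_experts(), once per MoE layer.
    topk_ids = torch.randint(0, 128, (num_tokens, top_k), dtype=torch.int32)
    capture(layer_id, topk_ids)

print(buffer[0])  # the experts chosen for token 0 at every layer
```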
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index a43fa350d805..65a5939918a2 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -48,6 +48,7 @@
 )
 from sglang.srt.layers.dp_attention import is_allocation_symmetric
 from sglang.srt.layers.moe import get_moe_runner_backend
+from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -203,6 +204,7 @@ def __init__(
         self,
         top_k: int,
         *,
+        layer_id: Optional[int] = None,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -224,6 +226,7 @@ def __init__(
 
         if use_grouped_topk:
             assert num_expert_group is not None and topk_group is not None
+        self.layer_id = layer_id
         self.topk_config = TopKConfig(
             top_k=top_k,
             use_grouped_topk=use_grouped_topk,
@@ -251,6 +254,7 @@ def forward_native(
         self.topk_config.torch_native = True
         return select_experts(
             hidden_states=hidden_states,
+            layer_id=self.layer_id,
             router_logits=router_logits,
             topk_config=self.topk_config,
             num_token_non_padded=num_token_non_padded,
@@ -300,6 +304,7 @@ def forward_cuda(
     ):
         topk_output = select_experts(
             hidden_states=hidden_states,
+            layer_id=self.layer_id,
             router_logits=router_logits,
             topk_config=self.topk_config,
             num_token_non_padded=num_token_non_padded,
@@ -317,6 +322,7 @@ def forward_cpu(
     ) -> TopKOutput:
         return select_experts(
             hidden_states=hidden_states,
+            layer_id=self.layer_id,
             router_logits=router_logits,
             topk_config=self.topk_config,
             num_token_non_padded=num_token_non_padded,
@@ -847,6 +853,7 @@ def select_experts(
     router_logits: torch.Tensor,
     topk_config: TopKConfig,
     *,
+    layer_id: Optional[int] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ) -> StandardTopKOutput:
@@ -974,7 +981,10 @@ def select_experts(
     )
 
     get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
-
+    get_global_experts_capturer().capture(
+        layer_id=layer_id,
+        topk_ids=topk_ids,
+    )
     return StandardTopKOutput(topk_weights, topk_ids, router_logits)
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index 2019704d761a..9a207b554e38 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -21,6 +21,7 @@
 from typing import Dict, List, Union
 
 import psutil
+import pybase64
 import setproctitle
 import zmq
 
@@ -266,8 +267,24 @@ def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput):
 
         return output_strs
 
+    def _extract_routed_experts(
+        self, recv_obj: BatchTokenIDOutput
+    ) -> Union[List[str], None]:
+        output_routed_experts = None
+        if recv_obj.output_routed_experts is not None:
+            output_routed_experts = [
+                (
+                    pybase64.b64encode(routed_experts.numpy().tobytes()).decode(
+                        "utf-8"
+                    )
+                    if routed_experts is not None
+                    else ""  # an empty string decodes to an empty array downstream
+                )
+                for routed_experts in recv_obj.output_routed_experts
+            ]
+        return output_routed_experts
+
     def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
         output_strs = self._decode_batch_token_id_output(recv_obj)
+        output_routed_experts = self._extract_routed_experts(recv_obj)
 
         return BatchStrOutput(
             rids=recv_obj.rids,
@@ -294,6 +311,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
             output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
             output_token_entropy_val=recv_obj.output_token_entropy_val,
             output_hidden_states=recv_obj.output_hidden_states,
+            output_routed_experts=output_routed_experts,
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
             retraction_counts=recv_obj.retraction_counts,
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 97df9c6db552..ef9f4886e557 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -23,6 +23,8 @@
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
 
+import torch
+
 from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
@@ -196,6 +198,8 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin):
     log_metrics: bool = True
     # Whether to return hidden states
     return_hidden_states: Union[List[bool], bool] = False
+    # Whether to return captured routed experts
+    return_routed_experts: bool = False
 
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
@@ -618,6 +622,7 @@ def __getitem__(self, i):
                 if isinstance(self.return_hidden_states, list)
                 else self.return_hidden_states
             ),
+            return_routed_experts=self.return_routed_experts,
             modalities=self.modalities[i] if self.modalities else None,
             session_params=self.session_params,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
@@ -687,6 +692,9 @@ class TokenizedGenerateReqInput(BaseReq):
     # Whether to return hidden states
     return_hidden_states: bool = False
 
+    # Whether to return captured routed experts
+    return_routed_experts: bool = False
+
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
@@ -953,6 +961,9 @@ class BatchTokenIDOutput(
     # Hidden states
     output_hidden_states: List[List[float]]
 
+    # The routed experts for each request, as a (seq_len - 1, num_layers, top_k) tensor
+    output_routed_experts: List[torch.Tensor]
+
     # The information of placeholder tokens (e.g., image token)
     # idx is the index of the token in the prompt after expansion.
     # val is the length of padded tokens after expansion.
@@ -1032,6 +1043,9 @@ class BatchStrOutput(
     # Hidden states
     output_hidden_states: List[List[float]]
 
+    # The routed experts for each request, base64-encoded by the detokenizer
+    output_routed_experts: List[str]
+
     # The information of placeholder tokens (e.g., image token)
     # idx is the index of the token in the prompt after expansion.
     # val is the length of padded tokens after expansion.
diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py
index b8c9174aa03c..978be0cd076a 100644
--- a/python/sglang/srt/managers/multi_tokenizer_mixin.py
+++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py
@@ -270,6 +270,9 @@ def _handle_output_by_index(output, i):
             output_hidden_states=_extract_field_by_index(
                 output, "output_hidden_states", i, check_length=False
             ),
+            output_routed_experts=_extract_field_by_index(
+                output, "output_routed_experts", i, check_length=False
+            ),
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
             retraction_counts=_extract_field_by_index(output, "retraction_counts", i),
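The transport format is the other half of the contract: the detokenizer above serializes each request's int32 tensor to a base64 string, and `extract_routed_experts_from_meta_info` reverses it on the client. A self-contained round-trip with illustrative shapes (3 tokens, 2 layers, top-2; the real shapes depend on the model):

```python
import numpy as np
import pybase64
import torch

routed = torch.randint(0, 128, (3, 2, 2), dtype=torch.int32)

# Encoding side, as in _extract_routed_experts above:
payload = pybase64.b64encode(routed.numpy().tobytes()).decode("utf-8")

# Decoding side, as in extract_routed_experts_from_meta_info:
flat = np.frombuffer(pybase64.b64decode(payload.encode("utf-8")), dtype=np.int32)
assert (flat.reshape(3, 2, 2) == routed.numpy()).all()
```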
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 286908866bf9..07a6455a442b 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -489,6 +489,7 @@ def __init__(
         custom_logit_processor: Optional[str] = None,
         require_reasoning: bool = False,
         return_hidden_states: bool = False,
+        return_routed_experts: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
         bootstrap_host: Optional[str] = None,
         bootstrap_port: Optional[int] = None,
@@ -679,6 +680,12 @@ def __init__(
         self.output_topk_p = None
         self.output_topk_index = None
 
+        # Captured routed experts (return value)
+        self.return_routed_experts = return_routed_experts
+        # CPU tensor of shape (seq_len - 1, num_layers, top_k),
+        # filled in when the request finishes
+        self.routed_experts: Optional[torch.Tensor] = None
+
         # Embedding (return values)
         self.embedding = None
 
@@ -1043,6 +1050,7 @@ def reset_for_retract(self):
         self.retraction_count += 1
 
         self.prefix_indices = torch.empty((0,), dtype=torch.int64)
+        self.routed_experts = None
         self.last_node = None
         self.swa_uuid_for_lock = None
         self.extend_input_len = 0
@@ -1219,6 +1227,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Whether to return hidden states
     return_hidden_states: bool = False
 
+    # Whether to return captured routed experts
+    return_routed_experts: bool = False
+
     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False
 
@@ -1266,6 +1277,7 @@ def init_new(
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            return_routed_experts=any(req.return_routed_experts for req in reqs),
             is_prefill_only=all(req.is_prefill_only for req in reqs),
             chunked_req=chunked_req,
             dllm_config=dllm_config,
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 8631452ce669..5e48d95fdc8b 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1413,6 +1413,7 @@ def handle_generate_request(
             custom_logit_processor=recv_req.custom_logit_processor,
             require_reasoning=recv_req.require_reasoning,
             return_hidden_states=recv_req.return_hidden_states,
+            return_routed_experts=recv_req.return_routed_experts,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
             bootstrap_port=recv_req.bootstrap_port,
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index bdc009ff4af9..4228dc536ffd 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -9,6 +9,7 @@
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.environ import envs
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOutput,
@@ -62,6 +63,14 @@ def process_batch_result_prebuilt(self: Scheduler, batch: ScheduleBatch):
             trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
 
         self.stream_output(batch.reqs, batch.return_logprob)
 
+    def maybe_collect_routed_experts(self: Scheduler, req: Req):
+        """Collect the routed experts for a request that has just finished."""
+        req.routed_experts = get_global_experts_capturer().get_routed_experts(
+            req_pool_idx=req.req_pool_idx,
+            seqlen=req.seqlen,
+            req_to_token_pool=self.req_to_token_pool,
+        )
+
     def process_batch_result_prefill(
         self: Scheduler,
         batch: ScheduleBatch,
@@ -116,6 +125,7 @@ def process_batch_result_prefill(
                 req.check_finished()
 
                 if req.finished():
+                    self.maybe_collect_routed_experts(req)
                     release_kv_cache(req, self.tree_cache)
                     req.time_stats.completion_time = time.perf_counter()
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
@@ -372,6 +382,8 @@ def process_batch_result_decode(
             req.check_finished(new_accepted_len)
 
             if req.finished():
+                self.maybe_collect_routed_experts(req)
+
                 if self.server_args.disaggregation_decode_enable_offload_kvcache:
                     # Asynchronously offload KV cache; release_kv_cache will be called after Device->Host transfer completes
                     if not self.decode_offload_manager.offload_kv_cache(req):
@@ -791,6 +803,7 @@ def stream_output_generation(
             spec_accepted_tokens = []
             retraction_counts = []
             output_hidden_states = None
+            output_routed_experts = None
 
             queue_times = []
             forward_entry_times = []
@@ -985,6 +998,10 @@ def stream_output_generation(
                     if output_hidden_states is None:
                         output_hidden_states = []
                     output_hidden_states.append(req.hidden_states)
+                if req.return_routed_experts:
+                    if output_routed_experts is None:
+                        output_routed_experts = []
+                    output_routed_experts.append(req.routed_experts)
 
                 if (
                     req.finished()
@@ -1034,6 +1051,7 @@ def stream_output_generation(
                     output_token_ids_logprobs_idx=output_token_ids_logprobs_idx,
                     output_token_entropy_val=None,
                     output_hidden_states=output_hidden_states,
+                    output_routed_experts=output_routed_experts,
                     placeholder_tokens_idx=None,
                     placeholder_tokens_val=None,
                     retraction_counts=retraction_counts,
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 2d46471decd5..0116fa10ec77 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -835,6 +835,7 @@ def _create_tokenized_object(
             custom_logit_processor=obj.custom_logit_processor,
             require_reasoning=obj.require_reasoning,
             return_hidden_states=obj.return_hidden_states,
+            return_routed_experts=obj.return_routed_experts,
             data_parallel_rank=obj.data_parallel_rank,
             priority=obj.priority,
             extra_key=obj.extra_key,
@@ -1574,6 +1575,9 @@ def _handle_batch_output(
         if getattr(recv_obj, "output_hidden_states", None):
             meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
 
+        if getattr(recv_obj, "output_routed_experts", None):
+            meta_info["routed_experts"] = recv_obj.output_routed_experts[i]
+
         if isinstance(recv_obj, BatchStrOutput):
             state.text += recv_obj.output_strs[i]
             if self.server_args.stream_output and state.obj.stream:
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 8d8c85e8f687..0d561aab7b02 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -317,15 +317,7 @@ def __init__(
         # Profile number of tokens
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = server_args.max_prefill_tokens
-        self.max_running_requests = min(
-            (
-                self.max_total_num_tokens // 2
-                if server_args.max_running_requests is None
-                else server_args.max_running_requests
-                // (server_args.dp_size if server_args.enable_dp_attention else 1)
-            ),
-            self.model_runner.req_to_token_pool.size,
-        )
+        self.max_running_requests = self.model_runner.max_running_requests
         assert self.max_running_requests > 0, "max_running_request is zero"
         self.max_queued_requests = server_args.max_queued_requests
         assert (
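On the scheduler side, `get_routed_experts` maps a finished request's token positions through `req_to_token_pool` and gathers the matching rows of the host cache. A toy version of that gather, with invented shapes and slot values:

```python
import torch

pool_tokens, num_layers, top_k = 16, 2, 2
host_buffer = torch.randint(
    0, 128, (pool_tokens, num_layers, top_k), dtype=torch.int32
)

# req_to_token maps (request, position) -> slot in the token pool; 0-padded.
req_to_token = torch.tensor([[3, 9, 4, 12, 0, 0]])
req_pool_idx, seqlen = 0, 5

cache_pool_idx = req_to_token[req_pool_idx][: seqlen - 1]
routed_experts = host_buffer[cache_pool_idx]
print(routed_experts.shape)  # (seqlen - 1, num_layers, top_k)
```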
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 26d61239fc7b..2e4497cd5b6e 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -97,6 +97,11 @@
     set_is_extend_in_batch,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.routed_experts_capturer import (
+    RoutedExpertsCapturer,
+    get_global_experts_capturer,
+    set_global_experts_capturer,
+)
 from sglang.srt.layers.moe.utils import get_moe_a2a_backend
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput
 from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype
@@ -557,6 +562,21 @@ def initialize(self, min_per_gpu_memory: float):
                 server_args.max_running_requests,
                 server_args.max_total_tokens,
             )
+
+        # Init max running requests
+        self.max_running_requests = min(
+            (
+                self.max_total_num_tokens // 2
+                if server_args.max_running_requests is None
+                else server_args.max_running_requests
+                // (server_args.dp_size if server_args.enable_dp_attention else 1)
+            ),
+            self.req_to_token_pool.size,
+        )
+
+        # Init routed experts capturer
+        self.init_routed_experts_capturer()
+
         if self.device == "cuda":
             self.init_cublas()
             self.init_attention_backend()
@@ -600,6 +620,40 @@ def initialize(self, min_per_gpu_memory: float):
         # Initialize piecewise CUDA graph
         self.init_piecewise_cuda_graphs()
 
+    def init_routed_experts_capturer(self):
+        # TODO: deduplicate this logic with TpModelWorker.
+        max_running_requests = min(
+            (
+                self.max_total_num_tokens // 2
+                if self.server_args.max_running_requests is None
+                else self.server_args.max_running_requests
+                // (
+                    self.server_args.dp_size
+                    if self.server_args.enable_dp_attention
+                    else 1
+                )
+            ),
+            self.req_to_token_pool.size,
+        )
+
+        if not self.server_args.disable_shared_experts_fusion and hasattr(
+            self.model, "num_fused_shared_experts"
+        ):
+            num_fused_shared_experts = self.model.num_fused_shared_experts
+        else:
+            num_fused_shared_experts = 0
+
+        set_global_experts_capturer(
+            RoutedExpertsCapturer.create(
+                enable=get_global_server_args().enable_return_routed_experts,
+                model_config=self.model_config,
+                num_fused_shared_experts=num_fused_shared_experts,
+                num_tokens=self.max_total_num_tokens + self.page_size,
+                max_running_requests=max_running_requests,
+                device=self.device,
+            )
+        )
+
     def remote_instance_init_transfer_engine(self):
         try:
             from mooncake.engine import TransferEngine
@@ -2840,6 +2894,13 @@ def forward(
             )
             output.expert_distribution_metrics = recorder_outputs.get("metrics")
 
+        # Copy the routed experts captured on device back to the host cache.
+        get_global_experts_capturer().on_forward_end(
+            forward_batch=forward_batch,
+            can_run_graph=output.can_run_graph,
+            cuda_graph_batch=getattr(self.graph_runner, "bs", None),
+        )
+
         if self.eplb_manager is not None:
             self.eplb_manager.on_forward_pass_end()
 
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index d9bb49559a35..38a4f3cde6c9 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -666,6 +666,7 @@ def __init__(
 
         self.topk = TopK(
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            layer_id=self.layer_id,
             renormalize=config.norm_topk_prob,
             use_grouped_topk=True,
             num_expert_group=config.n_group,
diff --git a/python/sglang/srt/models/ernie4.py b/python/sglang/srt/models/ernie4.py
index ab1b6576bfb6..dffd8f09a8bd 100644
--- a/python/sglang/srt/models/ernie4.py
+++ b/python/sglang/srt/models/ernie4.py
@@ -87,6 +87,7 @@ def __init__(
         self.topk = TopK(
             top_k=config.moe_k,
+            layer_id=layer_id,
             renormalize=True,
             use_grouped_topk=False,
             correction_bias=self.gate.e_score_correction_bias,
diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py
index c27816c2c93a..0e8c9cf79a2d 100644
--- a/python/sglang/srt/models/glm4_moe.py
+++ b/python/sglang/srt/models/glm4_moe.py
@@ -394,6 +394,7 @@ def __init__(
         self.topk = TopK(
             top_k=self.top_k + self.num_fused_shared_experts,
+            layer_id=self.layer_id,
             renormalize=config.norm_topk_prob,
             use_grouped_topk=True,
             num_expert_group=config.n_group,
diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index 3e7283dd49c5..0d79f898bfd7 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -113,6 +113,7 @@ def __init__(
         self.topk = TopK(
             top_k=config.num_experts_per_tok,
             renormalize=True,
+            layer_id=layer_id,
         )
 
         self.top_k = config.num_experts_per_tok
diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py
index fd513060a911..a089475b7aa5 100644
--- a/python/sglang/srt/models/grok.py
+++ b/python/sglang/srt/models/grok.py
@@ -142,6 +142,7 @@ def __init__(
         self.topk = TopK(
             top_k=top_k,
             renormalize=False,
+            layer_id=layer_id,
             custom_routing_function=custom_routing_function,
         )
 
diff --git a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py
index 87f1628679a3..300493a3f1e8 100644
--- a/python/sglang/srt/models/hunyuan.py
+++ b/python/sglang/srt/models/hunyuan.py
@@ -150,6 +150,7 @@ def __init__(
 
         self.topk = TopK(
             top_k=top_k,
+            layer_id=layer_id,
             renormalize=True if top_k > 1 else False,
         )
 
diff --git a/python/sglang/srt/models/longcat_flash.py b/python/sglang/srt/models/longcat_flash.py
index 4c2a625f54ae..275ee901917f 100644
--- a/python/sglang/srt/models/longcat_flash.py
+++ b/python/sglang/srt/models/longcat_flash.py
@@ -245,6 +245,7 @@ def __init__(
             renormalize=False,
             use_grouped_topk=False,
             correction_bias=self.router.e_score_correction_bias.data,
+            layer_id=layer_id,
         )
 
         self.topk.forward = self.topk.forward_native
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 03dad18458ba..c978d2c11614 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -161,6 +161,7 @@ def __init__(
         self.topk = TopK(
             top_k=config.num_experts_per_tok,
             renormalize=config.norm_topk_prob,
+            layer_id=layer_id,
         )
 
         self.experts = get_moe_impl_class(quant_config)(
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index 9388b974a33b..9a6d3070e8d9 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -227,6 +227,7 @@ def __init__(
             top_k=config.num_experts_per_tok,
             renormalize=config.norm_topk_prob,
             use_grouped_topk=False,
+            layer_id=layer_id,
         )
 
         self.experts = get_moe_impl_class(quant_config)(
diff --git a/python/sglang/srt/models/step3_vl.py b/python/sglang/srt/models/step3_vl.py
index 8d1673a05a0d..ad8fb6f9c142 100644
--- a/python/sglang/srt/models/step3_vl.py
+++ b/python/sglang/srt/models/step3_vl.py
@@ -129,6 +129,7 @@ def __init__(
             top_k=config.moe_top_k,
             renormalize=config.norm_expert_weight,
             use_grouped_topk=False,
+            layer_id=layer_id,
         )
 
         self.experts = get_moe_impl_class(quant_config)(
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 56ac53a5df9e..ae04fb46af55 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -572,6 +572,7 @@ class ServerArgs:
     disable_fast_image_processor: bool = False
    keep_mm_feature_on_device: bool = False
    enable_return_hidden_states: bool = False
+    enable_return_routed_experts: bool = False
    scheduler_recv_interval: int = 1
    numa_node: Optional[List[int]] = None
    enable_deterministic_inference: bool = False
@@ -4070,6 +4071,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
            action="store_true",
            help="Enable returning hidden states with responses.",
        )
+        parser.add_argument(
+            "--enable-return-routed-experts",
+            action="store_true",
+            help="Enable returning the routed experts of each layer with responses.",
+        )
        parser.add_argument(
            "--scheduler-recv-interval",
            type=int,
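With the plumbing above in place, the feature is toggled per server (`--enable-return-routed-experts`) and per request (`return_routed_experts`). A client-side sketch using the offline engine; it assumes `sgl.Engine` forwards `ServerArgs` fields passed as keyword arguments, as it does for other flags, and uses a placeholder model path:

```python
import sglang as sgl

from sglang.srt.layers.moe.routed_experts_capturer import (
    extract_routed_experts_from_meta_info,
)

engine = sgl.Engine(
    model_path="Qwen/Qwen3-30B-A3B",  # any MoE model whose TopK receives layer_id
    enable_return_routed_experts=True,
)
out = engine.generate(
    "The capital of France is",
    sampling_params={"temperature": 0, "max_new_tokens": 8},
    return_routed_experts=True,
)
experts = extract_routed_experts_from_meta_info(out)  # flat int32 array
engine.shutdown()
```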
"meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it" @@ -772,6 +773,7 @@ def get_benchmark_args( disable_tqdm=False, disable_stream=disable_stream, return_logprob=False, + return_routed_experts=False, seed=seed, disable_ignore_eos=disable_ignore_eos, extra_request_body=None, diff --git a/test/srt/rl/test_return_routed_experts.py b/test/srt/rl/test_return_routed_experts.py new file mode 100644 index 000000000000..da995cb17b0e --- /dev/null +++ b/test/srt/rl/test_return_routed_experts.py @@ -0,0 +1,187 @@ +import asyncio +import logging +import unittest +from typing import List + +import aiohttp +import requests +import torch +from torch.nn.utils.rnn import pad_sequence + +from sglang.srt.layers.moe.routed_experts_capturer import ( + extract_routed_experts_from_meta_info, +) +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_ENABLE_ROUTED_EXPERTS_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +SHAREGPT_URL = ( + "https://huggingface.co/datasets/anon8231489123/" + "ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" +) +logger = logging.getLogger(__name__) + + +class TestReturnRoutedExperts(CustomTestCase): + # modified from test_hicache.py + @classmethod + def setUpClass(cls): + + cls.baseline_args = [ + "--enable-return-routed-experts", + "--enable-deterministic-inference", + "--disable-overlap-schedule", + "--disable-cuda-graph", + "--disable-radix-cache", + "--tp", + 4, + "--dp", + 4, + "--enable-dp-attention", + ] + cls.reference_args = [ + "--enable-return-routed-experts", + "--enable-deterministic-inference", + "--tp", + 4, + "--dp", + 4, + "--enable-dp-attention", + ] + cls.sampling_args = { + "temperature": 0, + } + # prepare ShareGPT dataset + try: + response = requests.get(SHAREGPT_URL, timeout=60) + response.raise_for_status() + data = response.json() + print(f"Dataset size: {len(data)}") + except requests.exceptions.RequestException as e: + raise Exception(f"Failed to download ShareGPT dataset: {e}") from e + cls.texts = [] + for s in data: + if "conversations" in s and len(s["conversations"]) > 0: + try: + text = s["conversations"][0]["value"] + if isinstance(text, str) and len(text) <= 2000: + cls.texts.append(text) + except (KeyError, IndexError, TypeError) as e: + print(f"Warning: Skipping invalid conversation data: {e}") + continue + + if not cls.texts: + raise ValueError("No valid texts found in the dataset") + cls.texts = cls.texts[:100] + + @classmethod + def test_return_routed_experts(cls): + captured_baseline_experts = asyncio.run( + cls.fetch_result("baseline", cls.baseline_args) + ) + captured_reference_experts = asyncio.run( + cls.fetch_result("reference", cls.reference_args) + ) + + check_all_experts_id_valid(captured_baseline_experts) + check_all_experts_id_valid(captured_reference_experts) + + num_baseline_topks = ( + sum([len(seq) for seq in captured_baseline_experts]) + * len(captured_baseline_experts[0][0]) + * len(captured_baseline_experts[0][0][0]) + ) + + num_mismatches = compare_baseline_w_reference( + captured_baseline_experts, captured_reference_experts + ) + logger.info( + f"Total mismatches report: {num_mismatches} out of {num_baseline_topks} ({num_mismatches/num_baseline_topks:.4%})" + ) + print( + f"Total mismatches report: {num_mismatches} out of {num_baseline_topks} ({num_mismatches/num_baseline_topks:.4%})" + ) + 
assert ( + num_mismatches / num_baseline_topks < 0.05 + ), f"Too many mismatches: {num_mismatches} out of {num_baseline_topks} ({num_mismatches/num_baseline_topks:.4%})" + + @classmethod + async def fetch_result(cls, title, other_args): + try: + process = popen_launch_server( + DEFAULT_ENABLE_ROUTED_EXPERTS_MODEL_NAME_FOR_TEST, + DEFAULT_URL_FOR_TEST, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + async with aiohttp.ClientSession() as session: + tasks = [ + asyncio.create_task( + make_request( + session, + f"{DEFAULT_URL_FOR_TEST}/generate", + { + "text": text, + "sampling_params": cls.sampling_args, + "return_routed_experts": True, + "max_new_tokens": 100, + }, + ) + ) + for text in cls.texts + ] + # return value shape: List[[seq_len, num_layers, topk]...] + http_result = await asyncio.gather(*tasks) + except Exception as e: + raise e + finally: + kill_process_tree(process.pid) + + result = [ + extract_routed_experts_from_meta_info(res).reshape(-1, 48, 8) + for res in http_result + ] + + return result + + +async def make_request(session, url, payload): + """Make a single async HTTP request""" + async with session.post(url=url, json=payload) as response: + return await response.json() + + +def check_all_experts_id_valid(experts: List[List[List[int]]]): + tensor_list = [torch.tensor(lst) for lst in experts] + padded_tensor = pad_sequence(tensor_list, batch_first=True, padding_value=0) + + # temporary hardcode as we only use Qwen3 30BA3B + if not ((padded_tensor >= 0) & (padded_tensor <= 127)).all(): + raise ValueError( + f"Some expert indices are out of valid range [0, 127], MAX: {padded_tensor.max()} MIN: {padded_tensor.min()}" + ) + + +def compare_baseline_w_reference(baseline, reference): + num_total_mismatches = 0 + for baseline_seq, reference_seq in zip(baseline, reference): + for bsl_token, ref_token in zip(baseline_seq, reference_seq): + for bsl_topk, ref_topk in zip(bsl_token, ref_token): + len_bsl, len_ref = len(bsl_topk), len(ref_topk) + set_bsl, set_ref = set(bsl_topk), set(ref_topk) + if set_bsl != set_ref: + num_total_mismatches += len(set_bsl - set_ref) + if (len_bsl != len_ref) or (len_bsl != len(set_bsl)): + raise ValueError( + f"Duplicates experts ids found: Baseline({len_bsl}): {bsl_topk} vs Reference({len_ref}): {ref_topk}" + ) + return num_total_mismatches + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 14552ca4619c..35d4e21dd418 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -150,6 +150,7 @@ TestFile("test_multi_instance_release_memory_occupation.py", 64), TestFile("test_pp_single_node.py", 500), TestFile("test_epd_disaggregation.py", 150), + TestFile("rl/test_return_routed_experts.py", 300), ], "per-commit-8-gpu-h200": [ TestFile("test_deepseek_v3_basic.py", 275),
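For completeness, the same flow over HTTP, mirroring the request payload the test above sends (server address, port, and model are assumptions):

```python
import requests

from sglang.srt.layers.moe.routed_experts_capturer import (
    extract_routed_experts_from_meta_info,
)

# Assumes a server started with, e.g.:
#   python -m sglang.launch_server --model-path Qwen/Qwen3-30B-A3B \
#       --enable-return-routed-experts
resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        "return_routed_experts": True,
    },
).json()

flat = extract_routed_experts_from_meta_info(resp)
# Qwen3-30B-A3B: 48 MoE layers, top-8 routing (see the test above).
print(flat.reshape(-1, 48, 8).shape)
```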