Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/sglang/bench_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ async def async_request_sglang_generate(
"lora_path": request_func_input.lora_name,
"return_logprob": args.return_logprob,
"return_routed_experts": args.return_routed_experts,
"return_dsa_topk_indices": args.return_dsa_topk_indices,
"logprob_start_len": -1,
**request_func_input.extra_request_body,
}
Expand Down Expand Up @@ -3028,6 +3029,11 @@ def __call__(self, parser, namespace, values, option_string=None):
action="store_true",
help="Return routed experts.",
)
parser.add_argument(
"--return-dsa-topk-indices",
action="store_true",
help="Return DSA topk indices.",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--disable-ignore-eos",
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def generate(
custom_logit_processor: Optional[Union[List[str], str]] = None,
return_hidden_states: bool = False,
return_routed_experts: bool = False,
return_dsa_topk_indices: bool = False,
stream: bool = False,
bootstrap_host: Optional[Union[List[str], str]] = None,
bootstrap_port: Optional[Union[List[int], int]] = None,
Expand Down Expand Up @@ -263,6 +264,7 @@ def generate(
custom_logit_processor=custom_logit_processor,
return_hidden_states=return_hidden_states,
return_routed_experts=return_routed_experts,
return_dsa_topk_indices=return_dsa_topk_indices,
stream=stream,
bootstrap_host=bootstrap_host,
bootstrap_port=bootstrap_port,
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/layers/attention/nsa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,6 +1239,7 @@ def forward_extend(
if envs.SGLANG_NSA_FUSE_TOPK.get():
page_table_1 = topk_indices
else:
# NOTE: non-fused path — transform the top-k indices into page-table indices (page_size = 1).
if topk_transform_method == TopkTransformMethod.RAGGED:
topk_indices_offset = metadata.topk_indices_offset
assert topk_indices_offset is not None
Expand Down Expand Up @@ -1376,6 +1377,7 @@ def forward_decode(
if envs.SGLANG_NSA_FUSE_TOPK.get():
page_table_1 = topk_indices
else:
# NOTE: non-fused path — transform the top-k indices into page-table indices (page_size = 1).
page_table_1 = transform_index_page_table_decode(
page_table=metadata.page_table_1,
topk_indices=topk_indices,
Expand Down
118 changes: 118 additions & 0 deletions python/sglang/srt/layers/moe/routed_experts_capturer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def __init__(
num_experts_per_tok: int,
num_fused_shared_experts: int,
device: str,
enable_capture_dsa_topk_indices: bool = False,
num_dsa_topk_indices: int = 2048,
) -> None:
self.buffer = torch.zeros(
(
Expand All @@ -48,23 +50,56 @@ def __init__(
dtype=torch.int32,
device=device,
)
self.dsa_topk_indices_buffer = None
if enable_capture_dsa_topk_indices:
self.dsa_topk_indices_buffer = torch.zeros(
(
max(
get_global_server_args().chunked_prefill_size
* get_global_server_args().dp_size,
max_running_requests,
),
num_hidden_layers,
num_dsa_topk_indices,
),
dtype=torch.int32,
device=device,
)
self._finalize_allocation_log()

def get_buffer_size_bytes(self):
    """Return the size, in bytes, of the routed-experts device buffer."""
    assert hasattr(self, "buffer")
    experts_buffer = self.buffer
    return get_tensor_size_bytes(experts_buffer)

def get_dsa_topk_indices_buffer_size_bytes(self):
    """Return the size, in bytes, of the DSA top-k indices device buffer.

    Must only be called when DSA top-k capture was enabled at construction.
    """
    # `hasattr` is always true here: __init__ assigns the attribute
    # unconditionally (None when capture is disabled). Check for None instead
    # so misuse fails with a clear assertion rather than deep inside
    # get_tensor_size_bytes.
    assert self.dsa_topk_indices_buffer is not None
    return get_tensor_size_bytes(self.dsa_topk_indices_buffer)

def capture_fwd_routed_experts(self, layer_id: int, topk_ids: torch.Tensor):
    """Write one layer's per-token routed expert ids into the device buffer.

    Args:
        layer_id: Index of the layer being captured; must not be None.
        topk_ids: 2-D tensor of expert ids, one row per token in the batch.
    """
    assert layer_id is not None, "capturing routing experts but get layer_id None"
    num_tokens, _num_topk = topk_ids.shape
    self.buffer[:num_tokens, layer_id, :] = topk_ids

def capture_fwd_dsa_topk_indices(
    self, layer_id: int, dsa_topk_indices: torch.Tensor
):
    """Write one layer's DSA top-k indices into the device buffer.

    Args:
        layer_id: Index of the layer being captured; must not be None.
        dsa_topk_indices: 2-D tensor of indices, one row per token in the batch.
    """
    assert layer_id is not None, "capturing dsa topk indices but get layer_id None"
    # Fail fast with a clear message when capture was not enabled at init time;
    # otherwise indexing the None buffer below raises an opaque TypeError.
    assert (
        self.dsa_topk_indices_buffer is not None
    ), "DSA topk indices capture is not enabled (dsa_topk_indices_buffer is None)"
    batch, _ = dsa_topk_indices.shape
    self.dsa_topk_indices_buffer[:batch, layer_id, :] = dsa_topk_indices

def _finalize_allocation_log(self):
    """Log shape and memory footprint of the allocated device buffers."""
    experts_mb = self.get_buffer_size_bytes() / _MB
    logger.info(
        f"Routing experts device buffer allocated. #shape: {tuple(self.buffer.shape)}, size: {experts_mb:.2f} MB"
    )
    if self.dsa_topk_indices_buffer is None:
        return
    dsa_mb = self.get_dsa_topk_indices_buffer_size_bytes() / _MB
    logger.info(
        f"DSA topk indices device buffer allocated. #shape: {tuple(self.dsa_topk_indices_buffer.shape)}, size: {dsa_mb:.2f} MB"
    )


class _RoutedExpertsHostCache:
Expand All @@ -73,6 +108,8 @@ def __init__(
num_tokens: int,
num_hidden_layers: int,
num_experts_per_tok: int,
enable_capture_dsa_topk_indices: bool = False,
num_dsa_topk_indices: int = 2048,
) -> None:
self.num_tokens = num_tokens
self.buffer = torch.zeros(
Expand All @@ -85,12 +122,28 @@ def __init__(
device="cpu",
pin_memory=True,
)
self.dsa_topk_indices_buffer = None
if enable_capture_dsa_topk_indices:
self.dsa_topk_indices_buffer = torch.zeros(
(
num_tokens,
num_hidden_layers,
num_dsa_topk_indices,
),
dtype=torch.int32,
device="cpu",
pin_memory=True,
)
self._finalize_allocation_log()

def get_buffer_size_bytes(self):
    """Return the size, in bytes, of the routed-experts host buffer."""
    assert hasattr(self, "buffer")
    host_buffer = self.buffer
    return get_tensor_size_bytes(host_buffer)

def get_dsa_topk_indices_buffer_size_bytes(self):
    """Return the size, in bytes, of the DSA top-k indices host buffer.

    Must only be called when DSA top-k capture was enabled at construction.
    """
    # `hasattr` is always true here: __init__ assigns the attribute
    # unconditionally (None when capture is disabled). Check for None instead
    # so misuse fails with a clear assertion rather than deep inside
    # get_tensor_size_bytes.
    assert self.dsa_topk_indices_buffer is not None
    return get_tensor_size_bytes(self.dsa_topk_indices_buffer)

def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor):
    # Copy one layer's top-k expert ids into the pinned host buffer at the
    # token slots given by `loc`.
    # NOTE(review): `non_blocking=True` on a device->CPU copy may return before
    # the copy completes; readers of self.buffer must synchronize with the
    # device first — confirm a sync happens upstream before the buffer is read.
    self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True)

Expand All @@ -100,6 +153,13 @@ def _finalize_allocation_log(self):
logger.info(
f"Routing experts host buffer allocated. #tokens: {self.num_tokens}, size: {buffer_size_GB:.2f} GB"
)
if self.dsa_topk_indices_buffer is not None:
dsa_topk_indices_size_GB = (
self.get_dsa_topk_indices_buffer_size_bytes() / _GB
)
logger.info(
f"DSA topk indices host buffer allocated. #tokens: {self.num_tokens}, size: {dsa_topk_indices_size_GB:.2f} GB"
)


class RoutedExpertsCapturer(ABC):
Expand All @@ -111,6 +171,7 @@ def create(
num_tokens: int,
max_running_requests: int,
device: str,
enable_capture_dsa_topk_indices: bool = False,
):
if enable:
return _RoutedExpertsCapturerReal(
Expand All @@ -119,6 +180,7 @@ def create(
max_running_requests=max_running_requests,
num_fused_shared_experts=num_fused_shared_experts,
device=device,
enable_capture_dsa_topk_indices=enable_capture_dsa_topk_indices,
)
else:
return _RoutedExpertsCapturerNoop()
Expand All @@ -134,6 +196,9 @@ def _sync_fwd_experts_buffer_DtoH(
def capture(self, layer_id: int, topk_ids: torch.Tensor):
    """Abstract hook: record the routed expert ids produced by one layer."""
    raise NotImplementedError

def capture_dsa_topk_indices(self, layer_id: int, dsa_topk_indices: torch.Tensor):
    """Abstract hook: record the DSA top-k indices produced by one layer."""
    raise NotImplementedError

def get_routed_experts(
self,
req_pool_idx: int,
Expand All @@ -142,6 +207,14 @@ def get_routed_experts(
):
raise NotImplementedError

def get_dsa_topk_indices(
    self,
    req_pool_idx: int,
    seqlen: int,
    req_to_token_pool: ReqToTokenPool,
):
    """Abstract hook: return the captured DSA top-k indices for one request."""
    raise NotImplementedError

def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
    """Abstract hook invoked after a forward pass completes.

    The real implementation uses this to sync captured device buffers to host.
    """
    raise NotImplementedError

Expand All @@ -162,15 +235,24 @@ def __init__(
max_running_requests: int,
num_fused_shared_experts: int,
device: str,
enable_capture_dsa_topk_indices: bool = False,
):
self.num_fused_shared_experts = num_fused_shared_experts
self.num_hidden_layers = model_config.hf_text_config.num_hidden_layers
self.num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok

self.enable_capture_dsa_topk_indices = enable_capture_dsa_topk_indices
if enable_capture_dsa_topk_indices:
self.num_dsa_topk_indices = model_config.hf_text_config.index_topk
else:
self.num_dsa_topk_indices = None

self.host_cache = _RoutedExpertsHostCache(
num_tokens=num_tokens,
num_hidden_layers=self.num_hidden_layers,
num_experts_per_tok=self.num_experts_per_tok,
enable_capture_dsa_topk_indices=self.enable_capture_dsa_topk_indices,
num_dsa_topk_indices=self.num_dsa_topk_indices,
)

self.device_cache = _RoutedExpertsDeviceCache(
Expand All @@ -179,6 +261,8 @@ def __init__(
num_experts_per_tok=self.num_experts_per_tok,
num_fused_shared_experts=self.num_fused_shared_experts,
device=device,
enable_capture_dsa_topk_indices=self.enable_capture_dsa_topk_indices,
num_dsa_topk_indices=self.num_dsa_topk_indices,
)

def _sync_fwd_experts_buffer_DtoH(
Expand All @@ -205,9 +289,20 @@ def _sync_fwd_experts_buffer_DtoH(
local_start_pos:local_end_pos, :, : self.num_experts_per_tok
].cpu()

if self.enable_capture_dsa_topk_indices:
self.host_cache.dsa_topk_indices_buffer[out_cache_loc_cpu] = (
self.device_cache.dsa_topk_indices_buffer[
local_start_pos:local_end_pos, :, : self.num_dsa_topk_indices
].cpu()
)

def capture(self, layer_id: int, topk_ids: torch.Tensor):
    # Delegate to the device-side cache; the host sync happens later in
    # on_forward_end.
    self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids)

def capture_dsa_topk_indices(self, layer_id: int, dsa_topk_indices: torch.Tensor):
    """Record one layer's DSA top-k indices; no-op when capture is disabled."""
    if not self.enable_capture_dsa_topk_indices:
        return
    self.device_cache.capture_fwd_dsa_topk_indices(layer_id, dsa_topk_indices)

def get_routed_experts(
self,
req_pool_idx: int,
Expand All @@ -219,6 +314,18 @@ def get_routed_experts(
)
return self.get_host_cache().buffer[cache_pool_idx]

def get_dsa_topk_indices(
    self,
    req_pool_idx: int,
    seqlen: int,
    req_to_token_pool: ReqToTokenPool,
):
    """Gather the captured DSA top-k indices for the first ``seqlen - 1``
    tokens of a request, or None when capture is disabled."""
    if not self.enable_capture_dsa_topk_indices:
        return None
    # Map the request's token slots to host-cache rows before indexing.
    token_locs = (
        req_to_token_pool.req_to_token[req_pool_idx][: seqlen - 1].cpu().clone()
    )
    return self.get_host_cache().dsa_topk_indices_buffer[token_locs]

def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
self._sync_fwd_experts_buffer_DtoH(
forward_batch=forward_batch,
Expand Down Expand Up @@ -248,6 +355,9 @@ def _sync_fwd_experts_buffer_DtoH(
def capture(self, layer_id: int, topk_ids: torch.Tensor):
    # No-op: experts capture is disabled for this capturer.
    pass

def capture_dsa_topk_indices(self, layer_id: int, dsa_topk_indices: torch.Tensor):
    # No-op: DSA top-k capture is disabled for this capturer.
    pass

def get_routed_experts(
self,
req_pool_idx: int,
Expand All @@ -256,6 +366,14 @@ def get_routed_experts(
):
pass

def get_dsa_topk_indices(
    self,
    req_pool_idx: int,
    seqlen: int,
    req_to_token_pool: ReqToTokenPool,
):
    # No-op capturer: nothing was recorded, so this implicitly returns None.
    pass

def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
    # No-op: there are no device buffers to sync to host.
    pass

Expand Down
19 changes: 17 additions & 2 deletions python/sglang/srt/managers/detokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput):

def _extract_routed_experts(self, recv_obj: BatchTokenIDOutput) -> tuple:
    """Base64-encode the per-request routed-experts and DSA top-k tensors.

    Args:
        recv_obj: Batch output whose optional ``output_routed_experts`` and
            ``output_dsa_topk_indices`` fields hold one tensor (or None) per
            request.

    Returns:
        ``(output_routed_experts, output_dsa_topk_indices)``: each element is
        None when the corresponding field is absent from ``recv_obj``;
        otherwise a list with one entry per request — a base64 string of the
        tensor's raw bytes, or ``[]`` for requests that produced no tensor.
    """
    # The previous annotation (List[List[int]]) was wrong: this returns a
    # 2-tuple of lists of base64 strings / empty lists.
    output_routed_experts = None
    output_dsa_topk_indices = None
    if recv_obj.output_routed_experts is not None:
        # Distinct loop variable names: the old code reused the result names
        # as comprehension variables, which shadowed them and hurt readability.
        output_routed_experts = [
            (
                pybase64.b64encode(experts.numpy().tobytes()).decode("utf-8")
                if experts is not None
                else []
            )
            for experts in recv_obj.output_routed_experts
        ]
    if recv_obj.output_dsa_topk_indices is not None:
        output_dsa_topk_indices = [
            (
                pybase64.b64encode(indices.numpy().tobytes()).decode("utf-8")
                if indices is not None
                else []
            )
            for indices in recv_obj.output_dsa_topk_indices
        ]
    return output_routed_experts, output_dsa_topk_indices

def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
# If handling idle batch, set output_strs to [].
Expand All @@ -348,7 +360,9 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
if len(recv_obj.rids) > 0
else []
)
output_routed_experts = self._extract_routed_experts(recv_obj)
output_routed_experts, output_dsa_topk_indices = self._extract_routed_experts(
recv_obj
)

return BatchStrOutput(
rids=recv_obj.rids,
Expand Down Expand Up @@ -376,6 +390,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
output_token_entropy_val=recv_obj.output_token_entropy_val,
output_hidden_states=recv_obj.output_hidden_states,
output_routed_experts=output_routed_experts,
output_dsa_topk_indices=output_dsa_topk_indices,
customized_info=recv_obj.customized_info,
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
Expand Down
Loading
Loading