9 changes: 5 additions & 4 deletions python/sglang/srt/hardware_backend/npu/moe/topk.py
@@ -92,9 +92,10 @@ def fused_topk_npu(
if expert_location_dispatch_info is not None:
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
get_global_experts_capturer().capture(
layer_id=layer_id,
topk_ids=topk_ids,
)
if (cap := get_global_experts_capturer()) is not None:
cap.capture(
layer_id=layer_id,
topk_indices=topk_ids,
)

return StandardTopKOutput(topk_weights, topk_ids, router_logits)
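The call-site change above replaces an always-present no-op capturer with an `Optional` global guarded by a walrus-operator `None` check, so runs with capturing disabled skip the call entirely instead of dispatching into `_RoutedExpertsCapturerNoop`. A minimal sketch of the pattern (the `capture` body here is a placeholder, not the real implementation):

```python
from typing import Optional


class RoutedExpertsCapturer:
    def capture(self, layer_id: int, topk_indices) -> None:
        # Placeholder body; the real capturer writes into a device buffer.
        print(f"layer {layer_id}: captured {tuple(topk_indices.shape)}")


# Module-level singleton; None means capturing is disabled.
_global_expert_capturer: Optional[RoutedExpertsCapturer] = None


def get_global_experts_capturer() -> Optional[RoutedExpertsCapturer]:
    return _global_expert_capturer


def fused_topk_call_site(layer_id, topk_ids):
    # One fetch-and-test expression; when capturing is disabled the
    # singleton stays None and the hot path pays only this check.
    if (cap := get_global_experts_capturer()) is not None:
        cap.capture(layer_id=layer_id, topk_indices=topk_ids)
```

Compared with the removed no-op object, the `None` check avoids a method call per MoE layer when capturing is off and makes the disabled state explicit at every call site.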
321 changes: 49 additions & 272 deletions python/sglang/srt/layers/moe/routed_experts_capturer.py
@@ -1,6 +1,3 @@
import dataclasses
import logging
from abc import ABC
from typing import Optional

import numpy as np
@@ -13,116 +10,21 @@
get_dp_local_info,
is_dp_attention_enabled,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.layers.topk_capturer_base import BaseTopkCapturer
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args

logger = logging.getLogger(__name__)

_GB = 1024 * 1024 * 1024
_MB = 1024 * 1024
class RoutedExpertsCapturer(BaseTopkCapturer):
"""Capturer for routed experts with host buffer.


def get_tensor_size_bytes(t: torch.Tensor):
return np.prod(t.shape) * t.dtype.itemsize


@dataclasses.dataclass
class RoutedExpertsOutput:
"""Holds GPU tensors captured during forward for overlap scheduling.
Call copy_to_cpu() inside forward stream (before copy_done.record()),
then finalize() after copy_done.synchronize().
Routed experts share a global device buffer across DP ranks (indexed by
dp_rank), so `_get_local_slice` overrides the default to apply DP-rank-aware
slicing. The device cache also holds extra columns for any fused shared
experts; the host cache and user-facing return drop them via the
[:topk_size] truncation.
"""

out_cache_loc: torch.Tensor
routed_experts: torch.Tensor
host_cache: "_RoutedExpertsHostCache"

def copy_to_cpu(self):
self.out_cache_loc = self.out_cache_loc.to("cpu", non_blocking=True)
self.routed_experts = self.routed_experts.to("cpu", non_blocking=True)

def finalize(self):
self.host_cache.buffer[self.out_cache_loc] = self.routed_experts


class _RoutedExpertsDeviceCache:
def __init__(
self,
max_running_requests: int,
num_hidden_layers: int,
num_experts_per_tok: int,
num_fused_shared_experts: int,
device: str,
) -> None:
self.buffer = torch.zeros(
(
max(
get_global_server_args().chunked_prefill_size
* get_global_server_args().dp_size,
max_running_requests,
),
num_hidden_layers,
num_experts_per_tok + num_fused_shared_experts,
),
dtype=torch.int32,
device=device,
)
self._finalize_allocation_log()

def get_buffer_size_bytes(self):
assert hasattr(self, "buffer")
return get_tensor_size_bytes(self.buffer)

def capture_fwd_routed_experts(self, layer_id: int, topk_ids: torch.Tensor):
assert layer_id is not None, "capturing routing experts but get layer_id None"
batch, _ = topk_ids.shape
self.buffer[:batch, layer_id, :] = topk_ids

def _finalize_allocation_log(self):
"""Common logging and memory usage computation for captured experts buffers."""
buffer_size_MB = self.get_buffer_size_bytes() / _MB
logger.info(
f"Routing experts device buffer allocated. #shape: {tuple(self.buffer.shape)}, size: {buffer_size_MB:.2f} MB"
)


class _RoutedExpertsHostCache:
def __init__(
self,
num_tokens: int,
num_hidden_layers: int,
num_experts_per_tok: int,
) -> None:
self.num_tokens = num_tokens
self.buffer = torch.zeros(
(
num_tokens,
num_hidden_layers,
num_experts_per_tok,
),
dtype=torch.int32,
device="cpu",
pin_memory=True,
)
self._finalize_allocation_log()

def get_buffer_size_bytes(self):
assert hasattr(self, "buffer")
return get_tensor_size_bytes(self.buffer)

def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor):
self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True)

def _finalize_allocation_log(self):
"""Common logging and memory usage computation for captured experts buffers."""
buffer_size_GB = self.get_buffer_size_bytes() / _GB
logger.info(
f"Routing experts host buffer allocated. #tokens: {self.num_tokens}, size: {buffer_size_GB:.2f} GB"
)


class RoutedExpertsCapturer(ABC):
@staticmethod
def create(
enable: bool,
@@ -131,51 +33,16 @@ def create(
num_tokens: int,
max_running_requests: int,
device: str,
):
if enable:
return _RoutedExpertsCapturerReal(
model_config,
num_tokens=num_tokens,
max_running_requests=max_running_requests,
num_fused_shared_experts=num_fused_shared_experts,
device=device,
)
else:
return _RoutedExpertsCapturerNoop()

def _sync_fwd_experts_buffer_DtoH(
self,
forward_batch: ForwardBatch,
can_run_graph: bool,
cuda_graph_batch: int,
):
raise NotImplementedError

def capture(self, layer_id: int, topk_ids: torch.Tensor):
raise NotImplementedError

def get_routed_experts(
self,
req_pool_idx: int,
seqlen: int,
req_to_token_pool: ReqToTokenPool,
):
raise NotImplementedError

def on_forward_end(
self, forward_batch, can_run_graph, cuda_graph_batch, no_copy_to_cpu=False
) -> Optional[RoutedExpertsOutput]:
raise NotImplementedError

def get_host_cache(self):
raise NotImplementedError

def get_device_cache(self):
raise NotImplementedError


class _RoutedExpertsCapturerReal(RoutedExpertsCapturer):
"""Capturer for routed experts with host buffer"""
) -> Optional["RoutedExpertsCapturer"]:
if not enable:
return None
return RoutedExpertsCapturer(
model_config,
num_tokens=num_tokens,
max_running_requests=max_running_requests,
num_fused_shared_experts=num_fused_shared_experts,
device=device,
)

def __init__(
self,
@@ -186,144 +53,54 @@ def __init__(
device: str,
):
self.num_fused_shared_experts = num_fused_shared_experts
self.num_hidden_layers = model_config.hf_text_config.num_hidden_layers
self.num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok

self.host_cache = _RoutedExpertsHostCache(
num_tokens=num_tokens,
num_hidden_layers=self.num_hidden_layers,
num_experts_per_tok=self.num_experts_per_tok,
topk_size = model_config.hf_text_config.num_experts_per_tok
num_layers = model_config.hf_text_config.num_hidden_layers

server_args = get_global_server_args()
# FIXME: spec decoding is not accounted for here. The device buffer can
# overflow when max_running_requests * num_verify_tokens exceeds
# chunked_prefill_size * dp_size.
max_batch_size = max(
server_args.chunked_prefill_size * server_args.dp_size,
max_running_requests,
)

self.device_cache = _RoutedExpertsDeviceCache(
max_running_requests=max_running_requests,
num_hidden_layers=self.num_hidden_layers,
num_experts_per_tok=self.num_experts_per_tok,
num_fused_shared_experts=self.num_fused_shared_experts,
super().__init__(
num_tokens=num_tokens,
max_batch_size=max_batch_size,
num_layers=num_layers,
topk_size=topk_size,
device=device,
name="routed_experts",
device_topk_size=topk_size + num_fused_shared_experts,
)
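
The `max_batch_size` bound above takes the larger of the DP-aggregated chunked-prefill budget and the running-request cap; a worked instance of the arithmetic, including the overflow condition flagged in the FIXME (all numbers illustrative, not SGLang defaults):

```python
# Illustrative numbers, not SGLang defaults.
chunked_prefill_size = 8192
dp_size = 4
max_running_requests = 2048

max_batch_size = max(chunked_prefill_size * dp_size, max_running_requests)
assert max_batch_size == 32768  # the prefill budget dominates here

# The FIXME condition: with speculative decoding, each request may write
# num_verify_tokens rows per forward, which this bound does not include.
num_verify_tokens = 4
assert max_running_requests * num_verify_tokens <= max_batch_size  # OK here
```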

def _get_local_range(self, forward_batch, can_run_graph, cuda_graph_batch):
def _get_local_slice(
self,
forward_batch: ForwardBatch,
can_run_graph: bool,
cuda_graph_batch: Optional[int],
) -> torch.Tensor:
if is_dp_attention_enabled():
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
if can_run_graph:
local_start_pos = get_attention_dp_rank() * cuda_graph_batch
return local_start_pos, local_start_pos + local_num_tokens
local_end_pos = local_start_pos + local_num_tokens
else:
return 0, forward_batch.out_cache_loc.shape[0]

def _sync_fwd_experts_buffer_DtoH(
self,
forward_batch: ForwardBatch,
can_run_graph: bool,
cuda_graph_batch: int,
):
local_start_pos, local_end_pos = self._get_local_range(
forward_batch, can_run_graph, cuda_graph_batch
)
out_cache_loc_cpu = forward_batch.out_cache_loc.cpu()
self.host_cache.buffer[out_cache_loc_cpu] = self.device_cache.buffer[
local_start_pos:local_end_pos, :, : self.num_experts_per_tok
].cpu()

def _prepare_routed_experts_output(
self,
forward_batch: ForwardBatch,
can_run_graph: bool,
cuda_graph_batch: int,
) -> RoutedExpertsOutput:
local_start_pos, local_end_pos = self._get_local_range(
forward_batch, can_run_graph, cuda_graph_batch
)
return RoutedExpertsOutput(
out_cache_loc=forward_batch.out_cache_loc,
routed_experts=self.device_cache.buffer[
local_start_pos:local_end_pos, :, : self.num_experts_per_tok
],
host_cache=self.host_cache,
)

def capture(self, layer_id: int, topk_ids: torch.Tensor):
self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids)

def get_routed_experts(
self,
req_pool_idx: int,
seqlen: int,
req_to_token_pool: ReqToTokenPool,
):
cache_pool_idx = (
req_to_token_pool.req_to_token[req_pool_idx][: seqlen - 1].cpu().clone()
)
return self.get_host_cache().buffer[cache_pool_idx]

def on_forward_end(
self, forward_batch, can_run_graph, cuda_graph_batch, no_copy_to_cpu=False
) -> Optional[RoutedExpertsOutput]:
if no_copy_to_cpu:
return self._prepare_routed_experts_output(
forward_batch=forward_batch,
can_run_graph=can_run_graph,
cuda_graph_batch=cuda_graph_batch,
)
else:
self._sync_fwd_experts_buffer_DtoH(
forward_batch=forward_batch,
can_run_graph=can_run_graph,
cuda_graph_batch=cuda_graph_batch,
)
return None

def get_host_cache(self):
return self.host_cache

def get_device_cache(self):
return self.device_cache


class _RoutedExpertsCapturerNoop(RoutedExpertsCapturer):
def __init__(self):
pass

def _sync_fwd_experts_buffer_DtoH(
self,
forward_batch: ForwardBatch,
can_run_graph: bool,
cuda_graph_batch: int,
):
pass

def capture(self, layer_id: int, topk_ids: torch.Tensor):
pass

def get_routed_experts(
self,
req_pool_idx: int,
seqlen: int,
req_to_token_pool: ReqToTokenPool,
):
pass

def on_forward_end(
self, forward_batch, can_run_graph, cuda_graph_batch, no_copy_to_cpu=False
) -> Optional[RoutedExpertsOutput]:
return None

def get_host_cache(self):
pass

def get_device_cache(self):
pass
local_start_pos, local_end_pos = 0, forward_batch.out_cache_loc.shape[0]
return self.device_cache.buffer[
local_start_pos:local_end_pos, :, : self.topk_size
]


_global_expert_capturer: Optional[RoutedExpertsCapturer] = _RoutedExpertsCapturerNoop()
_global_expert_capturer: Optional[RoutedExpertsCapturer] = None


def get_global_experts_capturer():
def get_global_experts_capturer() -> Optional[RoutedExpertsCapturer]:
return _global_expert_capturer


def set_global_experts_capturer(capturer: RoutedExpertsCapturer):
def set_global_experts_capturer(capturer: Optional[RoutedExpertsCapturer]):
global _global_expert_capturer
_global_expert_capturer = capturer
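
Taken together, the refactored class keeps a device buffer of shape `(max_batch_size, num_layers, topk_size + num_fused_shared_experts)` that every MoE layer writes into during a forward pass, and a token-indexed host buffer of shape `(num_tokens, num_layers, topk_size)`; the device-to-host sync drops the fused-shared-expert columns via the `[:topk_size]` slice. A self-contained, CPU-only sketch of that buffer scheme under toy sizes (the `BaseTopkCapturer` plumbing and the DP-rank slicing from `_get_local_slice` are elided; only the non-DP branch is shown):

```python
# CPU-only sketch of the two-level buffer scheme; assumes only PyTorch.
import torch

max_batch_size, num_layers, topk, fused_shared = 8, 2, 4, 1
num_tokens = 32  # size of the token-indexed host cache

# Device buffer: per-forward scratch, rows indexed by position in the batch.
# It carries extra columns for fused shared experts.
device_buffer = torch.zeros(
    (max_batch_size, num_layers, topk + fused_shared), dtype=torch.int32
)

# Host buffer: persistent, rows indexed by out_cache_loc (token slots).
host_buffer = torch.zeros((num_tokens, num_layers, topk), dtype=torch.int32)


def capture(layer_id: int, topk_ids: torch.Tensor) -> None:
    # Each MoE layer writes its selected expert ids for the current batch.
    batch = topk_ids.shape[0]
    device_buffer[:batch, layer_id, :] = topk_ids


def sync_dtoh(out_cache_loc: torch.Tensor) -> None:
    # Copy the local batch rows to their token slots, dropping the
    # fused-shared-expert columns via the [:topk] truncation.
    batch = out_cache_loc.shape[0]
    host_buffer[out_cache_loc] = device_buffer[:batch, :, :topk]


# One forward step over two layers for a batch of 3 tokens:
for layer in range(num_layers):
    capture(layer, torch.randint(0, 64, (3, topk + fused_shared), dtype=torch.int32))
sync_dtoh(torch.tensor([5, 9, 21]))
print(host_buffer[5])  # the routed experts recorded for token slot 5
```

Under DP attention the copied row range would instead start at the rank's offset into the shared device buffer (for CUDA graphs, `get_attention_dp_rank() * cuda_graph_batch`), which is what the `_get_local_slice` override computes.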
