diff --git a/python/sglang/srt/hardware_backend/npu/moe/topk.py b/python/sglang/srt/hardware_backend/npu/moe/topk.py
index 3ceb0eaad10b..3e1b6d464603 100644
--- a/python/sglang/srt/hardware_backend/npu/moe/topk.py
+++ b/python/sglang/srt/hardware_backend/npu/moe/topk.py
@@ -92,9 +92,10 @@ def fused_topk_npu(
     if expert_location_dispatch_info is not None:
         topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
-    get_global_experts_capturer().capture(
-        layer_id=layer_id,
-        topk_ids=topk_ids,
-    )
+    if (cap := get_global_experts_capturer()) is not None:
+        cap.capture(
+            layer_id=layer_id,
+            topk_indices=topk_ids,
+        )

     return StandardTopKOutput(topk_weights, topk_ids, router_logits)
diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py
index dc68b7074cda..8b0d9e5930e6 100644
--- a/python/sglang/srt/layers/moe/routed_experts_capturer.py
+++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py
@@ -1,6 +1,3 @@
-import dataclasses
-import logging
-from abc import ABC
 from typing import Optional

 import numpy as np
@@ -13,116 +10,21 @@
     get_dp_local_info,
     is_dp_attention_enabled,
 )
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.layers.topk_capturer_base import BaseTopkCapturer
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import get_global_server_args

-logger = logging.getLogger(__name__)

-_GB = 1024 * 1024 * 1024
-_MB = 1024 * 1024
+class RoutedExpertsCapturer(BaseTopkCapturer):
+    """Capturer for routed experts with host buffer.

-
-def get_tensor_size_bytes(t: torch.Tensor):
-    return np.prod(t.shape) * t.dtype.itemsize
-
-
-@dataclasses.dataclass
-class RoutedExpertsOutput:
-    """Holds GPU tensors captured during forward for overlap scheduling.
-    Call copy_to_cpu() inside forward stream (before copy_done.record()),
-    then finalize() after copy_done.synchronize().
+    Routed experts share a global device buffer across DP ranks (indexed by
+    dp_rank), so `_get_local_slice` overrides the default to apply DP-rank-aware
+    slicing. The device cache also holds extra columns for any fused shared
+    experts; the host cache and user-facing return drop them via the
+    [:topk_size] truncation.
     """

-    out_cache_loc: torch.Tensor
-    routed_experts: torch.Tensor
-    host_cache: "_RoutedExpertsHostCache"
-
-    def copy_to_cpu(self):
-        self.out_cache_loc = self.out_cache_loc.to("cpu", non_blocking=True)
-        self.routed_experts = self.routed_experts.to("cpu", non_blocking=True)
-
-    def finalize(self):
-        self.host_cache.buffer[self.out_cache_loc] = self.routed_experts
-
-
-class _RoutedExpertsDeviceCache:
-    def __init__(
-        self,
-        max_running_requests: int,
-        num_hidden_layers: int,
-        num_experts_per_tok: int,
-        num_fused_shared_experts: int,
-        device: str,
-    ) -> None:
-        self.buffer = torch.zeros(
-            (
-                max(
-                    get_global_server_args().chunked_prefill_size
-                    * get_global_server_args().dp_size,
-                    max_running_requests,
-                ),
-                num_hidden_layers,
-                num_experts_per_tok + num_fused_shared_experts,
-            ),
-            dtype=torch.int32,
-            device=device,
-        )
-        self._finalize_allocation_log()
-
-    def get_buffer_size_bytes(self):
-        assert hasattr(self, "buffer")
-        return get_tensor_size_bytes(self.buffer)
-
-    def capture_fwd_routed_experts(self, layer_id: int, topk_ids: torch.Tensor):
-        assert layer_id is not None, "capturing routing experts but get layer_id None"
-        batch, _ = topk_ids.shape
-        self.buffer[:batch, layer_id, :] = topk_ids
-
-    def _finalize_allocation_log(self):
-        """Common logging and memory usage computation for captured experts buffers."""
-        buffer_size_MB = self.get_buffer_size_bytes() / _MB
-        logger.info(
-            f"Routing experts device buffer allocated. #shape: {tuple(self.buffer.shape)}, size: {buffer_size_MB:.2f} MB"
-        )
-
-
-class _RoutedExpertsHostCache:
-    def __init__(
-        self,
-        num_tokens: int,
-        num_hidden_layers: int,
-        num_experts_per_tok: int,
-    ) -> None:
-        self.num_tokens = num_tokens
-        self.buffer = torch.zeros(
-            (
-                num_tokens,
-                num_hidden_layers,
-                num_experts_per_tok,
-            ),
-            dtype=torch.int32,
-            device="cpu",
-            pin_memory=True,
-        )
-        self._finalize_allocation_log()
-
-    def get_buffer_size_bytes(self):
-        assert hasattr(self, "buffer")
-        return get_tensor_size_bytes(self.buffer)
-
-    def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor):
-        self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True)
-
-    def _finalize_allocation_log(self):
-        """Common logging and memory usage computation for captured experts buffers."""
-        buffer_size_GB = self.get_buffer_size_bytes() / _GB
-        logger.info(
-            f"Routing experts host buffer allocated. #tokens: {self.num_tokens}, size: {buffer_size_GB:.2f} GB"
-        )
-
-
-class RoutedExpertsCapturer(ABC):
     @staticmethod
     def create(
         enable: bool,
@@ -131,51 +33,16 @@ def create(
         num_tokens: int,
         max_running_requests: int,
         device: str,
-    ):
-        if enable:
-            return _RoutedExpertsCapturerReal(
-                model_config,
-                num_tokens=num_tokens,
-                max_running_requests=max_running_requests,
-                num_fused_shared_experts=num_fused_shared_experts,
-                device=device,
-            )
-        else:
-            return _RoutedExpertsCapturerNoop()
-
-    def _sync_fwd_experts_buffer_DtoH(
-        self,
-        forward_batch: ForwardBatch,
-        can_run_graph: bool,
-        cuda_graph_batch: int,
-    ):
-        raise NotImplementedError
-
-    def capture(self, layer_id: int, topk_ids: torch.Tensor):
-        raise NotImplementedError
-
-    def get_routed_experts(
-        self,
-        req_pool_idx: int,
-        seqlen: int,
-        req_to_token_pool: ReqToTokenPool,
-    ):
-        raise NotImplementedError
-
-    def on_forward_end(
-        self, forward_batch, can_run_graph, cuda_graph_batch, no_copy_to_cpu=False
-    ) -> Optional[RoutedExpertsOutput]:
-        raise NotImplementedError
-
-    def get_host_cache(self):
-        raise NotImplementedError
-
-    def get_device_cache(self):
-        raise NotImplementedError
-
-
-class _RoutedExpertsCapturerReal(RoutedExpertsCapturer):
-    """Capturer for routed experts with host buffer"""
+    ) -> Optional["RoutedExpertsCapturer"]:
+        if not enable:
+            return None
+        return RoutedExpertsCapturer(
+            model_config,
+            num_tokens=num_tokens,
+            max_running_requests=max_running_requests,
+            num_fused_shared_experts=num_fused_shared_experts,
+            device=device,
+        )

     def __init__(
         self,
@@ -186,144 +53,54 @@ def __init__(
         device: str,
     ):
         self.num_fused_shared_experts = num_fused_shared_experts
-        self.num_hidden_layers = model_config.hf_text_config.num_hidden_layers
-        self.num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok
-
-        self.host_cache = _RoutedExpertsHostCache(
-            num_tokens=num_tokens,
-            num_hidden_layers=self.num_hidden_layers,
-            num_experts_per_tok=self.num_experts_per_tok,
+        topk_size = model_config.hf_text_config.num_experts_per_tok
+        num_layers = model_config.hf_text_config.num_hidden_layers
+
+        server_args = get_global_server_args()
+        # FIXME: spec decoding is not accounted for here. The device buffer can
+        # overflow when max_running_requests * num_verify_tokens exceeds
+        # chunked_prefill_size * dp_size.
+        max_batch_size = max(
+            server_args.chunked_prefill_size * server_args.dp_size,
+            max_running_requests,
         )
-        self.device_cache = _RoutedExpertsDeviceCache(
-            max_running_requests=max_running_requests,
-            num_hidden_layers=self.num_hidden_layers,
-            num_experts_per_tok=self.num_experts_per_tok,
-            num_fused_shared_experts=self.num_fused_shared_experts,
+        super().__init__(
+            num_tokens=num_tokens,
+            max_batch_size=max_batch_size,
+            num_layers=num_layers,
+            topk_size=topk_size,
             device=device,
+            name="routed_experts",
+            device_topk_size=topk_size + num_fused_shared_experts,
         )

-    def _get_local_range(self, forward_batch, can_run_graph, cuda_graph_batch):
+    def _get_local_slice(
+        self,
+        forward_batch: ForwardBatch,
+        can_run_graph: bool,
+        cuda_graph_batch: Optional[int],
+    ) -> torch.Tensor:
         if is_dp_attention_enabled():
             local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
             if can_run_graph:
                 local_start_pos = get_attention_dp_rank() * cuda_graph_batch
-            return local_start_pos, local_start_pos + local_num_tokens
+            local_end_pos = local_start_pos + local_num_tokens
         else:
-            return 0, forward_batch.out_cache_loc.shape[0]
-
-    def _sync_fwd_experts_buffer_DtoH(
-        self,
-        forward_batch: ForwardBatch,
-        can_run_graph: bool,
-        cuda_graph_batch: int,
-    ):
-        local_start_pos, local_end_pos = self._get_local_range(
-            forward_batch, can_run_graph, cuda_graph_batch
-        )
-        out_cache_loc_cpu = forward_batch.out_cache_loc.cpu()
-        self.host_cache.buffer[out_cache_loc_cpu] = self.device_cache.buffer[
-            local_start_pos:local_end_pos, :, : self.num_experts_per_tok
-        ].cpu()
-
-    def _prepare_routed_experts_output(
-        self,
-        forward_batch: ForwardBatch,
-        can_run_graph: bool,
-        cuda_graph_batch: int,
-    ) -> RoutedExpertsOutput:
-        local_start_pos, local_end_pos = self._get_local_range(
-            forward_batch, can_run_graph, cuda_graph_batch
-        )
-        return RoutedExpertsOutput(
-            out_cache_loc=forward_batch.out_cache_loc,
-            routed_experts=self.device_cache.buffer[
-                local_start_pos:local_end_pos, :, : self.num_experts_per_tok
-            ],
-            host_cache=self.host_cache,
-        )
-
-    def capture(self, layer_id: int, topk_ids: torch.Tensor):
-        self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids)
-
-    def get_routed_experts(
-        self,
-        req_pool_idx: int,
-        seqlen: int,
-        req_to_token_pool: ReqToTokenPool,
-    ):
-        cache_pool_idx = (
-            req_to_token_pool.req_to_token[req_pool_idx][: seqlen - 1].cpu().clone()
-        )
-        return self.get_host_cache().buffer[cache_pool_idx]
-
-    def on_forward_end(
-        self, forward_batch, can_run_graph, cuda_graph_batch, no_copy_to_cpu=False
-    ) -> Optional[RoutedExpertsOutput]:
-        if no_copy_to_cpu:
-            return self._prepare_routed_experts_output(
-                forward_batch=forward_batch,
-                can_run_graph=can_run_graph,
-                cuda_graph_batch=cuda_graph_batch,
-            )
-        else:
-            self._sync_fwd_experts_buffer_DtoH(
-                forward_batch=forward_batch,
-                can_run_graph=can_run_graph,
-                cuda_graph_batch=cuda_graph_batch,
-            )
-            return None
-
-    def get_host_cache(self):
-        return self.host_cache
-
-    def get_device_cache(self):
-        return self.device_cache
-
-
-class _RoutedExpertsCapturerNoop(RoutedExpertsCapturer):
-    def __init__(self):
-        pass
-
-    def _sync_fwd_experts_buffer_DtoH(
-        self,
-        forward_batch: ForwardBatch,
-        can_run_graph: bool,
-        cuda_graph_batch: int,
-    ):
-        pass
-
-    def capture(self, layer_id: int, topk_ids: torch.Tensor):
-        pass
-
-    def get_routed_experts(
-        self,
-        req_pool_idx: int,
-        seqlen: int,
-        req_to_token_pool: ReqToTokenPool,
-    ):
-        pass
-
-    def on_forward_end(
-        self, forward_batch, can_run_graph, cuda_graph_batch, no_copy_to_cpu=False
-    ) -> Optional[RoutedExpertsOutput]:
-        return None
-
-    def get_host_cache(self):
-        pass
-
-    def get_device_cache(self):
-        pass
+            local_start_pos, local_end_pos = 0, forward_batch.out_cache_loc.shape[0]
+        return self.device_cache.buffer[
+            local_start_pos:local_end_pos, :, : self.topk_size
+        ]


-_global_expert_capturer: Optional[RoutedExpertsCapturer] = _RoutedExpertsCapturerNoop()
+_global_expert_capturer: Optional[RoutedExpertsCapturer] = None


-def get_global_experts_capturer():
+def get_global_experts_capturer() -> Optional[RoutedExpertsCapturer]:
     return _global_expert_capturer


-def set_global_experts_capturer(capturer: RoutedExpertsCapturer):
+def set_global_experts_capturer(capturer: Optional[RoutedExpertsCapturer]):
     global _global_expert_capturer
     _global_expert_capturer = capturer
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index 2eba4f094a3a..6fdbb1a70f6c 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -1071,10 +1071,11 @@ def _post_process_topk_ids(
     fused_shared_experts_scaling_factor = (
         topk_config.fused_shared_experts_scaling_factor
     )
-    get_global_experts_capturer().capture(
-        layer_id=layer_id,
-        topk_ids=topk_ids,
-    )
+    if (cap := get_global_experts_capturer()) is not None:
+        cap.capture(
+            layer_id=layer_id,
+            topk_indices=topk_ids,
+        )
     if _is_cuda:
         # When shared experts are fused (appended as extra columns in topk_ids),
         # EPLB dispatch must only remap the routed expert columns.
diff --git a/python/sglang/srt/layers/topk_capturer_base.py b/python/sglang/srt/layers/topk_capturer_base.py
new file mode 100644
index 000000000000..8620804e7b6f
--- /dev/null
+++ b/python/sglang/srt/layers/topk_capturer_base.py
@@ -0,0 +1,178 @@
+import dataclasses
+import logging
+from typing import Optional
+
+import torch
+
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+logger = logging.getLogger(__name__)
+
+_GB = 1024 * 1024 * 1024
+_MB = 1024 * 1024
+
+
+def get_tensor_size_bytes(t: torch.Tensor) -> int:
+    return t.numel() * t.element_size()
+
+
+class BaseDeviceCache:
+    def __init__(
+        self,
+        max_batch_size: int,
+        num_layers: int,
+        topk_size: int,
+        device: str,
+        name: str,
+    ):
+        self.buffer = torch.zeros(
+            (max_batch_size, num_layers, topk_size),
+            dtype=torch.int32,
+            device=device,
+        )
+        self.num_layers = num_layers
+        self.topk_size = topk_size
+        self.name = name
+        self._log_allocation()
+
+    def capture(self, layer_id: int, topk_indices: torch.Tensor):
+        batch = topk_indices.shape[0]
+        self.buffer[:batch, layer_id, :] = topk_indices
+
+    def get_buffer_size_bytes(self):
+        return get_tensor_size_bytes(self.buffer)
+
+    def _log_allocation(self):
+        size_mb = self.get_buffer_size_bytes() / _MB
+        logger.info(
+            f"DeviceCache[{self.name}] allocated: shape={tuple(self.buffer.shape)}, "
+            f"size={size_mb:.2f} MB"
+        )
+
+
+class BaseHostCache:
+    def __init__(self, num_tokens: int, num_layers: int, topk_size: int, name: str):
+        self.buffer = torch.zeros(
+            (num_tokens, num_layers, topk_size),
+            dtype=torch.int32,
+            device="cpu",
+            pin_memory=True,
+        )
+        self.num_tokens = num_tokens
+        self.num_layers = num_layers
+        self.topk_size = topk_size
+        self.name = name
+        self._log_allocation()
+
+    def get_buffer_size_bytes(self):
+        return get_tensor_size_bytes(self.buffer)
+
+    def _log_allocation(self):
+        size_gb = self.get_buffer_size_bytes() / _GB
+        logger.info(
+            f"HostCache[{self.name}] allocated: shape={tuple(self.buffer.shape)}, "
+            f"size={size_gb:.2f} GB"
+        )
+
+
+@dataclasses.dataclass
+class TopkCaptureOutput:
+    """Holds GPU tensors captured during forward for overlap scheduling.
+    Call copy_to_cpu() inside forward stream (before copy_done.record()),
+    then finalize() after copy_done.synchronize().
+    """
+
+    out_cache_loc: torch.Tensor
+    topk: torch.Tensor
+    host_cache: BaseHostCache
+
+    def copy_to_cpu(self):
+        self.out_cache_loc = self.out_cache_loc.to("cpu", non_blocking=True)
+        self.topk = self.topk.to("cpu", non_blocking=True)
+
+    def finalize(self):
+        self.host_cache.buffer[self.out_cache_loc] = self.topk
+
+
+class BaseTopkCapturer:
+    def __init__(
+        self,
+        num_tokens: int,
+        max_batch_size: int,
+        num_layers: int,
+        topk_size: int,
+        device: str,
+        name: str,
+        device_topk_size: Optional[int] = None,
+    ):
+        """device_topk_size defaults to topk_size; pass a different value when
+        the device buffer needs extra columns (e.g. fused shared experts) that
+        are dropped before writing to host_cache via [:topk_size] truncation.
+        """
+        self.num_layers = num_layers
+        self.topk_size = topk_size
+
+        self.host_cache = BaseHostCache(num_tokens, num_layers, topk_size, name=name)
+        self.device_cache = BaseDeviceCache(
+            max_batch_size,
+            num_layers,
+            device_topk_size if device_topk_size is not None else topk_size,
+            device,
+            name=name,
+        )
+
+    def capture(self, layer_id: int, topk_indices: torch.Tensor):
+        self.device_cache.capture(layer_id, topk_indices)
+
+    def _get_local_slice(
+        self,
+        forward_batch: ForwardBatch,
+        can_run_graph: bool,
+        cuda_graph_batch: Optional[int],
+    ) -> torch.Tensor:
+        """Return the device_cache slice for this forward batch, GPU-resident.
+
+        Default assumes per-rank-local capture: each rank writes [:local_num_tokens)
+        to its own device_cache. Subclasses with global-tensor capture semantics
+        (e.g. shared cuda graph buffer indexed by dp_rank) should override and
+        consume can_run_graph / cuda_graph_batch.
+        """
+        del can_run_graph, cuda_graph_batch  # reserved for subclass override
+        num_tokens = forward_batch.out_cache_loc.shape[0]
+        return self.device_cache.buffer[:num_tokens, :, : self.topk_size]
+
+    def get_topk(
+        self,
+        req_pool_idx: int,
+        seqlen: int,
+        req_to_token_pool: ReqToTokenPool,
+    ) -> torch.Tensor:
+        cache_pool_idx = req_to_token_pool.req_to_token[req_pool_idx][
+            : seqlen - 1
+        ].cpu()
+        return self.host_cache.buffer[cache_pool_idx]
+
+    def on_forward_end(
+        self,
+        forward_batch: ForwardBatch,
+        can_run_graph: bool,
+        cuda_graph_batch: Optional[int],
+        no_copy_to_cpu: bool = False,
+    ) -> Optional[TopkCaptureOutput]:
+        """If no_copy_to_cpu is True, return a TopkCaptureOutput holding GPU tensors so
+        the overlap thread can do non-blocking D2H + finalize itself. Otherwise sync
+        D2H inline and return None (legacy non-overlap path).
+        """
+        slice_gpu = self._get_local_slice(
+            forward_batch, can_run_graph, cuda_graph_batch
+        )
+        if no_copy_to_cpu:
+            return TopkCaptureOutput(
+                out_cache_loc=forward_batch.out_cache_loc,
+                topk=slice_gpu,
+                host_cache=self.host_cache,
+            )
+        out_cache_loc_cpu = forward_batch.out_cache_loc.cpu()
+        self.host_cache.buffer[out_cache_loc_cpu] = slice_gpu.cpu()
+        return None
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index 044deafdd77c..76983add6d53 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -106,7 +106,10 @@ def process_batch_result_prebuilt(self: Scheduler, batch: ScheduleBatch):

     def maybe_collect_routed_experts(self: Scheduler, req: Req):
         """Collect routed experts for a finished request."""
-        req.routed_experts = get_global_experts_capturer().get_routed_experts(
+        capturer = get_global_experts_capturer()
+        if capturer is None:
+            return
+        req.routed_experts = capturer.get_topk(
             req_pool_idx=req.req_pool_idx,
             seqlen=req.seqlen,
             req_to_token_pool=self.req_to_token_pool,
diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py
index 81fbc8a05b96..9d34ae31df20 100644
--- a/python/sglang/srt/managers/utils.py
+++ b/python/sglang/srt/managers/utils.py
@@ -8,7 +8,7 @@

 from sglang.srt.eplb.expert_distribution import ExpertDistributionMetrics
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.routed_experts_capturer import RoutedExpertsOutput
+from sglang.srt.layers.topk_capturer_base import TopkCaptureOutput
 from sglang.srt.managers.overlap_utils import FutureIndices
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
@@ -48,7 +48,7 @@ class GenerationBatchResult:
     next_draft_input: Optional[EagleDraftInput] = None

     # Routed experts: pending async D2H for overlap scheduling
-    routed_experts_output: Optional[RoutedExpertsOutput] = None
+    routed_experts_output: Optional[TopkCaptureOutput] = None

     # metrics
     expert_distribution_metrics: Optional[ExpertDistributionMetrics] = None
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index f5d9f6f6ab8e..565e20bfe8fe 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -118,13 +118,13 @@
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.routed_experts_capturer import (
     RoutedExpertsCapturer,
-    RoutedExpertsOutput,
     get_global_experts_capturer,
     set_global_experts_capturer,
 )
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput
 from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype
 from sglang.srt.layers.sampler import create_sampler
+from sglang.srt.layers.topk_capturer_base import TopkCaptureOutput
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.lora.lora_registry import LoRARef
@@ -305,7 +305,7 @@ class ModelRunnerOutput:
     logits_output: Union[LogitsProcessorOutput, PPProxyTensors]
     can_run_graph: bool
     expert_distribution_metrics: Optional[ExpertDistributionMetrics] = None
-    routed_experts_output: Optional[RoutedExpertsOutput] = None
+    routed_experts_output: Optional[TopkCaptureOutput] = None


 class ModelRunner(ModelRunnerKVCacheMixin):
@@ -3219,12 +3219,13 @@ def forward(
         output.expert_distribution_metrics = recorder_outputs.get("metrics")

         no_copy_to_cpu = not self.server_args.disable_overlap_schedule
-        output.routed_experts_output = get_global_experts_capturer().on_forward_end(
-            forward_batch=forward_batch,
-            can_run_graph=output.can_run_graph,
-            cuda_graph_batch=getattr(self.graph_runner, "bs", None),
-            no_copy_to_cpu=no_copy_to_cpu,
-        )
+        if (experts_capturer := get_global_experts_capturer()) is not None:
+            output.routed_experts_output = experts_capturer.on_forward_end(
+                forward_batch=forward_batch,
+                can_run_graph=output.can_run_graph,
+                cuda_graph_batch=getattr(self.graph_runner, "bs", None),
+                no_copy_to_cpu=no_copy_to_cpu,
+            )

         if self.eplb_manager is not None:
             self.eplb_manager.on_forward_pass_end()
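
For orientation, here is a small self-contained sketch of the buffer layout that BaseDeviceCache/BaseHostCache manage and of the synchronous D2H path in on_forward_end. All sizes below are invented for the example; only the shapes and the [:topk_size] truncation mirror the patch.

import torch

# Invented sizes: 16 KV-cache token slots, 2 MoE layers, top-4 routing,
# 1 fused shared expert appended on the device side only.
num_tokens, num_layers, topk_size, num_fused_shared = 16, 2, 4, 1
max_batch_size = 8

# BaseDeviceCache.buffer: (max_batch_size, num_layers, device_topk_size)
device_buf = torch.zeros(
    (max_batch_size, num_layers, topk_size + num_fused_shared), dtype=torch.int32
)
# BaseHostCache.buffer: (num_tokens, num_layers, topk_size), pinned in the real code
host_buf = torch.zeros((num_tokens, num_layers, topk_size), dtype=torch.int32)

# capture(layer_id, topk_indices): write the current batch's ids for one layer.
batch = 3
topk_indices = torch.randint(
    0, 64, (batch, topk_size + num_fused_shared), dtype=torch.int32
)
device_buf[:batch, 0, :] = topk_indices

# Synchronous path of on_forward_end (no_copy_to_cpu=False): drop the fused
# shared-expert columns via [:topk_size] and scatter rows to out_cache_loc.
out_cache_loc = torch.tensor([5, 9, 11])
host_buf[out_cache_loc] = device_buf[:batch, :, :topk_size].cpu()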
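
The DP-rank-aware override in RoutedExpertsCapturer._get_local_slice boils down to picking this rank's window out of the shared device buffer. A toy calculation with invented numbers:

# Under cuda graphs every attention-DP rank occupies a fixed-size window of the
# shared buffer, so the offset is dp_rank * cuda_graph_batch; otherwise the
# offset comes from get_dp_local_info(forward_batch).
dp_rank, cuda_graph_batch, local_num_tokens = 2, 8, 5
local_start_pos = dp_rank * cuda_graph_batch        # 16
local_end_pos = local_start_pos + local_num_tokens  # 21
# -> device_cache.buffer[16:21, :, :topk_size] is this rank's routed-experts slice.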
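
TopkCaptureOutput's docstring prescribes a specific ordering for the overlap-schedule path; the sketch below spells it out. The capturer, forward_batch, and copy_done event are stand-ins for objects the scheduler already holds; only the call order is the point, not the surrounding function names.

import torch

def overlap_forward_end_sketch(capturer, forward_batch, can_run_graph, graph_bs):
    copy_done = torch.cuda.Event()

    # Forward stream: take GPU-resident slices instead of syncing inline ...
    out = capturer.on_forward_end(
        forward_batch=forward_batch,
        can_run_graph=can_run_graph,
        cuda_graph_batch=graph_bs,
        no_copy_to_cpu=True,
    )
    # ... start the async D2H copy, then record completion on the stream.
    if out is not None:
        out.copy_to_cpu()
    copy_done.record()
    return out, copy_done

def overlap_consumer_sketch(out, copy_done):
    # Overlap/scheduler side: only after the copy has finished may the pinned
    # host buffer be scattered into place.
    copy_done.synchronize()
    if out is not None:
        out.finalize()  # host_cache.buffer[out_cache_loc] = topk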
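
Finally, get_topk (used by maybe_collect_routed_experts on the scheduler side) reads a finished request back out of the host cache through the request-to-token mapping. A stand-in req_to_token tensor replaces the real ReqToTokenPool in this sketch:

import torch

host_buf = torch.zeros((16, 2, 4), dtype=torch.int32)  # (tokens, layers, topk)
req_to_token = torch.arange(16, dtype=torch.int64).reshape(2, 8)  # toy pool: 2 reqs x 8 slots

req_pool_idx, seqlen = 1, 5
cache_pool_idx = req_to_token[req_pool_idx][: seqlen - 1].cpu()
routed_experts = host_buf[cache_pool_idx]  # (seqlen - 1, layers, topk)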