diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py
index 12d5577af2ba..5fe3e95c72a6 100644
--- a/python/sglang/srt/layers/moe/routed_experts_capturer.py
+++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py
@@ -1,3 +1,4 @@
+import dataclasses
 import logging
 from abc import ABC
 from typing import Optional
@@ -31,6 +32,25 @@ def get_tensor_size_bytes(t: torch.Tensor):
     return np.prod(t.shape) * t.dtype.itemsize
 
 
+@dataclasses.dataclass
+class RoutedExpertsOutput:
+    """Holds GPU tensors captured during forward for overlap scheduling.
+    Call copy_to_cpu() inside forward stream (before copy_done.record()),
+    then finalize() after copy_done.synchronize().
+    """
+
+    out_cache_loc: torch.Tensor
+    routed_experts: torch.Tensor
+    host_cache: "_RoutedExpertsHostCache"
+
+    def copy_to_cpu(self):
+        self.out_cache_loc = self.out_cache_loc.to("cpu", non_blocking=True)
+        self.routed_experts = self.routed_experts.to("cpu", non_blocking=True)
+
+    def finalize(self):
+        self.host_cache.buffer[self.out_cache_loc] = self.routed_experts
+
+
 class _RoutedExpertsDeviceCache:
     def __init__(
         self,
@@ -147,7 +167,9 @@ def get_routed_experts(
     ):
         raise NotImplementedError
 
-    def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
+    def on_forward_end(
+        self, forward_batch, can_run_graph, cuda_graph_batch, no_copy_to_cpu=False
+    ) -> Optional[RoutedExpertsOutput]:
         raise NotImplementedError
 
     def get_host_cache(self):
@@ -197,32 +219,48 @@ def __init__(
             device=device,
         )
 
-    def _sync_fwd_experts_buffer_DtoH(
-        self,
-        forward_batch: ForwardBatch,
-        can_run_graph: bool,
-        cuda_graph_batch: int,
-    ):
+    def _get_local_range(self, forward_batch, can_run_graph, cuda_graph_batch):
         # When DeepEP is enabled, capture() already does all_gather, so device_cache.buffer
         # contains data from all DP ranks. We should not slice by DP rank in this case.
         if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep():
             local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
-            # handle with cuda graph padding
             if can_run_graph:
                 local_start_pos = get_attention_dp_rank() * cuda_graph_batch
-                local_end_pos = local_start_pos + local_num_tokens
-            else:
-                local_end_pos = local_start_pos + local_num_tokens
+            return local_start_pos, local_start_pos + local_num_tokens
         else:
-            local_start_pos = 0
-            local_end_pos = forward_batch.out_cache_loc.shape[0]
+            return 0, forward_batch.out_cache_loc.shape[0]
 
-        # FIXME: sync explicitly here, overlap scheduler breaks here.
+    def _sync_fwd_experts_buffer_DtoH(
+        self,
+        forward_batch: ForwardBatch,
+        can_run_graph: bool,
+        cuda_graph_batch: int,
+    ):
+        local_start_pos, local_end_pos = self._get_local_range(
+            forward_batch, can_run_graph, cuda_graph_batch
+        )
         out_cache_loc_cpu = forward_batch.out_cache_loc.cpu()
         self.host_cache.buffer[out_cache_loc_cpu] = self.device_cache.buffer[
             local_start_pos:local_end_pos, :, : self.num_experts_per_tok
         ].cpu()
 
+    def _prepare_routed_experts_output(
+        self,
+        forward_batch: ForwardBatch,
+        can_run_graph: bool,
+        cuda_graph_batch: int,
+    ) -> RoutedExpertsOutput:
+        local_start_pos, local_end_pos = self._get_local_range(
+            forward_batch, can_run_graph, cuda_graph_batch
+        )
+        return RoutedExpertsOutput(
+            out_cache_loc=forward_batch.out_cache_loc,
+            routed_experts=self.device_cache.buffer[
+                local_start_pos:local_end_pos, :, : self.num_experts_per_tok
+            ],
+            host_cache=self.host_cache,
+        )
+
     def capture(self, layer_id: int, topk_ids: torch.Tensor):
         if get_moe_a2a_backend().is_deepep():
             local_topk_ids = topk_ids
@@ -243,12 +281,22 @@ def get_routed_experts(
         )
         return self.get_host_cache().buffer[cache_pool_idx]
 
-    def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
-        self._sync_fwd_experts_buffer_DtoH(
-            forward_batch=forward_batch,
-            can_run_graph=can_run_graph,
-            cuda_graph_batch=cuda_graph_batch,
-        )
+    def on_forward_end(
+        self, forward_batch, can_run_graph, cuda_graph_batch, no_copy_to_cpu=False
+    ) -> Optional[RoutedExpertsOutput]:
+        if no_copy_to_cpu:
+            return self._prepare_routed_experts_output(
+                forward_batch=forward_batch,
+                can_run_graph=can_run_graph,
+                cuda_graph_batch=cuda_graph_batch,
+            )
+        else:
+            self._sync_fwd_experts_buffer_DtoH(
+                forward_batch=forward_batch,
+                can_run_graph=can_run_graph,
+                cuda_graph_batch=cuda_graph_batch,
+            )
+            return None
 
     def get_host_cache(self):
         return self.host_cache
@@ -280,8 +328,10 @@ def get_routed_experts(
     ):
         pass
 
-    def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
-        pass
+    def on_forward_end(
+        self, forward_batch, can_run_graph, cuda_graph_batch, no_copy_to_cpu=False
+    ) -> Optional[RoutedExpertsOutput]:
+        return None
 
     def get_host_cache(self):
         pass
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index a9b7be0f14f1..ee8a9555e1ac 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -132,6 +132,9 @@ def process_batch_result_prefill(
         if self.is_generation:
             if result.copy_done is not None:
                 result.copy_done.synchronize()
+            if result.routed_experts_output is not None:
+                result.routed_experts_output.finalize()
+                result.routed_experts_output = None
 
             (
                 logits_output,
@@ -380,6 +383,9 @@
         ):
             if result.copy_done is not None:
                 result.copy_done.synchronize()
+            if result.routed_experts_output is not None:
+                result.routed_experts_output.finalize()
+                result.routed_experts_output = None
 
             logits_output, next_token_ids, can_run_cuda_graph = (
                 result.logits_output,
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index fb56de158374..cb6a0000efe3 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -482,6 +482,7 @@ def forward_batch_generation(
                 logits_output=logits_output,
                 can_run_cuda_graph=can_run_cuda_graph,
                 expert_distribution_metrics=out.expert_distribution_metrics,
+                routed_experts_output=out.routed_experts_output,
             )
 
             if is_verify:
diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py
index ba777330000f..8c3336fcb056 100644
--- a/python/sglang/srt/managers/utils.py
+++ b/python/sglang/srt/managers/utils.py
@@ -8,6 +8,7 @@
 from sglang.srt.eplb.expert_distribution import ExpertDistributionMetrics
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.routed_experts_capturer import RoutedExpertsOutput
 from sglang.srt.managers.overlap_utils import FutureIndices
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
 
@@ -46,6 +47,9 @@ class GenerationBatchResult:
     # relay path: forward stream -> next step forward
     next_draft_input: Optional[EagleDraftInput] = None
 
+    # Routed experts: pending async D2H for overlap scheduling
+    routed_experts_output: Optional[RoutedExpertsOutput] = None
+
     # metrics
     expert_distribution_metrics: Optional[ExpertDistributionMetrics] = None
 
@@ -87,6 +91,9 @@ def copy_to_cpu(self, return_logprob: bool):
         if self.accept_lens is not None:
             self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
 
+        if self.routed_experts_output is not None:
+            self.routed_experts_output.copy_to_cpu()
+
         if (x := self.expert_distribution_metrics) is not None:
             x.copy_to_cpu()
 
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 7a68e69c4f09..0f69e9248781 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -112,6 +112,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.routed_experts_capturer import (
     RoutedExpertsCapturer,
+    RoutedExpertsOutput,
     get_global_experts_capturer,
     set_global_experts_capturer,
 )
@@ -284,6 +285,7 @@ class ModelRunnerOutput:
     logits_output: Union[LogitsProcessorOutput, PPProxyTensors]
     can_run_graph: bool
     expert_distribution_metrics: Optional[ExpertDistributionMetrics] = None
+    routed_experts_output: Optional[RoutedExpertsOutput] = None
 
 
 class ModelRunner(ModelRunnerKVCacheMixin):
@@ -2868,7 +2870,6 @@ def forward(
         )
         output.expert_distribution_metrics = recorder_outputs.get("metrics")
 
-        # Copy cached routing experts' buffers back to CPU cache
         if not self.is_draft_worker:
             # In speculative decoding, num_tokens_per_bs > 1, so we need to pass
             # the actual number of tokens per dp rank in cuda graph, not batch size.
@@ -2877,10 +2878,12 @@ def forward(
             cuda_graph_num_tokens = (
                 self.graph_runner.bs * self.graph_runner.num_tokens_per_bs
             )
-            get_global_experts_capturer().on_forward_end(
+            no_copy_to_cpu = not self.server_args.disable_overlap_schedule
+            output.routed_experts_output = get_global_experts_capturer().on_forward_end(
                 forward_batch=forward_batch,
                 can_run_graph=output.can_run_graph,
                 cuda_graph_batch=cuda_graph_num_tokens,
+                no_copy_to_cpu=no_copy_to_cpu,
             )
 
         if self.eplb_manager is not None:
diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py
index 0ed93e198881..b234fdc9fb1c 100644
--- a/python/sglang/srt/speculative/eagle_worker_v2.py
+++ b/python/sglang/srt/speculative/eagle_worker_v2.py
@@ -866,6 +866,7 @@ def verify(self, batch: ModelWorkerBatch):
             can_run_cuda_graph=can_run_cuda_graph,
             next_draft_input=next_draft_input,
             accept_lens=accept_length,
+            routed_experts_output=forward_batch_output.routed_experts_output,
         )
 
     def _compute_spec_v2_logprobs(
diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py
index fbad4adb94e2..c5fbb12b3809 100644
--- a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py
+++ b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py
@@ -749,4 +749,5 @@ def verify(
             can_run_cuda_graph=can_run_cuda_graph,
             next_draft_input=next_draft_input,
             accept_lens=accept_length,
+            routed_experts_output=forward_batch_output.routed_experts_output,
         )
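
For context on how the new RoutedExpertsOutput is meant to be driven by the overlap scheduler, below is a minimal standalone sketch of the same pattern: issue the device-to-host copies with non_blocking=True on the forward stream, record an event, and scatter into the host cache only after that event has been synchronized. This is an illustration, not code from this diff; the buffer names, shapes, and dtypes are placeholders.

import torch

def stage_routed_experts(device_buffer, out_cache_loc):
    # Forward-stream side: start async D2H copies and record a completion event
    # (mirrors RoutedExpertsOutput.copy_to_cpu() followed by copy_done.record()).
    out_cache_loc_cpu = out_cache_loc.to("cpu", non_blocking=True)
    routed_experts_cpu = device_buffer.to("cpu", non_blocking=True)
    copy_done = torch.cuda.Event()
    copy_done.record()
    return out_cache_loc_cpu, routed_experts_cpu, copy_done

def finalize_routed_experts(host_buffer, out_cache_loc_cpu, routed_experts_cpu, copy_done):
    # Scheduler side: wait for the copies to land, then write into the host cache
    # (mirrors copy_done.synchronize() followed by RoutedExpertsOutput.finalize()).
    copy_done.synchronize()
    host_buffer[out_cache_loc_cpu] = routed_experts_cpu

if torch.cuda.is_available():
    num_tokens, num_layers, top_k, cache_size = 8, 2, 4, 32
    # Pinned host cache standing in for _RoutedExpertsHostCache.buffer.
    host_buffer = torch.zeros(cache_size, num_layers, top_k, dtype=torch.int32, pin_memory=True)
    device_buffer = torch.randint(0, 64, (num_tokens, num_layers, top_k), dtype=torch.int32, device="cuda")
    out_cache_loc = torch.arange(num_tokens, device="cuda")
    staged = stage_routed_experts(device_buffer, out_cache_loc)
    finalize_routed_experts(host_buffer, *staged)

Splitting the copy from the scatter is what lets the scheduler overlap the D2H transfer with subsequent work instead of blocking on an explicit .cpu() inside on_forward_end(), which is the path taken here when the overlap schedule is enabled (no_copy_to_cpu=True).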