Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 72 additions & 22 deletions python/sglang/srt/layers/moe/routed_experts_capturer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import logging
from abc import ABC
from typing import Optional
Expand Down Expand Up @@ -26,6 +27,25 @@ def get_tensor_size_bytes(t: torch.Tensor):
return np.prod(t.shape) * t.dtype.itemsize


@dataclasses.dataclass
class RoutedExpertsOutput:
    """GPU-side capture of routed-expert ids, deferred for overlap scheduling.

    Lifecycle: call copy_to_cpu() on the forward stream (before
    copy_done.record()), then finalize() only after copy_done.synchronize(),
    once the async device-to-host copies have landed.
    """

    # Token slots in the host cache that these captured experts belong to.
    out_cache_loc: torch.Tensor
    # Captured expert ids for the local tokens.
    routed_experts: torch.Tensor
    # CPU-side cache whose .buffer is written by finalize().
    host_cache: "_RoutedExpertsHostCache"

    def copy_to_cpu(self) -> None:
        """Kick off async D2H copies of both tensors, rebinding the fields."""
        self.out_cache_loc, self.routed_experts = (
            self.out_cache_loc.to("cpu", non_blocking=True),
            self.routed_experts.to("cpu", non_blocking=True),
        )

    def finalize(self) -> None:
        """Scatter the (now host-resident) expert ids into the host cache."""
        self.host_cache.buffer[self.out_cache_loc] = self.routed_experts


class _RoutedExpertsDeviceCache:
def __init__(
self,
Expand Down Expand Up @@ -142,7 +162,9 @@ def get_routed_experts(
):
raise NotImplementedError

def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
def on_forward_end(
self, forward_batch, can_run_graph, cuda_graph_batch, no_copy_to_cpu=False
) -> Optional[RoutedExpertsOutput]:
raise NotImplementedError

def get_host_cache(self):
Expand Down Expand Up @@ -181,30 +203,46 @@ def __init__(
device=device,
)

def _sync_fwd_experts_buffer_DtoH(
self,
forward_batch: ForwardBatch,
can_run_graph: bool,
cuda_graph_batch: int,
):
def _get_local_range(self, forward_batch, can_run_graph, cuda_graph_batch):
if is_dp_attention_enabled():
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
# handle with cuda graph padding
if can_run_graph:
local_start_pos = get_attention_dp_rank() * cuda_graph_batch
local_end_pos = local_start_pos + local_num_tokens
else:
local_end_pos = local_start_pos + local_num_tokens
return local_start_pos, local_start_pos + local_num_tokens
else:
local_start_pos = 0
local_end_pos = forward_batch.out_cache_loc.shape[0]
return 0, forward_batch.out_cache_loc.shape[0]

# FIXME: sync explicitly here, overlap scheduler breaks here.
def _sync_fwd_experts_buffer_DtoH(
self,
forward_batch: ForwardBatch,
can_run_graph: bool,
cuda_graph_batch: int,
):
local_start_pos, local_end_pos = self._get_local_range(
forward_batch, can_run_graph, cuda_graph_batch
)
out_cache_loc_cpu = forward_batch.out_cache_loc.cpu()
self.host_cache.buffer[out_cache_loc_cpu] = self.device_cache.buffer[
local_start_pos:local_end_pos, :, : self.num_experts_per_tok
].cpu()

def _prepare_routed_experts_output(
self,
forward_batch: ForwardBatch,
can_run_graph: bool,
cuda_graph_batch: int,
) -> RoutedExpertsOutput:
local_start_pos, local_end_pos = self._get_local_range(
forward_batch, can_run_graph, cuda_graph_batch
)
return RoutedExpertsOutput(
out_cache_loc=forward_batch.out_cache_loc,
routed_experts=self.device_cache.buffer[
local_start_pos:local_end_pos, :, : self.num_experts_per_tok
],
host_cache=self.host_cache,
)

    def capture(self, layer_id: int, topk_ids: torch.Tensor) -> None:
        # Record this layer's top-k expert ids into the device-side buffer.
        self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids)

Expand All @@ -219,12 +257,22 @@ def get_routed_experts(
)
return self.get_host_cache().buffer[cache_pool_idx]

def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
self._sync_fwd_experts_buffer_DtoH(
forward_batch=forward_batch,
can_run_graph=can_run_graph,
cuda_graph_batch=cuda_graph_batch,
)
def on_forward_end(
self, forward_batch, can_run_graph, cuda_graph_batch, no_copy_to_cpu=False
) -> Optional[RoutedExpertsOutput]:
if no_copy_to_cpu:
return self._prepare_routed_experts_output(
forward_batch=forward_batch,
can_run_graph=can_run_graph,
cuda_graph_batch=cuda_graph_batch,
)
else:
self._sync_fwd_experts_buffer_DtoH(
forward_batch=forward_batch,
can_run_graph=can_run_graph,
cuda_graph_batch=cuda_graph_batch,
)
return None

    def get_host_cache(self):
        # Host-side cache written by finalize() / _sync_fwd_experts_buffer_DtoH.
        return self.host_cache
Expand Down Expand Up @@ -256,8 +304,10 @@ def get_routed_experts(
):
pass

def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch):
pass
def on_forward_end(
self, forward_batch, can_run_graph, cuda_graph_batch, no_copy_to_cpu=False
) -> Optional[RoutedExpertsOutput]:
return None

    def get_host_cache(self):
        # No-op capturer keeps no host cache; callers get None implicitly.
        pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def process_batch_result_prefill(
if self.is_generation:
if result.copy_done is not None:
result.copy_done.synchronize()
if result.routed_experts_output is not None:
result.routed_experts_output.finalize()
result.routed_experts_output = None

(
logits_output,
Expand Down Expand Up @@ -391,6 +394,9 @@ def process_batch_result_decode(
):
if result.copy_done is not None:
result.copy_done.synchronize()
if result.routed_experts_output is not None:
result.routed_experts_output.finalize()
result.routed_experts_output = None

logits_output, next_token_ids, can_run_cuda_graph = (
result.logits_output,
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/tp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ def forward_batch_generation(
logits_output=logits_output,
can_run_cuda_graph=can_run_cuda_graph,
expert_distribution_metrics=out.expert_distribution_metrics,
routed_experts_output=out.routed_experts_output,
)

if is_verify:
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/managers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from sglang.srt.eplb.expert_distribution import ExpertDistributionMetrics
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.routed_experts_capturer import RoutedExpertsOutput
from sglang.srt.managers.overlap_utils import FutureIndices
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
Expand Down Expand Up @@ -46,6 +47,9 @@ class GenerationBatchResult:
# relay path: forward stream -> next step forward
next_draft_input: Optional[EagleDraftInput] = None

# Routed experts: pending async D2H for overlap scheduling
routed_experts_output: Optional[RoutedExpertsOutput] = None

# metrics
expert_distribution_metrics: Optional[ExpertDistributionMetrics] = None

Expand Down Expand Up @@ -87,6 +91,9 @@ def copy_to_cpu(self, return_logprob: bool):
if self.accept_lens is not None:
self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)

if self.routed_experts_output is not None:
self.routed_experts_output.copy_to_cpu()

if (x := self.expert_distribution_metrics) is not None:
x.copy_to_cpu()

Expand Down
7 changes: 5 additions & 2 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.routed_experts_capturer import (
RoutedExpertsCapturer,
RoutedExpertsOutput,
get_global_experts_capturer,
set_global_experts_capturer,
)
Expand Down Expand Up @@ -287,6 +288,7 @@ class ModelRunnerOutput:
logits_output: Union[LogitsProcessorOutput, PPProxyTensors]
can_run_graph: bool
expert_distribution_metrics: Optional[ExpertDistributionMetrics] = None
routed_experts_output: Optional[RoutedExpertsOutput] = None


class ModelRunner(ModelRunnerKVCacheMixin):
Expand Down Expand Up @@ -2934,11 +2936,12 @@ def forward(
)
output.expert_distribution_metrics = recorder_outputs.get("metrics")

# Copy cached routing experts' buffers back to CPU cache
get_global_experts_capturer().on_forward_end(
no_copy_to_cpu = not self.server_args.disable_overlap_schedule
output.routed_experts_output = get_global_experts_capturer().on_forward_end(
forward_batch=forward_batch,
can_run_graph=output.can_run_graph,
cuda_graph_batch=getattr(self.graph_runner, "bs", None),
no_copy_to_cpu=no_copy_to_cpu,
)

if self.eplb_manager is not None:
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/speculative/eagle_worker_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,7 @@ def verify(self, batch: ModelWorkerBatch):
can_run_cuda_graph=can_run_cuda_graph,
next_draft_input=next_draft_input,
accept_lens=accept_length,
routed_experts_output=forward_batch_output.routed_experts_output,
)

def _mamba_verify_update(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ def verify(
can_run_cuda_graph=can_run_cuda_graph,
next_draft_input=next_draft_input,
accept_lens=accept_length,
routed_experts_output=forward_batch_output.routed_experts_output,
)

def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
Expand Down
Loading