python/sglang/srt/managers/expert_distribution.py (21 additions, 0 deletions)
@@ -61,6 +61,10 @@ def with_current_layer(self, layer_idx):
     def with_debug_name(self, debug_name):
         yield
 
+    @contextmanager
+    def disable_this_region(self):
+        yield
+
     @contextmanager
     def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
         yield
@@ -116,6 +120,7 @@ def __init__(
         self._expert_location_metadata = expert_location_metadata
 
         self._recording = False
+        self._disable_all = False
         self._current_forward_pass_id = Withable()
         self._current_layer_idx = Withable()
         self._current_debug_name = Withable()
@@ -148,6 +153,16 @@ def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
         finally:
             self._on_forward_pass_end(forward_pass_id)
 
+    @contextmanager
+    def disable_this_region(self):
+        """Context manager to temporarily disable recording."""
+        previous_disable_all = self._disable_all
+        self._disable_all = True
+        try:
+            yield
+        finally:
+            self._disable_all = previous_disable_all
+
     def _on_forward_pass_start(self, forward_batch: ForwardBatch):
         if not self._recording:
             return
@@ -189,6 +204,8 @@ def on_deepep_dispatch_low_latency(
         )
 
     def _on_hook(self, hook_name: str, **kwargs):
+        if self._disable_all:
+            return
         if not (self._recording or torch.cuda.is_current_stream_capturing()):
             return
         gatherer = self._single_pass_gatherers[
@@ -462,6 +479,10 @@ def __init__(self, *args, **kwargs):
     def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
         topk_ids = topk_ids.flatten()
         mask = topk_ids != -1
+        assert self._data[layer_idx, :].shape == topk_ids.shape, (
+            "Shape mismatch between data and topk_ids. "
+            "Selecting experts is not supported for multi-token prediction at the moment."
+        )
         self._data[layer_idx, :].scatter_add_(
             dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
         )
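For reference, a minimal usage sketch of the new API (not part of this diff; run_draft_decoder is a hypothetical placeholder). Because disable_this_region() saves and restores the previous _disable_all value, regions nest safely:

    recorder = get_global_expert_distribution_recorder()

    with recorder.disable_this_region():
        # _on_hook() returns immediately inside this block, so no
        # expert statistics are accumulated for these layers.
        run_draft_decoder()  # hypothetical placeholder
        with recorder.disable_this_region():
            ...  # nesting is safe: the outer value is restored on exit
    # Recording behaves as before once the outermost block exits.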
python/sglang/srt/models/deepseek_nextn.py (7 additions, 4 deletions)
@@ -28,6 +28,9 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
@@ -82,7 +85,6 @@ def forward(
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-
         zero_allocator = BumpAllocator(
             buffer_size=2,
             dtype=torch.float32,
@@ -108,9 +110,10 @@
         )
 
         residual = None
-        hidden_states, residual = self.decoder(
-            positions, hidden_states, forward_batch, residual, zero_allocator
-        )
+        with get_global_expert_distribution_recorder().disable_this_region():
+            hidden_states, residual = self.decoder(
+                positions, hidden_states, forward_batch, residual, zero_allocator
+            )
 
         if not forward_batch.forward_mode.is_idle():
             if residual is not None:
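Read together, the two changes implement one pattern: the NextN (multi-token-prediction draft) decoder wraps its self.decoder call in disable_this_region(), so expert selections made during the draft pass are excluded from the recorded distribution. This appears to be the motivation for the new assert in on_select_experts, whose message notes that expert selection is not yet supported for multi-token prediction; disabling the recorder around the draft pass avoids the shape mismatch it guards against.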