diff --git a/python/sglang/srt/layers/moe/hash_topk.py b/python/sglang/srt/layers/moe/hash_topk.py index cdeca35f7724..f6929ab4bdbc 100644 --- a/python/sglang/srt/layers/moe/hash_topk.py +++ b/python/sglang/srt/layers/moe/hash_topk.py @@ -7,6 +7,9 @@ from torch import nn from sglang.srt.environ import envs +from sglang.srt.eplb.expert_distribution import ( + get_global_expert_distribution_recorder, +) from sglang.srt.eplb.expert_location_dispatch import ( ExpertLocationDispatchInfo, topk_ids_logical_to_physical, @@ -145,6 +148,7 @@ def forward( topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids) topk_output = StandardTopKOutput( topk_weights=topk_weights, topk_ids=topk_ids, router_logits=router_logits ) diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py index cb57ea40095d..b8c3f0c22129 100644 --- a/python/sglang/srt/models/deepseek_v4.py +++ b/python/sglang/srt/models/deepseek_v4.py @@ -3,6 +3,7 @@ import concurrent.futures import logging import time +from contextlib import nullcontext from typing import ( TYPE_CHECKING, Iterable, @@ -33,6 +34,7 @@ get_tp_group, ) from sglang.srt.environ import envs +from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.layers.attention.dsa.utils import ( can_dsa_cp_split, @@ -1134,13 +1136,19 @@ def forward( for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states = layer( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - input_ids=input_ids, - input_ids_global=input_ids_global, + ctx = ( + nullcontext() + if not get_global_server_args().disable_piecewise_cuda_graph + else get_global_expert_distribution_recorder().with_current_layer(i) ) + with ctx: + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + input_ids=input_ids, + input_ids_global=input_ids_global, + ) # CP all-gather only on the last PP rank; PP IPC carries CP-split tensors. if self.pp_group.is_last_rank and dsa_use_prefill_cp(forward_batch):