From bfee2a34020d8e85adb334b0e5412358a30ee01b Mon Sep 17 00:00:00 2001 From: xutizhou Date: Thu, 21 May 2026 14:26:18 +0800 Subject: [PATCH] Fix DeepSeek V4 expert distribution layer context --- python/sglang/srt/layers/moe/hash_topk.py | 4 ++++ python/sglang/srt/models/deepseek_v4.py | 22 ++++++++++++++++------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/moe/hash_topk.py b/python/sglang/srt/layers/moe/hash_topk.py index cdeca35f7724..f6929ab4bdbc 100644 --- a/python/sglang/srt/layers/moe/hash_topk.py +++ b/python/sglang/srt/layers/moe/hash_topk.py @@ -7,6 +7,9 @@ from torch import nn from sglang.srt.environ import envs +from sglang.srt.eplb.expert_distribution import ( + get_global_expert_distribution_recorder, +) from sglang.srt.eplb.expert_location_dispatch import ( ExpertLocationDispatchInfo, topk_ids_logical_to_physical, @@ -145,6 +148,7 @@ def forward( topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids) topk_output = StandardTopKOutput( topk_weights=topk_weights, topk_ids=topk_ids, router_logits=router_logits ) diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py index cb57ea40095d..4fcb727c2c7b 100644 --- a/python/sglang/srt/models/deepseek_v4.py +++ b/python/sglang/srt/models/deepseek_v4.py @@ -3,6 +3,7 @@ import concurrent.futures import logging import time +from contextlib import nullcontext from typing import ( TYPE_CHECKING, Iterable, @@ -33,6 +34,9 @@ get_tp_group, ) from sglang.srt.environ import envs +from sglang.srt.eplb.expert_distribution import ( + get_global_expert_distribution_recorder, +) from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.layers.attention.dsa.utils import ( can_dsa_cp_split, @@ -1134,13 +1138,19 @@ def forward( for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states = layer( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - input_ids=input_ids, - input_ids_global=input_ids_global, + ctx = ( + nullcontext() + if not get_global_server_args().disable_piecewise_cuda_graph + else get_global_expert_distribution_recorder().with_current_layer(i) ) + with ctx: + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + input_ids=input_ids, + input_ids_global=input_ids_global, + ) # CP all-gather only on the last PP rank; PP IPC carries CP-split tensors. if self.pp_group.is_last_rank and dsa_use_prefill_cp(forward_batch):