def get_dp_local_slice_cpu(
    forward_batch: ForwardBatch,
    can_run_graph: bool,
    cuda_graph_batch: Optional[int],
) -> Tuple[int, int]:
    """Return the host-side (start, length) slice of this DP rank's tokens.

    Operates purely on CPU data (``global_num_tokens_cpu``) so no
    device-to-host synchronization is triggered, and handles the
    cuda-graph-padded buffer layout where every rank occupies a fixed-size
    slot. Returns Python ints.
    """
    rank = get_attention_dp_rank()
    tokens_per_rank = forward_batch.global_num_tokens_cpu
    # Under a cuda graph each rank's region is padded to cuda_graph_batch
    # tokens, so offsets are a fixed stride; otherwise ranks are packed
    # back-to-back and the offset is the running total of preceding ranks.
    # NOTE(review): when can_run_graph is True, cuda_graph_batch is assumed
    # non-None — TODO confirm against callers.
    if can_run_graph:
        start = rank * cuda_graph_batch
    else:
        start = sum(tokens_per_rank[:rank])
    return start, tokens_per_rank[rank]
if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep(): - local_start_pos, local_num_tokens = get_dp_local_info(forward_batch) - if can_run_graph: - local_start_pos = get_attention_dp_rank() * cuda_graph_batch + # GPU->CPU sync would break overlap; operate on CPU directly. + local_start_pos, local_num_tokens = get_dp_local_slice_cpu( + forward_batch, can_run_graph, cuda_graph_batch + ) local_end_pos = local_start_pos + local_num_tokens else: local_start_pos, local_end_pos = 0, forward_batch.out_cache_loc.shape[0]