17 changes: 17 additions & 0 deletions python/sglang/srt/layers/dp_attention.py
@@ -402,6 +402,23 @@ def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.
    return forward_batch.dp_local_start_pos, forward_batch.dp_local_num_tokens


+def get_dp_local_slice_cpu(
+    forward_batch: ForwardBatch,
+    can_run_graph: bool,
+    cuda_graph_batch: Optional[int],
+) -> Tuple[int, int]:
+    # CPU (start, length) slice for DP-local data in a rank-padded buffer.
+    # Returns Python ints (no D2H sync) and handles the cuda-graph-padded layout.
+    global_num_tokens = forward_batch.global_num_tokens_cpu
+    dp_rank = get_attention_dp_rank()
+    local_num_tokens = global_num_tokens[dp_rank]
+    if can_run_graph:
+        local_start_pos = dp_rank * cuda_graph_batch
+    else:
+        local_start_pos = sum(global_num_tokens[:dp_rank])
+    return local_start_pos, local_num_tokens
+
+
@triton.jit
def memcpy_triton_kernel(
    dst_ptr,
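For reference, a minimal standalone sketch (not part of the PR; the function and variable names below are illustrative only) of the slice arithmetic the new helper performs: in eager mode a rank's data starts at the prefix sum of the earlier ranks' token counts, while under cuda-graph replay every rank owns a fixed slot of cuda_graph_batch entries.

# Hypothetical standalone sketch mirroring get_dp_local_slice_cpu's arithmetic.
def local_slice(global_num_tokens, dp_rank, can_run_graph, cuda_graph_batch=None):
    # The length is always this rank's real token count.
    local_num_tokens = global_num_tokens[dp_rank]
    if can_run_graph:
        # Cuda-graph replay pads every rank to a fixed slot of cuda_graph_batch entries.
        local_start_pos = dp_rank * cuda_graph_batch
    else:
        # Eager mode packs ranks densely: start is the prefix sum of earlier ranks.
        local_start_pos = sum(global_num_tokens[:dp_rank])
    return local_start_pos, local_num_tokens

# Four DP ranks with uneven token counts:
tokens = [3, 5, 2, 7]
assert local_slice(tokens, dp_rank=2, can_run_graph=False) == (8, 2)                      # 3 + 5 = 8
assert local_slice(tokens, dp_rank=2, can_run_graph=True, cuda_graph_batch=8) == (16, 2)  # 2 * 8 = 16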
10 changes: 5 additions & 5 deletions python/sglang/srt/state_capturer/routed_experts.py
@@ -7,9 +7,8 @@
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.layers.dp_attention import (
    attn_tp_all_gather_into_tensor,
-    get_attention_dp_rank,
    get_attention_tp_size,
-    get_dp_local_info,
+    get_dp_local_slice_cpu,
    is_dp_attention_enabled,
)
from sglang.srt.layers.moe import get_moe_a2a_backend
@@ -112,9 +111,10 @@ def _get_local_slice(
        # the per-rank buffer, so the local DP rank's data lives at [0:N_local]
        # rather than at the global [start_pos:end_pos] offset.
        if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep():
-            local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
-            if can_run_graph:
-                local_start_pos = get_attention_dp_rank() * cuda_graph_batch
+            # GPU->CPU sync would break overlap; operate on CPU directly.
+            local_start_pos, local_num_tokens = get_dp_local_slice_cpu(
+                forward_batch, can_run_graph, cuda_graph_batch
+            )
            local_end_pos = local_start_pos + local_num_tokens
        else:
            local_start_pos, local_end_pos = 0, forward_batch.out_cache_loc.shape[0]
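A hedged usage sketch of what the call site gains: because get_dp_local_slice_cpu reads forward_batch.global_num_tokens_cpu and returns plain Python ints, slicing the gathered buffer stays host-side and never forces a GPU->CPU sync, which is what would break overlap if device tensors had to be read back. The buffer and sizes below are stand-ins, not the actual sglang objects.

import torch

# Illustrative stand-ins only: 'gathered' plays the role of the rank-padded
# per-rank buffer, and (local_start_pos, local_num_tokens) the pair returned
# by get_dp_local_slice_cpu.
device = "cuda" if torch.cuda.is_available() else "cpu"
gathered = torch.randn(32, 128, device=device)

local_start_pos, local_num_tokens = 16, 2              # plain Python ints, no D2H sync needed
local_end_pos = local_start_pos + local_num_tokens

local_view = gathered[local_start_pos:local_end_pos]   # zero-copy view on the device
print(local_view.shape)                                 # torch.Size([2, 128])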