Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions python/sglang/srt/state_capturer/routed_experts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Tuple

import numpy as np
import pybase64
Expand All @@ -9,7 +9,6 @@
attn_tp_all_gather_into_tensor,
get_attention_dp_rank,
get_attention_tp_size,
get_dp_local_info,
is_dp_attention_enabled,
)
from sglang.srt.layers.moe import get_moe_a2a_backend
Expand All @@ -18,6 +17,21 @@
from sglang.srt.state_capturer.base import BaseTopkCapturer


def _get_dp_local_slice_indices(
    forward_batch: ForwardBatch,
    can_run_graph: bool,
    cuda_graph_batch: Optional[int],
) -> Tuple[int, int]:
    """Return the ``(start, end)`` token positions owned by this attention-DP rank.

    Args:
        forward_batch: Batch whose ``global_num_tokens_cpu`` holds one token
            count per DP rank; it is indexed by the current attention-DP rank.
        can_run_graph: Selects how the start offset is derived (see below).
        cuda_graph_batch: Per-rank stride used when ``can_run_graph`` is true.
            It is multiplied by the rank, so it must not be ``None`` in that
            case.

    Returns:
        A ``(start, end)`` pair where ``end - start`` equals this rank's entry
        in ``global_num_tokens_cpu``.
    """
    tokens_per_rank = forward_batch.global_num_tokens_cpu
    rank = get_attention_dp_rank()
    own_count = tokens_per_rank[rank]

    # Graph path: ranks are laid out at a fixed stride of ``cuda_graph_batch``
    # (presumably the padded per-rank batch size -- confirm with the caller).
    # Eager path: ranks are packed back to back, so the offset is the total
    # token count of all lower ranks.
    begin = rank * cuda_graph_batch if can_run_graph else sum(tokens_per_rank[:rank])
    return begin, begin + own_count


class RoutedExpertsCapturer(BaseTopkCapturer):
"""Capturer for routed experts with host buffer.

Expand Down Expand Up @@ -112,10 +126,9 @@ def _get_local_slice(
# the per-rank buffer, so the local DP rank's data lives at [0:N_local]
# rather than at the global [start_pos:end_pos] offset.
if is_dp_attention_enabled() and not get_moe_a2a_backend().is_deepep():
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
if can_run_graph:
local_start_pos = get_attention_dp_rank() * cuda_graph_batch
local_end_pos = local_start_pos + local_num_tokens
local_start_pos, local_end_pos = _get_dp_local_slice_indices(
forward_batch, can_run_graph, cuda_graph_batch
)
else:
local_start_pos, local_end_pos = 0, forward_batch.out_cache_loc.shape[0]
return self.device_cache.buffer[
Expand Down
Loading