Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion python/sglang/srt/layers/dp_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,24 @@ def _dp_gather_via_all_reduce(
forward_batch: ForwardBatch,
is_partial: bool,
):
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
# LogitsMetadata should have dp_local_start_pos set by compute_dp_attention_metadata().
# Avoid calling get_dp_local_info() to maintain separation of concerns.
if type(forward_batch).__name__ == "LogitsMetadata":
assert (
forward_batch.dp_local_start_pos is not None
), "LogitsMetadata.dp_local_start_pos should be set by compute_dp_attention_metadata()"
local_start_pos = forward_batch.dp_local_start_pos
local_num_tokens = forward_batch.dp_local_num_tokens
else:
# ForwardBatch: compute position info using global_num_tokens_gpu
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)

# Use actual tensor size for correctness (scheduler's estimate may not account for pruning)
actual_local_size = local_tokens.shape[0]
if isinstance(local_num_tokens, torch.Tensor):
local_num_tokens.fill_(actual_local_size)
else:
local_num_tokens = actual_local_size

global_tokens.fill_(0)
assert local_tokens.is_contiguous()
Expand Down
Loading