diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py
index 2b413b4467bd..99d0fcec51aa 100644
--- a/python/sglang/srt/layers/dp_attention.py
+++ b/python/sglang/srt/layers/dp_attention.py
@@ -419,7 +419,24 @@ def _dp_gather_via_all_reduce(
     forward_batch: ForwardBatch,
     is_partial: bool,
 ):
-    local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
+    # LogitsMetadata should have dp_local_start_pos set by compute_dp_attention_metadata().
+    # Avoid calling get_dp_local_info() to maintain separation of concerns.
+    if type(forward_batch).__name__ == "LogitsMetadata":
+        assert (
+            forward_batch.dp_local_start_pos is not None
+        ), "LogitsMetadata.dp_local_start_pos should be set by compute_dp_attention_metadata()"
+        local_start_pos = forward_batch.dp_local_start_pos
+        local_num_tokens = forward_batch.dp_local_num_tokens
+    else:
+        # ForwardBatch: compute position info using global_num_tokens_gpu
+        local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
+
+    # Use actual tensor size for correctness (scheduler's estimate may not account for pruning)
+    actual_local_size = local_tokens.shape[0]
+    if isinstance(local_num_tokens, torch.Tensor):
+        local_num_tokens.fill_(actual_local_size)
+    else:
+        local_num_tokens = actual_local_size
 
     global_tokens.fill_(0)
     assert local_tokens.is_contiguous()
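
Note: for readers unfamiliar with the gather pattern this function implements, the sketch below illustrates what "gather via all-reduce" means, as suggested by the function name and the zeroing of global_tokens in the context lines. It is a single-process simulation, not SGLang's actual implementation: each DP rank writes its local tokens into its own disjoint slice of a zeroed global buffer, and an all-reduce sum (modeled here by summing the per-rank buffers) reconstructs the full batch on every rank. The rank count, token counts, and hidden size are made up for illustration.

# Minimal single-process sketch of the dp-gather-via-all-reduce pattern.
# The real code would issue one slice-write per DP rank followed by a
# collective all_reduce; summing the per-rank buffers stands in for that
# collective here.
import torch

hidden_size = 4                      # hypothetical hidden dimension
local_num_tokens = [2, 3, 1]         # tokens held by each of 3 hypothetical DP ranks
total_tokens = sum(local_num_tokens)

# Each rank's local activations (random stand-ins).
local_tokens_per_rank = [torch.randn(n, hidden_size) for n in local_num_tokens]

# Per-rank staging buffers: zero everywhere except the rank's own slice,
# placed at that rank's local_start_pos.
buffers = []
local_start_pos = 0
for local_tokens in local_tokens_per_rank:
    global_tokens = torch.zeros(total_tokens, hidden_size)
    global_tokens[local_start_pos : local_start_pos + local_tokens.shape[0]] = local_tokens
    local_start_pos += local_tokens.shape[0]
    buffers.append(global_tokens)

# all_reduce(SUM) across ranks: because the slices are disjoint and the rest
# of each buffer is zero, the sum is the fully gathered tensor on every rank.
gathered = torch.stack(buffers).sum(dim=0)
assert torch.equal(gathered[:2], local_tokens_per_rank[0])
assert torch.equal(gathered[2:5], local_tokens_per_rank[1])

This also shows why the patch resizes local_num_tokens to local_tokens.shape[0]: if a rank's slice bounds disagree with the tensor it actually holds (e.g. after pruning), the slice-write would place tokens at the wrong offsets and the summed result would be corrupt.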