Merged · Changes from 1 commit
python/sglang/srt/layers/dp_attention.py (7 changes: 4 additions & 3 deletions)
@@ -13,6 +13,7 @@
     GroupCoordinator,
     get_tensor_model_parallel_world_size,
     get_tp_group,
+    get_world_group,
     tensor_model_parallel_all_reduce,
 )

@@ -79,14 +80,14 @@ def initialize_dp_attention(
     )
 
     if enable_dp_attention:
-        local_rank = tp_rank % (tp_size // dp_size)
+        # local_rank = tp_rank % (tp_size // dp_size)
         _ATTN_DP_SIZE = dp_size
         if moe_dense_tp_size is None:
             _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
         else:
             _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
-        local_rank = tp_rank
+        # local_rank = tp_rank
         _ATTN_DP_SIZE = 1
         _LOCAL_ATTN_DP_SIZE = 1

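For context (not part of the diff): a minimal sketch of how the `_ATTN_DP_SIZE` / `_LOCAL_ATTN_DP_SIZE` assignments above evaluate; the sizes below are hypothetical and chosen only for illustration.

```python
# Illustration only; tp_size, dp_size, and moe_dense_tp_size are hypothetical values.
tp_size, dp_size, moe_dense_tp_size = 8, 8, 4

attn_dp_size = dp_size  # corresponds to _ATTN_DP_SIZE
# _LOCAL_ATTN_DP_SIZE branch taken when moe_dense_tp_size is not None:
local_attn_dp_size = max(1, dp_size // (tp_size // moe_dense_tp_size))

print(attn_dp_size, local_attn_dp_size)  # 8 4
```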
@@ -96,7 +97,7 @@ def initialize_dp_attention(
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
-        local_rank,
+        get_world_group().local_rank,
Review comment (Collaborator):
nit: we can use tp_group.local_rank here

         torch.distributed.get_backend(tp_group.device_group),
         use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
         use_pymscclpp=False,
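For context (not part of the diff): the rank-list comprehension in the hunk above partitions the `pp_size * tp_size` ranks into consecutive groups of `_ATTN_TP_SIZE`. A minimal sketch with hypothetical sizes:

```python
# Illustration only; pp_size, tp_size, and attn_tp_size are hypothetical values.
pp_size, tp_size, attn_tp_size = 1, 8, 2

group_ranks = [
    list(range(head, head + attn_tp_size))
    for head in range(0, pp_size * tp_size, attn_tp_size)
]
print(group_ranks)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```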
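A hedged sketch of what the reviewer's nit might look like if applied. It assumes `tp_group` is the coordinator already in scope in `initialize_dp_attention` (used above via `tp_group.device_group`), and it is not part of the merged change.

```python
# Hypothetical follow-up per the review comment, not the merged code:
# take the local rank from the TP group coordinator instead of the world group.
tp_group = get_tp_group()

local_rank = tp_group.local_rank              # reviewer's suggestion
# local_rank = get_world_group().local_rank   # what this PR passes instead
```

The premise of the nit is that both attributes resolve to the same local device index for the constructed attention group; this sketch only shows the swap, it does not verify that equivalence.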