Merged
Changes from all commits
43 commits
3d2d2e9
disable vocab parallel head
ch-wan Apr 19, 2025
d6934d0
llama4 support
ch-wan Apr 19, 2025
8b322c6
Merge remote-tracking branch 'upstream/HEAD' into dev/dp-head
ch-wan Apr 21, 2025
2e2332a
use attn tp group for lm head
ch-wan Apr 21, 2025
24bcd75
fix
ch-wan Apr 21, 2025
14ed913
pass accuracy test
ch-wan Apr 21, 2025
6b43aa5
format
ch-wan Apr 21, 2025
c9dde02
use local attn dp size
ch-wan Apr 19, 2025
d0a9b99
fix
ch-wan Apr 19, 2025
515f20f
several fix
ch-wan Apr 22, 2025
462f51e
Update .gitignore
ch-wan Apr 23, 2025
5adc5e5
fix refactor
ch-wan Apr 23, 2025
9769217
optimize memory
ch-wan Apr 23, 2025
3b6b6d7
add debug info
ch-wan Apr 23, 2025
16c4b74
format
ch-wan Apr 23, 2025
f0674f7
format
ch-wan Apr 23, 2025
bad6e91
Merge branch 'main' into dev/fix-dp-ffn-cuda-graph
ch-wan Apr 24, 2025
62e05aa
Merge branch 'main' into dev/dp-head
liusy58 May 10, 2025
27a8ec3
Merge branch 'main' into dev/fix-dp-ffn-cuda-graph
liusy58 May 10, 2025
182aa52
Add `use_attn_tp_group` for user to decide whether to use vocabulary …
liusy58 May 10, 2025
4712ed0
Add `use_attn_tp_group` for user to decide whether to use vocabulary …
liusy58 May 10, 2025
c747204
Merge branch 'main' into dev/dp-head
ch-wan May 10, 2025
804311d
Merge branch 'main' into dev/fix-dp-ffn-cuda-graph
ch-wan May 10, 2025
5e8e44e
Rename `use_attn_tp_group` to `enable_dp_lm_head` and refactor the `_…
liusy58 May 11, 2025
027290c
Merge branch 'main' into dev/fix-dp-ffn-cuda-graph
liusy58 May 11, 2025
8c6ec17
Rename `use_attn_tp_group` to `enable_dp_lm_head` and refactor the `_…
liusy58 May 11, 2025
71f13f6
Merge branch 'main' into dev/dp-head
ch-wan May 11, 2025
f7e990f
Merge branch 'main' into dev/dp-head
liusy58 May 11, 2025
a8e3315
Merge branch 'main' into dev/fix-dp-ffn-cuda-graph
liusy58 May 11, 2025
e75d496
Merge branch 'main' into dev/dp-head
liusy58 May 11, 2025
efea846
Gather is needed if `enable_dp_lm_head` is not set.
liusy58 May 11, 2025
f84c245
Update scheduler.py
ch-wan May 11, 2025
0f43319
Merge branch 'main' into dev/dp-head
ch-wan May 11, 2025
e1500ff
Merge branch 'main' into dev/fix-dp-ffn-cuda-graph
ch-wan May 11, 2025
71c12f6
update code style
ch-wan May 11, 2025
160517b
format
ch-wan May 11, 2025
5d02170
fix
ch-wan May 11, 2025
8f1395a
Merge branch 'dev/dp-head' into dev/fix-dp-ffn-cuda-graph
liusy58 May 12, 2025
bb61b5c
Merge branch 'main' into dev/fix-dp-ffn-cuda-graph
ch-wan May 12, 2025
bf10e71
Update logits_processor.py
ch-wan May 12, 2025
54b9e5b
Merge branch 'main' into dev/fix-dp-ffn-cuda-graph
May 12, 2025
6d006c0
Merge branch 'main' into dev/fix-dp-ffn-cuda-graph
ch-wan May 12, 2025
25c838f
rename `dp_rank` to `attn_dp_rank`
May 12, 2025
86 changes: 68 additions & 18 deletions python/sglang/srt/layers/dp_attention.py
@@ -24,41 +24,71 @@
_ATTN_TP_GROUP = None
_ATTN_TP_RANK = None
_ATTN_TP_SIZE = None
_DP_RANK = None
_DP_SIZE = None
_ATTN_DP_RANK = None
_ATTN_DP_SIZE = None
_LOCAL_ATTN_DP_SIZE = None
_LOCAL_ATTN_DP_RANK = None


def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
if not enable_dp_attention:
return tp_rank, tp_size, 0

attn_tp_size = tp_size // dp_size
dp_rank = tp_rank // attn_tp_size
attn_dp_rank = tp_rank // attn_tp_size
attn_tp_rank = tp_rank % attn_tp_size
return attn_tp_rank, attn_tp_size, dp_rank

return attn_tp_rank, attn_tp_size, attn_dp_rank


def compute_dp_attention_local_info(
enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
):
if not enable_dp_attention:
return tp_rank, tp_size, 0

local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
local_tp_rank = tp_rank % local_tp_size
local_dp_size = max(1, dp_size // (tp_size // local_tp_size))

local_attn_tp_size = local_tp_size // local_dp_size
local_attn_dp_rank = local_tp_rank // local_attn_tp_size
local_attn_tp_rank = local_tp_rank % local_attn_tp_size

return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank


def initialize_dp_attention(
enable_dp_attention: bool,
tp_rank: int,
tp_size: int,
dp_size: int,
moe_dense_tp_size: int,
pp_size: int,
):
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK

from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

_ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
_ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
enable_dp_attention, tp_rank, tp_size, dp_size
)
_, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
)

if enable_dp_attention:
local_rank = tp_rank % (tp_size // dp_size)
_DP_SIZE = dp_size
_ATTN_DP_SIZE = dp_size
if moe_dense_tp_size is None:
_LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
else:
_LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
else:
local_rank = tp_rank
_DP_SIZE = 1
_ATTN_DP_SIZE = 1
_LOCAL_ATTN_DP_SIZE = 1

tp_group = get_tp_group()
_ATTN_TP_GROUP = GroupCoordinator(
@@ -93,13 +123,33 @@ def get_attention_tp_size():


def get_attention_dp_rank():
assert _DP_RANK is not None, "dp attention not initialized!"
return _DP_RANK
assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
return _ATTN_DP_RANK


def get_attention_dp_size():
assert _DP_SIZE is not None, "dp attention not initialized!"
return _DP_SIZE
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
return _ATTN_DP_SIZE


def get_local_attention_dp_rank():
assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
return _LOCAL_ATTN_DP_RANK


def get_local_attention_dp_size():
assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
return _LOCAL_ATTN_DP_SIZE


def get_local_attention_dp_rank():
assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
return _LOCAL_ATTN_DP_RANK


def get_local_attention_dp_size():
assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
return _LOCAL_ATTN_DP_SIZE
Comment on lines +145 to +152
Collaborator
duplicated



@contextmanager
@@ -112,19 +162,19 @@ def disable_dp_size():
Args:
tp_group (GroupCoordinator): the tp group coordinator
"""
global _DP_SIZE
assert _DP_SIZE is not None, "dp attention not initialized!"
global _ATTN_DP_SIZE
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"

old_dp_size = _DP_SIZE
_DP_SIZE = 1
old_dp_size = _ATTN_DP_SIZE
_ATTN_DP_SIZE = 1
try:
yield
finally:
_DP_SIZE = old_dp_size
_ATTN_DP_SIZE = old_dp_size


def get_dp_local_info(forward_batch: ForwardBatch):
dp_rank = get_attention_dp_rank()
dp_rank = get_local_attention_dp_rank()

if forward_batch.dp_local_start_pos is None:
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
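To make the new rank bookkeeping easier to review, here is a minimal standalone sketch that mirrors the two helpers added in this file and prints both the world and local info for a hypothetical deployment (tp_size=8, dp_size=4, moe_dense_tp_size=2); the configuration values are illustrative only, not taken from a real run.

```python
# Sketch of the rank math in dp_attention.py; the example config is hypothetical.

def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    attn_tp_size = tp_size // dp_size        # TP ranks per attention-DP group
    attn_dp_rank = tp_rank // attn_tp_size   # which attention-DP group this rank is in
    attn_tp_rank = tp_rank % attn_tp_size    # rank inside that group
    return attn_tp_rank, attn_tp_size, attn_dp_rank


def compute_dp_attention_local_info(
    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
    local_tp_rank = tp_rank % local_tp_size
    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
    local_attn_tp_size = local_tp_size // local_dp_size
    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank


if __name__ == "__main__":
    for tp_rank in range(8):
        world = compute_dp_attention_world_info(True, tp_rank, 8, 4)
        local = compute_dp_attention_local_info(True, tp_rank, 8, 4, 2)
        print(f"tp_rank={tp_rank}: world={world}, local={local}")
```

With moe_dense_tp_size=None the local values collapse to the world values, which is consistent with the `_LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE` branch in `initialize_dp_attention`.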
20 changes: 17 additions & 3 deletions python/sglang/srt/layers/logits_processor.py
@@ -30,9 +30,10 @@
attn_tp_all_gather,
dp_gather_replicate,
dp_scatter,
get_attention_dp_rank,
get_attention_dp_size,
get_attention_tp_size,
get_local_attention_dp_rank,
get_local_attention_dp_size,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -46,6 +47,18 @@
logger = logging.getLogger(__name__)


from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.utils import dump_to_file

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class LogitsProcessorOutput:
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -170,7 +183,7 @@ def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
return

cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
dp_rank = get_attention_dp_rank()
dp_rank = get_local_attention_dp_rank()
if dp_rank == 0:
dp_local_start_pos = torch.zeros_like(
self.global_num_tokens_for_logprob_gpu[0]
@@ -324,7 +337,8 @@ def forward(

if self.debug_tensor_dump_output_folder:
assert (
not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
not self.do_tensor_parallel_all_gather
or get_local_attention_dp_size() == 1
), "dp attention + sharded lm_head doesn't support full logits"
full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
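The logits_processor change switches the per-rank slice computation to `get_local_attention_dp_rank`. A small sketch of that slicing logic, using a made-up token-count tensor, shows how each rank derives its start offset and length from the cumulative sum (as in `compute_dp_attention_metadata`):

```python
import torch

# Hypothetical per-rank token counts for four attention-DP ranks (illustration only).
counts = torch.tensor([3, 5, 2, 4])


def local_slice(counts: torch.Tensor, dp_rank: int):
    # Rank 0 starts at offset 0; every other rank starts at the cumulative
    # sum of the preceding ranks' counts.
    cumtokens = torch.cumsum(counts, dim=0)
    start = torch.zeros_like(counts[0]) if dp_rank == 0 else cumtokens[dp_rank - 1]
    return int(start), int(counts[dp_rank])


print([local_slice(counts, r) for r in range(4)])
# -> [(0, 3), (3, 5), (8, 2), (12, 4)]
```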
32 changes: 21 additions & 11 deletions python/sglang/srt/managers/scheduler.py
@@ -207,7 +208,8 @@ def __init__(
self.page_size = server_args.page_size

# Distributed rank info
self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
compute_dp_attention_world_info(
server_args.enable_dp_attention,
self.tp_rank,
@@ -772,7 +773,7 @@ def event_loop_pp(self):

if not self.pp_group.is_last_rank:
# send out reqs to the next stage
dp_offset = self.dp_rank * self.attn_tp_size
dp_offset = self.attn_dp_rank * self.attn_tp_size
if self.attn_tp_rank == 0:
point_to_point_pyobj(
recv_reqs,
@@ -819,7 +820,7 @@ def recv_requests(self) -> List[Req]:
recv_reqs = None
else:
if self.attn_tp_rank == 0:
dp_offset = self.dp_rank * self.attn_tp_size
dp_offset = self.attn_dp_rank * self.attn_tp_size
recv_reqs = point_to_point_pyobj(
[],
self.pp_rank * self.tp_size + dp_offset,
@@ -1622,6 +1623,7 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
local_batch,
dp_size=self.server_args.dp_size,
attn_tp_size=self.attn_tp_size,
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
tp_cpu_group=self.tp_cpu_group,
get_idle_batch=self.get_idle_batch,
disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1634,6 +1636,7 @@ def prepare_dp_attn_batch_raw(
local_batch: ScheduleBatch,
dp_size,
attn_tp_size: int,
moe_dense_tp_size: Optional[int],
tp_cpu_group,
get_idle_batch,
disable_cuda_graph: bool,
@@ -1643,15 +1646,15 @@
# Check if other DP workers have running batches
if local_batch is None:
num_tokens = 0
global_num_tokens_for_logprob = 0
num_tokens_for_logprob = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
num_tokens = num_tokens * speculative_num_draft_tokens
global_num_tokens_for_logprob = num_tokens
num_tokens_for_logprob = num_tokens
else:
num_tokens = local_batch.extend_num_tokens
global_num_tokens_for_logprob = sum(
num_tokens_for_logprob = sum(
[
# We should have at least 1 token for sample in every case.
max(extend_len - logprob_start_len, 1)
@@ -1678,7 +1681,7 @@
[
num_tokens,
can_cuda_graph,
global_num_tokens_for_logprob,
num_tokens_for_logprob,
is_extend_in_batch,
],
dtype=torch.int64,
@@ -1701,8 +1704,15 @@
local_batch = get_idle_batch()

if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens
local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
# TODO: handle the case when moe_dense_tp_size != 1
if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
local_batch.global_num_tokens = [num_tokens]
local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
else:
local_batch.global_num_tokens = global_num_tokens
local_batch.global_num_tokens_for_logprob = (
global_num_tokens_for_logprob
)

# Check forward mode for cuda graph
if not disable_cuda_graph:
@@ -2182,8 +2192,8 @@ def close_session(self, recv_req: CloseSessionReqInput):

def get_print_prefix(self):
prefix = ""
if self.dp_rank is not None:
prefix += f" DP{self.dp_rank}"
if self.attn_dp_rank is not None:
prefix += f" DP{self.attn_dp_rank}"
if self.server_args.tp_size > 1:
prefix += f" TP{self.tp_rank}"
if self.pp_size > 1:
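As a reading aid for the branch added to `prepare_dp_attn_batch_raw`, here is a reduced sketch of the decision it encodes; the function and argument names below are simplified stand-ins, not the scheduler's real signature:

```python
from typing import List, Optional, Tuple


def pick_global_num_tokens(
    moe_dense_tp_size: Optional[int],
    enable_dp_lm_head: bool,
    num_tokens: int,
    num_tokens_for_logprob: int,
    global_num_tokens: List[int],
    global_num_tokens_for_logprob: List[int],
) -> Tuple[List[int], List[int]]:
    # When moe_dense_tp_size == 1 and enable_dp_lm_head is set, each scheduler
    # keeps only its local token counts; otherwise it stores the full lists
    # gathered across all DP workers. (The TODO in the diff notes that the
    # moe_dense_tp_size != 1 case is not yet handled by the local-only path.)
    if moe_dense_tp_size == 1 and enable_dp_lm_head:
        return [num_tokens], [num_tokens_for_logprob]
    return global_num_tokens, global_num_tokens_for_logprob
```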
1 change: 1 addition & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -400,6 +400,7 @@ def init_torch_distributed(self):
tp_rank=self.tp_rank,
tp_size=self.tp_size,
dp_size=self.server_args.dp_size,
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
pp_size=self.server_args.pp_size,
)

16 changes: 8 additions & 8 deletions python/sglang/srt/models/deepseek_v2.py
@@ -40,9 +40,9 @@
attn_tp_reduce_scatter,
dp_gather_partial,
dp_scatter,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
@@ -438,7 +438,6 @@ def __init__(
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.dp_size = get_attention_dp_size()
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()

@@ -1133,7 +1132,7 @@ def __init__(
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
self.layer_id = layer_id
self.dp_size = get_attention_dp_size()
self.local_dp_size = get_local_attention_dp_size()
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.self_attn = DeepseekV2AttentionMLA(
@@ -1184,7 +1183,8 @@ def __init__(
)

self.input_is_scattered = (
previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
layer_id > 0
and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
)
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

@@ -1264,7 +1264,7 @@ def forward_ffn_with_full_input(
# Gather
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
if self.dp_size != 1:
if self.local_dp_size != 1:
if self.attn_tp_rank == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
@@ -1289,7 +1289,7 @@

# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
# Scatter
if self.dp_size != 1:
if self.local_dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
@@ -1413,7 +1413,7 @@ def __init__(
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.dp_size = get_attention_dp_size()
self.dp_size = get_local_attention_dp_size()

def get_input_embeddings(self) -> torch.Tensor:
return self.embed_tokens
@@ -1478,7 +1478,7 @@ def __init__(
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.dp_size = get_attention_dp_size()
self.dp_size = get_local_attention_dp_size()

def determine_n_share_experts_fusion(
self, architecture: str = "DeepseekV3ForCausalLM"