diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py
index 0f1e453bfd0..6de1797ae1b 100644
--- a/python/sglang/srt/layers/dp_attention.py
+++ b/python/sglang/srt/layers/dp_attention.py
@@ -24,8 +24,10 @@
 _ATTN_TP_GROUP = None
 _ATTN_TP_RANK = None
 _ATTN_TP_SIZE = None
-_DP_RANK = None
-_DP_SIZE = None
+_ATTN_DP_RANK = None
+_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_RANK = None
 
 
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
@@ -33,9 +35,27 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
         return tp_rank, tp_size, 0
 
     attn_tp_size = tp_size // dp_size
-    dp_rank = tp_rank // attn_tp_size
+    attn_dp_rank = tp_rank // attn_tp_size
     attn_tp_rank = tp_rank % attn_tp_size
-    return attn_tp_rank, attn_tp_size, dp_rank
+
+    return attn_tp_rank, attn_tp_size, attn_dp_rank
+
+
+def compute_dp_attention_local_info(
+    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
+    local_tp_rank = tp_rank % local_tp_size
+    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
+
+    local_attn_tp_size = local_tp_size // local_dp_size
+    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
+    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
+
+    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank
 
 
 def initialize_dp_attention(
@@ -43,22 +63,32 @@
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    moe_dense_tp_size: int,
     pp_size: int,
 ):
-    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
 
-    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
+    _, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+    )
 
     if enable_dp_attention:
         local_rank = tp_rank % (tp_size // dp_size)
-        _DP_SIZE = dp_size
+        _ATTN_DP_SIZE = dp_size
+        if moe_dense_tp_size is None:
+            _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
+        else:
+            _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
         local_rank = tp_rank
-        _DP_SIZE = 1
+        _ATTN_DP_SIZE = 1
+        _LOCAL_ATTN_DP_SIZE = 1
 
     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
@@ -93,13 +123,23 @@ def get_attention_tp_size():
 
 
 def get_attention_dp_rank():
-    assert _DP_RANK is not None, "dp attention not initialized!"
-    return _DP_RANK
+    assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_DP_RANK
 
 
 def get_attention_dp_size():
-    assert _DP_SIZE is not None, "dp attention not initialized!"
-    return _DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE
 
 
 @contextmanager
@@ -112,19 +152,19 @@ def disable_dp_size():
     Args:
         tp_group (GroupCoordinator): the tp group coordinator
     """
-    global _DP_SIZE
-    assert _DP_SIZE is not None, "dp attention not initialized!"
+    global _ATTN_DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
 
-    old_dp_size = _DP_SIZE
-    _DP_SIZE = 1
+    old_dp_size = _ATTN_DP_SIZE
+    _ATTN_DP_SIZE = 1
     try:
         yield
     finally:
-        _DP_SIZE = old_dp_size
+        _ATTN_DP_SIZE = old_dp_size
 
 
 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank = get_attention_dp_rank()
+    dp_rank = get_local_attention_dp_rank()
 
     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 5a4f0781729..60091b9a483 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -30,9 +30,10 @@
     attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
-    get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_local_attention_dp_rank,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -170,7 +171,7 @@ def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
             return
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank = get_attention_dp_rank()
+        dp_rank = get_local_attention_dp_rank()
         if dp_rank == 0:
             dp_local_start_pos = torch.zeros_like(
                 self.global_num_tokens_for_logprob_gpu[0]
@@ -324,7 +325,8 @@ def forward(
 
         if self.debug_tensor_dump_output_folder:
             assert (
-                not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
+                not self.do_tensor_parallel_all_gather
+                or get_local_attention_dp_size() == 1
             ), "dp attention + sharded lm_head doesn't support full logits"
             full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
             dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index fa1e20b0c6a..6ae3004c61c 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -207,7 +207,8 @@ def __init__(
         self.page_size = server_args.page_size
 
         # Distributed rank info
-        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
                 self.tp_rank,
@@ -772,7 +773,7 @@ def event_loop_pp(self):
 
                 if not self.pp_group.is_last_rank:
                     # send out reqs to the next stage
-                    dp_offset = self.dp_rank * self.attn_tp_size
+                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                     if self.attn_tp_rank == 0:
                         point_to_point_pyobj(
                             recv_reqs,
@@ -819,7 +820,7 @@ def recv_requests(self) -> List[Req]:
                 recv_reqs = None
         else:
             if self.attn_tp_rank == 0:
-                dp_offset = self.dp_rank * self.attn_tp_size
+                dp_offset = self.attn_dp_rank * self.attn_tp_size
                 recv_reqs = point_to_point_pyobj(
                     [],
                     self.pp_rank * self.tp_size + dp_offset,
@@ -1622,6 +1623,7 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
             local_batch,
             dp_size=self.server_args.dp_size,
             attn_tp_size=self.attn_tp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             tp_cpu_group=self.tp_cpu_group,
             get_idle_batch=self.get_idle_batch,
             disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1634,6 +1636,7 @@ def prepare_dp_attn_batch_raw(
     local_batch: ScheduleBatch,
     dp_size,
     attn_tp_size: int,
+    moe_dense_tp_size: Optional[int],
     tp_cpu_group,
     get_idle_batch,
     disable_cuda_graph: bool,
@@ -1643,15 +1646,15 @@ def prepare_dp_attn_batch_raw(
     # Check if other DP workers have running batches
     if local_batch is None:
         num_tokens = 0
-        global_num_tokens_for_logprob = 0
+        num_tokens_for_logprob = 0
     elif local_batch.forward_mode.is_decode():
         num_tokens = local_batch.batch_size()
         if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
             num_tokens = num_tokens * speculative_num_draft_tokens
-        global_num_tokens_for_logprob = num_tokens
+        num_tokens_for_logprob = num_tokens
     else:
         num_tokens = local_batch.extend_num_tokens
-        global_num_tokens_for_logprob = sum(
+        num_tokens_for_logprob = sum(
             [
                 # We should have at least 1 token for sample in every case.
                 max(extend_len - logprob_start_len, 1)
@@ -1678,7 +1681,7 @@ def prepare_dp_attn_batch_raw(
         [
             num_tokens,
             can_cuda_graph,
-            global_num_tokens_for_logprob,
+            num_tokens_for_logprob,
             is_extend_in_batch,
         ],
         dtype=torch.int64,
@@ -1701,8 +1704,15 @@ def prepare_dp_attn_batch_raw(
         local_batch = get_idle_batch()
 
     if local_batch is not None:
-        local_batch.global_num_tokens = global_num_tokens
-        local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
+        # TODO: handle the case when moe_dense_tp_size != 1
+        if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+            local_batch.global_num_tokens = [num_tokens]
+            local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
+        else:
+            local_batch.global_num_tokens = global_num_tokens
+            local_batch.global_num_tokens_for_logprob = (
+                global_num_tokens_for_logprob
+            )
 
     # Check forward mode for cuda graph
     if not disable_cuda_graph:
@@ -2182,8 +2192,8 @@ def close_session(self, recv_req: CloseSessionReqInput):
 
     def get_print_prefix(self):
         prefix = ""
-        if self.dp_rank is not None:
-            prefix += f" DP{self.dp_rank}"
+        if self.attn_dp_rank is not None:
+            prefix += f" DP{self.attn_dp_rank}"
         if self.server_args.tp_size > 1:
             prefix += f" TP{self.tp_rank}"
         if self.pp_size > 1:
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index a102e63ae1a..1b81155d24c 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -400,6 +400,7 @@ def init_torch_distributed(self):
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             dp_size=self.server_args.dp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             pp_size=self.server_args.pp_size,
         )
 
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index e8ef96a6eec..4dfbad77d7e 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -40,9 +40,9 @@
     attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -438,7 +438,6 @@ def __init__(
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
 
@@ -1133,7 +1132,7 @@ def __init__(
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
         self.self_attn = DeepseekV2AttentionMLA(
@@ -1184,7 +1183,8 @@ def __init__(
         )
 
         self.input_is_scattered = (
-            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
+            layer_id > 0
+            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
 
@@ -1264,7 +1264,7 @@ def forward_ffn_with_full_input(
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (
@@ -1289,7 +1289,7 @@
         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
 
         # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (
@@ -1413,7 +1413,7 @@ def __init__(
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()
 
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -1478,7 +1478,7 @@ def __init__(
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()
 
     def determine_n_share_experts_fusion(
         self, architecture: str = "DeepseekV3ForCausalLM"
diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py
index a4f2d03a8d8..d309d0be1a0 100644
--- a/python/sglang/srt/models/llama4.py
+++ b/python/sglang/srt/models/llama4.py
@@ -30,9 +30,9 @@
 from sglang.srt.layers.dp_attention import (
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -198,7 +198,6 @@ def __init__(
         self.use_rope = int((layer_id + 1) % 4 != 0)
         self.use_qk_norm = config.use_qk_norm and self.use_rope
 
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
 
@@ -342,7 +341,7 @@ def __init__(
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
 
@@ -405,7 +404,7 @@ def forward(
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (
@@ -430,7 +429,7 @@ def forward(
         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
 
         # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (