From 3d2d2e9b30d04c96cbd3d4259566f7a961d82c52 Mon Sep 17 00:00:00 2001 From: ch-wan Date: Sat, 19 Apr 2025 09:58:06 +0000 Subject: [PATCH 01/26] disable vocab parallel head --- python/sglang/srt/layers/logits_processor.py | 30 ++++++++++--------- .../srt/layers/vocab_parallel_embedding.py | 2 ++ python/sglang/srt/models/deepseek_v2.py | 14 ++++----- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 981040d0dfd..d5d4b4754e5 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -422,7 +422,7 @@ def _get_logits( last position (e.g., extend without input logprobs). The caller should guarantee the given hidden_states follow this constraint. """ - if self.do_tensor_parallel_all_gather_dp_attn: + if lm_head.enable_tp and self.do_tensor_parallel_all_gather_dp_attn: logits_metadata.compute_dp_attention_metadata(hidden_states) hidden_states, local_hidden_states = ( logits_metadata.gathered_buffer, @@ -441,19 +441,21 @@ def _get_logits( if self.logit_scale is not None: logits.mul_(self.logit_scale) - if self.do_tensor_parallel_all_gather: - logits = tensor_model_parallel_all_gather(logits) - - if self.do_tensor_parallel_all_gather_dp_attn: - logits, global_logits = ( - torch.empty( - (local_hidden_states.shape[0], logits.shape[1]), - device=logits.device, - dtype=logits.dtype, - ), - logits, - ) - dp_scatter(logits, global_logits, logits_metadata) + if lm_head.enable_tp: + + if self.do_tensor_parallel_all_gather: + logits = tensor_model_parallel_all_gather(logits) + + if self.do_tensor_parallel_all_gather_dp_attn: + logits, global_logits = ( + torch.empty( + (local_hidden_states.shape[0], logits.shape[1]), + device=logits.device, + dtype=logits.dtype, + ), + logits, + ) + dp_scatter(logits, global_logits, logits_metadata) logits = logits[:, : self.config.vocab_size].float() diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index ebc148feb97..8f7fe07c624 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -525,6 +525,7 @@ def __init__( padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + enable_tp: bool = True, use_presharded_weights: bool = False, ): super().__init__( @@ -535,6 +536,7 @@ def __init__( padding_size, quant_config, prefix, + enable_tp, use_presharded_weights=use_presharded_weights, ) self.quant_config = quant_config diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 47dc5beaf5a..a6224edd9ef 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -96,6 +96,9 @@ logger = logging.getLogger(__name__) +def _enable_moe_dense_fully_dp(): + return global_server_args_dict["moe_dense_tp_size"] == 1 + class AttnForwardMethod(IntEnum): # Use multi-head attention MHA = auto() @@ -1076,7 +1079,7 @@ def __init__( prefix=add_prefix("mlp", prefix), ) else: - if self._enable_moe_dense_fully_dp(): + if _enable_moe_dense_fully_dp(): mlp_tp_rank, mlp_tp_size = 0, 1 else: mlp_tp_rank, mlp_tp_size = None, None @@ -1100,10 +1103,6 @@ def __init__( config.hidden_size, eps=config.rms_norm_eps ) - @staticmethod - def _enable_moe_dense_fully_dp(): - return global_server_args_dict["moe_dense_tp_size"] == 1 - @staticmethod def 
_compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool): is_sparse = is_nextn or ( @@ -1114,7 +1113,7 @@ def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool): ffn_input_mode = ( _FFNInputMode.SCATTERED if (global_server_args_dict["enable_deepep_moe"] and is_sparse) - or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) + or (_enable_moe_dense_fully_dp() and not is_sparse) else _FFNInputMode.FULL ) return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode) @@ -1267,7 +1266,7 @@ def forward_ffn_with_scattered_input( ) if not ( - self._enable_moe_dense_fully_dp() + _enable_moe_dense_fully_dp() and (not self.info.is_sparse) and hidden_states.shape[0] == 0 ): @@ -1391,6 +1390,7 @@ def __init__( config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), + enable_tp=not _enable_moe_dense_fully_dp(), ) self.logits_processor = LogitsProcessor(config) self.dp_size = get_attention_dp_size() From d6934d02aaa3d07a71acb87a7a1ccd8612e6e57e Mon Sep 17 00:00:00 2001 From: ch-wan Date: Sat, 19 Apr 2025 22:46:02 +0000 Subject: [PATCH 02/26] llama4 support --- python/sglang/srt/models/deepseek_v2.py | 15 ++++++++------- python/sglang/srt/models/llama.py | 2 ++ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index a6224edd9ef..d2d9054faea 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -96,9 +96,6 @@ logger = logging.getLogger(__name__) -def _enable_moe_dense_fully_dp(): - return global_server_args_dict["moe_dense_tp_size"] == 1 - class AttnForwardMethod(IntEnum): # Use multi-head attention MHA = auto() @@ -1079,7 +1076,7 @@ def __init__( prefix=add_prefix("mlp", prefix), ) else: - if _enable_moe_dense_fully_dp(): + if self._enable_moe_dense_fully_dp(): mlp_tp_rank, mlp_tp_size = 0, 1 else: mlp_tp_rank, mlp_tp_size = None, None @@ -1103,6 +1100,10 @@ def __init__( config.hidden_size, eps=config.rms_norm_eps ) + @staticmethod + def _enable_moe_dense_fully_dp(): + return global_server_args_dict["moe_dense_tp_size"] == 1 + @staticmethod def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool): is_sparse = is_nextn or ( @@ -1113,7 +1114,7 @@ def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool): ffn_input_mode = ( _FFNInputMode.SCATTERED if (global_server_args_dict["enable_deepep_moe"] and is_sparse) - or (_enable_moe_dense_fully_dp() and not is_sparse) + or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) else _FFNInputMode.FULL ) return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode) @@ -1266,7 +1267,7 @@ def forward_ffn_with_scattered_input( ) if not ( - _enable_moe_dense_fully_dp() + self._enable_moe_dense_fully_dp() and (not self.info.is_sparse) and hidden_states.shape[0] == 0 ): @@ -1390,7 +1391,7 @@ def __init__( config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - enable_tp=not _enable_moe_dense_fully_dp(), + enable_tp=not global_server_args_dict["enable_dp_attention"], ) self.logits_processor = LogitsProcessor(config) self.dp_size = get_attention_dp_size() diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 008e542048a..209a162318c 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -43,6 +43,7 @@ ParallelLMHead, VocabParallelEmbedding, ) +from 
sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import ( default_weight_loader, @@ -389,6 +390,7 @@ def __init__( config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), + enable_tp=not global_server_args_dict["enable_dp_attention"], ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) From 2e2332a64fccc6e17999386579666b4f580c1d08 Mon Sep 17 00:00:00 2001 From: ch-wan Date: Mon, 21 Apr 2025 04:35:57 +0000 Subject: [PATCH 03/26] use attn tp group for lm head --- python/sglang/srt/layers/dp_attention.py | 4 +-- python/sglang/srt/layers/logits_processor.py | 30 ++----------------- .../srt/layers/vocab_parallel_embedding.py | 30 +++++++++++++------ python/sglang/srt/models/deepseek_v2.py | 14 ++++----- python/sglang/srt/models/llama.py | 2 +- 5 files changed, 33 insertions(+), 47 deletions(-) diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 5f140a3df96..10b4bbabf01 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -249,12 +249,12 @@ def dp_scatter( ) -def tp_reduce_scatter( +def attn_tp_reduce_scatter( output: torch.Tensor, input_list: List[torch.Tensor], ): return get_attention_tp_group().reduce_scatter(output, input_list) -def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor): +def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor): return get_attention_tp_group().all_gather(input_, tensor_list=output_list) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index d5d4b4754e5..bb488e78f1f 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -23,13 +23,10 @@ from torch import nn from sglang.srt.distributed import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) from sglang.srt.layers.dp_attention import ( - dp_gather_replicate, - dp_scatter, get_attention_dp_rank, get_attention_dp_size, ) @@ -201,9 +198,6 @@ def __init__( self.do_tensor_parallel_all_gather = ( not skip_all_gather and get_tensor_model_parallel_world_size() > 1 ) - self.do_tensor_parallel_all_gather_dp_attn = ( - self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1 - ) self.final_logit_softcapping = getattr( self.config, "final_logit_softcapping", None ) @@ -422,13 +416,6 @@ def _get_logits( last position (e.g., extend without input logprobs). The caller should guarantee the given hidden_states follow this constraint. 
""" - if lm_head.enable_tp and self.do_tensor_parallel_all_gather_dp_attn: - logits_metadata.compute_dp_attention_metadata(hidden_states) - hidden_states, local_hidden_states = ( - logits_metadata.gathered_buffer, - hidden_states.clone(), - ) - dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) if hasattr(lm_head, "weight"): logits = torch.matmul( @@ -441,21 +428,8 @@ def _get_logits( if self.logit_scale is not None: logits.mul_(self.logit_scale) - if lm_head.enable_tp: - - if self.do_tensor_parallel_all_gather: - logits = tensor_model_parallel_all_gather(logits) - - if self.do_tensor_parallel_all_gather_dp_attn: - logits, global_logits = ( - torch.empty( - (local_hidden_states.shape[0], logits.shape[1]), - device=logits.device, - dtype=logits.dtype, - ), - logits, - ) - dp_scatter(logits, global_logits, logits_metadata) + if self.do_tensor_parallel_all_gather: + logits = tensor_model_parallel_all_gather(logits) logits = logits[:, : self.config.vocab_size].float() diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 8f7fe07c624..06cce04ccb8 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -13,6 +13,10 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) +from sglang.srt.layers.dp_attention import ( + get_attention_tp_rank, + get_attention_tp_size, +) from sglang.srt.layers.parameter import BasevLLMParameter from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, @@ -214,12 +218,14 @@ def __init__( self, num_embeddings: int, embedding_dim: int, + *, params_dtype: Optional[torch.dtype] = None, org_num_embeddings: Optional[int] = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", enable_tp: bool = True, + use_attn_tp_group: bool = False, use_presharded_weights: bool = False, ): super().__init__() @@ -227,9 +233,14 @@ def __init__( self.enable_tp = enable_tp if self.enable_tp: - tp_rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() + if use_attn_tp_group: + tp_rank = get_attention_tp_rank() + self.tp_size = get_attention_tp_size() + else: + tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() else: + assert use_attn_tp_group is False tp_rank = 0 self.tp_size = 1 @@ -519,24 +530,25 @@ def __init__( self, num_embeddings: int, embedding_dim: int, + *, bias: bool = False, params_dtype: Optional[torch.dtype] = None, org_num_embeddings: Optional[int] = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - enable_tp: bool = True, + use_attn_tp_group: bool = False, use_presharded_weights: bool = False, ): super().__init__( num_embeddings, embedding_dim, - params_dtype, - org_num_embeddings, - padding_size, - quant_config, - prefix, - enable_tp, + params_dtype=params_dtype, + org_num_embeddings=org_num_embeddings, + padding_size=padding_size, + quant_config=quant_config, + prefix=prefix, + use_attn_tp_group=use_attn_tp_group, use_presharded_weights=use_presharded_weights, ) self.quant_config = quant_config diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index bc077d26ed0..76566a92ffa 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -41,8 +41,8 @@ 
get_attention_dp_size, get_attention_tp_rank, get_attention_tp_size, - tp_all_gather, - tp_reduce_scatter, + attn_tp_all_gather, + attn_tp_reduce_scatter, ) from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -1273,7 +1273,7 @@ def forward_ffn_with_scattered_input( forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, ) - tp_all_gather( + attn_tp_all_gather( list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states ) @@ -1289,7 +1289,7 @@ def forward_ffn_with_scattered_input( if self.input_is_scattered: tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) hidden_states = tensor_list[self.attn_tp_rank] - tp_reduce_scatter(hidden_states, tensor_list) + attn_tp_reduce_scatter(hidden_states, tensor_list) if hidden_states.shape[0] != 0: hidden_states, residual = self.post_attention_layernorm( hidden_states, residual @@ -1299,7 +1299,7 @@ def forward_ffn_with_scattered_input( hidden_states += residual tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) hidden_states = tensor_list[self.attn_tp_rank] - tp_reduce_scatter(hidden_states, tensor_list) + attn_tp_reduce_scatter(hidden_states, tensor_list) residual = hidden_states if hidden_states.shape[0] != 0: hidden_states = self.post_attention_layernorm(hidden_states) @@ -1323,7 +1323,7 @@ def forward_ffn_with_scattered_input( forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, ) - tp_all_gather( + attn_tp_all_gather( list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states ) @@ -1346,7 +1346,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, - enable_tp=not global_server_args_dict["enable_dp_attention"], + use_attn_tp_group=global_server_args_dict["enable_dp_attention"], ) self.layers = nn.ModuleList( [ diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 209a162318c..a50e6c4e612 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -390,7 +390,7 @@ def __init__( config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - enable_tp=not global_server_args_dict["enable_dp_attention"], + use_attn_tp_group=global_server_args_dict["enable_dp_attention"], ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) From 24bcd7575d4546cde4f72bcb221316f252e15e6d Mon Sep 17 00:00:00 2001 From: ch-wan Date: Mon, 21 Apr 2025 04:45:40 +0000 Subject: [PATCH 04/26] fix --- python/sglang/srt/models/deepseek_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 76566a92ffa..9e5018045dd 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1346,7 +1346,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, - use_attn_tp_group=global_server_args_dict["enable_dp_attention"], + enable_tp=not global_server_args_dict["enable_dp_attention"], ) self.layers = nn.ModuleList( [ @@ -1448,7 +1448,7 @@ def __init__( config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - enable_tp=not global_server_args_dict["enable_dp_attention"], + use_attn_tp_group=global_server_args_dict["enable_dp_attention"], ) self.logits_processor = LogitsProcessor(config) self.dp_size = get_attention_dp_size() From 
14ed913d12c5a7f14cdb22deac78e3f392d6f764 Mon Sep 17 00:00:00 2001 From: ch-wan Date: Mon, 21 Apr 2025 05:19:00 +0000 Subject: [PATCH 05/26] pass accuracy test --- python/sglang/srt/layers/logits_processor.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index bb488e78f1f..eb3a588b021 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -23,12 +23,13 @@ from torch import nn from sglang.srt.distributed import ( - get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) from sglang.srt.layers.dp_attention import ( get_attention_dp_rank, get_attention_dp_size, + get_attention_tp_size, + attn_tp_all_gather, ) from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -195,9 +196,11 @@ def __init__( super().__init__() self.config = config self.logit_scale = logit_scale + self.attn_tp_size = get_attention_tp_size() self.do_tensor_parallel_all_gather = ( - not skip_all_gather and get_tensor_model_parallel_world_size() > 1 + not skip_all_gather and self.attn_tp_size > 1 ) + self.use_attn_tp_group = get_attention_dp_size() > 1 self.final_logit_softcapping = getattr( self.config, "final_logit_softcapping", None ) @@ -429,7 +432,17 @@ def _get_logits( logits.mul_(self.logit_scale) if self.do_tensor_parallel_all_gather: - logits = tensor_model_parallel_all_gather(logits) + if self.use_attn_tp_group: + global_logits = torch.empty( + (self.config.vocab_size, logits.shape[0]), + device=logits.device, + dtype=logits.dtype, + ) + global_logits = global_logits.T + attn_tp_all_gather(list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits) + logits = global_logits + else: + logits = tensor_model_parallel_all_gather(logits) logits = logits[:, : self.config.vocab_size].float() From 6b43aa59867c8b4ab33be075c5e0f233e8491d07 Mon Sep 17 00:00:00 2001 From: ch-wan Date: Mon, 21 Apr 2025 05:26:12 +0000 Subject: [PATCH 06/26] format --- python/sglang/srt/layers/logits_processor.py | 10 +++++----- python/sglang/srt/layers/vocab_parallel_embedding.py | 5 +---- python/sglang/srt/models/deepseek_v2.py | 4 ++-- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index eb3a588b021..d9d2986619c 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -22,14 +22,12 @@ import triton.language as tl from torch import nn -from sglang.srt.distributed import ( - tensor_model_parallel_all_gather, -) +from sglang.srt.distributed import tensor_model_parallel_all_gather from sglang.srt.layers.dp_attention import ( + attn_tp_all_gather, get_attention_dp_rank, get_attention_dp_size, get_attention_tp_size, - attn_tp_all_gather, ) from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -439,7 +437,9 @@ def _get_logits( dtype=logits.dtype, ) global_logits = global_logits.T - attn_tp_all_gather(list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits) + attn_tp_all_gather( + list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits + ) logits = global_logits else: logits = tensor_model_parallel_all_gather(logits) diff --git 
a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 06cce04ccb8..ec7c140ae01 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -13,10 +13,7 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from sglang.srt.layers.dp_attention import ( - get_attention_tp_rank, - get_attention_tp_size, -) +from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size from sglang.srt.layers.parameter import BasevLLMParameter from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 9e5018045dd..0831e03b72e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -36,13 +36,13 @@ ) from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.dp_attention import ( + attn_tp_all_gather, + attn_tp_reduce_scatter, dp_gather_partial, dp_scatter, get_attention_dp_size, get_attention_tp_rank, get_attention_tp_size, - attn_tp_all_gather, - attn_tp_reduce_scatter, ) from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( From c9dde0218615a54bd18e618797a2a57bcb9bc53c Mon Sep 17 00:00:00 2001 From: ch-wan Date: Sat, 19 Apr 2025 20:37:44 +0000 Subject: [PATCH 07/26] use local attn dp size (cherry picked from commit fdf58ea5150b46145f57348074fbc821778b1e6f) --- .gitignore | 1 + python/sglang/srt/layers/dp_attention.py | 68 +++++++++++++++++--- python/sglang/srt/layers/logits_processor.py | 12 ++-- python/sglang/srt/models/deepseek_v2.py | 12 ++-- python/sglang/srt/models/llama4.py | 9 ++- 5 files changed, 72 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index 75e29fac373..7dfc995f1f6 100644 --- a/.gitignore +++ b/.gitignore @@ -175,6 +175,7 @@ benchmark/llava_bench/images benchmark/llava_bench/mme_pack *.jsonl tmp*.txt +core.* # Plots *.png diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 10b4bbabf01..6569f7daa48 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -26,16 +26,33 @@ _ATTN_TP_SIZE = None _DP_RANK = None _DP_SIZE = None - +_LOCAL_ATTN_DP_SIZE = None +_LOCAL_ATTN_DP_RANK = None def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size): if not enable_dp_attention: return tp_rank, tp_size, 0 attn_tp_size = tp_size // dp_size - dp_rank = tp_rank // attn_tp_size + attn_dp_rank = tp_rank //tp_size attn_tp_rank = tp_rank % attn_tp_size - return attn_tp_rank, attn_tp_size, dp_rank + + return attn_tp_rank, attn_tp_size, attn_dp_rank + + +def compute_dp_attention_local_info(enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size): + if not enable_dp_attention: + return tp_rank, tp_size, 0 + + local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size + local_tp_rank = tp_rank % local_tp_size + local_dp_size = dp_size // (tp_size // local_tp_size) + + local_attn_tp_size = local_tp_size // local_dp_size + local_attn_dp_rank = local_tp_rank // local_attn_tp_size + local_attn_tp_rank = local_tp_rank % local_attn_tp_size + + return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank def initialize_dp_attention( @@ -44,18 +61,29 @@ def initialize_dp_attention( tp_size: int, dp_size: int, ): - global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, 
_DP_RANK, _DP_SIZE + global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE + global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP - _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info( + _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info( enable_dp_attention, tp_rank, tp_size, dp_size ) + _, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info( + enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size + ) if enable_dp_attention: - _DP_SIZE = dp_size + _ATTN_DP_SIZE = dp_size + if moe_dense_tp_size is None: + _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE + else: + _LOCAL_ATTN_DP_SIZE = dp_size // (tp_size // moe_dense_tp_size) else: - _DP_SIZE = 1 + _ATTN_DP_SIZE = 1 + _LOCAL_ATTN_DP_SIZE = 1 + + logger.info(f"{(_ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE)=}") tp_group = get_tp_group() _ATTN_TP_GROUP = GroupCoordinator( @@ -95,8 +123,28 @@ def get_attention_dp_rank(): def get_attention_dp_size(): - assert _DP_SIZE is not None, "dp attention not initialized!" - return _DP_SIZE + assert _ATTN_DP_SIZE is not None, "dp attention not initialized!" + return _ATTN_DP_SIZE + + +def get_local_attention_dp_rank(): + assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!" + return _LOCAL_ATTN_DP_RANK + + +def get_local_attention_dp_size(): + assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!" + return _LOCAL_ATTN_DP_SIZE + + +def get_local_attention_dp_rank(): + assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!" + return _LOCAL_ATTN_DP_RANK + + +def get_local_attention_dp_size(): + assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!" + return _LOCAL_ATTN_DP_SIZE @contextmanager @@ -121,7 +169,7 @@ def disable_dp_size(): def get_dp_local_info(forward_batch: ForwardBatch): - dp_rank = get_attention_dp_rank() + dp_rank = get_local_attention_dp_rank() if forward_batch.dp_local_start_pos is None: cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index d9d2986619c..9ef7b429b4e 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -25,9 +25,8 @@ from sglang.srt.distributed import tensor_model_parallel_all_gather from sglang.srt.layers.dp_attention import ( attn_tp_all_gather, - get_attention_dp_rank, - get_attention_dp_size, - get_attention_tp_size, + get_local_attention_dp_rank, + get_local_attention_dp_size, ) from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -165,7 +164,7 @@ def compute_dp_attention_metadata(self, hidden_states: torch.Tensor): return cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0) - dp_rank = get_attention_dp_rank() + dp_rank = get_local_attention_dp_rank() if dp_rank == 0: dp_local_start_pos = torch.zeros_like( self.global_num_tokens_for_logprob_gpu[0] @@ -194,11 +193,10 @@ def __init__( super().__init__() self.config = config self.logit_scale = logit_scale - self.attn_tp_size = get_attention_tp_size() self.do_tensor_parallel_all_gather = ( not skip_all_gather and self.attn_tp_size > 1 ) - self.use_attn_tp_group = get_attention_dp_size() > 1 + self.use_attn_tp_group = get_local_attention_dp_size() > 1 self.final_logit_softcapping = getattr( self.config, 
"final_logit_softcapping", None ) @@ -310,7 +308,7 @@ def forward( if self.debug_tensor_dump_output_folder: assert ( - not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1 + not self.do_tensor_parallel_all_gather or get_local_attention_dp_size() == 1 ), "dp attention + sharded lm_head doesn't support full logits" full_logits = self._get_logits(hidden_states, lm_head, logits_metadata) dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0831e03b72e..d44ec9fe91d 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -40,7 +40,7 @@ attn_tp_reduce_scatter, dp_gather_partial, dp_scatter, - get_attention_dp_size, + get_local_attention_dp_size, get_attention_tp_rank, get_attention_tp_size, ) @@ -419,7 +419,6 @@ def __init__( self.v_head_dim = v_head_dim self.q_lora_rank = q_lora_rank self.kv_lora_rank = kv_lora_rank - self.dp_size = get_attention_dp_size() attn_tp_rank = get_attention_tp_rank() attn_tp_size = get_attention_tp_size() @@ -1084,7 +1083,7 @@ def __init__( max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.enable_dp_attention = global_server_args_dict["enable_dp_attention"] self.layer_id = layer_id - self.dp_size = get_attention_dp_size() + self.local_dp_size = get_local_attention_dp_size() self.attn_tp_size = get_attention_tp_size() self.attn_tp_rank = get_attention_tp_rank() self.self_attn = DeepseekV2AttentionMLA( @@ -1214,7 +1213,7 @@ def forward_ffn_with_full_input( # Gather if get_tensor_model_parallel_world_size() > 1: # all gather and all reduce - if self.dp_size != 1: + if self.local_dp_size != 1: if self.attn_tp_rank == 0: hidden_states += residual hidden_states, local_hidden_states = ( @@ -1239,7 +1238,7 @@ def forward_ffn_with_full_input( # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter # Scatter - if self.dp_size != 1: + if self.local_dp_size != 1: # important: forward batch.gathered_buffer is used both after scatter and after gather. # be careful about this! 
hidden_states, global_hidden_states = ( @@ -1361,8 +1360,6 @@ def __init__( ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.dp_size = get_attention_dp_size() - def forward( self, input_ids: torch.Tensor, @@ -1451,7 +1448,6 @@ def __init__( use_attn_tp_group=global_server_args_dict["enable_dp_attention"], ) self.logits_processor = LogitsProcessor(config) - self.dp_size = get_attention_dp_size() def get_input_embeddings(self) -> nn.Embedding: return self.model.embed_tokens diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py index 88c3716f76a..c21eac0a011 100644 --- a/python/sglang/srt/models/llama4.py +++ b/python/sglang/srt/models/llama4.py @@ -30,7 +30,7 @@ from sglang.srt.layers.dp_attention import ( dp_gather_partial, dp_scatter, - get_attention_dp_size, + get_local_attention_dp_size, get_attention_tp_rank, get_attention_tp_size, ) @@ -152,7 +152,6 @@ def __init__( self.use_rope = int((layer_id + 1) % 4 != 0) self.use_qk_norm = config.use_qk_norm and self.use_rope - self.dp_size = get_attention_dp_size() attn_tp_rank = get_attention_tp_rank() attn_tp_size = get_attention_tp_size() @@ -297,7 +296,7 @@ def __init__( rope_theta = config.rope_theta rope_scaling = config.rope_scaling max_position_embeddings = config.max_position_embeddings - self.dp_size = get_attention_dp_size() + self.local_dp_size = get_local_attention_dp_size() self.attn_tp_size = get_attention_tp_size() self.attn_tp_rank = get_attention_tp_rank() @@ -360,7 +359,7 @@ def forward( # Gather if get_tensor_model_parallel_world_size() > 1: # all gather and all reduce - if self.dp_size != 1: + if self.local_dp_size != 1: if self.attn_tp_rank == 0: hidden_states += residual hidden_states, local_hidden_states = ( @@ -385,7 +384,7 @@ def forward( # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter # Scatter - if self.dp_size != 1: + if self.local_dp_size != 1: # important: forward batch.gathered_buffer is used both after scatter and after gather. # be careful about this! 
hidden_states, global_hidden_states = ( From d0a9b996c74f788e8d0827121c260f6e5c989364 Mon Sep 17 00:00:00 2001 From: ch-wan Date: Sat, 19 Apr 2025 20:44:45 +0000 Subject: [PATCH 08/26] fix (cherry picked from commit 13e931be9fefcdce2680a8eed31da2d3e7d45824) --- python/sglang/srt/layers/dp_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 6569f7daa48..00a4aaf1c35 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -34,7 +34,7 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si return tp_rank, tp_size, 0 attn_tp_size = tp_size // dp_size - attn_dp_rank = tp_rank //tp_size + attn_dp_rank = tp_rank // attn_tp_size attn_tp_rank = tp_rank % attn_tp_size return attn_tp_rank, attn_tp_size, attn_dp_rank From 515f20fbb1cfa3ed023efc65102c215d406d7124 Mon Sep 17 00:00:00 2001 From: ch-wan Date: Tue, 22 Apr 2025 22:23:53 +0000 Subject: [PATCH 09/26] several fix --- python/sglang/srt/layers/dp_attention.py | 7 +++---- python/sglang/srt/layers/logits_processor.py | 2 ++ python/sglang/srt/model_executor/model_runner.py | 4 ++++ python/sglang/srt/models/deepseek_v2.py | 3 ++- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 00a4aaf1c35..1f14cae03ed 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -46,7 +46,7 @@ def compute_dp_attention_local_info(enable_dp_attention, tp_rank, tp_size, dp_si local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size local_tp_rank = tp_rank % local_tp_size - local_dp_size = dp_size // (tp_size // local_tp_size) + local_dp_size = max(1, dp_size // (tp_size // local_tp_size)) local_attn_tp_size = local_tp_size // local_dp_size local_attn_dp_rank = local_tp_rank // local_attn_tp_size @@ -60,6 +60,7 @@ def initialize_dp_attention( tp_rank: int, tp_size: int, dp_size: int, + moe_dense_tp_size: int, ): global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK @@ -78,13 +79,11 @@ def initialize_dp_attention( if moe_dense_tp_size is None: _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE else: - _LOCAL_ATTN_DP_SIZE = dp_size // (tp_size // moe_dense_tp_size) + _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size)) else: _ATTN_DP_SIZE = 1 _LOCAL_ATTN_DP_SIZE = 1 - logger.info(f"{(_ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE)=}") - tp_group = get_tp_group() _ATTN_TP_GROUP = GroupCoordinator( [ diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 9ef7b429b4e..d6ea820d805 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -25,6 +25,7 @@ from sglang.srt.distributed import tensor_model_parallel_all_gather from sglang.srt.layers.dp_attention import ( attn_tp_all_gather, + get_attention_tp_size, get_local_attention_dp_rank, get_local_attention_dp_size, ) @@ -193,6 +194,7 @@ def __init__( super().__init__() self.config = config self.logit_scale = logit_scale + self.attn_tp_size = get_attention_tp_size() self.do_tensor_parallel_all_gather = ( not skip_all_gather and self.attn_tp_size > 1 ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 
6a48a60b716..de578aca92b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -362,6 +362,7 @@ def init_torch_distributed(self): tp_rank=self.tp_rank, tp_size=self.tp_size, dp_size=self.server_args.dp_size, + moe_dense_tp_size=self.server_args.moe_dense_tp_size, ) min_per_gpu_memory = get_available_gpu_memory( @@ -1022,9 +1023,12 @@ def forward( and self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch) ): + logger.info("cuda graph replay is enabled") return self.cuda_graph_runner.replay( forward_batch, skip_attn_backend_init=skip_attn_backend_init ) + else: + logger.info("cuda graph replay is disabled") if forward_batch.forward_mode.is_decode(): return self.forward_decode(forward_batch) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index d44ec9fe91d..ab651252c21 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1133,7 +1133,8 @@ def __init__( ) self.input_is_scattered = ( - previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED + layer_id > 0 + and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED ) self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 From 462f51e8b8edea0d786534029c3963624296c64e Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Tue, 22 Apr 2025 20:47:47 -0700 Subject: [PATCH 10/26] Update .gitignore --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index 7dfc995f1f6..75e29fac373 100644 --- a/.gitignore +++ b/.gitignore @@ -175,7 +175,6 @@ benchmark/llava_bench/images benchmark/llava_bench/mme_pack *.jsonl tmp*.txt -core.* # Plots *.png From 5adc5e51d6f5aa28326b2ebaf80d1bf402ce1b54 Mon Sep 17 00:00:00 2001 From: ch-wan Date: Wed, 23 Apr 2025 03:57:10 +0000 Subject: [PATCH 11/26] fix refactor --- python/sglang/srt/layers/dp_attention.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 1f14cae03ed..58223209b1f 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -24,8 +24,8 @@ _ATTN_TP_GROUP = None _ATTN_TP_RANK = None _ATTN_TP_SIZE = None -_DP_RANK = None -_DP_SIZE = None +_ATTN_DP_RANK = None +_ATTN_DP_SIZE = None _LOCAL_ATTN_DP_SIZE = None _LOCAL_ATTN_DP_RANK = None @@ -117,8 +117,8 @@ def get_attention_tp_size(): def get_attention_dp_rank(): - assert _DP_RANK is not None, "dp attention not initialized!" - return _DP_RANK + assert _ATTN_DP_RANK is not None, "dp attention not initialized!" + return _ATTN_DP_RANK def get_attention_dp_size(): @@ -156,15 +156,15 @@ def disable_dp_size(): Args: tp_group (GroupCoordinator): the tp group coordinator """ - global _DP_SIZE - assert _DP_SIZE is not None, "dp attention not initialized!" + global _ATTN_DP_SIZE + assert _ATTN_DP_SIZE is not None, "dp attention not initialized!" 
- old_dp_size = _DP_SIZE - _DP_SIZE = 1 + old_dp_size = _ATTN_DP_SIZE + _ATTN_DP_SIZE = 1 try: yield finally: - _DP_SIZE = old_dp_size + _ATTN_DP_SIZE = old_dp_size def get_dp_local_info(forward_batch: ForwardBatch): From 9769217dc964ede3a65d2c5c5ff0bf767c7957de Mon Sep 17 00:00:00 2001 From: ch-wan Date: Wed, 23 Apr 2025 04:56:25 +0000 Subject: [PATCH 12/26] optimize memory --- python/sglang/srt/managers/scheduler.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index d2a601f919e..ee9a6b97f08 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -193,7 +193,7 @@ def __init__( # Distributed rank info self.dp_size = server_args.dp_size - self.attn_tp_rank, self.attn_tp_size, self.dp_rank = ( + self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = ( compute_dp_attention_world_info( server_args.enable_dp_attention, self.tp_rank, @@ -708,7 +708,7 @@ def recv_requests(self) -> List[Req]: control_reqs = None if self.attn_tp_size != 1: - attn_tp_rank_0 = self.dp_rank * self.attn_tp_size + attn_tp_rank_0 = self.attn_dp_rank * self.attn_tp_size work_reqs = broadcast_pyobj( work_reqs, self.attn_tp_rank, @@ -1433,6 +1433,7 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): local_batch, dp_size=self.server_args.dp_size, attn_tp_size=self.attn_tp_size, + moe_dense_tp_size = self.server_args.moe_dense_tp_size, tp_cpu_group=self.tp_cpu_group, get_idle_batch=self.get_idle_batch, disable_cuda_graph=self.server_args.disable_cuda_graph, @@ -1445,6 +1446,7 @@ def prepare_dp_attn_batch_raw( local_batch: ScheduleBatch, dp_size, attn_tp_size: int, + moe_dense_tp_size: Optional[int], tp_cpu_group, get_idle_batch, disable_cuda_graph: bool, @@ -1454,15 +1456,15 @@ def prepare_dp_attn_batch_raw( # Check if other DP workers have running batches if local_batch is None: num_tokens = 0 - global_num_tokens_for_logprob = 0 + num_tokens_for_logprob = 0 elif local_batch.forward_mode.is_decode(): num_tokens = local_batch.batch_size() if not spec_algorithm.is_none() and spec_algorithm.is_eagle(): num_tokens = num_tokens * speculative_num_draft_tokens - global_num_tokens_for_logprob = num_tokens + num_tokens_for_logprob = num_tokens else: num_tokens = local_batch.extend_num_tokens - global_num_tokens_for_logprob = sum( + num_tokens_for_logprob = sum( [ # We should have at least 1 token for sample in every case. 
max(extend_len - logprob_start_len, 1) @@ -1489,7 +1491,7 @@ def prepare_dp_attn_batch_raw( [ num_tokens, can_cuda_graph, - global_num_tokens_for_logprob, + num_tokens_for_logprob, is_extend_in_batch, ], dtype=torch.int64, @@ -1512,8 +1514,12 @@ def prepare_dp_attn_batch_raw( local_batch = get_idle_batch() if local_batch is not None: - local_batch.global_num_tokens = global_num_tokens - local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob + if moe_dense_tp_size == 1: + local_batch.global_num_tokens = [num_tokens] + local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob] + else: + local_batch.global_num_tokens = global_num_tokens + local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob # Check forward mode for cuda graph if not disable_cuda_graph: From 3b6b6d79a06fca6f4714f8673dccd37c08e2ef34 Mon Sep 17 00:00:00 2001 From: ch-wan Date: Wed, 23 Apr 2025 08:44:59 +0000 Subject: [PATCH 13/26] add debug info --- python/sglang/srt/model_executor/model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index de578aca92b..dcd2788b0c7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1023,12 +1023,12 @@ def forward( and self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch) ): - logger.info("cuda graph replay is enabled") + logger.debug("Cuda graph replay is enabled") return self.cuda_graph_runner.replay( forward_batch, skip_attn_backend_init=skip_attn_backend_init ) else: - logger.info("cuda graph replay is disabled") + logger.debug("Cuda graph replay is disabled") if forward_batch.forward_mode.is_decode(): return self.forward_decode(forward_batch) From 16c4b744074a3aef0931b2086e81b902b4e24737 Mon Sep 17 00:00:00 2001 From: ch-wan Date: Wed, 23 Apr 2025 08:47:33 +0000 Subject: [PATCH 14/26] format --- python/sglang/srt/layers/dp_attention.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 58223209b1f..3dd2d07a00a 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -29,6 +29,7 @@ _LOCAL_ATTN_DP_SIZE = None _LOCAL_ATTN_DP_RANK = None + def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size): if not enable_dp_attention: return tp_rank, tp_size, 0 @@ -40,10 +41,12 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si return attn_tp_rank, attn_tp_size, attn_dp_rank -def compute_dp_attention_local_info(enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size): +def compute_dp_attention_local_info( + enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size +): if not enable_dp_attention: return tp_rank, tp_size, 0 - + local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size local_tp_rank = tp_rank % local_tp_size local_dp_size = max(1, dp_size // (tp_size // local_tp_size)) From f0674f71fe499893e9c8cbf5f00bcfd875360a9e Mon Sep 17 00:00:00 2001 From: ch-wan Date: Wed, 23 Apr 2025 09:36:31 +0000 Subject: [PATCH 15/26] format --- python/sglang/srt/layers/logits_processor.py | 3 ++- python/sglang/srt/managers/scheduler.py | 6 ++++-- python/sglang/srt/models/deepseek_v2.py | 2 +- python/sglang/srt/models/llama4.py | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git 
a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index d6ea820d805..e6bd82d848e 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -310,7 +310,8 @@ def forward( if self.debug_tensor_dump_output_folder: assert ( - not self.do_tensor_parallel_all_gather or get_local_attention_dp_size() == 1 + not self.do_tensor_parallel_all_gather + or get_local_attention_dp_size() == 1 ), "dp attention + sharded lm_head doesn't support full logits" full_logits = self._get_logits(hidden_states, lm_head, logits_metadata) dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ee9a6b97f08..e4a5b6b78fa 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1433,7 +1433,7 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): local_batch, dp_size=self.server_args.dp_size, attn_tp_size=self.attn_tp_size, - moe_dense_tp_size = self.server_args.moe_dense_tp_size, + moe_dense_tp_size=self.server_args.moe_dense_tp_size, tp_cpu_group=self.tp_cpu_group, get_idle_batch=self.get_idle_batch, disable_cuda_graph=self.server_args.disable_cuda_graph, @@ -1519,7 +1519,9 @@ def prepare_dp_attn_batch_raw( local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob] else: local_batch.global_num_tokens = global_num_tokens - local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob + local_batch.global_num_tokens_for_logprob = ( + global_num_tokens_for_logprob + ) # Check forward mode for cuda graph if not disable_cuda_graph: diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index ab651252c21..969e009adf5 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -40,9 +40,9 @@ attn_tp_reduce_scatter, dp_gather_partial, dp_scatter, - get_local_attention_dp_size, get_attention_tp_rank, get_attention_tp_size, + get_local_attention_dp_size, ) from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py index c21eac0a011..8841ae2f647 100644 --- a/python/sglang/srt/models/llama4.py +++ b/python/sglang/srt/models/llama4.py @@ -30,9 +30,9 @@ from sglang.srt.layers.dp_attention import ( dp_gather_partial, dp_scatter, - get_local_attention_dp_size, get_attention_tp_rank, get_attention_tp_size, + get_local_attention_dp_size, ) from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( From 182aa52574eaefdedc4c9db1aec95d060142bdec Mon Sep 17 00:00:00 2001 From: liusy58 Date: Sat, 10 May 2025 19:49:38 +0800 Subject: [PATCH 16/26] Add `use_attn_tp_group` for user to decide whether to use vocabulary parallelism or data parallelism for LM head. 
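
The trade-off this patch introduces can be illustrated with a small, self-contained sketch (illustrative only; the helper name and the plain-integer arguments below are assumptions, not part of the patch): when `use_attn_tp_group` is set, the vocabulary is sharded across the attention TP group, so the logits all-gather stays within that group instead of spanning the full tensor-parallel world size.

    # Hedged sketch, not part of the diff below: which all-gather the logits
    # processor ends up performing under each setting. The helper name and the
    # example numbers are illustrative assumptions.
    def logits_all_gather_plan(
        use_attn_tp_group: bool,
        tp_size: int,
        dp_size: int,
        skip_all_gather: bool = False,
    ) -> str:
        # Attention-TP group size under DP attention (tp_size == dp_size * attn_tp_size).
        attn_tp_size = tp_size // dp_size
        if use_attn_tp_group:
            # Vocabulary parallelism over the attention TP group only:
            # no gather across DP groups is needed.
            if not skip_all_gather and attn_tp_size > 1:
                return "attn-tp all_gather"
            return "none"
        # Default: vocabulary parallelism over the full TP world size.
        if not skip_all_gather and tp_size > 1:
            return "full-tp all_gather"
        return "none"

    # With tp=8 and dp=8 the attention TP group has size 1, so the new flag
    # removes the logits all-gather entirely:
    assert logits_all_gather_plan(True, tp_size=8, dp_size=8) == "none"
    assert logits_all_gather_plan(False, tp_size=8, dp_size=8) == "full-tp all_gather"

Note that without --enable-dp-attention the attention TP group coincides with the full TP group, so the two paths behave the same; the flag only changes behavior when DP attention is active.
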
--- docs/backend/server_arguments.md | 1 + python/sglang/srt/layers/logits_processor.py | 24 ++++++++++++++------ python/sglang/srt/managers/schedule_batch.py | 1 + python/sglang/srt/models/llama.py | 2 +- python/sglang/srt/server_args.py | 6 +++++ 5 files changed, 26 insertions(+), 8 deletions(-) diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 0516c844811..d793d9317b9 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -222,3 +222,4 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `triton_attention_num_kv_splits` | Use to adjust the number of KV splits in triton kernels. | `8` | | `flashinfer_mla_disable_ragged` | Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use it when FlashInfer is being used as the MLA backend. | `False` | | `disable_chunked_prefix_cache` | Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend. | `False` | +| `use_attn_tp_group` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | `False` | diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 8b1c343efe6..59866e372bd 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -22,10 +22,12 @@ import triton.language as tl from torch import nn -from sglang.srt.distributed import tensor_model_parallel_all_gather +from sglang.srt.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from sglang.srt.layers.dp_attention import ( attn_tp_all_gather, - get_attention_dp_rank, get_attention_dp_size, get_attention_tp_size, ) @@ -194,11 +196,19 @@ def __init__( super().__init__() self.config = config self.logit_scale = logit_scale - self.attn_tp_size = get_attention_tp_size() - self.do_tensor_parallel_all_gather = ( - not skip_all_gather and self.attn_tp_size > 1 - ) - self.use_attn_tp_group = get_attention_dp_size() > 1 + self.use_attn_tp_group = global_server_args_dict["use_attn_tp_group"] + if self.use_attn_tp_group: + self.attn_tp_size = get_attention_tp_size() + self.do_tensor_parallel_all_gather = ( + not skip_all_gather and self.attn_tp_size > 1 + ) + else: + self.do_tensor_parallel_all_gather = ( + not skip_all_gather and get_tensor_model_parallel_world_size() > 1 + ) + self.do_tensor_parallel_all_gather_dp_attn = ( + self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1 + ) self.final_logit_softcapping = getattr( self.config, "final_logit_softcapping", None ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ee9f40719f2..bd43335abd9 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -73,6 +73,7 @@ "disable_radix_cache": ServerArgs.disable_radix_cache, "enable_deepep_moe": ServerArgs.enable_deepep_moe, "enable_dp_attention": ServerArgs.enable_dp_attention, + "use_attn_tp_group": ServerArgs.use_attn_tp_group, "enable_ep_moe": ServerArgs.enable_ep_moe, "enable_nan_detection": ServerArgs.enable_nan_detection, 
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 517588c91ea..a07677bfc66 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -421,7 +421,7 @@ def __init__( config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_attention"], + use_attn_tp_group=global_server_args_dict["use_attn_tp_group"], ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0aa71e34478..90585b85ab0 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -159,6 +159,7 @@ class ServerArgs: disable_overlap_schedule: bool = False enable_mixed_chunk: bool = False enable_dp_attention: bool = False + use_attn_tp_group: bool = False enable_ep_moe: bool = False enable_deepep_moe: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" @@ -1049,6 +1050,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.", ) + parser.add_argument( + "--use-attn-tp-group", + action="store_true", + help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.", + ) parser.add_argument( "--enable-ep-moe", action="store_true", From 4712ed035581658eca1e366cd1464323ec571c0f Mon Sep 17 00:00:00 2001 From: liusy58 Date: Sat, 10 May 2025 19:49:38 +0800 Subject: [PATCH 17/26] Add `use_attn_tp_group` for user to decide whether to use vocabulary parallelism or data parallelism for LM head. --- docs/backend/server_arguments.md | 1 + python/sglang/srt/layers/logits_processor.py | 24 +++++++++++++++----- python/sglang/srt/managers/schedule_batch.py | 1 + python/sglang/srt/models/llama.py | 2 +- python/sglang/srt/server_args.py | 6 +++++ 5 files changed, 27 insertions(+), 7 deletions(-) diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 0516c844811..d793d9317b9 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -222,3 +222,4 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `triton_attention_num_kv_splits` | Use to adjust the number of KV splits in triton kernels. | `8` | | `flashinfer_mla_disable_ragged` | Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use it when FlashInfer is being used as the MLA backend. | `False` | | `disable_chunked_prefix_cache` | Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend. | `False` | +| `use_attn_tp_group` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. 
| `False` | diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index ece6b9e8c04..4f132843b17 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -22,9 +22,13 @@ import triton.language as tl from torch import nn -from sglang.srt.distributed import tensor_model_parallel_all_gather +from sglang.srt.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from sglang.srt.layers.dp_attention import ( attn_tp_all_gather, + get_attention_dp_size, get_attention_tp_size, get_local_attention_dp_rank, get_local_attention_dp_size, @@ -194,11 +198,19 @@ def __init__( super().__init__() self.config = config self.logit_scale = logit_scale - self.attn_tp_size = get_attention_tp_size() - self.do_tensor_parallel_all_gather = ( - not skip_all_gather and self.attn_tp_size > 1 - ) - self.use_attn_tp_group = get_local_attention_dp_size() > 1 + self.use_attn_tp_group = global_server_args_dict["use_attn_tp_group"] + if self.use_attn_tp_group: + self.attn_tp_size = get_attention_tp_size() + self.do_tensor_parallel_all_gather = ( + not skip_all_gather and self.attn_tp_size > 1 + ) + else: + self.do_tensor_parallel_all_gather = ( + not skip_all_gather and get_tensor_model_parallel_world_size() > 1 + ) + self.do_tensor_parallel_all_gather_dp_attn = ( + self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1 + ) self.final_logit_softcapping = getattr( self.config, "final_logit_softcapping", None ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ee9f40719f2..bd43335abd9 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -73,6 +73,7 @@ "disable_radix_cache": ServerArgs.disable_radix_cache, "enable_deepep_moe": ServerArgs.enable_deepep_moe, "enable_dp_attention": ServerArgs.enable_dp_attention, + "use_attn_tp_group": ServerArgs.use_attn_tp_group, "enable_ep_moe": ServerArgs.enable_ep_moe, "enable_nan_detection": ServerArgs.enable_nan_detection, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 517588c91ea..a07677bfc66 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -421,7 +421,7 @@ def __init__( config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_attention"], + use_attn_tp_group=global_server_args_dict["use_attn_tp_group"], ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0aa71e34478..90585b85ab0 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -159,6 +159,7 @@ class ServerArgs: disable_overlap_schedule: bool = False enable_mixed_chunk: bool = False enable_dp_attention: bool = False + use_attn_tp_group: bool = False enable_ep_moe: bool = False enable_deepep_moe: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" @@ -1049,6 +1050,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. 
Currently only DeepSeek-V2 is supported.", ) + parser.add_argument( + "--use-attn-tp-group", + action="store_true", + help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.", + ) parser.add_argument( "--enable-ep-moe", action="store_true", From 5e8e44e137e6af57a121a7ae9bbfe4bb01b061fd Mon Sep 17 00:00:00 2001 From: liusy58 Date: Sun, 11 May 2025 13:30:40 +0800 Subject: [PATCH 18/26] Rename `use_attn_tp_group` to `enable_dp_lm_head` and refactor the `_get_logits` function to support vocabulary parallelism by default. --- docs/backend/server_arguments.md | 2 +- python/sglang/srt/layers/logits_processor.py | 62 ++++++++++++++++---- python/sglang/srt/managers/schedule_batch.py | 2 +- python/sglang/srt/models/llama.py | 2 +- python/sglang/srt/server_args.py | 9 ++- 5 files changed, 59 insertions(+), 18 deletions(-) diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index d793d9317b9..429cda14bdc 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -222,4 +222,4 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `triton_attention_num_kv_splits` | Use to adjust the number of KV splits in triton kernels. | `8` | | `flashinfer_mla_disable_ragged` | Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use it when FlashInfer is being used as the MLA backend. | `False` | | `disable_chunked_prefix_cache` | Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend. | `False` | -| `use_attn_tp_group` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | `False` | +| `enable_dp_lm_head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | `False` | diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 4f132843b17..91fc93e12be 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -28,6 +28,8 @@ ) from sglang.srt.layers.dp_attention import ( attn_tp_all_gather, + dp_gather_replicate, + dp_scatter, get_attention_dp_size, get_attention_tp_size, get_local_attention_dp_rank, @@ -198,7 +200,7 @@ def __init__( super().__init__() self.config = config self.logit_scale = logit_scale - self.use_attn_tp_group = global_server_args_dict["use_attn_tp_group"] + self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"] if self.use_attn_tp_group: self.attn_tp_size = get_attention_tp_size() self.do_tensor_parallel_all_gather = ( @@ -430,20 +432,21 @@ def _get_logits( last position (e.g., extend without input logprobs). The caller should guarantee the given hidden_states follow this constraint. 
""" + if self.use_attn_tp_group: + if hasattr(lm_head, "weight"): + logits = torch.matmul( + hidden_states.to(lm_head.weight.dtype), lm_head.weight.T + ) + else: + # GGUF models + logits = lm_head.quant_method.apply( + lm_head, hidden_states, embedding_bias + ) - if hasattr(lm_head, "weight"): - logits = torch.matmul( - hidden_states.to(lm_head.weight.dtype), lm_head.weight.T - ) - else: - # GGUF models - logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias) - - if self.logit_scale is not None: - logits.mul_(self.logit_scale) + if self.logit_scale is not None: + logits.mul_(self.logit_scale) - if self.do_tensor_parallel_all_gather: - if self.use_attn_tp_group: + if self.do_tensor_parallel_all_gather: global_logits = torch.empty( (self.config.vocab_size, logits.shape[0]), device=logits.device, @@ -454,9 +457,42 @@ def _get_logits( list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits ) logits = global_logits + else: + if self.do_tensor_parallel_all_gather_dp_attn: + logits_metadata.compute_dp_attention_metadata(hidden_states) + hidden_states, local_hidden_states = ( + logits_metadata.gathered_buffer, + hidden_states.clone(), + ) + dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) + + if hasattr(lm_head, "weight"): + logits = torch.matmul( + hidden_states.to(lm_head.weight.dtype), lm_head.weight.T + ) else: + # GGUF models + logits = lm_head.quant_method.apply( + lm_head, hidden_states, embedding_bias + ) + + if self.logit_scale is not None: + logits.mul_(self.logit_scale) + + if self.do_tensor_parallel_all_gather: logits = tensor_model_parallel_all_gather(logits) + if self.do_tensor_parallel_all_gather_dp_attn: + logits, global_logits = ( + torch.empty( + (local_hidden_states.shape[0], logits.shape[1]), + device=logits.device, + dtype=logits.dtype, + ), + logits, + ) + dp_scatter(logits, global_logits, logits_metadata) + logits = logits[:, : self.config.vocab_size].float() if self.final_logit_softcapping: diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index bd43335abd9..f6728d59230 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -73,7 +73,7 @@ "disable_radix_cache": ServerArgs.disable_radix_cache, "enable_deepep_moe": ServerArgs.enable_deepep_moe, "enable_dp_attention": ServerArgs.enable_dp_attention, - "use_attn_tp_group": ServerArgs.use_attn_tp_group, + "enable_dp_lm_head": ServerArgs.enable_dp_lm_head, "enable_ep_moe": ServerArgs.enable_ep_moe, "enable_nan_detection": ServerArgs.enable_nan_detection, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index a07677bfc66..dc4d8f9df35 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -421,7 +421,7 @@ def __init__( config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["use_attn_tp_group"], + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 90585b85ab0..610a66974ee 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -159,7 +159,7 @@ class ServerArgs: disable_overlap_schedule: bool = False 
 enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
-    use_attn_tp_group: bool = False
+    enable_dp_lm_head: bool = False
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
@@ -318,6 +318,11 @@ def __post_init__(self):
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
             )
 
+        if self.enable_dp_lm_head:
+            assert (
+                self.enable_dp_attention
+            ), "Please enable dp attention when setting enable_dp_lm_head."
+
         # DeepEP MoE
         self.enable_sp_layernorm = False
         if self.enable_deepep_moe:
@@ -1051,7 +1056,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
             help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
         )
         parser.add_argument(
-            "--use-attn-tp-group",
+            "--enable-dp-lm-head",
             action="store_true",
             help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
         )

From 8c6ec17e4b8b2cc3f43cb507f952e03b68c3783a Mon Sep 17 00:00:00 2001
From: liusy58
Date: Sun, 11 May 2025 13:30:40 +0800
Subject: [PATCH 19/26] Rename `use_attn_tp_group` to `enable_dp_lm_head` and refactor the `_get_logits` function to support vocabulary parallelism by default.

---
 docs/backend/server_arguments.md | 2 +-
 python/sglang/srt/layers/logits_processor.py | 62 ++++++++++++++++----
 python/sglang/srt/managers/schedule_batch.py | 2 +-
 python/sglang/srt/models/llama.py | 2 +-
 python/sglang/srt/server_args.py | 9 ++-
 5 files changed, 59 insertions(+), 18 deletions(-)

diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index d793d9317b9..429cda14bdc 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -222,4 +222,4 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `triton_attention_num_kv_splits` | Use to adjust the number of KV splits in triton kernels. | `8` |
 | `flashinfer_mla_disable_ragged` | Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use it when FlashInfer is being used as the MLA backend. | `False` |
 | `disable_chunked_prefix_cache` | Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend. | `False` |
-| `use_attn_tp_group` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | `False` |
+| `enable_dp_lm_head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.
| `False` | diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 59866e372bd..d803be54aea 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -28,6 +28,8 @@ ) from sglang.srt.layers.dp_attention import ( attn_tp_all_gather, + dp_gather_replicate, + dp_scatter, get_attention_dp_size, get_attention_tp_size, ) @@ -196,7 +198,7 @@ def __init__( super().__init__() self.config = config self.logit_scale = logit_scale - self.use_attn_tp_group = global_server_args_dict["use_attn_tp_group"] + self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"] if self.use_attn_tp_group: self.attn_tp_size = get_attention_tp_size() self.do_tensor_parallel_all_gather = ( @@ -427,20 +429,21 @@ def _get_logits( last position (e.g., extend without input logprobs). The caller should guarantee the given hidden_states follow this constraint. """ + if self.use_attn_tp_group: + if hasattr(lm_head, "weight"): + logits = torch.matmul( + hidden_states.to(lm_head.weight.dtype), lm_head.weight.T + ) + else: + # GGUF models + logits = lm_head.quant_method.apply( + lm_head, hidden_states, embedding_bias + ) - if hasattr(lm_head, "weight"): - logits = torch.matmul( - hidden_states.to(lm_head.weight.dtype), lm_head.weight.T - ) - else: - # GGUF models - logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias) - - if self.logit_scale is not None: - logits.mul_(self.logit_scale) + if self.logit_scale is not None: + logits.mul_(self.logit_scale) - if self.do_tensor_parallel_all_gather: - if self.use_attn_tp_group: + if self.do_tensor_parallel_all_gather: global_logits = torch.empty( (self.config.vocab_size, logits.shape[0]), device=logits.device, @@ -451,9 +454,42 @@ def _get_logits( list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits ) logits = global_logits + else: + if self.do_tensor_parallel_all_gather_dp_attn: + logits_metadata.compute_dp_attention_metadata(hidden_states) + hidden_states, local_hidden_states = ( + logits_metadata.gathered_buffer, + hidden_states.clone(), + ) + dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) + + if hasattr(lm_head, "weight"): + logits = torch.matmul( + hidden_states.to(lm_head.weight.dtype), lm_head.weight.T + ) else: + # GGUF models + logits = lm_head.quant_method.apply( + lm_head, hidden_states, embedding_bias + ) + + if self.logit_scale is not None: + logits.mul_(self.logit_scale) + + if self.do_tensor_parallel_all_gather: logits = tensor_model_parallel_all_gather(logits) + if self.do_tensor_parallel_all_gather_dp_attn: + logits, global_logits = ( + torch.empty( + (local_hidden_states.shape[0], logits.shape[1]), + device=logits.device, + dtype=logits.dtype, + ), + logits, + ) + dp_scatter(logits, global_logits, logits_metadata) + logits = logits[:, : self.config.vocab_size].float() if self.final_logit_softcapping: diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index bd43335abd9..f6728d59230 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -73,7 +73,7 @@ "disable_radix_cache": ServerArgs.disable_radix_cache, "enable_deepep_moe": ServerArgs.enable_deepep_moe, "enable_dp_attention": ServerArgs.enable_dp_attention, - "use_attn_tp_group": ServerArgs.use_attn_tp_group, + "enable_dp_lm_head": ServerArgs.enable_dp_lm_head, "enable_ep_moe": ServerArgs.enable_ep_moe, "enable_nan_detection": 
ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index a07677bfc66..dc4d8f9df35 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -421,7 +421,7 @@ def __init__(
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["use_attn_tp_group"],
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 90585b85ab0..610a66974ee 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -159,7 +159,7 @@ class ServerArgs:
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
-    use_attn_tp_group: bool = False
+    enable_dp_lm_head: bool = False
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
@@ -318,6 +318,11 @@ def __post_init__(self):
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
             )
 
+        if self.enable_dp_lm_head:
+            assert (
+                self.enable_dp_attention
+            ), "Please enable dp attention when setting enable_dp_lm_head."
+
         # DeepEP MoE
         self.enable_sp_layernorm = False
         if self.enable_deepep_moe:
@@ -1051,7 +1056,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
             help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
         )
        parser.add_argument(
-            "--use-attn-tp-group",
+            "--enable-dp-lm-head",
            action="store_true",
            help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
        )

From efea846c686395c35f0bb0c5064efd1750f9cf46 Mon Sep 17 00:00:00 2001
From: liusy58
Date: Sun, 11 May 2025 16:38:59 +0800
Subject: [PATCH 20/26] Gather is needed if `enable_dp_lm_head` is not set.
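
When `--enable-dp-lm-head` is set, the LM head is sharded over the attention
TP group and the full logits are rebuilt by an all-gather along the vocabulary
dimension inside that group, so no gather across DP ranks is needed for the
final logits; when it is not set, the DP-wide gather is still required, hence
the extra condition in the scheduler change below. A minimal single-process
sketch of the shard-and-gather idea (plain PyTorch with made-up shapes, not
the actual SGLang code path):

    import torch

    torch.manual_seed(0)
    num_tokens, hidden_size, vocab_size, attn_tp_size = 4, 16, 32, 4

    hidden_states = torch.randn(num_tokens, hidden_size)
    lm_head_weight = torch.randn(vocab_size, hidden_size)

    # Reference: a fully replicated LM head projection.
    full_logits = hidden_states @ lm_head_weight.T

    # Vocab-parallel: each "rank" holds one vocab shard and computes partial logits.
    weight_shards = lm_head_weight.tensor_split(attn_tp_size, dim=0)
    partial_logits = [hidden_states @ shard.T for shard in weight_shards]

    # The attn-TP all-gather concatenates the shards back into the full vocab
    # dimension (emulated here with torch.cat on a single process).
    gathered_logits = torch.cat(partial_logits, dim=-1)

    assert torch.allclose(full_logits, gathered_logits, atol=1e-5)
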
--- python/sglang/srt/managers/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 6654e7dc048..b75dc473cf6 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1691,7 +1691,7 @@ def prepare_dp_attn_batch_raw( local_batch = get_idle_batch() if local_batch is not None: - if moe_dense_tp_size == 1: + if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]: local_batch.global_num_tokens = [num_tokens] local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob] else: From f84c245f98fe6d1f7c3703efcae82fddd0981922 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Sun, 11 May 2025 01:48:48 -0700 Subject: [PATCH 21/26] Update scheduler.py --- python/sglang/srt/managers/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index b75dc473cf6..3b3c91eb891 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1691,6 +1691,7 @@ def prepare_dp_attn_batch_raw( local_batch = get_idle_batch() if local_batch is not None: + # TODO: handle the case when moe_dense_tp_size != 1 if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]: local_batch.global_num_tokens = [num_tokens] local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob] From 71c12f6ba425c65e64281dd74435d0639b485763 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Sun, 11 May 2025 11:45:08 -0700 Subject: [PATCH 22/26] update code style --- python/sglang/srt/layers/logits_processor.py | 77 ++++++++------------ 1 file changed, 32 insertions(+), 45 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 62a7decd615..afa89ac8dc3 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -205,6 +205,7 @@ def __init__( self.do_tensor_parallel_all_gather = ( not skip_all_gather and self.attn_tp_size > 1 ) + self.do_tensor_parallel_all_gather_dp_attn = False else: self.do_tensor_parallel_all_gather = ( not skip_all_gather and get_tensor_model_parallel_world_size() > 1 @@ -430,21 +431,29 @@ def _get_logits( last position (e.g., extend without input logprobs). The caller should guarantee the given hidden_states follow this constraint. 
""" - if self.use_attn_tp_group: - if hasattr(lm_head, "weight"): - logits = torch.matmul( - hidden_states.to(lm_head.weight.dtype), lm_head.weight.T - ) - else: - # GGUF models - logits = lm_head.quant_method.apply( - lm_head, hidden_states, embedding_bias - ) + if self.do_tensor_parallel_all_gather_dp_attn: + logits_metadata.compute_dp_attention_metadata(hidden_states) + hidden_states, local_hidden_states = ( + logits_metadata.gathered_buffer, + hidden_states.clone(), + ) + dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) + + if hasattr(lm_head, "weight"): + logits = torch.matmul( + hidden_states.to(lm_head.weight.dtype), lm_head.weight.T + ) + else: + # GGUF models + logits = lm_head.quant_method.apply( + lm_head, hidden_states, embedding_bias + ) - if self.logit_scale is not None: - logits.mul_(self.logit_scale) + if self.logit_scale is not None: + logits.mul_(self.logit_scale) - if self.do_tensor_parallel_all_gather: + if self.do_tensor_parallel_all_gather: + if self.use_attn_tp_group: global_logits = torch.empty( (self.config.vocab_size, logits.shape[0]), device=logits.device, @@ -455,41 +464,19 @@ def _get_logits( list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits ) logits = global_logits - else: - if self.do_tensor_parallel_all_gather_dp_attn: - logits_metadata.compute_dp_attention_metadata(hidden_states) - hidden_states, local_hidden_states = ( - logits_metadata.gathered_buffer, - hidden_states.clone(), - ) - dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) - - if hasattr(lm_head, "weight"): - logits = torch.matmul( - hidden_states.to(lm_head.weight.dtype), lm_head.weight.T - ) else: - # GGUF models - logits = lm_head.quant_method.apply( - lm_head, hidden_states, embedding_bias - ) - - if self.logit_scale is not None: - logits.mul_(self.logit_scale) - - if self.do_tensor_parallel_all_gather: logits = tensor_model_parallel_all_gather(logits) - if self.do_tensor_parallel_all_gather_dp_attn: - logits, global_logits = ( - torch.empty( - (local_hidden_states.shape[0], logits.shape[1]), - device=logits.device, - dtype=logits.dtype, - ), - logits, - ) - dp_scatter(logits, global_logits, logits_metadata) + if self.do_tensor_parallel_all_gather_dp_attn: + logits, global_logits = ( + torch.empty( + (local_hidden_states.shape[0], logits.shape[1]), + device=logits.device, + dtype=logits.dtype, + ), + logits, + ) + dp_scatter(logits, global_logits, logits_metadata) logits = logits[:, : self.config.vocab_size].float() From 160517b405fd61d9dc7028e500b7bf61d5e3d635 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Sun, 11 May 2025 11:47:35 -0700 Subject: [PATCH 23/26] format --- python/sglang/srt/layers/logits_processor.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index afa89ac8dc3..5a4f0781729 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -438,16 +438,14 @@ def _get_logits( hidden_states.clone(), ) dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) - + if hasattr(lm_head, "weight"): logits = torch.matmul( hidden_states.to(lm_head.weight.dtype), lm_head.weight.T ) else: # GGUF models - logits = lm_head.quant_method.apply( - lm_head, hidden_states, embedding_bias - ) + logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias) if self.logit_scale is not None: 
logits.mul_(self.logit_scale) From 5d02170c45e61a7eac42fe1c49fb841c4604c341 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Sun, 11 May 2025 12:41:54 -0700 Subject: [PATCH 24/26] fix --- python/sglang/srt/models/deepseek_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 9c3e0831fa6..e8ef96a6eec 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1475,7 +1475,7 @@ def __init__( config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_attention"], + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) self.dp_size = get_attention_dp_size() From bf10e717819fea7270ac070b19e3fe5b72dd2f1f Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Sun, 11 May 2025 23:12:05 -0700 Subject: [PATCH 25/26] Update logits_processor.py --- python/sglang/srt/layers/logits_processor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index eab59cf12b1..70c82399249 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -30,7 +30,6 @@ attn_tp_all_gather, dp_gather_replicate, dp_scatter, - get_attention_dp_rank, get_attention_dp_size, get_attention_tp_size, get_local_attention_dp_rank, From 25c838f716c84b3fbc0e27abd0c925aaab536301 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A2=89=E6=B2=86?= Date: Mon, 12 May 2025 17:32:15 +0800 Subject: [PATCH 26/26] rename `dp_rank` to `attn_dp_rank` --- python/sglang/srt/managers/scheduler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 9c48aae5adb..6ae3004c61c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -773,7 +773,7 @@ def event_loop_pp(self): if not self.pp_group.is_last_rank: # send out reqs to the next stage - dp_offset = self.dp_rank * self.attn_tp_size + dp_offset = self.attn_dp_rank * self.attn_tp_size if self.attn_tp_rank == 0: point_to_point_pyobj( recv_reqs, @@ -820,7 +820,7 @@ def recv_requests(self) -> List[Req]: recv_reqs = None else: if self.attn_tp_rank == 0: - dp_offset = self.dp_rank * self.attn_tp_size + dp_offset = self.attn_dp_rank * self.attn_tp_size recv_reqs = point_to_point_pyobj( [], self.pp_rank * self.tp_size + dp_offset, @@ -2192,8 +2192,8 @@ def close_session(self, recv_req: CloseSessionReqInput): def get_print_prefix(self): prefix = "" - if self.dp_rank is not None: - prefix += f" DP{self.dp_rank}" + if self.attn_dp_rank is not None: + prefix += f" DP{self.attn_dp_rank}" if self.server_args.tp_size > 1: prefix += f" TP{self.tp_rank}" if self.pp_size > 1: