Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions python/sglang/srt/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_pp_indices(
) -> Tuple[int, int]:
"""Try to evenly distribute layers across partitions.
If the number of layers is not divisible by the number of partitions,
the last partition will have the remaining layers.
the first N partitions will have one extra layer, where N = remainder.
"""
# partition_list_str can be set to None in sglang
partition_list_str = os.getenv("SGLANG_PP_LAYER_PARTITION", None)
Expand All @@ -83,12 +83,19 @@ def get_pp_indices(
start_layer = sum(partitions[:pp_rank])
end_layer = start_layer + partitions[pp_rank]
else:
layers_per_partition = num_hidden_layers // pp_size
start_layer = pp_rank * layers_per_partition
end_layer = start_layer + layers_per_partition

if pp_rank == pp_size - 1:
end_layer = num_hidden_layers
base_layers = num_hidden_layers // pp_size
remainder = num_hidden_layers % pp_size
# Distribute the extra layers to the first 'remainder' partitions
if pp_rank < remainder:
# This partition gets one extra layer
start_layer = pp_rank * (base_layers + 1)
end_layer = start_layer + (base_layers + 1)
else:
# This partition gets only base layers
start_layer = (
remainder * (base_layers + 1) + (pp_rank - remainder) * base_layers
)
end_layer = start_layer + base_layers
Comment on lines +86 to +98
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

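The logic for calculating `start_layer` and `end_layer` is correct, but it can be simplified for better readability and maintainability. Because every rank below `remainder` holds exactly one extra layer, `start_layer` can be computed with a single formula, `pp_rank * base_layers + min(pp_rank, remainder)`, and `end_layer` then depends only on whether the current rank itself receives an extra layer. For example, with 10 layers and `pp_size = 4` (`base_layers = 2`, `remainder = 2`), the four ranks are assigned the layer ranges [0, 3), [3, 6), [6, 8), and [8, 10).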

Suggested change
base_layers = num_hidden_layers // pp_size
remainder = num_hidden_layers % pp_size
# Distribute the extra layers to the first 'remainder' partitions
if pp_rank < remainder:
# This partition gets one extra layer
start_layer = pp_rank * (base_layers + 1)
end_layer = start_layer + (base_layers + 1)
else:
# This partition gets only base layers
start_layer = (
remainder * (base_layers + 1) + (pp_rank - remainder) * base_layers
)
end_layer = start_layer + base_layers
base_layers = num_hidden_layers // pp_size
remainder = num_hidden_layers % pp_size
# Each rank `i` is assigned `base_layers + (1 if i < remainder else 0)` layers.
# The start layer for `pp_rank` is the sum of layers for all previous ranks.
start_layer = pp_rank * base_layers + min(pp_rank, remainder)
if pp_rank < remainder:
# This partition gets an extra layer.
end_layer = start_layer + base_layers + 1
else:
# This partition gets base layers.
end_layer = start_layer + base_layers


return (start_layer, end_layer)

Expand Down
21 changes: 14 additions & 7 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3254,13 +3254,20 @@ def __init__(
self.model = DeepseekV2Model(
config, quant_config, prefix=add_prefix("model", prefix)
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
if self.pp_group.is_last_rank:
if self.pp_group.world_size == 1 and config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
else:
# ranks other than the last rank will have a placeholder layer
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config)

self._routed_experts_weights_of_layer = LazyValue(
Expand Down
Loading