python/sglang/srt/models/phi3_small.py (18 changes: 16 additions & 2 deletions)
@@ -6,7 +6,7 @@
 from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
 
-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -17,6 +17,7 @@
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
@@ -294,20 +295,33 @@ def __init__(
         super().__init__()
 
         self.config = config
+
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
         self.mup_embedding_multiplier = config.mup_embedding_multiplier
-        self.start_layer, self.end_layer, self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: Phi3SmallDecoderLayer(
                 config,
                 int(prefix.split(".")[-1]),
                 quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
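The gating added above builds the real VocabParallelEmbedding only on the first pipeline rank; every other rank stores a PPMissingLayer placeholder, so self.embed_tokens still exists without allocating embedding weights. A minimal sketch of that placeholder pattern, assuming a parameter-free pass-through module (the class below is illustrative, not sglang's actual PPMissingLayer):

import torch.nn as nn

class PPMissingLayerSketch(nn.Module):
    # Hypothetical stand-in for sglang's PPMissingLayer: it registers no
    # parameters, so non-first pipeline ranks skip embedding allocation
    # while model code can still reference self.embed_tokens.
    def forward(self, *args, **kwargs):
        # Non-first ranks receive hidden states from the previous pipeline
        # stage instead of computing embeddings, so this module only passes
        # its input through.
        return args[0] if args else None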
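The new pp_rank and pp_size arguments let make_layers build only this rank's contiguous slice of the decoder stack and report it back as the local [start_layer, end_layer) range. A minimal sketch of that kind of partitioning, assuming an even contiguous split with earlier ranks absorbing any remainder (sglang's actual make_layers in sglang.srt.utils may divide layers differently):

def partition_layers(num_layers: int, pp_rank: int, pp_size: int):
    # Evenly assign a contiguous block of layer indices to each pipeline
    # rank; the first (num_layers % pp_size) ranks get one extra layer.
    per_rank, remainder = divmod(num_layers, pp_size)
    start = pp_rank * per_rank + min(pp_rank, remainder)
    end = start + per_rank + (1 if pp_rank < remainder else 0)
    return start, end

# Example: 32 layers over 4 pipeline stages.
assert [partition_layers(32, r, 4) for r in range(4)] == [
    (0, 8), (8, 16), (16, 24), (24, 32),
]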