8 changes: 7 additions & 1 deletion python/sglang/srt/models/qwen2.py
@@ -43,6 +43,7 @@
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
@@ -258,6 +259,7 @@ def __init__(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
enable_tp=not global_server_args_dict["enable_dp_attention"],
prefix=add_prefix("embed_tokens", prefix),
)
else:
@@ -326,7 +328,11 @@ def forward(
}
)
else:
hidden_states, _ = self.norm(hidden_states, residual)
if hidden_states.shape[0] != 0:
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

# If this function is called, it should always initialize KV cache scale
58 changes: 42 additions & 16 deletions python/sglang/srt/models/qwen3.py
@@ -14,6 +14,8 @@
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
)
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -54,18 +56,21 @@ def __init__(
self.hidden_size = hidden_size
self.tp_size = get_tensor_model_parallel_world_size()
Review comment (Contributor), severity: medium
The self.tp_size attribute is initialized here but is not used in the subsequent calculations for self.num_heads or self.num_kv_heads within this __init__ method. Instead, attn_tp_size is used, which is appropriate for data-parallel attention. If self.tp_size is not utilized elsewhere in the Qwen3Attention class, consider removing this line to improve clarity and reduce potential confusion for future maintainers.

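A minimal sketch of the head-partitioning logic this comment refers to, assuming only the attention-TP size matters for these calculations; the helper name and example numbers are illustrative and not part of this PR:

```python
# Illustrative only: per-rank head counts depend solely on the attention-TP
# size, so Qwen3Attention.__init__ would not need the full TP world size.
def partition_heads(total_num_heads: int, total_num_kv_heads: int, attn_tp_size: int):
    assert total_num_heads % attn_tp_size == 0
    num_heads = total_num_heads // attn_tp_size
    if total_num_kv_heads >= attn_tp_size:
        # More KV heads than attention-TP ranks: split them across ranks.
        assert total_num_kv_heads % attn_tp_size == 0
    else:
        # Fewer KV heads than attention-TP ranks: replicate them.
        assert attn_tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // attn_tp_size)
    return num_heads, num_kv_heads

# Example: 32 query heads, 8 KV heads, attention-TP size 4 -> (8, 2) per rank.
print(partition_heads(32, 8, 4))
```
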
self.total_num_heads = num_heads
assert self.total_num_heads % self.tp_size == 0
self.num_heads = self.total_num_heads // self.tp_size
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()

assert self.total_num_heads % attn_tp_size == 0
self.num_heads = self.total_num_heads // attn_tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= self.tp_size:
if self.total_num_kv_heads >= attn_tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % self.tp_size == 0
assert self.total_num_kv_heads % attn_tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert self.tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
assert attn_tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
@@ -84,13 +89,18 @@
self.total_num_kv_heads,
bias=attention_bias,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=attention_bias,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
reduce_results=False,
prefix=add_prefix("o_proj", prefix),
)

@@ -176,6 +186,18 @@ def __init__(
config.hidden_size, eps=config.rms_norm_eps
)

self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
is_layer_sparse=False,
is_previous_layer_sparse=False,
)
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
)

def forward(
self,
positions: torch.Tensor,
@@ -184,20 +206,24 @@ def forward(
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
)
if hidden_states.shape[0] != 0:
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
)
hidden_states = self.mlp(hidden_states)
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
return hidden_states, residual

