73 changes: 60 additions & 13 deletions python/sglang/srt/models/qwen3_5.py
@@ -30,7 +30,7 @@
)

# Distributed
from sglang.srt.distributed import get_pp_group
from sglang.srt.distributed import get_pp_group, get_pp_indices
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
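
For orientation: get_pp_indices maps a pipeline-parallel rank to the half-open range of decoder layers that rank materializes. A minimal sketch of that contract, assuming an even split with earlier ranks absorbing the remainder (the actual sglang helper may partition differently):

# Hypothetical sketch of a get_pp_indices-style helper; not sglang's code.
def pp_indices_sketch(num_layers: int, pp_rank: int, pp_size: int) -> tuple[int, int]:
    base, remainder = divmod(num_layers, pp_size)
    # Ranks below `remainder` own one extra layer each.
    start = pp_rank * base + min(pp_rank, remainder)
    end = start + base + (1 if pp_rank < remainder else 0)
    return start, end

# e.g. 48 layers on 2 ranks: rank 0 owns [0, 24), rank 1 owns [24, 48)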

@@ -59,6 +59,7 @@
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.radix_linear_attention import RadixLinearAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
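
PPMissingLayer, imported above, is a placeholder for modules owned by another pipeline rank, so attribute lookups like self.embed_tokens stay valid on every rank. A rough, hypothetical equivalent of the idea:

import torch.nn as nn

# Hypothetical stand-in, not sglang's implementation: holds no parameters and,
# if ever called, passes its first input through unchanged.
class PPMissingLayerSketch(nn.Module):
    def forward(self, *args, **kwargs):
        return args[0] if args else None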
@@ -680,6 +681,8 @@ def __init__(
org_num_embeddings=config.vocab_size,
enable_tp=not is_dp_attention_enabled(),
)
else:
self.embed_tokens = PPMissingLayer()

# Decoder layers
def get_layer(idx: int, prefix: str):
@@ -703,13 +706,36 @@ def get_layer(idx: int, prefix: str):
prefix=f"{prefix}.layers",
)

pp_rank = self.pp_group.rank_in_group
pp_size = self.pp_group.world_size
num_layers = config.num_hidden_layers
self._start_layer, self._end_layer = (
get_pp_indices(
num_layers,
pp_rank,
pp_size,
)
if pp_rank is not None and pp_size is not None
else (0, num_layers)
)

# Final normalization
if self.pp_group.is_last_rank:
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()

def get_input_embeddings(self) -> nn.Embedding:
def get_input_embeddings(self):
return self.embed_tokens

@property
def start_layer(self) -> int:
return self._start_layer

@property
def end_layer(self) -> int:
return self._end_layer

@torch.no_grad()
def forward(
self,
@@ -733,7 +759,7 @@ def forward(
residual = pp_proxy_tensors["residual"]

# Pass through decoder layers
for layer_idx in range(len(self.layers)):
for layer_idx in range(self.start_layer, self.end_layer):
layer = self.layers[layer_idx]
with get_global_expert_distribution_recorder().with_current_layer(
layer_idx
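
Taken together, the reworked forward pass runs only the local layer slice and hands activations to the next rank unless it is the last one. A condensed, hypothetical skeleton of that control flow (sglang's batching, expert-recording, and deepstack details omitted):

def pp_forward_skeleton(self, input_ids, positions, forward_batch, pp_proxy_tensors=None):
    if self.pp_group.is_first_rank:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
    else:
        # Resume from the activations shipped by the previous pipeline rank.
        hidden_states = pp_proxy_tensors["hidden_states"]
        residual = pp_proxy_tensors["residual"]

    # Only the layers this rank materialized are executed.
    for layer_idx in range(self.start_layer, self.end_layer):
        hidden_states, residual = self.layers[layer_idx](
            positions, hidden_states, forward_batch, residual
        )

    if not self.pp_group.is_last_rank:
        # Defer the final norm to the last rank; ship activations onward.
        return PPProxyTensors({"hidden_states": hidden_states, "residual": residual})
    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states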
@@ -1045,14 +1071,31 @@ def __init__(

self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes

@property
def start_layer(self) -> int:
return getattr(getattr(self, "model", None), "start_layer", 0)

@property
def end_layer(self) -> int:
model = getattr(self, "model", None)
end_layer = getattr(model, "end_layer", None)
if end_layer is not None:
return end_layer
cfg = getattr(model, "config", None)
return int(getattr(cfg, "num_hidden_layers", 0))

def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
embed = self.model.embed_tokens.weight if self.pp_group.is_first_rank else None
head = self.lm_head.weight if self.pp_group.is_last_rank else None
return embed, head

def set_embed_and_head(self, embed, head):
del self.model.embed_tokens.weight
del self.lm_head.weight
self.model.embed_tokens.weight = embed
self.lm_head.weight = head
if self.pp_group.is_first_rank and embed is not None:
del self.model.embed_tokens.weight
self.model.embed_tokens.weight = embed
if self.pp_group.is_last_rank and head is not None:
del self.lm_head.weight
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
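
The rank guards are load-bearing: on intermediate PP ranks, embed_tokens is a PPMissingLayer and lm_head is not materialized, so unconditionally deleting or reassigning .weight would raise. A hedged illustration of the resulting calling convention:

# Either element may be None on a given rank; set_embed_and_head skips the
# components this rank does not own.
embed, head = model.get_embed_and_head()
model.set_embed_and_head(embed, head)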

@@ -1138,13 +1181,17 @@ def __init__(
self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes

def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
embed = self.model.embed_tokens.weight if self.pp_group.is_first_rank else None
head = self.lm_head.weight if self.pp_group.is_last_rank else None
return embed, head

def set_embed_and_head(self, embed, head):
del self.model.embed_tokens.weight
del self.lm_head.weight
self.model.embed_tokens.weight = embed
self.lm_head.weight = head
if self.pp_group.is_first_rank and embed is not None:
del self.model.embed_tokens.weight
self.model.embed_tokens.weight = embed
if self.pp_group.is_last_rank and head is not None:
del self.lm_head.weight
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()

54 changes: 54 additions & 0 deletions test/registered/distributed/test_pp_single_node.py
@@ -350,6 +350,60 @@ def test_pp_consistency(self):
)


class TestQwen35PPAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.base_url = "http://127.0.0.1:23337"  # a different port to avoid conflicts
cls.model_name = (
"Qwen/Qwen3.5-35B-A3B" # replace with your Qwen Model if needed
)

def run_gsm8k_test(self, pp_size):
process = popen_launch_server(
self.model_name,
self.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--pp-size",
pp_size,
"--chunked-prefill-size",
256,
],
)

try:
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
time.sleep(5)
return metrics
finally:
kill_process_tree(process.pid)

def test_pp_consistency(self):
baseline = self.run_gsm8k_test(pp_size=1)
pp_metrics = self.run_gsm8k_test(pp_size=2)

print(f"[Qwen35 PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")

self.assertGreaterEqual(baseline["accuracy"], 0.83)
self.assertGreaterEqual(
pp_metrics["accuracy"],
baseline["accuracy"] - 0.02,
msg=(
f"PP accuracy dropped more than 2% compared to baseline. "
f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}"
),
)
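
To run only this comparison against a local checkout, something like the following should work (module name assumed from the file location above):

# Hypothetical direct invocation; adjust the module name to your layout.
import unittest

if __name__ == "__main__":
    unittest.main(
        module="test_pp_single_node",
        defaultTest="TestQwen35PPAccuracy.test_pp_consistency",
    )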


class TestFixedBugs(unittest.TestCase):
def test_chunked_prefill_with_small_bs(self):
model = DEFAULT_MODEL_NAME_FOR_TEST