12 changes: 12 additions & 0 deletions .github/workflows/pr-test.yml
@@ -229,6 +229,18 @@ jobs:
cd test/srt
python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache

- name: Benchmark offline decode throughput (PP=2)
timeout-minutes: 10
run: |
cd test/srt
python3 -m unittest test_bench_serving.TestBenchServing.test_pp_offline_throughput_default_decode

- name: Benchmark offline prefill throughput (PP=2)
timeout-minutes: 10
run: |
cd test/srt
python3 -m unittest test_bench_serving.TestBenchServing.test_pp_long_context_prefill

accuracy-test-1-gpu:
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
github.event.pull_request.draft == false
3 changes: 0 additions & 3 deletions python/sglang/srt/managers/schedule_policy.py
@@ -468,9 +468,6 @@ def add_one_req(
return AddReqResult.OTHER

with self._lock_node(req.last_node):
if total_tokens > self.rem_total_tokens:
return AddReqResult.NO_TOKEN

if (
enable_hierarchical_cache
and req.last_node_global is not None
8 changes: 4 additions & 4 deletions python/sglang/srt/managers/scheduler.py
@@ -719,7 +719,7 @@ def event_loop_pp(self):
server_is_idle = False
result = self.run_batch(self.cur_batch)

# send the outputs to the next step
# (last rank) send the outputs to the next step
if self.pp_group.is_last_rank:
if self.cur_batch:
next_token_ids, bids[mb_id] = (
@@ -759,18 +759,18 @@ def event_loop_pp(self):
self.process_batch_result(mbs[next_mb_id], output_result)
last_mbs[next_mb_id] = mbs[next_mb_id]

# carry the outputs to the next stage
# (not last rank)
if not self.pp_group.is_last_rank:
if self.cur_batch:
bids[mb_id] = result.bid
# carry the outputs to the next stage
# send the outputs from the last round to let the next stage worker run post processing
if pp_outputs:
# send the outputs from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)

if not self.pp_group.is_last_rank:
# send out reqs to the next stage
dp_offset = self.dp_rank * self.attn_tp_size
if self.attn_tp_rank == 0:
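For context on the `event_loop_pp` changes above, here is a minimal sketch (not the actual scheduler code) of the micro-batch rotation the new comments describe: the last pipeline rank samples tokens and runs post-processing, while earlier ranks only relay proxy tensors downstream. `get_next_batch`, `run_batch`, `send_downstream`, and `process_result` are hypothetical stand-ins for scheduler internals.

```python
# A minimal sketch of the pipeline-parallel micro-batch rotation; it only
# illustrates the asymmetry between the last rank (samples and post-processes)
# and earlier ranks (forward proxy tensors). Not the real scheduler.
def pp_event_loop_sketch(pp_group, pp_size, get_next_batch, run_batch,
                         send_downstream, process_result):
    mbs = [None] * pp_size  # one slot per in-flight micro-batch
    mb_id = 0
    while True:
        mbs[mb_id] = get_next_batch()
        result = run_batch(mbs[mb_id]) if mbs[mb_id] is not None else None

        if pp_group.is_last_rank:
            # The last stage owns sampling, so its results feed post-processing
            # (detokenization, streaming) instead of a downstream stage.
            if result is not None:
                process_result(mbs[mb_id], result)
        elif result is not None:
            # Earlier stages relay hidden_states / residual so the next stage
            # can start its forward pass for this micro-batch.
            send_downstream(result)

        mb_id = (mb_id + 1) % pp_size  # rotate through micro-batch slots
```

The real loop also receives proxy tensors from the previous rank, tracks batch ids per micro-batch, and handles idle detection, which this sketch omits.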
11 changes: 9 additions & 2 deletions python/sglang/srt/model_executor/model_runner.py
@@ -32,6 +32,7 @@
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
get_tp_group,
get_world_group,
init_distributed_environment,
initialize_model_parallel,
set_custom_all_reduce,
@@ -404,7 +405,10 @@ def init_torch_distributed(self):
)

min_per_gpu_memory = get_available_gpu_memory(
self.device, self.gpu_id, distributed=self.tp_size > 1
self.device,
self.gpu_id,
distributed=get_world_group().world_size > 1,
cpu_group=get_world_group().cpu_group,
)
self.tp_group = get_tp_group()
self.attention_tp_group = get_attention_tp_group()
@@ -716,7 +720,10 @@ def init_lora_manager(self):

def profile_max_num_token(self, total_gpu_memory: int):
available_gpu_memory = get_available_gpu_memory(
self.device, self.gpu_id, distributed=self.tp_size > 1
self.device,
self.gpu_id,
distributed=get_world_group().world_size > 1,
cpu_group=get_world_group().cpu_group,
)
if self.use_mla_backend:
num_layers = (
132 changes: 98 additions & 34 deletions python/sglang/srt/models/mixtral.py
@@ -16,13 +16,15 @@
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Mixtral model."""

from typing import Iterable, Optional, Tuple
import logging
from typing import Iterable, Optional, Tuple, Union

import torch
from torch import nn
from transformers import MixtralConfig

from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
@@ -38,14 +40,17 @@
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
from sglang.srt.utils import add_prefix, make_layers

logger = logging.getLogger(__name__)


class MixtralMoE(nn.Module):
@@ -257,43 +262,68 @@ def __init__(
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()

self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=add_prefix("embed_tokens", prefix),
)
self.layers = nn.ModuleList(
[
MixtralDecoderLayer(
config,
i,
quant_config=quant_config,
prefix=add_prefix(f"layers.{i}", prefix),
)
for i in range(config.num_hidden_layers)
]
if self.pp_group.is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=add_prefix("embed_tokens", prefix),
)
else:
self.embed_tokens = PPMissingLayer()

self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers,
lambda idx, prefix: MixtralDecoderLayer(
config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix="layers",
return_tuple=True,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

if self.pp_group.is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer(return_tuple=True)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
if self.pp_group.is_first_rank:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
else:
hidden_states = input_embeds
residual = None
for i in range(len(self.layers)):
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]

for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual
)
hidden_states, _ = self.norm(hidden_states, residual)

if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
hidden_states, _ = self.norm(hidden_states, residual)

return hidden_states
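
A self-contained toy (not sglang code) that mirrors the stage layout introduced above: only the first pipeline rank owns the embedding, only the last rank owns the final norm, and the hand-off between stages carries a dict of tensors in place of `PPProxyTensors` (the real model also carries `residual` alongside `hidden_states`).

```python
import torch
from torch import nn


class ToyStage(nn.Module):
    """One pipeline stage of a toy model; placeholders stand in for PPMissingLayer."""

    def __init__(self, hidden, num_layers, is_first, is_last):
        super().__init__()
        self.is_first, self.is_last = is_first, is_last
        self.embed = nn.Embedding(1000, hidden) if is_first else nn.Identity()
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))
        self.norm = nn.LayerNorm(hidden) if is_last else nn.Identity()

    def forward(self, input_ids=None, proxy=None):
        if self.is_first:
            hidden_states = self.embed(input_ids)
        else:
            hidden_states = proxy["hidden_states"]   # received from the previous stage
        for layer in self.layers:
            hidden_states = torch.relu(layer(hidden_states))
        if self.is_last:
            return self.norm(hidden_states)          # final hidden states -> logits processor
        return {"hidden_states": hidden_states}      # shipped to the next stage


# Two stages, each owning half of a 4-layer toy model.
stage0 = ToyStage(16, num_layers=2, is_first=True, is_last=False)
stage1 = ToyStage(16, num_layers=2, is_first=False, is_last=True)
tokens = torch.randint(0, 1000, (2, 8))
out = stage1(proxy=stage0(input_ids=tokens))         # shape: (2, 8, 16)
```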


@@ -306,6 +336,7 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__()
self.pp_group = get_pp_group()
self.config = config
self.quant_config = quant_config
self.model = MixtralModel(
@@ -322,12 +353,31 @@ def forward(
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
hidden_states = self.model(
input_ids,
positions,
forward_batch,
input_embeds,
pp_proxy_tensors=pp_proxy_tensors,
)

if self.pp_group.is_last_rank:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return hidden_states

@property
def start_layer(self):
return self.model.start_layer

@property
def end_layer(self):
return self.model.end_layer

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@@ -348,6 +398,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
layer_id = get_layer_id(name)
if (
layer_id is not None
and hasattr(self.model, "start_layer")
and (
layer_id < self.model.start_layer
or layer_id >= self.model.end_layer
)
):
continue

if "rotary_emb.inv_freq" in name:
continue

@@ -398,11 +459,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if name is None:
continue

param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
if name in params_dict.keys():
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
else:
logger.warning(f"Parameter {name} not found in params_dict")


EntryClass = MixtralForCausalLM
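
The new skip in `load_weights` drops parameters whose decoder-layer index falls outside `[start_layer, end_layer)`, since those layers are never instantiated on this pipeline stage. Below is a hedged sketch of that filter; the regex-based `layer_id_of` is an assumption standing in for sglang's `get_layer_id` helper.

```python
import re
from typing import Optional


def layer_id_of(param_name: str) -> Optional[int]:
    # Hypothetical approximation of get_layer_id: pull the index out of "layers.{i}.".
    m = re.search(r"layers\.(\d+)\.", param_name)
    return int(m.group(1)) if m else None


def belongs_to_stage(param_name: str, start_layer: int, end_layer: int) -> bool:
    layer_id = layer_id_of(param_name)
    # Non-layer weights (embeddings, norm, lm_head) pass through here; the
    # `name in params_dict` check above still filters them per stage.
    if layer_id is None:
        return True
    return start_layer <= layer_id < end_layer


assert belongs_to_stage("model.layers.3.self_attn.qkv_proj.weight", 0, 16)
assert not belongs_to_stage("model.layers.20.block_sparse_moe.gate.weight", 0, 16)
```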
6 changes: 6 additions & 0 deletions python/sglang/srt/server_args.py
@@ -347,6 +347,12 @@ def __post_init__(self):
f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)

if self.pp_size > 1:
self.disable_overlap_schedule = True
logger.warning(
"Pipeline parallelism is incompatible with overlap schedule."
)

# Speculative Decoding
if self.speculative_algorithm == "NEXTN":
# NEXTN shares the same implementation of EAGLE
10 changes: 6 additions & 4 deletions python/sglang/srt/utils.py
@@ -282,7 +282,9 @@ def inner_func(*args, **kwargs):
return wrapper


def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
def get_available_gpu_memory(
device, gpu_id, distributed=False, empty_cache=True, cpu_group=None
):
"""
Get available memory for cuda:gpu_id device.
When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -344,10 +346,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()

if distributed:
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
torch.device(device, gpu_id)
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
torch.distributed.all_reduce(
tensor, op=torch.distributed.ReduceOp.MIN, group=cpu_group
)
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
free_gpu_memory = tensor.item()

return free_gpu_memory / (1 << 30)
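
A minimal sketch, assuming a Gloo-backed CPU process group, of what the revised reduction does: each rank contributes its local free memory and all ranks agree on the minimum, without first staging the scalar on the GPU (which the old `.to(torch.device(device, gpu_id))` required). The function name is illustrative, not sglang's API.

```python
import torch
import torch.distributed as dist


def min_free_gpu_memory_gb(gpu_id: int, cpu_group=None) -> float:
    free_bytes, _total = torch.cuda.mem_get_info(gpu_id)
    tensor = torch.tensor(float(free_bytes), dtype=torch.float32)  # stays on CPU
    if dist.is_initialized():
        # MIN across every rank in the group; passing group=None would use the
        # default (world) group, mirroring get_world_group().cpu_group above.
        dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=cpu_group)
    return tensor.item() / (1 << 30)
```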
44 changes: 44 additions & 0 deletions test/srt/test_bench_serving.py
@@ -272,6 +272,50 @@ def test_moe_offline_throughput_without_radix_cache(self):
else:
self.assertGreater(res["output_throughput"], 2200)

def test_pp_offline_throughput_default_decode(self):
res = run_bench_serving(
model=DEFAULT_MOE_MODEL_NAME_FOR_TEST,
num_prompts=1000,
request_rate=float("inf"),
random_input_len=1,
random_output_len=1024,
other_server_args=["--pp", "2"],
need_warmup=True,
seed=42,
)

if is_in_ci():
write_github_step_summary(
f"### test_pp_offline_throughput_default_decode\n"
f'Output throughput: {res["output_throughput"]:.2f} token/s\n'
)
self.assertGreater(res["output_throughput"], 7500)

def test_pp_long_context_prefill(self):
res = run_bench_serving(
model="meta-llama/Llama-3.3-70B-Instruct",
num_prompts=4,
request_rate=float("inf"),
random_input_len=128000,
random_output_len=1,
dataset_name="random",
other_server_args=[
"--quantization",
"fp8",
"--pp",
2,
],
need_warmup=False,
seed=42,
)

if is_in_ci():
write_github_step_summary(
f"### test_pp_long_context_latency_prefill\n"
f'input_throughput: {res["input_throughput"]:.2f} ms\n'
)
self.assertGreater(res["input_throughput"], 4000)


if __name__ == "__main__":
unittest.main()