46 changes: 2 additions & 44 deletions .github/workflows/pr-test.yml
@@ -87,53 +87,11 @@ jobs:
run: |
bash scripts/ci_install_dependency.sh

- name: Test data parallelism (DP=2)
timeout-minutes: 10
run: |
cd test/srt
python3 test_data_parallelism.py

- name: Test data parallelism attention (DP=2)
timeout-minutes: 10
run: |
cd test/srt
python3 test_dp_attention.py

- name: Test update weights from distributed
timeout-minutes: 10
run: |
cd test/srt
python3 test_update_weights_from_distributed.py

- name: Test VerlEngine
timeout-minutes: 10
run: |
cd test/srt
python3 test_verl_engine.py

- name: Test Patch Torch
timeout-minutes: 10
run: |
cd test/srt
python3 test_patch_torch.py

- name: Test expert parallelism (EP=2)
timeout-minutes: 10
run: |
cd test/srt
python3 test_moe_ep.py

- name: Test torch compile (TP=2)
- name: Run test
timeout-minutes: 10
run: |
cd test/srt
python3 test_mla_tp.py

- name: Test lora tensor parallelism (TP=2)
timeout-minutes: 10
run: |
cd test/srt/models/lora
python3 test_lora_tp.py
python3 run_suite.py --suite per-commit-2-gpu

performance-test-1-gpu-part-1:
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
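Note: the eight separate 2-GPU test steps above are consolidated into a single "Run test" step that delegates to the test-suite runner. The individual tests now live in the per-commit-2-gpu suite defined in test/srt/run_suite.py below, so the same set can be reproduced outside CI with the command the workflow now runs:

    cd test/srt
    python3 run_suite.py --suite per-commit-2-gpu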
6 changes: 5 additions & 1 deletion python/sglang/srt/constrained/base_grammar_backend.py
@@ -169,7 +169,9 @@ def reset(self):
self.cache.clear()


def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
def create_grammar_backend(
server_args: ServerArgs, tokenizer, vocab_size: int
) -> Optional[BaseGrammarBackend]:
if server_args.grammar_backend == "outlines":
from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

@@ -188,6 +190,8 @@ def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
tokenizer=tokenizer,
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
)
elif server_args.grammar_backend == "none":
return None
else:
raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")

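Note: create_grammar_backend gains an explicit return type and can now return None when --grammar-backend none is selected, so every caller needs a None guard. A minimal sketch of the expected call-site pattern (the call site shown here is illustrative, not the actual scheduler code):

    # Hedged sketch: handling a disabled grammar backend at a call site.
    grammar_backend = create_grammar_backend(server_args, tokenizer, vocab_size)
    if grammar_backend is None:
        # --grammar-backend none: constrained decoding is unavailable, so
        # requests carrying a grammar constraint must be rejected here.
        ...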
9 changes: 6 additions & 3 deletions python/sglang/srt/managers/schedule_batch.py
@@ -599,6 +599,7 @@ def reset_for_retract(self):
self.extend_logprob_start_len = 0
self.is_chunked = 0
self.req_pool_idx = None
self.already_computed = 0

def __repr__(self):
return (
@@ -960,8 +961,6 @@ def prepare_for_extend(self):
# If req.input_embeds is already a list, append its content directly
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting

if req.is_retracted:
req.already_computed = 0
req.cached_tokens += pre_len - req.already_computed
req.already_computed = seq_len
req.is_retracted = False
@@ -1189,7 +1188,11 @@ def get_required_tokens(num_reqs: int):
self.req_to_token_pool.free(req.req_pool_idx)
else:
# TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices)
last_uncached_pos = (
(len(req.prefix_indices) + server_args.page_size - 1)
// server_args.page_size
* server_args.page_size
)
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
]
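Note: reset_for_retract now also clears already_computed, which lets prepare_for_extend drop its is_retracted special case. Separately, the retraction path rounds the free boundary up to the next multiple of page_size, so a page still partially occupied by cached prefix tokens is not handed back to the allocator. A self-contained sketch of that rounding (values are illustrative):

    def round_up_to_page(num_prefix_tokens: int, page_size: int) -> int:
        # Ceiling division, then scale back up: round up to a multiple of page_size.
        return (num_prefix_tokens + page_size - 1) // page_size * page_size

    assert round_up_to_page(13, 4) == 16  # 13 cached tokens keep their 4th page intact
    assert round_up_to_page(16, 4) == 16  # already page-aligned; nothing changes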
76 changes: 23 additions & 53 deletions python/sglang/srt/metrics/collector.py
@@ -33,7 +33,7 @@ class SchedulerMetricsCollector:

def __init__(self, labels: Dict[str, str]) -> None:
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
from prometheus_client import Gauge
from prometheus_client import Gauge, Histogram

self.labels = labels
self.last_log_time = time.time()
@@ -139,10 +139,10 @@ def __init__(self, labels: Dict[str, str]) -> None:
labelnames=labels.keys(),
buckets=[
0.1,
0.3,
0.5,
0.7,
0.9,
0.2,
0.4,
0.6,
0.8,
1,
2,
4,
@@ -153,36 +153,9 @@ def __init__(self, labels: Dict[str, str]) -> None:
40,
60,
80,
120,
160,
],
)

self.histogram_time_per_output_token = Histogram(
name="sglang:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
labelnames=labels.keys(),
buckets=[
0.002,
0.005,
0.010,
0.020,
0.030,
0.040,
0.050,
0.060,
0.070,
0.080,
0.090,
0.100,
0.150,
0.200,
0.300,
0.400,
0.600,
0.800,
1.000,
2.000,
100,
200,
400,
],
)

@@ -202,17 +175,18 @@ def __init__(self, labels: Dict[str, str]) -> None:
0.030,
0.035,
0.040,
0.050,
0.075,
0.060,
0.080,
0.100,
0.150,
0.200,
0.300,
0.400,
0.500,
0.750,
0.600,
0.800,
1.000,
2.000,
4.000,
6.000,
8.000,
],
)

@@ -224,23 +198,22 @@ def __init__(self, labels: Dict[str, str]) -> None:
0.1,
0.2,
0.4,
0.6,
0.8,
1,
2,
5,
4,
6,
8,
10,
20,
40,
60,
80,
100,
150,
200,
250,
300,
350,
500,
1000,
400,
800,
],
)

@@ -256,13 +229,10 @@ def observe_one_finished_request(
):
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
if cached_tokens > 0:
self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
self.num_requests_total.labels(**self.labels).inc(1)
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
if generation_tokens >= 1:
self.histogram_time_per_output_token.labels(**self.labels).observe(
e2e_latency / generation_tokens
)

def observe_time_to_first_token(self, value: float):
self.histogram_time_to_first_token.labels(**self.labels).observe(value)
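Note: the time_per_output_token histogram is removed outright (its name also disappears from the assertion list in test_metrics.py below, and inter-token latency remains covered by sglang:inter_token_latency_seconds), while the surviving bucket ladders are re-spaced. A minimal sketch of how prometheus_client histograms use these explicit bucket lists (the metric below is a stand-in, not one of the collector's metrics):

    from prometheus_client import Histogram

    # Prometheus histogram buckets are cumulative: an observation increments
    # every bucket whose upper bound is >= the value, plus the implicit +Inf bucket.
    demo_ttft = Histogram(
        name="demo_time_to_first_token_seconds",
        documentation="Stand-in TTFT histogram.",
        labelnames=["model_name"],
        buckets=[0.1, 0.2, 0.4, 0.6, 0.8, 1, 2, 4],
    )
    demo_ttft.labels(model_name="demo").observe(0.35)  # counted under le=0.4 and above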
20 changes: 12 additions & 8 deletions python/sglang/srt/server_args.py
@@ -128,7 +128,7 @@ class ServerArgs:
# Kernel backend
attention_backend: Optional[str] = None
sampling_backend: Optional[str] = None
grammar_backend: Optional[str] = "xgrammar"
grammar_backend: Optional[str] = None

# Speculative decoding
speculative_algorithm: Optional[str] = None
@@ -193,6 +193,13 @@ class ServerArgs:
disaggregation_bootstrap_port: int = 8998

def __post_init__(self):
# Expert parallelism
if self.enable_ep_moe:
self.ep_size = self.tp_size
logger.info(
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)

# Set missing default values
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
@@ -274,12 +281,9 @@ def __post_init__(self):
)
self.disable_cuda_graph = True

# Expert parallelism
if self.enable_ep_moe:
self.ep_size = self.tp_size
logger.info(
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
# Choose grammar backend
if self.grammar_backend is None:
self.grammar_backend = "xgrammar"

# Data parallelism attention
if self.enable_dp_attention:
@@ -813,7 +817,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--grammar-backend",
type=str,
choices=["xgrammar", "outlines", "llguidance"],
choices=["xgrammar", "outlines", "llguidance", "none"],
default=ServerArgs.grammar_backend,
help="Choose the backend for grammar-guided decoding.",
)
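Note: two changes sit in this file. The EP-MoE adjustment simply moves earlier in __post_init__, while grammar_backend's dataclass default becomes None so that __post_init__ can tell "flag not given" (fall back to xgrammar) apart from an explicit request to disable constrained decoding. Both launch commands below should therefore be valid (flag values come from this diff; the model path is a placeholder):

    python3 -m sglang.launch_server --model-path <model> --grammar-backend none
    python3 -m sglang.launch_server --model-path <model>  # resolves to xgrammar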
4 changes: 0 additions & 4 deletions python/sglang/test/test_utils.py
@@ -1012,9 +1012,6 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):


class CustomTestCase(unittest.TestCase):
pass

"""
def _callTestMethod(self, method):
max_retry = int(
os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")
@@ -1023,4 +1020,3 @@ def _callTestMethod(self, method):
lambda: super(CustomTestCase, self)._callTestMethod(method),
max_retry=max_retry,
)
"""
6 changes: 3 additions & 3 deletions test/srt/models/lora/test_lora_tp.py
@@ -33,6 +33,9 @@
],
max_loras_per_batch=1,
),
]

ALL_OTHER_LORA_MODELS = [
LoRAModelCase(
base="meta-llama/Llama-3.1-8B-Instruct",
adaptors=[
@@ -43,9 +46,6 @@
],
max_loras_per_batch=1,
),
]

ALL_OTHER_LORA_MODELS = [
LoRAModelCase(
base="meta-llama/Llama-2-7b-hf",
adaptors=[LoRAAdaptor(name="winddude/wizardLM-LlaMA-LoRA-7B")],
16 changes: 13 additions & 3 deletions test/srt/run_suite.py
@@ -16,7 +16,7 @@ class TestFile:
TestFile("models/lora/test_lora.py", 76),
TestFile("models/lora/test_lora_backend.py", 420),
TestFile("models/lora/test_multi_lora_backend.py", 144),
TestFile("models/test_embedding_models.py", 119),
TestFile("models/test_embedding_models.py", 35),
TestFile("models/test_generation_models.py", 103),
TestFile("models/test_grok_models.py", 60),
TestFile("models/test_qwen_models.py", 82),
@@ -38,7 +38,7 @@ class TestFile:
TestFile("test_metrics.py", 32),
TestFile("test_mla.py", 92),
TestFile("test_mla_deepseek_v3.py", 221),
TestFile("test_mla_int8_deepseek_v3.py", 421),
TestFile("test_mla_int8_deepseek_v3.py", 522),
TestFile("test_mla_flashinfer.py", 395),
TestFile("test_mla_fp8.py", 93),
TestFile("test_no_chunked_prefill.py", 126),
@@ -59,7 +59,7 @@ class TestFile:
TestFile("test_srt_endpoint.py", 94),
TestFile("test_torch_compile.py", 76),
TestFile("test_torch_compile_moe.py", 85),
TestFile("test_torch_native_attention_backend.py", 149),
TestFile("test_torch_native_attention_backend.py", 123),
TestFile("test_torchao.py", 70),
TestFile("test_triton_attention_kernels.py", 4),
TestFile("test_triton_attention_backend.py", 134),
@@ -76,6 +76,16 @@ class TestFile:
TestFile("test_hicache.py", 60),
TestFile("test_hicache_mla.py", 90),
],
"per-commit-2-gpu": [
TestFile("test_data_parallelism.py", 90),
TestFile("test_dp_attention.py", 90),
TestFile("test_update_weights_from_distributed.py", 100),
TestFile("test_verl_engine.py", 100),
TestFile("test_patch_torch.py", 30),
TestFile("test_moe_ep.py", 220),
TestFile("test_mla_tp.py", 420),
TestFile("test_lora_tp.py", 300),
],
"nightly": [
TestFile("test_nightly_gsm8k_eval.py"),
],
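Note: the new per-commit-2-gpu suite collects exactly the tests removed from pr-test.yml, each paired with a number that, judging by the TestFile entries elsewhere in this file, is an estimated runtime in seconds. A minimal sketch of the structure this file appears to use (field names are assumptions inferred from this diff, not checked against the full file):

    import dataclasses

    @dataclasses.dataclass
    class TestFile:
        name: str
        estimated_time: float = 60  # seconds, used to budget/partition a suite

    suites = {
        "per-commit-2-gpu": [
            TestFile("test_data_parallelism.py", 90),
            TestFile("test_moe_ep.py", 220),
        ],
    }
    total = sum(t.estimated_time for t in suites["per-commit-2-gpu"])  # rough budget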
4 changes: 4 additions & 0 deletions test/srt/test_dp_attention.py
@@ -60,3 +60,7 @@ def test_mgsm_en(self):
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["score"], 0.8)


if __name__ == "__main__":
unittest.main()
1 change: 0 additions & 1 deletion test/srt/test_metrics.py
@@ -63,7 +63,6 @@ def test_metrics_enabled(self):
"sglang:cached_tokens_total",
"sglang:num_requests_total",
"sglang:time_to_first_token_seconds",
"sglang:time_per_output_token_seconds",
"sglang:inter_token_latency_seconds",
"sglang:e2e_request_latency_seconds",
]