Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
ff45cbe
Add simple unit test for disaggregation feature
ShangmingCai Apr 21, 2025
3cc7ed1
Merge branch 'main' into add_pd_test
ShangmingCai Apr 23, 2025
f8afe9c
Add disaggregation test into run_suite
ShangmingCai Apr 23, 2025
b27efb4
Fix ci dependency for mooncake
ShangmingCai Apr 23, 2025
6b08ed8
fix dependency
ShangmingCai Apr 23, 2025
98a0329
minor
ShangmingCai Apr 23, 2025
1581015
fix rdma dependency
ShangmingCai Apr 23, 2025
4afe3fe
fix lint
ShangmingCai Apr 23, 2025
84a6251
fix tp size
ShangmingCai Apr 23, 2025
dc79d56
fix lint
ShangmingCai Apr 23, 2025
777bfca
fix model tp
ShangmingCai Apr 23, 2025
69c66c7
tmp check ci env
ShangmingCai Apr 23, 2025
7f016b5
Merge branch 'main' into add_pd_test
ShangmingCai May 6, 2025
76407ef
fix dependency
ShangmingCai May 6, 2025
a9c38ab
Add a new job in pr-test
ShangmingCai May 6, 2025
cb79c0b
Merge branch 'main' into add_pd_test
ShangmingCai May 6, 2025
d01fe2c
check driver
ShangmingCai May 6, 2025
5f7da22
Add rdma dependencies
ShangmingCai May 6, 2025
236c557
Fix tzdata install
ShangmingCai May 6, 2025
18ab9a4
Fix tzdata again
ShangmingCai May 6, 2025
8883bf3
fix
ShangmingCai May 6, 2025
3098146
more
ShangmingCai May 6, 2025
c788ac5
more
ShangmingCai May 7, 2025
82bcade
more
ShangmingCai May 7, 2025
0698a9a
more
ShangmingCai May 7, 2025
4c3572c
more
ShangmingCai May 7, 2025
c8fe5b8
more
ShangmingCai May 7, 2025
3a051f7
more
ShangmingCai May 7, 2025
a4b4d09
more
ShangmingCai May 7, 2025
a51a0fe
Merge branch 'main' into add_pd_test
ShangmingCai May 8, 2025
4d814ba
clean script
ShangmingCai May 8, 2025
8250951
fix merge
ShangmingCai May 8, 2025
ca47625
Merge branch 'main' into add_pd_test
ShangmingCai May 8, 2025
d38f860
fix pr-test.yaml
ShangmingCai May 8, 2025
0a6dcd3
more
ShangmingCai May 8, 2025
37d4e83
use 8 gpu runner
ShangmingCai May 8, 2025
c6c2f69
Merge branch 'main' into add_pd_test
ShangmingCai May 8, 2025
db7c365
tmp enlarge timeout to verify correctness
ShangmingCai May 8, 2025
b034a91
separate pd test
ShangmingCai May 8, 2025
dc85c95
Done
ShangmingCai May 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions python/sglang/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,81 @@ def popen_launch_server(
raise TimeoutError("Server failed to start within the timeout period.")


def popen_launch_pd_server(
    model: str,
    base_url: str,
    timeout: float,
    api_key: Optional[str] = None,
    other_args: list[str] = (),
    env: Optional[dict] = None,
    return_stdout_stderr: Optional[tuple] = None,
):
    """Launch ``sglang.launch_server`` as a subprocess and wait for it to be healthy.

    Args:
        model: Value passed to ``--model-path``.
        base_url: URL of the form ``scheme://host:port``; the host and port are
            forwarded as ``--host``/``--port`` and ``{base_url}/health`` is polled.
        timeout: Maximum number of seconds to wait for a 200 health response.
        api_key: Optional key forwarded as ``--api-key`` and sent as a bearer
            token on the health probe.
        other_args: Extra command-line arguments placed right after the model path.
        env: Environment for the subprocess (``None`` inherits the current one).
        return_stdout_stderr: Optional ``(stdout, stderr)`` pair handed to
            ``subprocess.Popen``; when omitted the child inherits both streams.

    Returns:
        The ``subprocess.Popen`` handle once the health endpoint answers 200.

    Raises:
        ValueError: If ``base_url`` has no host or no explicit port.
        Exception: If the server process exits before becoming healthy.
        TimeoutError: If the server is not healthy within ``timeout`` seconds
            (the process tree is killed first).
    """
    # Local import keeps this function self-contained without touching the
    # file-level import block.
    from urllib.parse import urlsplit

    parsed_url = urlsplit(base_url)
    host, port = parsed_url.hostname, parsed_url.port
    if not host or port is None:
        raise ValueError(
            f"base_url must look like 'http://host:port', got {base_url!r}"
        )

    command = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model,
        *[str(x) for x in other_args],
        "--host",
        host,
        "--port",
        str(port),
    ]

    if api_key:
        command += ["--api-key", api_key]

    print(f"command={' '.join(command)}")

    if return_stdout_stderr:
        process = subprocess.Popen(
            command,
            stdout=return_stdout_stderr[0],
            stderr=return_stdout_stderr[1],
            env=env,
            text=True,
        )
    else:
        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)

    headers = {"Content-Type": "application/json; charset=utf-8"}
    if api_key:
        # Only send credentials when a key exists; otherwise the header would
        # carry the literal string "Bearer None".
        headers["Authorization"] = f"Bearer {api_key}"

    poll_interval = 10  # seconds between health probes
    request_timeout = 10  # seconds a single probe may block
    deadline = time.monotonic() + timeout

    with requests.Session() as session:
        while time.monotonic() < deadline:
            try:
                response = session.get(
                    f"{base_url}/health",
                    headers=headers,
                    timeout=request_timeout,
                )
                if response.status_code == 200:
                    return process
            except requests.RequestException:
                # The server is not accepting connections yet; keep polling.
                pass

            return_code = process.poll()
            if return_code is not None:
                raise Exception(f"Server unexpectedly exits ({return_code=}).")

            # Never sleep past the deadline.
            remaining = deadline - time.monotonic()
            time.sleep(max(0.0, min(poll_interval, remaining)))

    kill_process_tree(process.pid)
    raise TimeoutError("Server failed to start within the timeout period.")


def run_with_timeout(
func: Callable,
args: tuple = (),
Expand Down
2 changes: 1 addition & 1 deletion scripts/ci_install_dependency.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pip install -e "python[all]"

# Install additional dependencies
pip install torch_memory_saver
pip install transformers==4.51.0 sentence_transformers accelerate peft pandas datasets timm torchaudio
pip install transformers==4.51.0 sentence_transformers accelerate peft pandas datasets timm torchaudio mooncake-transfer-engine

# For compiling xgrammar kernels
pip install cuda-python nvidia-cuda-nvrtc-cu12
1 change: 1 addition & 0 deletions test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class TestFile:
"per-commit-2-gpu": [
TestFile("models/lora/test_lora_tp.py", 300),
TestFile("test_data_parallelism.py", 90),
TestFile("test_disaggregation.py", 90),
TestFile("test_dp_attention.py", 90),
TestFile("test_mla_tp.py", 420),
TestFile("test_moe_ep.py", 220),
Expand Down
137 changes: 137 additions & 0 deletions test/srt/test_disaggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import subprocess
import threading
import time
import unittest
from types import SimpleNamespace

import requests
import torch

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_pd_server,
run_with_timeout,
)


class TestDisaggregationMooncake(CustomTestCase):
@classmethod
def setUpClass(cls):
    """Start the prefill server, the decode server and the mini load balancer.

    The prefill and decode servers listen on the default test port plus 100
    and plus 200 respectively; the load balancer listens on the default test
    port and forwards to both. If any step fails, everything that was already
    started is torn down before the error is re-raised.
    """
    # Pre-declare the handles so cleanup can run no matter where setup stops.
    cls.process_prefill = None
    cls.process_decode = None
    cls.process_lb = None

    # Review note from the PR discussion: a reviewer asked whether a small
    # llama model within a node (maybe tp4+tp4) could be used for faster CI;
    # the reply was that this test model is already lite.
    cls.model = "lmsys/sglang-ci-dsv3-test"
    cls.base_host = "127.0.0.1"
    cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
    cls.lb_url = DEFAULT_URL_FOR_TEST
    cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
    cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"

    try:
        run_with_timeout(
            cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )
        run_with_timeout(
            cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )

        cls.wait_server_ready(cls.prefill_url + "/health")
        cls.wait_server_ready(cls.decode_url + "/health")

        lb_command = [
            "python3",
            "-m",
            "sglang.srt.disaggregation.mini_lb",
            "--prefill",
            cls.prefill_url,
            "--decode",
            cls.decode_url,
            "--host",
            cls.base_host,
            "--port",
            str(cls.base_port),
        ]

        print("Starting load balancer:", " ".join(lb_command))
        # Inherit stdout/stderr instead of piping them: nothing ever reads
        # the pipes, so a chatty child could block once the pipe fills up.
        cls.process_lb = subprocess.Popen(lb_command, stdout=None, stderr=None)
        cls.wait_server_ready(cls.lb_url + "/health")
    except Exception:
        # unittest skips tearDownClass when setUpClass raises, so release the
        # already-started processes here before propagating the failure.
        cls.tearDownClass()
        raise

@classmethod
def start_prefill(cls):
    """Launch the prefill-mode server and store its handle on the class."""
    prefill_port = str(cls.base_port + 100)

    launch_flags = ["--trust-remote-code"]
    launch_flags += ["--disaggregation-mode", "prefill"]
    launch_flags += ["--host", cls.base_host]
    launch_flags += ["--port", prefill_port]

    server = popen_launch_pd_server(
        cls.model,
        cls.prefill_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=launch_flags,
    )
    cls.process_prefill = server

@classmethod
def start_decode(cls):
    """Launch the decode-mode server and store its handle on the class."""
    decode_port = str(cls.base_port + 200)

    launch_flags = ["--trust-remote-code"]
    launch_flags += ["--disaggregation-mode", "decode"]
    launch_flags += ["--host", cls.base_host]
    launch_flags += ["--port", decode_port]
    # The decode server is started with GPU id 1 as its base device.
    launch_flags += ["--base-gpu-id", "1"]

    server = popen_launch_pd_server(
        cls.model,
        cls.decode_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=launch_flags,
    )
    cls.process_decode = server

@classmethod
def wait_server_ready(cls, url, timeout=60):
    """Poll ``url`` once per second until it answers with HTTP 200.

    Args:
        url: Full URL to probe (callers pass a ``/health`` endpoint).
        timeout: Seconds to keep trying before giving up.

    Raises:
        RuntimeError: If no 200 response is received within ``timeout`` seconds.
    """
    # Cap each probe so one stuck connection cannot outlive the overall wait.
    request_timeout = 5
    start_time = time.time()
    while True:
        try:
            response = requests.get(url, timeout=request_timeout)
            if response.status_code == 200:
                print(f"Server {url} is ready")
                return
        except requests.RequestException:
            # Connection refused or timed out: the server is not up yet.
            pass

        if time.time() - start_time > timeout:
            raise RuntimeError(f"Server {url} failed to start in {timeout}s")
        time.sleep(1)

@classmethod
def tearDownClass(cls):
    """Kill the load balancer, decode and prefill process trees, best effort.

    Handles that were never assigned (for example because setup failed early)
    are skipped, and an error while killing one process does not prevent the
    remaining ones from being cleaned up.
    """
    for attr_name in ("process_lb", "process_decode", "process_prefill"):
        # getattr with a default tolerates attributes that were never set.
        process = getattr(cls, attr_name, None)
        if not process:
            continue
        try:
            kill_process_tree(process.pid)
        except Exception as e:
            print(f"Error killing process {process.pid}: {e}")

def test_gsm8k(self):
    """Run the few-shot GSM8K evaluation through the load balancer.

    The test passes when the reported accuracy is greater than 0.62.
    """
    lb_port = int(self.lb_url.split(":")[-1])
    eval_settings = {
        "num_shots": 5,
        "data_path": None,
        "num_questions": 200,
        "max_new_tokens": 512,
        "parallel": 128,
        "host": "http://127.0.0.1",
        "port": lb_port,
    }
    args = SimpleNamespace(**eval_settings)

    metrics = run_eval_few_shot_gsm8k(args)
    print(f"Evaluation metrics: {metrics}")

    accuracy = metrics["accuracy"]
    self.assertGreater(accuracy, 0.62)


# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading