diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 888a1fe937a1..85b4feedab72 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1041,6 +1041,7 @@ def get_benchmark_args( gsp_output_len=32, gsp_num_turns=1, header=None, + max_concurrency=None, ): return SimpleNamespace( backend=backend, @@ -1082,6 +1083,7 @@ def get_benchmark_args( gsp_output_len=gsp_output_len, gsp_num_turns=gsp_num_turns, header=header, + max_concurrency=max_concurrency, ) diff --git a/test/registered/distributed/test_disaggregation_dp_attention.py b/test/registered/distributed/test_disaggregation_dp_attention.py index 249a11e75a7b..a723111e03e7 100644 --- a/test/registered/distributed/test_disaggregation_dp_attention.py +++ b/test/registered/distributed/test_disaggregation_dp_attention.py @@ -1,6 +1,7 @@ import unittest from types import SimpleNamespace +from sglang.bench_serving import run_benchmark from sglang.srt.environ import envs from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k @@ -10,11 +11,12 @@ from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + get_benchmark_args, popen_launch_pd_server, try_cached_model, ) -register_cuda_ci(est_time=155, suite="stage-c-test-8-gpu-h20") +register_cuda_ci(est_time=580, suite="stage-c-test-8-gpu-h20") class TestDisaggregationDPAttention(PDDisaggregationServerBase): @@ -104,6 +106,24 @@ def test_gsm8k(self): class TestDisaggregationDPAttentionRoundRobin(TestDisaggregationDPAttention): LOAD_BALANCE_METHOD = "round_robin" + # TODO: add test for other load balance methods + # TODO: add a balancedness metric + + def test_bench_serving(self): + args = get_benchmark_args( + base_url=f"http://{self.base_host}:{self.lb_port}", + dataset_name="random", + tokenizer=self.model, + num_prompts=1000, + random_input_len=4096, + random_output_len=1024, + request_rate=float("inf"), + max_concurrency=256, + ) + result = run_benchmark(args) + + self.assertLess(result["mean_tpot_ms"], 20) + self.assertEqual(result["completed"], 1000) if __name__ == "__main__":