From 4a336d926f2538d21896ccbdb02fbddf01bf3571 Mon Sep 17 00:00:00 2001 From: Weiliangl User Date: Fri, 20 Mar 2026 09:30:01 +0000 Subject: [PATCH 1/2] Fix cuda graph max bs capture upper bound --- python/sglang/srt/server_args.py | 3 +++ test/registered/core/test_server_args.py | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 2be1c045a0bf..068a7e33665c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1299,6 +1299,9 @@ def _generate_cuda_graph_batch_sizes(self): capture_bs = [bs for bs in capture_bs if bs <= self.cuda_graph_max_bs] + if self.cuda_graph_max_bs not in capture_bs: + capture_bs.append(self.cuda_graph_max_bs) + return capture_bs def _generate_piecewise_cuda_graph_tokens(self): diff --git a/test/registered/core/test_server_args.py b/test/registered/core/test_server_args.py index 41fd5c04cba4..41b4f64f3444 100644 --- a/test/registered/core/test_server_args.py +++ b/test/registered/core/test_server_args.py @@ -48,6 +48,16 @@ def test_pd_decode_defaults_to_round_robin(self): self.assertEqual(server_args.load_balance_method, "round_robin") +class TestCudaGraphBatchSizes(unittest.TestCase): + def test_generate_cuda_graph_batch_sizes_includes_max_bs(self): + server_args = ServerArgs(model_path="dummy", cuda_graph_max_bs=500) + + capture_bs = server_args._generate_cuda_graph_batch_sizes() + + self.assertIn(500, capture_bs) + self.assertEqual(capture_bs[-1], 500) + + class TestPortArgs(unittest.TestCase): @patch("sglang.srt.server_args.get_free_port") @patch("sglang.srt.server_args.tempfile.NamedTemporaryFile") From 265a484b3a1e1534d3bad841c47cfa1250c520ea Mon Sep 17 00:00:00 2001 From: Weiliangl User Date: Fri, 20 Mar 2026 09:34:32 +0000 Subject: [PATCH 2/2] Remove cuda graph max bs regression test --- test/registered/core/test_server_args.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/test/registered/core/test_server_args.py b/test/registered/core/test_server_args.py index 41b4f64f3444..41fd5c04cba4 100644 --- a/test/registered/core/test_server_args.py +++ b/test/registered/core/test_server_args.py @@ -48,16 +48,6 @@ def test_pd_decode_defaults_to_round_robin(self): self.assertEqual(server_args.load_balance_method, "round_robin") -class TestCudaGraphBatchSizes(unittest.TestCase): - def test_generate_cuda_graph_batch_sizes_includes_max_bs(self): - server_args = ServerArgs(model_path="dummy", cuda_graph_max_bs=500) - - capture_bs = server_args._generate_cuda_graph_batch_sizes() - - self.assertIn(500, capture_bs) - self.assertEqual(capture_bs[-1], 500) - - class TestPortArgs(unittest.TestCase): @patch("sglang.srt.server_args.get_free_port") @patch("sglang.srt.server_args.tempfile.NamedTemporaryFile")