From 63fbb7c7701c325a9a682816686f744e21d377bd Mon Sep 17 00:00:00 2001 From: Weisi Duan Date: Mon, 30 Mar 2026 11:45:55 -0700 Subject: [PATCH] Fix in place mode in pause generation --- python/sglang/srt/managers/scheduler.py | 13 +- .../test_scheduler_pause_generation.py | 130 ++++++++++++++++++ 2 files changed, 140 insertions(+), 3 deletions(-) create mode 100644 test/registered/unit/managers/test_scheduler_pause_generation.py diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5b4d85c9ecac..3d53d235a89b 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -3246,6 +3246,16 @@ def _pause_engine(self) -> Tuple[List[Req], int]: def pause_generation(self, recv_req: PauseGenerationReqInput): self._engine_paused = True + if recv_req.mode == "in_place": + # In-place pause: just set the flag and return immediately. + # All scheduler state (running_batch, last_batch, chunked_req, + # result_queue) is left untouched. On resume, the normal event + # loop (get_next_batch_to_run) handles last_batch merge, + # chunked_req cleanup, and overlap result processing through + # the standard code paths. This avoids duplicating batch + # manipulation logic and the accounting bugs that come with it. + return + if self.enable_overlap and self.last_batch: # Process the results of the last batch tmp_batch, tmp_result = self.result_queue.popleft() @@ -3253,9 +3263,6 @@ def pause_generation(self, recv_req: PauseGenerationReqInput): if self.last_batch and self.last_batch.forward_mode.is_extend(): chunked_req_to_exclude = set() - if recv_req.mode == "in_place": - if self.chunked_req is not None: - chunked_req_to_exclude.add(self.chunked_req) self.last_batch.filter_batch( chunked_req_to_exclude=list(chunked_req_to_exclude) ) diff --git a/test/registered/unit/managers/test_scheduler_pause_generation.py b/test/registered/unit/managers/test_scheduler_pause_generation.py new file mode 100644 index 000000000000..ee5813d30049 --- /dev/null +++ b/test/registered/unit/managers/test_scheduler_pause_generation.py @@ -0,0 +1,130 @@ +import unittest +from collections import deque +from unittest.mock import MagicMock + +from sglang.srt.managers.io_struct import PauseGenerationReqInput +from sglang.srt.managers.scheduler import Scheduler +from sglang.test.ci.ci_register import register_cpu_ci + +register_cpu_ci(est_time=2, suite="stage-a-cpu-only") + + +class TestSchedulerPauseGeneration(unittest.TestCase): + def _new_scheduler(self) -> Scheduler: + scheduler = Scheduler.__new__(Scheduler) + scheduler._engine_paused = False + scheduler.enable_overlap = False + scheduler.last_batch = None + scheduler.cur_batch = None + scheduler.chunked_req = None + scheduler.running_batch = MagicMock() + scheduler.running_batch.reqs = [] + scheduler.running_batch.is_empty.return_value = True + scheduler.running_batch.batch_is_full = False + scheduler.tree_cache = MagicMock() + scheduler.tree_cache.protected_size.return_value = 0 + scheduler.req_to_token_pool = MagicMock() + scheduler.result_queue = deque() + # Support _kv_snap diagnostic logging in patched schedulers + scheduler.token_to_kv_pool_allocator = MagicMock() + scheduler.token_to_kv_pool_allocator.available_size.return_value = 1000 + scheduler.max_total_num_tokens = 1000 + scheduler._get_token_info = MagicMock(return_value=(0, 0, 1000, 0)) + return scheduler + + def test_inplace_only_sets_flag(self): + """in_place pause should only set _engine_paused and return.""" + scheduler = self._new_scheduler() + scheduler.last_batch = MagicMock() + scheduler.cur_batch = MagicMock() + scheduler.chunked_req = MagicMock() + + original_last_batch = scheduler.last_batch + original_cur_batch = scheduler.cur_batch + original_chunked_req = scheduler.chunked_req + + scheduler.pause_generation(PauseGenerationReqInput(mode="in_place")) + + self.assertTrue(scheduler._engine_paused) + # All state must be preserved — no mutation + self.assertIs(scheduler.last_batch, original_last_batch) + self.assertIs(scheduler.cur_batch, original_cur_batch) + self.assertIs(scheduler.chunked_req, original_chunked_req) + + def test_inplace_does_not_drain_overlap_queue(self): + """in_place should not process the overlap result_queue.""" + scheduler = self._new_scheduler() + scheduler.enable_overlap = True + scheduler.last_batch = MagicMock() + scheduler.result_queue = deque([(MagicMock(), MagicMock())]) + + scheduler.pause_generation(PauseGenerationReqInput(mode="in_place")) + + self.assertTrue(scheduler._engine_paused) + self.assertEqual(len(scheduler.result_queue), 1) + + def test_inplace_does_not_merge_batch(self): + """in_place should not filter or merge last_batch into running_batch.""" + scheduler = self._new_scheduler() + last_batch = MagicMock() + last_batch.forward_mode.is_extend.return_value = True + scheduler.last_batch = last_batch + + scheduler.pause_generation(PauseGenerationReqInput(mode="in_place")) + + last_batch.filter_batch.assert_not_called() + scheduler.running_batch.merge_batch.assert_not_called() + + def test_abort_clears_state(self): + """abort mode should clear last_batch and cur_batch.""" + scheduler = self._new_scheduler() + scheduler.last_batch = MagicMock() + scheduler.last_batch.forward_mode.is_extend.return_value = False + scheduler.cur_batch = MagicMock() + + scheduler.pause_generation(PauseGenerationReqInput(mode="abort")) + + self.assertTrue(scheduler._engine_paused) + self.assertIsNone(scheduler.last_batch) + self.assertIsNone(scheduler.cur_batch) + + def test_retract_clears_running_batch(self): + """retract mode should retract all requests from running_batch.""" + scheduler = self._new_scheduler() + scheduler.last_batch = None + scheduler.running_batch.reqs = [MagicMock(), MagicMock()] + scheduler.running_batch.__len__ = lambda self: len(self.reqs) + scheduler.running_batch.is_empty.return_value = False + scheduler.waiting_queue = [] + scheduler._add_request_to_queue = MagicMock() + + retracted = [MagicMock(), MagicMock()] + scheduler.running_batch.retract_all.return_value = retracted + scheduler.running_batch.filter_batch = MagicMock() + scheduler.server_args = MagicMock() + + scheduler.pause_generation(PauseGenerationReqInput(mode="retract")) + + self.assertTrue(scheduler._engine_paused) + scheduler.running_batch.retract_all.assert_called_once() + self.assertEqual(scheduler._add_request_to_queue.call_count, 2) + self.assertIsNone(scheduler.chunked_req) + + def test_abort_drains_overlap_queue(self): + """abort with overlap enabled should drain the result_queue.""" + scheduler = self._new_scheduler() + scheduler.enable_overlap = True + mock_batch = MagicMock() + mock_batch.forward_mode.is_extend.return_value = False + scheduler.last_batch = mock_batch + scheduler.result_queue = deque([(MagicMock(), MagicMock())]) + scheduler.process_batch_result = MagicMock() + + scheduler.pause_generation(PauseGenerationReqInput(mode="abort")) + + scheduler.process_batch_result.assert_called_once() + self.assertEqual(len(scheduler.result_queue), 0) + + +if __name__ == "__main__": + unittest.main()