Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3246,16 +3246,23 @@ def _pause_engine(self) -> Tuple[List[Req], int]:
def pause_generation(self, recv_req: PauseGenerationReqInput):
self._engine_paused = True

if recv_req.mode == "in_place":
# In-place pause: just set the flag and return immediately.
# All scheduler state (running_batch, last_batch, chunked_req,
# result_queue) is left untouched. On resume, the normal event
# loop (get_next_batch_to_run) handles last_batch merge,
# chunked_req cleanup, and overlap result processing through
# the standard code paths. This avoids duplicating batch
# manipulation logic and the accounting bugs that come with it.
return

if self.enable_overlap and self.last_batch:
# Process the results of the last batch
tmp_batch, tmp_result = self.result_queue.popleft()
self.process_batch_result(tmp_batch, tmp_result)

if self.last_batch and self.last_batch.forward_mode.is_extend():
chunked_req_to_exclude = set()
if recv_req.mode == "in_place":
if self.chunked_req is not None:
chunked_req_to_exclude.add(self.chunked_req)
self.last_batch.filter_batch(
chunked_req_to_exclude=list(chunked_req_to_exclude)
)
Expand Down
130 changes: 130 additions & 0 deletions test/registered/unit/managers/test_scheduler_pause_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import unittest
from collections import deque
from unittest.mock import MagicMock

from sglang.srt.managers.io_struct import PauseGenerationReqInput
from sglang.srt.managers.scheduler import Scheduler
from sglang.test.ci.ci_register import register_cpu_ci

register_cpu_ci(est_time=2, suite="stage-a-cpu-only")


class TestSchedulerPauseGeneration(unittest.TestCase):
def _new_scheduler(self) -> Scheduler:
scheduler = Scheduler.__new__(Scheduler)
scheduler._engine_paused = False
scheduler.enable_overlap = False
scheduler.last_batch = None
scheduler.cur_batch = None
scheduler.chunked_req = None
scheduler.running_batch = MagicMock()
scheduler.running_batch.reqs = []
scheduler.running_batch.is_empty.return_value = True
scheduler.running_batch.batch_is_full = False
scheduler.tree_cache = MagicMock()
scheduler.tree_cache.protected_size.return_value = 0
scheduler.req_to_token_pool = MagicMock()
scheduler.result_queue = deque()
# Support _kv_snap diagnostic logging in patched schedulers
scheduler.token_to_kv_pool_allocator = MagicMock()
scheduler.token_to_kv_pool_allocator.available_size.return_value = 1000
scheduler.max_total_num_tokens = 1000
scheduler._get_token_info = MagicMock(return_value=(0, 0, 1000, 0))
return scheduler

def test_inplace_only_sets_flag(self):
"""in_place pause should only set _engine_paused and return."""
scheduler = self._new_scheduler()
scheduler.last_batch = MagicMock()
scheduler.cur_batch = MagicMock()
scheduler.chunked_req = MagicMock()

original_last_batch = scheduler.last_batch
original_cur_batch = scheduler.cur_batch
original_chunked_req = scheduler.chunked_req

scheduler.pause_generation(PauseGenerationReqInput(mode="in_place"))

self.assertTrue(scheduler._engine_paused)
# All state must be preserved — no mutation
self.assertIs(scheduler.last_batch, original_last_batch)
self.assertIs(scheduler.cur_batch, original_cur_batch)
self.assertIs(scheduler.chunked_req, original_chunked_req)

def test_inplace_does_not_drain_overlap_queue(self):
"""in_place should not process the overlap result_queue."""
scheduler = self._new_scheduler()
scheduler.enable_overlap = True
scheduler.last_batch = MagicMock()
scheduler.result_queue = deque([(MagicMock(), MagicMock())])

scheduler.pause_generation(PauseGenerationReqInput(mode="in_place"))

self.assertTrue(scheduler._engine_paused)
self.assertEqual(len(scheduler.result_queue), 1)

def test_inplace_does_not_merge_batch(self):
"""in_place should not filter or merge last_batch into running_batch."""
scheduler = self._new_scheduler()
last_batch = MagicMock()
last_batch.forward_mode.is_extend.return_value = True
scheduler.last_batch = last_batch

scheduler.pause_generation(PauseGenerationReqInput(mode="in_place"))

last_batch.filter_batch.assert_not_called()
scheduler.running_batch.merge_batch.assert_not_called()

def test_abort_clears_state(self):
"""abort mode should clear last_batch and cur_batch."""
scheduler = self._new_scheduler()
scheduler.last_batch = MagicMock()
scheduler.last_batch.forward_mode.is_extend.return_value = False
scheduler.cur_batch = MagicMock()

scheduler.pause_generation(PauseGenerationReqInput(mode="abort"))

self.assertTrue(scheduler._engine_paused)
self.assertIsNone(scheduler.last_batch)
self.assertIsNone(scheduler.cur_batch)

def test_retract_clears_running_batch(self):
"""retract mode should retract all requests from running_batch."""
scheduler = self._new_scheduler()
scheduler.last_batch = None
scheduler.running_batch.reqs = [MagicMock(), MagicMock()]
scheduler.running_batch.__len__ = lambda self: len(self.reqs)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This assignment to __len__ is unnecessary because the pause_generation implementation checks len(self.running_batch.reqs), not len(self.running_batch). Furthermore, mocking special methods like __len__ on a MagicMock instance should typically be done by setting return_value or side_effect on the attribute, though it is not needed here.

scheduler.running_batch.is_empty.return_value = False
scheduler.waiting_queue = []
scheduler._add_request_to_queue = MagicMock()

retracted = [MagicMock(), MagicMock()]
scheduler.running_batch.retract_all.return_value = retracted
scheduler.running_batch.filter_batch = MagicMock()
scheduler.server_args = MagicMock()

scheduler.pause_generation(PauseGenerationReqInput(mode="retract"))

self.assertTrue(scheduler._engine_paused)
scheduler.running_batch.retract_all.assert_called_once()
self.assertEqual(scheduler._add_request_to_queue.call_count, 2)
self.assertIsNone(scheduler.chunked_req)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This assertion is weak because scheduler.chunked_req is already initialized to None in _new_scheduler. To properly verify that the retract mode clears the chunked request, you should set scheduler.chunked_req to a non-None value (e.g., a MagicMock) before calling pause_generation.


def test_abort_drains_overlap_queue(self):
"""abort with overlap enabled should drain the result_queue."""
scheduler = self._new_scheduler()
scheduler.enable_overlap = True
mock_batch = MagicMock()
mock_batch.forward_mode.is_extend.return_value = False
scheduler.last_batch = mock_batch
scheduler.result_queue = deque([(MagicMock(), MagicMock())])
scheduler.process_batch_result = MagicMock()

scheduler.pause_generation(PauseGenerationReqInput(mode="abort"))

scheduler.process_batch_result.assert_called_once()
self.assertEqual(len(scheduler.result_queue), 0)


if __name__ == "__main__":
unittest.main()
Loading