diff --git a/tests/core/test_dp_scheduler.py b/tests/core/test_dp_scheduler.py
index f07fdad1e9..e356368672 100644
--- a/tests/core/test_dp_scheduler.py
+++ b/tests/core/test_dp_scheduler.py
@@ -16,6 +16,7 @@
 import pytest
 
 from vllm.config import VllmConfig
+from vllm.v1.core.sched.interface import PauseState
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -90,8 +91,8 @@ def test_init_creates_worker_processes(
                 assert len(scheduler.processes) == 2
                 assert len(scheduler.input_queues) == 2
                 # output_queues is a dict with (rank, command) tuple keys
-                # 2 ranks × 15 commands (SchedulerCommand enum)
-                assert len(scheduler.output_queues) == 30
+                # 2 ranks × 17 commands (SchedulerCommand enum)
+                assert len(scheduler.output_queues) == 34
                 assert scheduler.log_stats is True
                 assert len(scheduler.per_rank_kv_cache_configs) == 2
 
@@ -702,6 +703,138 @@ def test_reset_encoder_cache(self, mock_vllm_config, mock_kv_cache_config,
                 scheduler.input_queues[1].put.assert_called_with(
                     (SchedulerCommand.RESET_ENCODER_CACHE, None))
 
+    def test_set_pause_state(self, mock_vllm_config, mock_kv_cache_config,
+                             mock_structured_output_manager):
+        """Test set_pause_state sends SET_PAUSE_STATE command to all workers."""
+        with patch(
+                'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
+        ):
+            with patch('multiprocessing.get_context'):
+                scheduler = DPScheduler(
+                    vllm_config=mock_vllm_config,
+                    kv_cache_config=mock_kv_cache_config,
+                    structured_output_manager=mock_structured_output_manager,
+                    block_size=16,
+                )
+
+                scheduler.input_queues = [MagicMock(), MagicMock()]
+
+                mock_queue_0 = MagicMock()
+                mock_queue_0.get.return_value = None
+                mock_queue_1 = MagicMock()
+                mock_queue_1.get.return_value = None
+
+                scheduler.output_queues = {
+                    (0, "set_pause_state"): mock_queue_0,
+                    (1, "set_pause_state"): mock_queue_1,
+                }
+
+                scheduler.set_pause_state(PauseState.PAUSED_ALL)
+
+                # Verify SET_PAUSE_STATE commands were sent to all ranks
+                scheduler.input_queues[0].put.assert_called_with(
+                    (SchedulerCommand.SET_PAUSE_STATE, PauseState.PAUSED_ALL))
+                scheduler.input_queues[1].put.assert_called_with(
+                    (SchedulerCommand.SET_PAUSE_STATE, PauseState.PAUSED_ALL))
+
+                # Verify we waited for completion from both ranks
+                mock_queue_0.get.assert_called_once()
+                mock_queue_1.get.assert_called_once()
+
+    def test_set_pause_state_paused_new(self, mock_vllm_config,
+                                        mock_kv_cache_config,
+                                        mock_structured_output_manager):
+        """Test set_pause_state with PAUSED_NEW state."""
+        with patch(
+                'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
+        ):
+            with patch('multiprocessing.get_context'):
+                scheduler = DPScheduler(
+                    vllm_config=mock_vllm_config,
+                    kv_cache_config=mock_kv_cache_config,
+                    structured_output_manager=mock_structured_output_manager,
+                    block_size=16,
+                )
+
+                scheduler.input_queues = [MagicMock(), MagicMock()]
+
+                mock_queue_0 = MagicMock()
+                mock_queue_0.get.return_value = None
+                mock_queue_1 = MagicMock()
+                mock_queue_1.get.return_value = None
+
+                scheduler.output_queues = {
+                    (0, "set_pause_state"): mock_queue_0,
+                    (1, "set_pause_state"): mock_queue_1,
+                }
+
+                scheduler.set_pause_state(PauseState.PAUSED_NEW)
+
+                scheduler.input_queues[0].put.assert_called_with(
+                    (SchedulerCommand.SET_PAUSE_STATE, PauseState.PAUSED_NEW))
+                scheduler.input_queues[1].put.assert_called_with(
+                    (SchedulerCommand.SET_PAUSE_STATE, PauseState.PAUSED_NEW))
+
+    def test_pause_state_property(self, mock_vllm_config, mock_kv_cache_config,
+                                  mock_structured_output_manager):
+        """Test pause_state property queries rank 0 and returns the state."""
+        with patch(
+                'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
+        ):
+            with patch('multiprocessing.get_context'):
+                scheduler = DPScheduler(
+                    vllm_config=mock_vllm_config,
+                    kv_cache_config=mock_kv_cache_config,
+                    structured_output_manager=mock_structured_output_manager,
+                    block_size=16,
+                )
+
+                scheduler.input_queues = [MagicMock(), MagicMock()]
+
+                mock_queue_0 = MagicMock()
+                mock_queue_0.get.return_value = PauseState.PAUSED_ALL
+
+                scheduler.output_queues = {
+                    (0, "get_pause_state"): mock_queue_0,
+                }
+
+                result = scheduler.pause_state
+
+                # Verify GET_PAUSE_STATE command was sent only to rank 0
+                scheduler.input_queues[0].put.assert_called_with(
+                    (SchedulerCommand.GET_PAUSE_STATE, None))
+                # Verify rank 1 was NOT queried
+                scheduler.input_queues[1].put.assert_not_called()
+
+                assert result == PauseState.PAUSED_ALL
+
+    def test_pause_state_unpaused(self, mock_vllm_config, mock_kv_cache_config,
+                                  mock_structured_output_manager):
+        """Test pause_state property returns UNPAUSED state."""
+        with patch(
+                'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
+        ):
+            with patch('multiprocessing.get_context'):
+                scheduler = DPScheduler(
+                    vllm_config=mock_vllm_config,
+                    kv_cache_config=mock_kv_cache_config,
+                    structured_output_manager=mock_structured_output_manager,
+                    block_size=16,
+                )
+
+                scheduler.input_queues = [MagicMock(), MagicMock()]
+
+                mock_queue_0 = MagicMock()
+                mock_queue_0.get.return_value = PauseState.UNPAUSED
+
+                scheduler.output_queues = {
+                    (0, "get_pause_state"): mock_queue_0,
+                }
+
+                result = scheduler.pause_state
+
+                assert result == PauseState.UNPAUSED
 
     def test_make_stats_aggregates_from_workers(
             self, mock_vllm_config, mock_kv_cache_config,
             mock_structured_output_manager):
diff --git a/tpu_inference/core/sched/dp_scheduler.py b/tpu_inference/core/sched/dp_scheduler.py
index 34cdd65e7d..944acc1c57 100644
--- a/tpu_inference/core/sched/dp_scheduler.py
+++ b/tpu_inference/core/sched/dp_scheduler.py
@@ -26,7 +26,7 @@
 from vllm.config import VllmConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.sched.async_scheduler import AsyncScheduler
-from vllm.v1.core.sched.interface import SchedulerInterface
+from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
 from vllm.v1.core.sched.output import (CachedRequestData, GrammarOutput,
                                        SchedulerOutput)
 from vllm.v1.core.sched.scheduler import Scheduler
@@ -59,6 +59,8 @@ class SchedulerCommand(Enum):
     GET_TOKEN_COUNT = "get_token_count"
     GET_COMPUTED_BLOCKS = "get_computed_blocks"
     RESET_ENCODER_CACHE = "reset_encoder_cache"
+    SET_PAUSE_STATE = "set_pause_state"
+    GET_PAUSE_STATE = "get_pause_state"
     SHUTDOWN = "shutdown"
 
 
@@ -192,6 +194,15 @@ def _scheduler_worker_process(
                     scheduler.reset_encoder_cache()
                     output_queues[command.value].put(None)
 
+                case SchedulerCommand.SET_PAUSE_STATE:
+                    pause_state = data
+                    scheduler.set_pause_state(pause_state)
+                    output_queues[command.value].put(None)
+
+                case SchedulerCommand.GET_PAUSE_STATE:
+                    result = scheduler.pause_state
+                    output_queues[command.value].put(result)
+
                 case SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS:
                     result = scheduler.get_num_unfinished_requests()
                     output_queues[command.value].put(result)
@@ -761,6 +772,24 @@ def reset_encoder_cache(self) -> None:
             self._get_result_from_queue(rank,
                                         SchedulerCommand.RESET_ENCODER_CACHE)
 
+    @property
+    def pause_state(self) -> PauseState:
+        """Get the pause state from the first DP rank scheduler.
+
+        All ranks share the same pause state, so we only need to query one.
+        """
+        self.input_queues[0].put((SchedulerCommand.GET_PAUSE_STATE, None))
+        return self._get_result_from_queue(0, SchedulerCommand.GET_PAUSE_STATE)
+
+    def set_pause_state(self, pause_state: PauseState) -> None:
+        """Set pause state for all DP rank schedulers."""
+        for rank in range(self.dp_size):
+            self.input_queues[rank].put(
+                (SchedulerCommand.SET_PAUSE_STATE, pause_state))
+
+        for rank in range(self.dp_size):
+            self._get_result_from_queue(rank, SchedulerCommand.SET_PAUSE_STATE)
+
     def make_stats(self,
                    spec_decoding_stats=None,
                    kv_connector_stats=None) -> Optional[SchedulerStats]: