Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 135 additions & 2 deletions tests/core/test_dp_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import pytest
from vllm.config import VllmConfig
from vllm.v1.core.sched.interface import PauseState
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import KVCacheConfig
Expand Down Expand Up @@ -90,8 +91,8 @@ def test_init_creates_worker_processes(
assert len(scheduler.processes) == 2
assert len(scheduler.input_queues) == 2
# output_queues is a dict with (rank, command) tuple keys
# 2 ranks × 15 commands (SchedulerCommand enum)
assert len(scheduler.output_queues) == 30
# 2 ranks × 17 commands (SchedulerCommand enum)
assert len(scheduler.output_queues) == 34
assert scheduler.log_stats is True
assert len(scheduler.per_rank_kv_cache_configs) == 2

Expand Down Expand Up @@ -702,6 +703,138 @@ def test_reset_encoder_cache(self, mock_vllm_config, mock_kv_cache_config,
scheduler.input_queues[1].put.assert_called_with(
(SchedulerCommand.RESET_ENCODER_CACHE, None))

def test_set_pause_state(self, mock_vllm_config, mock_kv_cache_config,
mock_structured_output_manager):
"""Test set_pause_state sends SET_PAUSE_STATE command to all workers."""
with patch(
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
):
with patch('multiprocessing.get_context'):
scheduler = DPScheduler(
vllm_config=mock_vllm_config,
kv_cache_config=mock_kv_cache_config,
structured_output_manager=mock_structured_output_manager,
block_size=16,
)

scheduler.input_queues = [MagicMock(), MagicMock()]

mock_queue_0 = MagicMock()
mock_queue_0.get.return_value = None
mock_queue_1 = MagicMock()
mock_queue_1.get.return_value = None

scheduler.output_queues = {
(0, "set_pause_state"): mock_queue_0,
(1, "set_pause_state"): mock_queue_1,
}

scheduler.set_pause_state(PauseState.PAUSED_ALL)

# Verify SET_PAUSE_STATE commands were sent to all ranks
scheduler.input_queues[0].put.assert_called_with(
(SchedulerCommand.SET_PAUSE_STATE, PauseState.PAUSED_ALL))
scheduler.input_queues[1].put.assert_called_with(
(SchedulerCommand.SET_PAUSE_STATE, PauseState.PAUSED_ALL))

# Verify we waited for completion from both ranks
mock_queue_0.get.assert_called_once()
mock_queue_1.get.assert_called_once()

def test_set_pause_state_paused_new(self, mock_vllm_config,
mock_kv_cache_config,
mock_structured_output_manager):
"""Test set_pause_state with PAUSED_NEW state."""
with patch(
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
):
with patch('multiprocessing.get_context'):
scheduler = DPScheduler(
vllm_config=mock_vllm_config,
kv_cache_config=mock_kv_cache_config,
structured_output_manager=mock_structured_output_manager,
block_size=16,
)

scheduler.input_queues = [MagicMock(), MagicMock()]

mock_queue_0 = MagicMock()
mock_queue_0.get.return_value = None
mock_queue_1 = MagicMock()
mock_queue_1.get.return_value = None

scheduler.output_queues = {
(0, "set_pause_state"): mock_queue_0,
(1, "set_pause_state"): mock_queue_1,
}

scheduler.set_pause_state(PauseState.PAUSED_NEW)

scheduler.input_queues[0].put.assert_called_with(
(SchedulerCommand.SET_PAUSE_STATE, PauseState.PAUSED_NEW))
scheduler.input_queues[1].put.assert_called_with(
(SchedulerCommand.SET_PAUSE_STATE, PauseState.PAUSED_NEW))

def test_pause_state_property(self, mock_vllm_config, mock_kv_cache_config,
mock_structured_output_manager):
"""Test pause_state property queries rank 0 and returns the state."""
with patch(
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
):
with patch('multiprocessing.get_context'):
scheduler = DPScheduler(
vllm_config=mock_vllm_config,
kv_cache_config=mock_kv_cache_config,
structured_output_manager=mock_structured_output_manager,
block_size=16,
)

scheduler.input_queues = [MagicMock(), MagicMock()]

mock_queue_0 = MagicMock()
mock_queue_0.get.return_value = PauseState.PAUSED_ALL

scheduler.output_queues = {
(0, "get_pause_state"): mock_queue_0,
}

result = scheduler.pause_state

# Verify GET_PAUSE_STATE command was sent only to rank 0
scheduler.input_queues[0].put.assert_called_with(
(SchedulerCommand.GET_PAUSE_STATE, None))
# Verify rank 1 was NOT queried
scheduler.input_queues[1].put.assert_not_called()

assert result == PauseState.PAUSED_ALL

def test_pause_state_unpaused(self, mock_vllm_config, mock_kv_cache_config,
mock_structured_output_manager):
"""Test pause_state property returns UNPAUSED state."""
with patch(
'tpu_inference.core.sched.dp_scheduler._scheduler_worker_process'
):
with patch('multiprocessing.get_context'):
scheduler = DPScheduler(
vllm_config=mock_vllm_config,
kv_cache_config=mock_kv_cache_config,
structured_output_manager=mock_structured_output_manager,
block_size=16,
)

scheduler.input_queues = [MagicMock(), MagicMock()]

mock_queue_0 = MagicMock()
mock_queue_0.get.return_value = PauseState.UNPAUSED

scheduler.output_queues = {
(0, "get_pause_state"): mock_queue_0,
}

result = scheduler.pause_state

assert result == PauseState.UNPAUSED

def test_make_stats_aggregates_from_workers(
self, mock_vllm_config, mock_kv_cache_config,
mock_structured_output_manager):
Expand Down
31 changes: 30 additions & 1 deletion tpu_inference/core/sched/dp_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, GrammarOutput,
SchedulerOutput)
from vllm.v1.core.sched.scheduler import Scheduler
Expand Down Expand Up @@ -59,6 +59,8 @@ class SchedulerCommand(Enum):
GET_TOKEN_COUNT = "get_token_count"
GET_COMPUTED_BLOCKS = "get_computed_blocks"
RESET_ENCODER_CACHE = "reset_encoder_cache"
SET_PAUSE_STATE = "set_pause_state"
GET_PAUSE_STATE = "get_pause_state"
SHUTDOWN = "shutdown"


Expand Down Expand Up @@ -192,6 +194,15 @@ def _scheduler_worker_process(
scheduler.reset_encoder_cache()
output_queues[command.value].put(None)

case SchedulerCommand.SET_PAUSE_STATE:
pause_state = data
scheduler.set_pause_state(pause_state)
output_queues[command.value].put(None)

case SchedulerCommand.GET_PAUSE_STATE:
result = scheduler.pause_state
output_queues[command.value].put(result)

case SchedulerCommand.GET_NUM_UNFINISHED_REQUESTS:
result = scheduler.get_num_unfinished_requests()
output_queues[command.value].put(result)
Expand Down Expand Up @@ -761,6 +772,24 @@ def reset_encoder_cache(self) -> None:
self._get_result_from_queue(rank,
SchedulerCommand.RESET_ENCODER_CACHE)

@property
def pause_state(self) -> PauseState:
"""Get the pause state from the first DP rank scheduler.

All ranks share the same pause state, so we only need to query one.
"""
self.input_queues[0].put((SchedulerCommand.GET_PAUSE_STATE, None))
return self._get_result_from_queue(0, SchedulerCommand.GET_PAUSE_STATE)

def set_pause_state(self, pause_state: PauseState) -> None:
"""Set pause state for all DP rank schedulers."""
for rank in range(self.dp_size):
self.input_queues[rank].put(
(SchedulerCommand.SET_PAUSE_STATE, pause_state))

for rank in range(self.dp_size):
self._get_result_from_queue(rank, SchedulerCommand.SET_PAUSE_STATE)

def make_stats(self,
spec_decoding_stats=None,
kv_connector_stats=None) -> Optional[SchedulerStats]:
Expand Down
Loading