diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md
index 3d5d7ce45cce..516b24443d4c 100644
--- a/docs/usage/v1_guide.md
+++ b/docs/usage/v1_guide.md
@@ -69,7 +69,7 @@ This living user guide outlines a few known **important changes and limitations*
 way by using a simple dictionary (e.g., {request_id: num_tokens}) to dynamically
 allocate a fixed token budget per request, enabling features like chunked prefills,
 prefix caching, and speculative decoding without a strict separation between prefill
-and decode phases.
+and decode phases. The V1 scheduler supports multiple scheduling policies, including First-Come, First-Served (FCFS) and priority-based scheduling (where requests are processed based on assigned priority, with FCFS as a tie-breaker), configurable via the `--scheduling-policy` argument.
 
 ### Semantic Changes and Deprecated Features
 
diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index f38454b1b288..079179730d5d 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -1246,3 +1246,597 @@ def test_memory_leak():
 
     # Confirm no memory leak.
     assert_scheduler_empty(scheduler)
+
+
+def create_scheduler_with_priority(
+    model: str = "facebook/opt-125m",
+    max_num_seqs: int = 16,
+    max_num_batched_tokens: int = 8192,
+    enable_prefix_caching: Optional[bool] = None,
+    long_prefill_token_threshold: int = 0,
+    disable_chunked_mm_input: bool = False,
+    use_kv_connector: bool = False,
+    num_blocks: int = 10000,
+    block_size: int = 16,
+    max_model_len: Optional[int] = None,
+    num_speculative_tokens: Optional[int] = None,
+) -> Scheduler:
+    '''Create scheduler with priority policy enabled.
+
+    Args:
+      model: model under test
+      max_num_seqs: max sequences to schedule
+      max_num_batched_tokens: max num tokens to batch
+      enable_prefix_caching: optionally force APC config
+                             (True/False) or use default
+                             (None)
+
+    Returns:
+      {class}`Scheduler` instance with priority scheduling
+    '''
+    if max_model_len is None:
+        max_model_len = max_num_batched_tokens
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=max_num_seqs,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_model_len=max_model_len,
+        long_prefill_token_threshold=long_prefill_token_threshold,
+        disable_chunked_mm_input=disable_chunked_mm_input,
+        enable_chunked_prefill=True,
+        policy="priority",  # Enable priority scheduling
+    )
+    model_config = ModelConfig(
+        model=model,
+        task="auto",
+        tokenizer=model,
+        tokenizer_mode="auto",
+        trust_remote_code=True,
+        dtype="float16",
+        seed=42,
+    )
+    # Cache config, optionally force APC
+    kwargs_cache = ({} if enable_prefix_caching is None else {
+        'enable_prefix_caching': enable_prefix_caching
+    })
+    cache_config = CacheConfig(
+        block_size=block_size,
+        gpu_memory_utilization=0.9,
+        swap_space=0,
+        cache_dtype="auto",
+        **kwargs_cache,
+    )
+    kv_transfer_config = KVTransferConfig(
+        kv_connector="SharedStorageConnector",
+        kv_role="kv_both",
+        kv_connector_extra_config={"shared_storage_path": "local_storage"},
+    ) if use_kv_connector else None
+
+    speculative_config: Optional[SpeculativeConfig] = None
+    if num_speculative_tokens is not None:
+        speculative_config = SpeculativeConfig(
+            model="ngram", num_speculative_tokens=num_speculative_tokens)
+
+    vllm_config = VllmConfig(
+        scheduler_config=scheduler_config,
+        model_config=model_config,
+        cache_config=cache_config,
+        kv_transfer_config=kv_transfer_config,
+        speculative_config=speculative_config,
+    )
+    kv_cache_config = KVCacheConfig(
+        num_blocks=num_blocks,  # A large number of blocks to hold all requests
+        tensors={},
+        kv_cache_groups=[
+            KVCacheGroupSpec(['layer'],
+                             FullAttentionSpec(block_size, 1, 1, torch.float32,
+                                               False))
+        ],
+    )
+    cache_config.num_gpu_blocks = num_blocks
+    return Scheduler(
+        vllm_config=vllm_config,
+        kv_cache_config=kv_cache_config,
+        log_stats=True,
+        structured_output_manager=StructuredOutputManager(vllm_config),
+    )
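(Aside: the `policy="priority"` knob used by the helper above is the same switch that the guide change documents. A hedged sketch, outside this PR's diff, of enabling it — the CLI flag name comes from the guide text; the `SchedulerConfig` fields mirror the helper:)

```python
# Hedged sketch (not part of this diff): enabling the priority policy.
#
# CLI (flag name per the v1_guide.md change above):
#   vllm serve facebook/opt-125m --scheduling-policy priority
#
# Programmatically, it is the `policy` field on SchedulerConfig, exactly as
# create_scheduler_with_priority() sets it:
from vllm.config import SchedulerConfig

scheduler_config = SchedulerConfig(
    max_num_seqs=16,
    max_num_batched_tokens=8192,
    max_model_len=8192,
    enable_chunked_prefill=True,
    policy="priority",  # default is "fcfs"
)
```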
+
+
+def create_requests_with_priority(
+    num_requests: int,
+    priorities: list[int],
+    arrival_times: Optional[list[float]] = None,
+    num_tokens: int = 10,
+    mm_positions: Optional[list[PlaceholderRange]] = None,
+    max_tokens: int = 16,
+    stop_token_ids: Optional[list[int]] = None,
+    prompt_logprobs: Optional[int] = None):
+    """Create requests with specified priorities and arrival times."""
+    assert len(priorities) == num_requests
+    if arrival_times is not None:
+        assert len(arrival_times) == num_requests
+    else:
+        arrival_times = [float(i) for i in range(num_requests)]
+
+    sampling_params = SamplingParams(ignore_eos=False,
+                                     max_tokens=max_tokens,
+                                     stop_token_ids=stop_token_ids,
+                                     prompt_logprobs=prompt_logprobs)
+    requests = []
+    for i in range(num_requests):
+        if mm_positions is not None:
+            mm_position = mm_positions[i]
+            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
+        else:
+            mm_position = None
+            mm_inputs = None
+        request = Request(
+            request_id=f"{i}",
+            prompt_token_ids=[i] * num_tokens,
+            sampling_params=sampling_params,
+            multi_modal_inputs=mm_inputs,
+            multi_modal_placeholders=mm_position,
+            multi_modal_hashes=None,
+            eos_token_id=EOS_TOKEN_ID,
+            arrival_time=arrival_times[i],
+            priority=priorities[i],
+        )
+        requests.append(request)
+    return requests
+
+
+def test_priority_scheduling_basic_ordering():
+    """Test that requests are scheduled in priority order
+    (lower value = higher priority)."""
+    scheduler = create_scheduler_with_priority()
+
+    # Create requests with different priorities
+    # Priority 0 (highest), 1, 2 (lowest)
+    priorities = [2, 0, 1]  # Add in non-priority order
+    arrival_times = [1.0, 2.0, 3.0]  # All different arrival times
+    requests = create_requests_with_priority(num_requests=3,
+                                             priorities=priorities,
+                                             arrival_times=arrival_times)
+
+    # Add requests in non-priority order
+    for request in requests:
+        scheduler.add_request(request)
+
+    # Schedule and verify priority order
+    output = scheduler.schedule()
+
+    # Should schedule all requests since they fit in budget
+    assert len(output.scheduled_new_reqs) == 3
+
+    # Verify they are scheduled in priority order:
+    # req_1 (priority 0), req_2 (priority 1), req_0 (priority 2)
+    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
+    assert scheduled_req_ids == ["1", "2", "0"]
+
+
+def test_priority_scheduling_arrival_time_tiebreaker():
+    """Test that arrival time is used
+    as tiebreaker when priorities are equal."""
+    scheduler = create_scheduler_with_priority()
+
+    # Create requests with same priority but different arrival times
+    priorities = [1, 1, 1]  # All same priority
+    arrival_times = [3.0, 1.0, 2.0]  # Different arrival times
+    requests = create_requests_with_priority(num_requests=3,
+                                             priorities=priorities,
+                                             arrival_times=arrival_times)
+
+    # Add requests in non-arrival order
+    for request in requests:
+        scheduler.add_request(request)
+
+    # Schedule and verify arrival time order
+    output = scheduler.schedule()
+
+    # Should schedule all requests since they fit in budget
+    assert len(output.scheduled_new_reqs) == 3
+
+    # Verify they are scheduled in arrival time order:
+    # req_1 (1.0), req_2 (2.0), req_0 (3.0)
+    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
+    assert scheduled_req_ids == ["1", "2", "0"]
+
+
+def test_priority_scheduling_mixed_priority_and_arrival():
+    """Test priority scheduling with mixed priorities and arrival times."""
+    scheduler = create_scheduler_with_priority()
+
+    # Create requests with mixed priorities and arrival times
+    priorities = [2, 1, 1, 0]  # Mixed priorities
+    arrival_times = [1.0, 3.0, 2.0, 4.0]  # Mixed arrival times
+    requests = create_requests_with_priority(num_requests=4,
+                                             priorities=priorities,
+                                             arrival_times=arrival_times)
+
+    # Add requests
+    for request in requests:
+        scheduler.add_request(request)
+
+    # Schedule and verify order
+    output = scheduler.schedule()
+
+    # Should schedule all requests since they fit in budget
+    assert len(output.scheduled_new_reqs) == 4
+
+    # Expected order:
+    # 1. req_3 (priority 0, arrival 4.0)
+    # 2. req_2 (priority 1, arrival 2.0) - earlier arrival than req_1
+    # 3. req_1 (priority 1, arrival 3.0)
+    # 4. req_0 (priority 2, arrival 1.0)
+    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
+    assert scheduled_req_ids == ["3", "2", "1", "0"]
+
+
+def test_priority_scheduling_preemption():
+    """Test that priority scheduling preempts
+    lower priority requests when memory is constrained."""
+    # Create scheduler with very limited memory to force preemption
+    scheduler = create_scheduler_with_priority(
+        max_num_seqs=3,  # Allow multiple requests
+        max_num_batched_tokens=200,
+        num_blocks=6,  # Very limited blocks to force memory pressure
+        block_size=16,  # Standard block size
+    )
+
+    # Create initial low-priority requests that will consume most memory
+    low_priority_requests = create_requests_with_priority(
+        num_requests=2,
+        priorities=[5, 5],  # Low priority
+        arrival_times=[1.0, 2.0],
+        num_tokens=30  # Large enough to consume significant memory
+    )
+
+    # Add and schedule low priority requests
+    for request in low_priority_requests:
+        scheduler.add_request(request)
+
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 2
+
+    # Simulate model execution to move requests to running state
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in low_priority_requests],
+        req_id_to_index={
+            req.request_id: i
+            for i, req in enumerate(low_priority_requests)
+        },
+        sampled_token_ids=[[100] for _ in low_priority_requests],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    scheduler.update_from_output(output, model_output)
+
+    # Verify both requests are running
+    assert len(scheduler.running) == 2
+
+    # Now add a high-priority request that requires memory allocation
+    # This should trigger preemption due to memory constraints
+    high_priority_request = create_requests_with_priority(
+        num_requests=1,
+        priorities=[0],  # High priority
+        arrival_times=[3.0],
+        num_tokens=30  # Large enough to require significant memory
+    )[0]
+
+    scheduler.add_request(high_priority_request)
+
+    # Schedule again - this should trigger
+    # preemption when trying to allocate memory
+    output = scheduler.schedule()
+
+    # Due to the scheduler's design, if preemption happens
+    # during running request scheduling,
+    # waiting requests won't be scheduled in the same step
+    # Let's check if preemption occurred by looking at the waiting queue
+
+    # If preemption happened, we should see requests in the
+    # waiting queue
+    if len(scheduler.waiting) > 1:  # high priority + preempted request
+        # Preemption occurred - verify the high priority request
+        # gets scheduled next
+        output2 = scheduler.schedule()
+        assert len(output2.scheduled_new_reqs) == 1
+        # High priority request
+        assert output2.scheduled_new_reqs[0].req_id == "0"
+    else:
+        # No preemption needed - all requests fit
+        # This is also valid behavior if memory allows
+        assert len(output.scheduled_new_reqs) == 1
+        # High priority request
+        assert output.scheduled_new_reqs[0].req_id == "0"
+
+
+def test_priority_scheduling_no_preemption_when_space_available():
+    """Test that preemption doesn't happen
+    when there's space for new requests."""
+    scheduler = create_scheduler_with_priority(
+        max_num_seqs=3,  # Allow 3 concurrent requests
+        max_num_batched_tokens=200,  # Sufficient token budget
+    )
+
+    # Add two low-priority running requests
+    low_priority_requests = create_requests_with_priority(
+        num_requests=2,
+        priorities=[5, 5],
+        arrival_times=[1.0, 2.0],
+        num_tokens=30)
+
+    for request in low_priority_requests:
+        scheduler.add_request(request)
+
+    output = scheduler.schedule()
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in low_priority_requests],
+        req_id_to_index={
+            req.request_id: i
+            for i, req in enumerate(low_priority_requests)
+        },
+        sampled_token_ids=[[100] for _ in low_priority_requests],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    scheduler.update_from_output(output, model_output)
+
+    # Add high-priority request
+    high_priority_request = create_requests_with_priority(num_requests=1,
+                                                          priorities=[0],
+                                                          arrival_times=[3.0],
+                                                          num_tokens=30)[0]
+
+    scheduler.add_request(high_priority_request)
+
+    # Schedule - should not preempt since there's space
+    output = scheduler.schedule()
+
+    # Should schedule the new request without preemption
+    assert len(output.scheduled_new_reqs) == 1
+    assert len(scheduler.running) == 3  # All three requests running
+    assert len(scheduler.waiting) == 0  # No requests waiting
+
+
+def test_priority_scheduling_preemption_victim_selection():
+    """Test that the correct victim is selected for
+    preemption based on priority and arrival time."""
+    # This test verifies the priority-based victim selection logic
+    # by checking the waiting queue order after adding requests with different
+    # priorities
+    scheduler = create_scheduler_with_priority(
+        max_num_seqs=1,  # Force sequential processing to test priority order
+    )
+
+    # Create requests with different priorities
+    requests = create_requests_with_priority(
+        num_requests=3,
+        priorities=[3, 2, 0],  # Different priorities: low, medium, high
+        arrival_times=[1.0, 2.0, 3.0],
+        num_tokens=10)
+
+    # Add all requests
+    for request in requests:
+        scheduler.add_request(request)
+
+    # Schedule - should only schedule the highest priority request
+    # (req_2, priority 0)
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 1
+    assert output.scheduled_new_reqs[0].req_id == "2"  # Highest priority
+
+    # Verify the waiting queue has the remaining requests in priority order
+    assert len(scheduler.waiting) == 2
+
+    # Extract waiting requests and verify priority order
+    temp_waiting = list(scheduler.waiting)
+    temp_waiting.sort()  # Sort by (priority, arrival_time, request)
+
+    waiting_priorities = [priority for priority, _, _ in temp_waiting]
+    waiting_req_ids = [req.request_id for _, _, req in temp_waiting]
+
+    # Should be req_1 (priority 2) then req_0 (priority 3)
+    assert waiting_priorities == [2, 3]
+    assert waiting_req_ids == ["1", "0"]
+
+
+def test_priority_scheduling_equal_priority_preemption():
+    """Test arrival time tiebreaker when requests have equal priority."""
+    # This test verifies that arrival time is used as a tiebreaker for equal
+    # priorities
+    scheduler = create_scheduler_with_priority(
+        max_num_seqs=1,  # Force sequential processing
+    )
+
+    # Create requests with same priority but different arrival times
+    requests = create_requests_with_priority(
+        num_requests=3,
+        priorities=[2, 2, 2],  # Same priority
+        arrival_times=[3.0, 1.0, 2.0],  # Different arrival times
+        num_tokens=10)
+
+    # Add all requests
+    for request in requests:
+        scheduler.add_request(request)
+
+    # Schedule - should schedule the request with earliest arrival time
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 1
+    assert output.scheduled_new_reqs[0].req_id == "1"  # Earliest arrival (1.0)
+
+    # Verify the waiting queue has remaining requests in arrival time order
+    assert len(scheduler.waiting) == 2
+
+    # Extract waiting requests and verify arrival time order
+    temp_waiting = list(scheduler.waiting)
+    temp_waiting.sort()  # Sort by (priority, arrival_time, request)
+
+    waiting_arrival_times = [
+        arrival_time for _, arrival_time, _ in temp_waiting
+    ]
+    waiting_req_ids = [req.request_id for _, _, req in temp_waiting]
+
+    # Should be req_2 (arrival 2.0) then req_0 (arrival 3.0)
+    assert waiting_arrival_times == [2.0, 3.0]
+    assert waiting_req_ids == ["2", "0"]
+
+
+def test_priority_scheduling_waiting_queue_order():
+    """Test that the waiting queue maintains priority order."""
+    scheduler = create_scheduler_with_priority(
+        max_num_seqs=1,  # Only one request can run at a time
+    )
+
+    # Create multiple requests with different priorities
+    requests = create_requests_with_priority(
+        num_requests=4,
+        priorities=[3, 1, 2, 0],  # Mixed priorities
+        arrival_times=[1.0, 2.0, 3.0, 4.0],
+        num_tokens=10)
+
+    # Add all requests
+    for request in requests:
+        scheduler.add_request(request)
+
+    # Schedule - should only schedule the highest priority request
+    # (req_3, priority 0)
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 1
+    assert output.scheduled_new_reqs[0].req_id == "3"
+
+    # Verify waiting queue has remaining requests in priority order
+    assert len(scheduler.waiting) == 3
+
+    # Extract requests from waiting queue
+    # (it's a heap, so we need to pop to see order)
+    waiting_priorities = []
+    waiting_req_ids = []
+    temp_waiting = list(scheduler.waiting)  # Copy the heap
+    temp_waiting.sort()  # Sort by (priority, arrival_time, request)
+
+    for priority, arrival_time, request in temp_waiting:
+        waiting_priorities.append(priority)
+        waiting_req_ids.append(request.request_id)
+
+    # Should be ordered by priority: req_1 (1), req_2 (2), req_0 (3)
+    assert waiting_req_ids == ["1", "2", "0"]
+    assert waiting_priorities == [1, 2, 3]
+
+
+def test_priority_scheduling_fcfs_fallback():
+    """Test that FCFS behavior is maintained when all
+    requests have same priority."""
+    scheduler = create_scheduler_with_priority()
+
+    # Create requests with same priority but different arrival times
+    priorities = [1, 1, 1, 1]  # All same priority
+    arrival_times = [4.0, 1.0, 3.0, 2.0]  # Different arrival times
+    requests = create_requests_with_priority(num_requests=4,
+                                             priorities=priorities,
+                                             arrival_times=arrival_times)
+
+    # Add requests
+    for request in requests:
+        scheduler.add_request(request)
+
+    # Schedule
+    output = scheduler.schedule()
+
+    # Should schedule all requests in arrival time order
+    assert len(output.scheduled_new_reqs) == 4
+    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
+
+    # Expected order by arrival time:
+    # req_1 (1.0), req_3 (2.0), req_2 (3.0), req_0 (4.0)
+    assert scheduled_req_ids == ["1", "3", "2", "0"]
+
+
+def test_priority_scheduling_with_limited_slots():
+    """Test priority scheduling when max_num_seqs limits concurrent requests."""
+    scheduler = create_scheduler_with_priority(
+        max_num_seqs=2,  # Only allow 2 concurrent requests
+        max_num_batched_tokens=1000,  # Plenty of token budget
+    )
+
+    # Create requests with different priorities
+    requests = create_requests_with_priority(
+        num_requests=4,
+        priorities=[3, 1, 2, 0],  # Mixed priorities
+        arrival_times=[1.0, 2.0, 3.0, 4.0],
+        num_tokens=10)
+
+    # Add all requests
+    for request in requests:
+        scheduler.add_request(request)
+
+    # Schedule - should only schedule the 2 highest priority requests
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 2
+
+    # Should schedule req_3 (priority 0) and req_1 (priority 1)
+    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
+    assert "3" in scheduled_req_ids  # Priority 0
+    assert "1" in scheduled_req_ids  # Priority 1
+
+    # Remaining requests should be in waiting queue in priority order
+    assert len(scheduler.waiting) == 2
+
+    # Extract waiting requests and verify order
+    temp_waiting = list(scheduler.waiting)
+    temp_waiting.sort()
+    waiting_priorities = [priority for priority, _, _ in temp_waiting]
+    waiting_req_ids = [req.request_id for _, _, req in temp_waiting]
+
+    # Should be req_2 (priority 2) then req_0 (priority 3)
+    assert waiting_priorities == [2, 3]
+    assert waiting_req_ids == ["2", "0"]
+
+
+def test_priority_scheduling_heap_property():
+    """Test that the waiting queue maintains heap
+    property for priority scheduling."""
+    scheduler = create_scheduler_with_priority(
+        max_num_seqs=1,  # Only one request can run at a time
+    )
+
+    # Add requests in random priority order
+    priorities = [5, 1, 8, 3, 2, 7, 4, 6]
+    arrival_times = [float(i) for i in range(len(priorities))]
+    requests = create_requests_with_priority(num_requests=len(priorities),
+                                             priorities=priorities,
+                                             arrival_times=arrival_times,
+                                             num_tokens=10)
+
+    # Add all requests
+    for request in requests:
+        scheduler.add_request(request)
+
+    # Schedule one request at a time and verify priority order
+    scheduled_priorities = []
+
+    while scheduler.waiting:
+        output = scheduler.schedule()
+        if output.scheduled_new_reqs:
+            req = output.scheduled_new_reqs[0]
+            scheduled_priorities.append(requests[int(req.req_id)].priority)
+
+            # Simulate completion to make room for next request
+            model_output = ModelRunnerOutput(
+                req_ids=[req.req_id],
+                req_id_to_index={req.req_id: 0},
+                sampled_token_ids=[[100]],
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+            )
+            scheduler.update_from_output(output, model_output)
+
+            # Finish the request to make room for the next one
+            scheduler.finish_requests(req.req_id,
+                                      RequestStatus.FINISHED_STOPPED)
+
+    # Verify requests were scheduled in priority order (lowest value first)
+    expected_priorities = sorted(priorities)
+    assert scheduled_priorities == expected_priorities
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index ce16a1ed5a09..89e4016a7b97 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -2,10 +2,11 @@
 
 from __future__ import annotations
 
+import heapq
 import time
 from collections import defaultdict, deque
 from collections.abc import Iterable
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, cast
 
 from vllm.config import VllmConfig
 from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
@@ -63,8 +64,8 @@ def __init__(
 
         # Scheduling constraints.
         self.max_num_running_reqs = self.scheduler_config.max_num_seqs
-        self.max_num_scheduled_tokens = \
-            self.scheduler_config.max_num_batched_tokens
+        self.max_num_scheduled_tokens = (
+            self.scheduler_config.max_num_batched_tokens)
         self.max_model_len = self.scheduler_config.max_model_len
         self.enable_kv_cache_events = (
             self.kv_events_config is not None
@@ -88,8 +89,12 @@ def __init__(
 
         # req_id -> Request
         self.requests: dict[str, Request] = {}
+        # Scheduling policy
+        self.policy = self.scheduler_config.policy
         # Priority queues for requests.
-        self.waiting: deque[Request] = deque()
+        self.waiting: Union[list[tuple[int, float, Request]],
+                            deque[Request]] = ([] if self.policy == "priority"
+                                               else deque())
         self.running: list[Request] = []
 
         # The request IDs that are finished in between the previous and the
@@ -104,8 +109,8 @@ def __init__(
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
         # Request id -> deque of CachedRequestData
-        self._cached_reqs_data: dict[
-            str, deque[CachedRequestData]] = defaultdict(deque)
+        self._cached_reqs_data: dict[str, deque[CachedRequestData]] = (
+            defaultdict(deque))
 
         # Encoder-related.
         # Calculate encoder cache size if applicable
@@ -209,10 +214,16 @@ def schedule(self) -> SchedulerOutput:
             encoder_inputs_to_schedule = None
             new_encoder_budget = encoder_budget
             if request.has_encoder_inputs:
-                (encoder_inputs_to_schedule, num_new_tokens,
-                 new_encoder_budget) = self._try_schedule_encoder_inputs(
-                     request, request.num_computed_tokens, num_new_tokens,
-                     encoder_budget)
+                (
+                    encoder_inputs_to_schedule,
+                    num_new_tokens,
+                    new_encoder_budget,
+                ) = self._try_schedule_encoder_inputs(
+                    request,
+                    request.num_computed_tokens,
+                    num_new_tokens,
+                    encoder_budget,
+                )
 
             if num_new_tokens == 0:
                 # The request cannot be scheduled because one of the following
@@ -230,18 +241,33 @@
 
             num_draft_tokens = max(
                 num_new_tokens + request.num_computed_tokens -
-                request.num_tokens, 0)
+                request.num_tokens,
+                0,
+            )
 
             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
                     num_new_tokens,
                     num_draft_tokens=num_draft_tokens,
-                    num_lookahead_tokens=self.num_lookahead_tokens)
+                    num_lookahead_tokens=self.num_lookahead_tokens,
+                )
                 if new_blocks is None:
                     # The request cannot be scheduled.
                     # Preempt the lowest-priority request.
-                    preempted_req = self.running.pop()
+                    if not self.running:
+                        # No request to preempt.
+                        can_schedule = False
+                        break
+                    if self.policy == "priority":
+                        # The lowest-priority request has the *largest*
+                        # priority value (ties broken by latest arrival).
+                        preempted_req = max(
+                            self.running,
+                            key=lambda r: (r.priority, r.arrival_time),
+                        )
+                        self.running.remove(preempted_req)
+                    else:
+                        preempted_req = self.running.pop()
+
                     self.kv_cache_manager.free(preempted_req)
                     preempted_req.status = RequestStatus.PREEMPTED
                     preempted_req.num_computed_tokens = 0
@@ -249,7 +275,19 @@
                     preempted_req.record_event(
                         EngineCoreEventType.PREEMPTED, scheduled_timestamp)
 
-                    self.waiting.appendleft(preempted_req)
+                    if self.policy == "priority":
+                        heapq.heappush(
+                            cast(list[tuple[int, float, Request]],
+                                 self.waiting),
+                            (
+                                preempted_req.priority,
+                                preempted_req.arrival_time,
+                                preempted_req,
+                            ),
+                        )
+                    else:
+                        cast(deque[Request],
+                             self.waiting).appendleft(preempted_req)
                     preempted_reqs.append(preempted_req)
                     if preempted_req == request:
                         # No more request to preempt.
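(Aside: the waiting-queue heap and the preemption victim above use the same `(priority, arrival_time)` key at opposite extremes; a self-contained sketch — `Req` is a hypothetical stand-in, not vLLM's `Request` class:)

```python
import heapq
from dataclasses import dataclass


@dataclass
class Req:  # hypothetical stand-in for vllm.v1.request.Request
    request_id: str
    priority: int  # lower value = more important
    arrival_time: float


# Waiting side: a heap keyed by (priority, arrival_time) pops the most
# important, oldest request first.
waiting: list[tuple[int, float, Req]] = []
for r in (Req("a", 2, 1.0), Req("b", 0, 3.0), Req("c", 0, 2.0)):
    heapq.heappush(waiting, (r.priority, r.arrival_time, r))
popped = [heapq.heappop(waiting)[2].request_id for _ in range(3)]
assert popped == ["c", "b", "a"]

# Running side: the preemption victim is the *least* important request,
# with ties broken toward the most recent arrival - hence max(), not min().
running = [Req("x", 1, 5.0), Req("y", 3, 4.0), Req("z", 3, 6.0)]
victim = max(running, key=lambda r: (r.priority, r.arrival_time))
assert victim.request_id == "z"
```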
@@ -307,7 +345,10 @@ def schedule(self) -> SchedulerOutput:
 
         # Use a temporary queue to collect requests that need to be skipped
         # and put back at the head of the waiting queue later
-        skipped_waiting_requests: deque[Request] = deque()
+        skipped_waiting_requests: Union[list[tuple[int, float, Request]],
+                                        deque[Request]] = ([] if self.policy
+                                                           == "priority" else
+                                                           deque())
 
         # Next, schedule the WAITING requests.
         if not preempted_reqs:
@@ -315,7 +356,14 @@
                 if len(self.running) == self.max_num_running_reqs:
                     break
 
-                request = self.waiting[0]
+                if self.policy == "priority":
+                    if not self.waiting:
+                        # Should not happen due to outer loop condition
+                        break
+                    priority_val, arrival_time_val, request = heapq.heappop(
+                        cast(list[tuple[int, float, Request]], self.waiting))
+                else:
+                    request = cast(deque[Request], self.waiting).popleft()
 
                 # KVTransfer: skip request if still waiting for remote kvs.
                 if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
@@ -325,9 +373,18 @@
                     else:
                         logger.debug(
                             "%s is still in WAITING_FOR_REMOTE_KVS state.",
-                            request.request_id)
-                        self.waiting.popleft()
-                        skipped_waiting_requests.appendleft(request)
+                            request.request_id,
+                        )
+                        if self.policy == "priority":
+                            waiting_queue = cast(
+                                list[tuple[int, float, Request]],
+                                skipped_waiting_requests)
+                            heapq.heappush(
+                                waiting_queue,
+                                (priority_val, arrival_time_val, request))
+                        else:
+                            cast(deque[Request],
+                                 skipped_waiting_requests).appendleft(request)
                         continue
 
                 # Skip request if the structured output request is still waiting
@@ -337,19 +394,33 @@
                     if structured_output_req and structured_output_req.grammar:
                         request.status = RequestStatus.WAITING
                     else:
-                        self.waiting.popleft()
-                        skipped_waiting_requests.appendleft(request)
+                        if self.policy == "priority":
+                            waiting_queue = cast(
+                                list[tuple[int, float, Request]],
+                                skipped_waiting_requests)
+                            heapq.heappush(
+                                waiting_queue,
+                                (priority_val, arrival_time_val, request))
+                        else:
+                            cast(deque[Request],
+                                 skipped_waiting_requests).appendleft(request)
                         continue
 
                 # Check that adding the request still respects the max_loras
                 # constraint.
-                if self.lora_config and request.lora_request and (
-                        len(scheduled_loras) == self.lora_config.max_loras
-                        and request.lora_request.lora_int_id
-                        not in scheduled_loras):
+                if (self.lora_config and request.lora_request and
+                    (len(scheduled_loras) == self.lora_config.max_loras and
+                     request.lora_request.lora_int_id not in scheduled_loras)):
                     # Scheduling would exceed max_loras, skip.
-                    self.waiting.popleft()
-                    skipped_waiting_requests.appendleft(request)
+                    if self.policy == "priority":
+                        waiting_queue = cast(list[tuple[int, float, Request]],
+                                             skipped_waiting_requests)
+                        heapq.heappush(
+                            waiting_queue,
+                            (priority_val, arrival_time_val, request))
+                    else:
+                        cast(deque[Request],
+                             skipped_waiting_requests).appendleft(request)
                     continue
 
                 num_external_computed_tokens = 0
@@ -358,9 +429,8 @@
                 # Get already-cached tokens.
                 if request.num_computed_tokens == 0:
                     # Get locally-cached tokens.
-                    new_computed_blocks, num_new_local_computed_tokens = \
-                        self.kv_cache_manager.get_computed_blocks(
-                            request)
+                    new_computed_blocks, num_new_local_computed_tokens = (
+                        self.kv_cache_manager.get_computed_blocks(request))
 
                     # Get externally-cached tokens if using a KVConnector.
                     if self.connector is not None:
@@ -400,11 +470,16 @@
 
                 # Schedule encoder inputs.
                 if request.has_encoder_inputs:
-                    (encoder_inputs_to_schedule, num_new_tokens,
-                     new_encoder_budget
-                     ) = self._try_schedule_encoder_inputs(
-                         request, num_computed_tokens, num_new_tokens,
-                         encoder_budget)
+                    (
+                        encoder_inputs_to_schedule,
+                        num_new_tokens,
+                        new_encoder_budget,
+                    ) = self._try_schedule_encoder_inputs(
+                        request,
+                        num_computed_tokens,
+                        num_new_tokens,
+                        encoder_budget,
+                    )
                     if num_new_tokens == 0:
                         # The request cannot be scheduled.
                         break
@@ -419,6 +494,15 @@
                 )
                 if new_blocks is None:
                     # The request cannot be scheduled.
+                    if self.policy == "priority":
+                        waiting_queue = cast(list[tuple[int, float, Request]],
+                                             skipped_waiting_requests)
+                        heapq.heappush(
+                            waiting_queue,
+                            (priority_val, arrival_time_val, request))
+                    else:
+                        # For FCFS, push back to the front of the deque.
+                        cast(deque[Request], self.waiting).appendleft(request)
                     break
 
                 # KVConnector: update internal state after allocation.
@@ -432,17 +516,27 @@
                         num_external_computed_tokens,
                     )
 
-                self.waiting.popleft()
+                # Request was already popped from self.waiting
+                # (either via heapq.heappop or self.waiting.popleft())
+                # unless it was re-added above due to new_blocks being None.
                 if load_kv_async:
                     # If loading async, allocate memory and put request
                     # into the WAITING_FOR_REMOTE_KV state.
-                    skipped_waiting_requests.appendleft(request)
+                    if self.policy == "priority":
+                        waiting_queue = cast(list[tuple[int, float, Request]],
+                                             skipped_waiting_requests)
+                        heapq.heappush(
+                            waiting_queue,
+                            (priority_val, arrival_time_val, request))
+                    else:
+                        cast(deque[Request],
+                             skipped_waiting_requests).appendleft(request)
                     request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
                     continue
 
                 if request.use_structured_output:
-                    structured_output_request_ids[
-                        request.request_id] = req_index
+                    structured_output_request_ids[request.request_id] = (
+                        req_index)
                 req_index += 1
                 self.running.append(request)
                 if self.log_stats:
@@ -478,7 +572,17 @@
 
         # Put back any skipped requests at the head of the waiting queue
        if skipped_waiting_requests:
-            self.waiting.extendleft(skipped_waiting_requests)
+            if self.policy == "priority":
+                waiting_queue = cast(list[tuple[int, float, Request]],
+                                     self.waiting)
+                skipped_queue = cast(list[tuple[int, float, Request]],
+                                     skipped_waiting_requests)
+                for item in skipped_queue:
+                    heapq.heappush(waiting_queue, item)
+            else:  # FCFS
+                waiting_deque = cast(deque[Request], self.waiting)
+                skipped_deque = cast(deque[Request], skipped_waiting_requests)
+                waiting_deque.extendleft(skipped_deque)
 
         # Check if the scheduling constraints are satisfied.
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@@ -488,8 +592,8 @@
         # Since some requests in the RUNNING queue may not be scheduled in
         # this step, the total number of scheduled requests can be smaller than
         # len(self.running).
-        assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
-                len(scheduled_running_reqs) <= len(self.running))
+        assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
+            scheduled_running_reqs) <= len(self.running)
 
         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
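(Aside: the skip handling above is a pop/stash/merge-back pattern over the heap; a condensed sketch of just that control flow, with plain strings standing in for request objects:)

```python
import heapq

# Entries mirror the scheduler's (priority, arrival_time, request) tuples.
waiting = [(0, 1.0, "a"), (1, 0.5, "b"), (2, 2.0, "c")]
heapq.heapify(waiting)

skipped = []  # requests popped but not schedulable this step
scheduled = []
while waiting:
    entry = heapq.heappop(waiting)
    if entry[2] == "b":  # stand-in for "still waiting on remote KVs / FSM"
        heapq.heappush(skipped, entry)
        continue
    scheduled.append(entry[2])

# Merge the stash back so skipped requests keep their priority position.
for entry in skipped:
    heapq.heappush(waiting, entry)

assert scheduled == ["a", "c"] and waiting[0][2] == "b"
```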
@@ -734,7 +838,8 @@ def update_from_output(
                     spec_decoding_stats = self.make_spec_decoding_stats(
                         spec_decoding_stats,
                         num_draft_tokens=len(scheduled_spec_token_ids),
-                        num_accepted_tokens=len(generated_token_ids) - 1)
+                        num_accepted_tokens=len(generated_token_ids) - 1,
+                    )
 
             cached_encoder_input_ids = (
                 self.encoder_cache_manager.get_cached_input_ids(request))
@@ -862,7 +967,13 @@ def get_request_counts(self) -> tuple[int, int]:
         return len(self.running), len(self.waiting)
 
     def add_request(self, request: Request) -> None:
-        self.waiting.append(request)
+        if self.policy == "priority":
+            heapq.heappush(
+                cast(list[tuple[int, float, Request]], self.waiting),
+                (request.priority, request.arrival_time, request),
+            )
+        else:
+            cast(deque[Request], self.waiting).append(request)
         self.requests[request.request_id] = request
         if self.log_stats:
             request.record_event(EngineCoreEventType.QUEUED)
@@ -892,7 +1003,15 @@ def finish_requests(
             if request.status == RequestStatus.RUNNING:
                 self.running.remove(request)
             else:
-                self.waiting.remove(request)
+                if self.policy == "priority":
+                    # Remove the heap entry and restore the heap invariant.
+                    waiting_heap = cast(list[tuple[int, float, Request]],
+                                        self.waiting)
+                    waiting_heap.remove(
+                        (request.priority, request.arrival_time, request))
+                    heapq.heapify(waiting_heap)
+                else:
+                    cast(deque[Request], self.waiting).remove(request)
             request.status = finished_status
             self._free_request(request)
 
@@ -921,7 +1032,7 @@ def _free_blocks(self, request: Request):
         del self.requests[request.request_id]
 
     def get_num_unfinished_requests(self) -> int:
-        return len(self.waiting) + len(self.running)
+        return len(cast(deque[Request], self.waiting)) + len(self.running)
 
     def has_finished_requests(self) -> bool:
         return len(self.finished_req_ids) > 0
@@ -939,7 +1050,7 @@ def make_stats(
         assert prefix_cache_stats is not None
         return SchedulerStats(
             num_running_reqs=len(self.running),
-            num_waiting_reqs=len(self.waiting),
+            num_waiting_reqs=len(cast(deque[Request], self.waiting)),
             gpu_cache_usage=self.kv_cache_manager.usage,
             prefix_cache_stats=prefix_cache_stats,
             spec_decoding_stats=spec_decoding_stats,
@@ -957,7 +1068,8 @@ def make_spec_decoding_stats(
             spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens)
         spec_decoding_stats.observe_draft(
             num_draft_tokens=num_draft_tokens,
-            num_accepted_tokens=num_accepted_tokens)
+            num_accepted_tokens=num_accepted_tokens,
+        )
         return spec_decoding_stats
 
     def shutdown(self) -> None:
@@ -981,8 +1093,8 @@ def _connector_finished(
         """
         if self.connector is None:
             return False, None
-        assert len(self.kv_cache_config.kv_cache_groups
-                   ) == 1, "KV connector only supports one KV cache group now"
+        assert (len(self.kv_cache_config.kv_cache_groups) == 1
+                ), "KV connector only supports one KV cache group now"
         block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0]
         return self.connector.request_finished(request, block_ids)
 
@@ -1000,8 +1112,8 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
         """
         if request.request_id not in self.finished_recving_kv_req_ids:
             return False
-        assert len(self.kv_cache_config.kv_cache_groups
-                   ) == 1, "KV connector only supports one KV cache group now"
+        assert (len(self.kv_cache_config.kv_cache_groups) == 1
+                ), "KV connector only supports one KV cache group now"
         # Now that the blocks are ready, actually cache them.
         block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0]
         num_computed_tokens = len(block_ids) * self.block_size
@@ -1032,9 +1144,9 @@ def _update_from_kv_xfer_finished(self,
            scheduler the request during the next step.
         """
         # P/D: update recv and send status from last step.
-        for req_id in (model_runner_output.finished_recving or ()):
+        for req_id in model_runner_output.finished_recving or ():
             logger.debug("Finished recving KV transfer for request %s", req_id)
             self.finished_recving_kv_req_ids.add(req_id)
-        for req_id in (model_runner_output.finished_sending or ()):
+        for req_id in model_runner_output.finished_sending or ():
             logger.debug("Finished sending KV transfer for request %s", req_id)
             self._free_blocks(self.requests[req_id])
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 0c9f61a76427..46710e98db97 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -63,6 +63,7 @@ class EngineCoreRequest(
     # belong to, to cover a race condition where the request is sent before
     # a wave finished notification is received.
     current_wave: int = 0
+    priority: int = 0
 
 
 class EngineCoreEventType(enum.IntEnum):
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 64a756148780..15d9f5e917dd 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -217,8 +217,6 @@ def process_inputs(
         # TODO(woosuk): Support encoder-decoder models.
         self._validate_lora(lora_request)
         self._validate_params(params, lora_request)
-        if priority != 0:
-            raise ValueError("V1 does not support priority yet.")
         if trace_headers is not None:
             raise ValueError("V1 does not support tracing yet.")
         if prompt_adapter_request is not None:
@@ -327,6 +325,7 @@ def process_inputs(
             arrival_time=arrival_time,
             lora_request=lora_request,
             cache_salt=decoder_inputs.get("cache_salt"),
+            priority=priority,
         )
 
     def _validate_model_inputs(self,
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 42c75ef96401..350826386c0a 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import enum
+import time
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
@@ -27,17 +28,22 @@ def __init__(
         sampling_params: SamplingParams,
         eos_token_id: Optional[int],
         client_index: int = 0,
+        arrival_time: Optional[float] = None,
         lora_request: Optional["LoRARequest"] = None,
         structured_output_request: Optional["StructuredOutputRequest"] = None,
         cache_salt: Optional[str] = None,
+        priority: int = 0,
     ) -> None:
         self.request_id = request_id
         self.client_index = client_index
+        self.priority = priority
         self.sampling_params = sampling_params
         # Because of LoRA, the eos token id can be different for each request.
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
         self.structured_output_request = structured_output_request
+        self.arrival_time = arrival_time if arrival_time is not None else \
+            time.time()
 
         self.status = (RequestStatus.WAITING_FOR_FSM
                        if sampling_params.guided_decoding is not None else
@@ -91,17 +97,19 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
 
         return cls(
             request_id=request.request_id,
             client_index=request.client_index,
             prompt_token_ids=request.prompt_token_ids,
             multi_modal_inputs=request.mm_inputs,
             multi_modal_hashes=request.mm_hashes,
             multi_modal_placeholders=request.mm_placeholders,
             sampling_params=request.sampling_params,
             eos_token_id=request.eos_token_id,
+            arrival_time=request.arrival_time,
             lora_request=request.lora_request,
             structured_output_request=StructuredOutputRequest(
                 sampling_params=request.sampling_params),
             cache_salt=request.cache_salt,
+            priority=request.priority,
         )
 
     def append_output_token_ids(
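(Aside: with the processor check removed, a request's priority now flows end to end. A hedged client-side sketch — `LLM.generate` does take a `priority` parameter and engine args include a scheduling-policy option, but treat the exact call shape here as illustrative of a build containing this PR:)

```python
from vllm import LLM, SamplingParams

# Assumes a vLLM build containing this PR; V1 now accepts priority
# instead of raising "V1 does not support priority yet."
llm = LLM(model="facebook/opt-125m", scheduling_policy="priority")

prompts = ["Background batch job:", "Interactive user query:"]
params = SamplingParams(max_tokens=32)

# Lower value = higher priority, matching Request.priority above: under
# load, the second prompt should be scheduled ahead of the first.
outputs = llm.generate(prompts, params, priority=[10, 0])
for out in outputs:
    print(out.outputs[0].text)
```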