diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index a858a4d8c823..ec928c0fe104 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -12,6 +12,7 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.engine.exceptions import SchedulerWaitingQueueFullError from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) from vllm.v1.outputs import ModelRunnerOutput @@ -1832,3 +1833,109 @@ def test_schedule_skip_tokenizer_init_structured_output_request(): assert len(output.scheduled_new_reqs) == 0 assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 + + +def test_scheduler_max_waiting_queue_length(): + """Test that V1 scheduler respects max_waiting_queue_length setting.""" + max_waiting_queue_length = 2 + scheduler = create_scheduler( + max_num_seqs=64, + max_num_batched_tokens=100, + max_waiting_queue_length=max_waiting_queue_length, + ) + requests = create_requests(num_requests=max_waiting_queue_length) + + # Add requests up to the limit + for i, request in enumerate(requests): + scheduler.add_request(request) + assert len(scheduler.waiting) == i + 1 + + assert len(scheduler.waiting) == max_waiting_queue_length + # Try to add one more request - should raise exception + overflow_request = create_requests(num_requests=1)[0] + overflow_request.request_id = "overflow" + + with pytest.raises(SchedulerWaitingQueueFullError, + match="Scheduler waiting queue is full"): + scheduler.add_request(overflow_request) + + # Verify that the queue size hasn't changed + assert len(scheduler.waiting) == max_waiting_queue_length + + +def test_scheduler_max_waiting_queue_length_disabled(): + """Test that V1 scheduler allows unlimited queue when + max_waiting_queue_length is None.""" + scheduler = create_scheduler( + max_num_seqs=64, + max_num_batched_tokens=100, + max_waiting_queue_length=None, # No limit + ) + + # Add many requests - should not raise an exception + num_requests = 10 + requests = create_requests(num_requests=num_requests) + for i, request in enumerate(requests): + scheduler.add_request(request) + assert len(scheduler.waiting) == i + 1 + + +def test_scheduler_max_waiting_queue_length_with_scheduling(): + """Test max_waiting_queue_length behavior when requests are being + scheduled.""" + + max_waiting_queue_length = 2 + scheduler = create_scheduler( + max_num_seqs=1, # Only 1 can run at once, forcing others to wait + max_num_batched_tokens=100, + max_waiting_queue_length=max_waiting_queue_length, + ) + + # Add requests up to the waiting queue limit + requests = create_requests(num_requests=max_waiting_queue_length) + + # Add requests up to the limit + for request in requests: + scheduler.add_request(request) + + # All requests should be in waiting queue initially + assert len(scheduler.waiting) == max_waiting_queue_length + assert len(scheduler.running) == 0 + + # Schedule one request (should move 1 from waiting to running) + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 1 # max_num_seqs = 1 + assert len(scheduler.running) == 1 + assert len( + scheduler.waiting) == max_waiting_queue_length - 1 # 1 left in waiting + + # Now add one more request to fill the waiting queue back to its limit + additional_request = create_requests(num_requests=1)[0] + additional_request.request_id = "additional" + scheduler.add_request(additional_request) + + assert len( + scheduler.waiting) == max_waiting_queue_length # back to full capacity + + # Try to add one more request - should raise exception + overflow_request = create_requests(num_requests=1)[0] + overflow_request.request_id = "overflow" + + with pytest.raises(SchedulerWaitingQueueFullError, + match="Scheduler waiting queue is full"): + scheduler.add_request(overflow_request) + + # Verify queue sizes are unchanged + assert len(scheduler.waiting) == max_waiting_queue_length + assert len(scheduler.running) == 1 + + +def test_scheduler_max_waiting_queue_length_zero(): + """Test that max_waiting_queue_length=0 raises ValueError.""" + with pytest.raises(ValueError, + match="max_waiting_queue_length cannot be 0"): + create_scheduler( + max_num_seqs=1, # Only 1 can run at once + max_num_batched_tokens=100, + max_waiting_queue_length=0, # Should raise ValueError + ) diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 0b7d8251b640..07175396b0d0 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -32,6 +32,7 @@ def create_scheduler( num_speculative_tokens: Optional[int] = None, skip_tokenizer_init: bool = False, async_scheduling: bool = False, + max_waiting_queue_length: Optional[int] = None, ) -> Union[Scheduler, AsyncScheduler]: '''Create scheduler under test. @@ -56,6 +57,7 @@ def create_scheduler( disable_chunked_mm_input=disable_chunked_mm_input, enable_chunked_prefill=True, async_scheduling=async_scheduling, + max_waiting_queue_length=max_waiting_queue_length, ) model_config = ModelConfig( model=model, diff --git a/vllm/config.py b/vllm/config.py index f94c08c32536..272e6a317826 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2292,6 +2292,11 @@ class SchedulerConfig: structured outputs, speculative decoding, and pipeline parallelism. """ + max_waiting_queue_length: Optional[int] = None + """Maximum number of requests that can be in the waiting queue. + When the queue reaches this limit, new requests will be rejected + with HTTP 503 error. If None, no limit is enforced.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -2455,6 +2460,19 @@ def _verify_args(self) -> Self: def is_multi_step(self) -> bool: return self.num_scheduler_steps > 1 + @field_validator("max_waiting_queue_length") + @classmethod + def validate_max_waiting_queue_length( + cls, value: Optional[int]) -> Optional[int]: + if value == 0: + raise ValueError( + "max_waiting_queue_length cannot be 0. Use None for unlimited " + "queue or a positive integer for a limited queue.") + if value is not None and value < 0: + raise ValueError( + "max_waiting_queue_length must be None or a positive integer") + return value + Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu"] diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a5eb16a53976..ca2eb179eee5 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -266,7 +266,7 @@ async def create_chat_completion( generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) assert len(generators) == 1 result_generator, = generators @@ -289,7 +289,7 @@ async def create_chat_completion( conversation, tokenizer, request_metadata) except ValueError as e: # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) def get_chat_request_role(self, request: ChatCompletionRequest) -> str: if request.add_generation_prompt: @@ -470,7 +470,7 @@ async def chat_completion_stream_generator( reasoning_parser = self.reasoning_parser(tokenizer) except RuntimeError as e: logger.exception("Error in reasoning parser creation.") - data = self.create_streaming_error_response(str(e)) + data = self.create_streaming_error_response(e) yield f"data: {data}\n\n" yield "data: [DONE]\n\n" return @@ -484,7 +484,7 @@ async def chat_completion_stream_generator( tool_parsers = [None] * num_choices except Exception as e: logger.exception("Error in tool parser creation.") - data = self.create_streaming_error_response(str(e)) + data = self.create_streaming_error_response(e) yield f"data: {data}\n\n" yield "data: [DONE]\n\n" return @@ -935,7 +935,7 @@ async def chat_completion_stream_generator( except Exception as e: # TODO: Use a vllm-specific Validation Error logger.exception("Error in chat completion stream generator.") - data = self.create_streaming_error_response(str(e)) + data = self.create_streaming_error_response(e) yield f"data: {data}\n\n" # Send the final done message after all response.n are finished yield "data: [DONE]\n\n" @@ -961,7 +961,7 @@ async def chat_completion_full_generator( return self.create_error_response("Client disconnected") except ValueError as e: # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) assert final_res is not None @@ -990,7 +990,7 @@ async def chat_completion_full_generator( reasoning_parser = self.reasoning_parser(tokenizer) except RuntimeError as e: logger.exception("Error in reasoning parser creation.") - return self.create_error_response(str(e)) + return self.create_error_response(e) # If the reasoning parser is enabled, # tool calls are extracted exclusively from the content. reasoning_content, content = ( @@ -1065,7 +1065,7 @@ async def chat_completion_full_generator( tool_parser = self.tool_parser(tokenizer) except RuntimeError as e: logger.exception("Error in tool parser creation.") - return self.create_error_response(str(e)) + return self.create_error_response(e) tool_call_info = tool_parser.extract_tool_calls( content if content is not None else "", request=request) diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index 3ac4f01ea602..16376a478b65 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -73,7 +73,7 @@ async def _preprocess( except (ValueError, TypeError) as e: logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) + return self.create_error_response(e) def _build_response( self, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 1e1f655022f0..8e0e117347eb 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -137,16 +137,16 @@ async def create_completion( ) except ValueError as e: logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) + return self.create_error_response(e) except TypeError as e: logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) + return self.create_error_response(e) except RuntimeError as e: logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) + return self.create_error_response(e) except jinja2.TemplateError as e: logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) + return self.create_error_response(e) # Schedule the request and get the result generator. generators: list[AsyncGenerator[RequestOutput, None]] = [] @@ -229,7 +229,7 @@ async def create_completion( generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) result_generator = merge_async_iterators(*generators) @@ -293,7 +293,7 @@ async def create_completion( return self.create_error_response("Client disconnected") except ValueError as e: # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) # When user requests streaming but we don't stream, we still need to # return a streaming response with a single event. @@ -475,7 +475,7 @@ async def completion_stream_generator( except Exception as e: # TODO: Use a vllm-specific Validation Error - data = self.create_streaming_error_response(str(e)) + data = self.create_streaming_error_response(e) yield f"data: {data}\n\n" yield "data: [DONE]\n\n" diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index e87decfe636a..9938567c7652 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -95,7 +95,7 @@ async def _preprocess( return None except (ValueError, TypeError) as e: logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) + return self.create_error_response(e) def _build_response( self, @@ -196,6 +196,6 @@ def _validate_request( try: pooling_params.verify(self.model_config) except ValueError as e: - return self.create_error_response(str(e)) + return self.create_error_response(e) return None diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 462317a0878c..49ecb07b3525 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -76,6 +76,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of, merge_async_iterators, random_uuid) +from vllm.v1.engine.exceptions import SchedulerWaitingQueueFullError logger = init_logger(__name__) @@ -404,16 +405,28 @@ async def _collect_batch( def create_error_response( self, - message: str, + message: Union[str, Exception], err_type: str = "BadRequestError", status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: - return ErrorResponse(message=message, + + if isinstance(message, SchedulerWaitingQueueFullError): + return ErrorResponse( + message=str(message), + type="ServiceUnavailableError", + code=HTTPStatus.SERVICE_UNAVAILABLE.value, + ) + elif isinstance(message, Exception): + message_str = str(message) + else: + message_str = message + + return ErrorResponse(message=message_str, type=err_type, code=status_code.value) def create_streaming_error_response( self, - message: str, + message: Union[str, Exception], err_type: str = "BadRequestError", status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: json_str = json.dumps({ diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index c2ed50d04d12..8b8d38daaf6d 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -135,7 +135,7 @@ async def create_pooling( ) except (ValueError, TypeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) + return self.create_error_response(e) # Schedule the request and get the result generator. generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] @@ -166,7 +166,7 @@ async def create_pooling( generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) result_generator = merge_async_iterators(*generators) @@ -195,7 +195,7 @@ async def create_pooling( return self.create_error_response("Client disconnected") except ValueError as e: # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) return response diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index a359371848ce..1e8e253a7e8a 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -187,7 +187,7 @@ async def create_responses( generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) assert len(generators) == 1 result_generator, = generators @@ -244,7 +244,7 @@ async def create_responses( request_metadata, ) except Exception as e: - return self.create_error_response(str(e)) + return self.create_error_response(e) async def responses_full_generator( self, @@ -267,7 +267,7 @@ async def responses_full_generator( return self.create_error_response("Client disconnected") except ValueError as e: # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) assert final_res is not None assert len(final_res.outputs) == 1 @@ -278,7 +278,7 @@ async def responses_full_generator( reasoning_parser = self.reasoning_parser(tokenizer) except RuntimeError as e: logger.exception("Error in reasoning parser creation.") - return self.create_error_response(str(e)) + return self.create_error_response(e) reasoning_content, content = ( reasoning_parser.extract_reasoning_content(final_output.text, @@ -391,7 +391,7 @@ async def _run_background_request( except Exception as e: logger.exception("Background request failed for %s", request.request_id) - response = self.create_error_response(str(e)) + response = self.create_error_response(e) if isinstance(response, ErrorResponse): # If the request has failed, update the status to "failed". diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 8d47a417f9cd..9005350eacdf 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -385,7 +385,7 @@ async def create_score( return self.create_error_response("Client disconnected") except ValueError as e: # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) async def do_rerank( self, @@ -431,7 +431,7 @@ async def do_rerank( return self.create_error_response("Client disconnected") except ValueError as e: # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) def request_output_to_score_response( self, diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index 09b346dcef6b..d5cd70d6de0b 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -171,7 +171,7 @@ async def _create_speech_to_text( except ValueError as e: logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) + return self.create_error_response(e) list_result_generator: Optional[list[AsyncGenerator[RequestOutput, None]]] = None @@ -200,7 +200,7 @@ async def _create_speech_to_text( ] except ValueError as e: # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) if request.stream: return stream_generator_method(request, list_result_generator, @@ -218,7 +218,7 @@ async def _create_speech_to_text( return self.create_error_response("Client disconnected") except ValueError as e: # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + return self.create_error_response(e) async def _speech_to_text_stream_generator( self, @@ -324,7 +324,7 @@ async def _speech_to_text_stream_generator( except Exception as e: # TODO: Use a vllm-specific Validation Error logger.exception("Error in %s stream generator.", self.task_type) - data = self.create_streaming_error_response(str(e)) + data = self.create_streaming_error_response(e) yield f"data: {data}\n\n" # Send the final done message after all response.n are finished yield "data: [DONE]\n\n" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 446f98034cb8..73043857a33a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -28,6 +28,7 @@ from vllm.v1.core.sched.utils import check_stop from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs) +from vllm.v1.engine.exceptions import SchedulerWaitingQueueFullError from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput @@ -957,6 +958,14 @@ def get_request_counts(self) -> tuple[int, int]: return len(self.running), len(self.waiting) def add_request(self, request: Request) -> None: + # Check if the waiting queue has reached its maximum capacity + if (self.scheduler_config.max_waiting_queue_length is not None + and len(self.waiting) + >= self.scheduler_config.max_waiting_queue_length): + raise SchedulerWaitingQueueFullError( + f"Scheduler waiting queue is full. Cannot add request " + f"{request.request_id}.") + self.waiting.add_request(request) self.requests[request.request_id] = request if self.log_stats: diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 3754570dfaaa..562b9ea04654 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -29,7 +29,8 @@ from vllm.utils import Device, cdiv from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import EngineCoreClient -from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError +from vllm.v1.engine.exceptions import (EngineDeadError, EngineGenerateError, + SchedulerWaitingQueueFullError) from vllm.v1.engine.output_processor import (OutputProcessor, RequestOutputCollector) from vllm.v1.engine.parallel_sampling import ParentRequest @@ -351,6 +352,12 @@ async def generate( logger.info("Request %s failed (bad request).", request_id) raise + # Scheduler waiting queue is full. + except SchedulerWaitingQueueFullError: + if self.log_requests: + logger.info("Request %s failed (queue full).", request_id) + raise + # Unexpected error in the generate() task (possibly recoverable). except Exception as e: await self.abort(request_id) @@ -513,6 +520,12 @@ async def encode( logger.info("Request %s failed (bad request).", request_id) raise + # Scheduler waiting queue is full. + except SchedulerWaitingQueueFullError: + if self.log_requests: + logger.info("Request %s failed (queue full).", request_id) + raise + # Unexpected error in the generate() task (possibly recoverable). except Exception as e: await self.abort(request_id) diff --git a/vllm/v1/engine/exceptions.py b/vllm/v1/engine/exceptions.py index 692ba9dc840f..6da128a0ef6c 100644 --- a/vllm/v1/engine/exceptions.py +++ b/vllm/v1/engine/exceptions.py @@ -15,3 +15,9 @@ def __init__(self, *args, suppress_context: bool = False, **kwargs): # Make stack trace clearer when using with LLMEngine by # silencing irrelevant ZMQError. self.__suppress_context__ = suppress_context + + +class SchedulerWaitingQueueFullError(Exception): + """Raised when the scheduler's waiting queue is full and cannot accept + new requests.""" + pass