diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index d0bdd4916144..a5fe6297c92a 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -464,6 +464,7 @@ class EngineArgs:
     max_logprobs: int = ModelConfig.max_logprobs
     logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
     disable_log_stats: bool = False
+    max_waiting_queue_time: float | None = None
     aggregate_engine_logging: bool = False
     revision: str | None = ModelConfig.revision
     code_revision: str | None = ModelConfig.code_revision
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 0b3b29cd6c1f..ecb37a78ebfd 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -112,6 +112,15 @@ async def abort(self, request_id: str | Iterable[str]) -> None:
 
     @abstractmethod
     async def is_tracing_enabled(self) -> bool: ...
 
+    @abstractmethod
+    def get_estimated_queue_time(self) -> float:
+        """Get the estimated queue time in seconds based on the historical average.
+
+        This is the predicted wait time for a new request before it gets
+        scheduled. Returns 0.0 if no historical data is available.
+        """
+        ...
+
     @abstractmethod
     async def do_log_stats(self) -> None: ...
diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py
index ad7982b615c4..72475f608009 100644
--- a/vllm/entrypoints/openai/chat_completion/serving.py
+++ b/vllm/entrypoints/openai/chat_completion/serving.py
@@ -105,12 +105,14 @@ def __init__(
         enable_log_outputs: bool = False,
         enable_log_deltas: bool = True,
         default_chat_template_kwargs: dict[str, Any] | None = None,
+        max_waiting_queue_time: float | None = None,
     ) -> None:
         super().__init__(
             engine_client=engine_client,
             models=models,
             request_logger=request_logger,
             return_tokens_as_token_ids=return_tokens_as_token_ids,
+            max_waiting_queue_time=max_waiting_queue_time,
         )
 
         self.openai_serving_render = openai_serving_render
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index ab28b62999d8..a618917f9d09 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -156,9 +156,14 @@ class BaseFrontendArgs:
     """If set to True, log the stack trace of error responses"""
     tokens_only: bool = False
     """
-    If set to True, only enable the Tokens In<>Out endpoint. 
+    If set to True, only enable the Tokens In<>Out endpoint.
     This is intended for use in a Disaggregated Everything setup.
     """
+    max_waiting_queue_time: float | None = None
+    """
+    Maximum estimated queue time in seconds. If the predicted waiting time
+    exceeds this limit, new requests will be rejected with a 503 error.
+    """
 
     @classmethod
     def _customize_cli_kwargs(
diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py
index 96cd7797c14d..0ca72777396b 100644
--- a/vllm/entrypoints/openai/completion/serving.py
+++ b/vllm/entrypoints/openai/completion/serving.py
@@ -59,12 +59,14 @@ def __init__(
         return_tokens_as_token_ids: bool = False,
         enable_prompt_tokens_details: bool = False,
         enable_force_include_usage: bool = False,
+        max_waiting_queue_time: float | None = None,
     ):
         super().__init__(
             engine_client=engine_client,
             models=models,
             request_logger=request_logger,
             return_tokens_as_token_ids=return_tokens_as_token_ids,
+            max_waiting_queue_time=max_waiting_queue_time,
         )
 
         self.openai_serving_render = openai_serving_render
diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py
index 405db1a134c1..fd3ad5cd7678 100644
--- a/vllm/entrypoints/openai/engine/serving.py
+++ b/vllm/entrypoints/openai/engine/serving.py
@@ -185,6 +185,7 @@ def __init__(
         *,
         request_logger: RequestLogger | None,
         return_tokens_as_token_ids: bool = False,
+        max_waiting_queue_time: float | None = None,
     ):
         super().__init__()
 
@@ -194,6 +195,7 @@ def __init__(
         self.request_logger = request_logger
         self.return_tokens_as_token_ids = return_tokens_as_token_ids
+        self.max_waiting_queue_time = max_waiting_queue_time
 
         self.model_config = engine_client.model_config
         self.renderer = engine_client.renderer
 
@@ -465,6 +467,18 @@ def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None:
             )
         return None
 
+    def _validate_max_waiting_queue_time(self) -> ErrorResponse | None:
+        if self.max_waiting_queue_time is not None:
+            queue_time = self.engine_client.get_estimated_queue_time()
+            if queue_time > self.max_waiting_queue_time:
+                return self.create_error_response(
+                    "The server is currently experiencing high load.\n"
+                    "Please try again later.",
+                    err_type="ServiceUnavailableError",
+                    status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+                )
+        return None
+
     def _create_pooling_params(
         self,
         ctx: ServeContext,
@@ -598,7 +612,8 @@ async def _check_model(
         request: AnyRequest,
     ) -> ErrorResponse | None:
         error_response = None
-
+        if error_response := self._validate_max_waiting_queue_time():
+            return error_response
         if self._is_model_supported(request.model):
             return None
         if request.model in self.models.lora_requests:
diff --git a/vllm/entrypoints/openai/generate/api_router.py b/vllm/entrypoints/openai/generate/api_router.py
index c81c295e4597..2c27d1325077 100644
--- a/vllm/entrypoints/openai/generate/api_router.py
+++ b/vllm/entrypoints/openai/generate/api_router.py
@@ -116,6 +116,7 @@ async def init_generate_state(
             enable_force_include_usage=args.enable_force_include_usage,
             enable_log_outputs=args.enable_log_outputs,
             enable_log_deltas=args.enable_log_deltas,
+            max_waiting_queue_time=args.max_waiting_queue_time,
         )
         if "generate" in supported_tasks
         else None
@@ -131,6 +132,7 @@ async def init_generate_state(
             return_tokens_as_token_ids=args.return_tokens_as_token_ids,
             enable_prompt_tokens_details=args.enable_prompt_tokens_details,
             enable_force_include_usage=args.enable_force_include_usage,
+            max_waiting_queue_time=args.max_waiting_queue_time,
         )
         if "generate" in supported_tasks
         else None
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index a9c42e78e53b..c51012a20a91 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -867,6 +867,9 @@ def get_tokenizer(self) -> TokenizerLike:
     async def is_tracing_enabled(self) -> bool:
         return self.observability_config.otlp_traces_endpoint is not None
 
+    def get_estimated_queue_time(self) -> float:
+        return self.output_processor.queue_time_tracker.avg_queue_time
+
     async def do_log_stats(self) -> None:
         if self.logger_manager:
             self.logger_manager.log()
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index f9e965092288..85bc7609eecb 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -34,6 +34,7 @@ from vllm.v1.metrics.stats import (
     IterationStats,
     LoRARequestStates,
+    QueueTimeTracker,
     RequestStateStats,
     SchedulerStats,
 )
 
@@ -429,6 +430,7 @@ def __init__(
         self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
         self.lora_states = LoRARequestStates(log_stats)
         self.tracing_enabled = tracing_enabled
+        self.queue_time_tracker = QueueTimeTracker()
 
     def get_num_unfinished_requests(self):
         return len(self.request_states)
@@ -793,6 +795,13 @@ def _update_stats_from_finished(
 
         assert finish_reason is not None
         assert req_state.stats is not None
+
+        # Update the historical queue time tracker before computing
+        # finished request stats.
+        if req_state.stats.scheduled_ts > 0.0 and req_state.stats.queued_ts > 0.0:
+            queued_time = req_state.stats.scheduled_ts - req_state.stats.queued_ts
+            self.queue_time_tracker.observe(queued_time)
+
         iteration_stats.update_from_finished_request(
             finish_reason=finish_reason,
             num_prompt_tokens=req_state.prompt_len,
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 4a1e8b6f35ce..48e5414d4097 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -111,6 +111,85 @@ def hit_rate(self) -> float:
         return self.aggregated_query_hit / self.aggregated_query_total
 
 
+class QueueTimeTracker:
+    """Tracks average queue time over a sliding window of recent requests.
+
+    Modeled after CachingMetrics. Maintains a running average of queue time
+    (scheduled_ts - queued_ts) across the most recent N finished requests.
+
+    Includes time-based decay: if no requests have finished within a period
+    based on the current avg_queue_time, the tracker resets to 0.0 to avoid
+    blocking new requests when the server is idle.
+
+    Args:
+        max_recent_requests: The number of the most recent requests to
+            aggregate. Defaults to 100.
+        decay_multiplier: Multiplier for avg_queue_time to determine the decay
+            period. Defaults to 3.0 (i.e., decay after 3x the avg queue time
+            with no new observations). The minimum decay period is 10 seconds.
+    """
+
+    def __init__(
+        self,
+        max_recent_requests: int = 100,
+        decay_multiplier: float = 3.0,
+    ) -> None:
+        self.max_recent_requests = max_recent_requests
+        self.decay_multiplier = decay_multiplier
+        self.aggregated_requests = 0
+        self.aggregated_queue_time = 0.0
+        self.queue: deque[tuple[int, float]] = deque()
+        self._last_observation_time: float | None = None
+
+    def observe(self, queue_time: float) -> None:
+        """Observe the queue time of a single finished request."""
+        self._last_observation_time = time.time()
+        self.queue.append((1, queue_time))
+        self.aggregated_requests += 1
+        self.aggregated_queue_time += queue_time
+
+        while (
+            len(self.queue) > 1 and self.aggregated_requests > self.max_recent_requests
+        ):
+            old_count, old_qt = self.queue.popleft()
+            self.aggregated_requests -= old_count
+            self.aggregated_queue_time -= old_qt
+
+    @property
+    def avg_queue_time(self) -> float:
+        """Return the average queue time in seconds.
+
+        Returns 0.0 if no requests have been observed recently (within
+        the decay period based on the current avg_queue_time), indicating
+        that the server is idle.
+        """
+        if self.aggregated_requests == 0:
+            return 0.0
+
+        # Compute the dynamic decay period: decay_multiplier times the current
+        # average queue time, with a minimum of 10 seconds for the cold-start case.
+        current_avg = self.aggregated_queue_time / self.aggregated_requests
+        decay_period = max(10.0, current_avg * self.decay_multiplier)
+
+        # Decay if there have been no observations within the decay period.
+        if (
+            self._last_observation_time is not None
+            and time.time() - self._last_observation_time > decay_period
+        ):
+            # Reset the tracker - the server has been idle.
+            self.reset()
+            return 0.0
+
+        return current_avg
+
+    def reset(self) -> None:
+        """Reset the tracker to its initial state."""
+        self.aggregated_requests = 0
+        self.aggregated_queue_time = 0.0
+        self.queue.clear()
+        self._last_observation_time = None
+
+
 @dataclass
 class PrefixCacheStats(BaseCacheStats):
     """
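
Illustrative usage (not part of the diff): a minimal sketch of the sliding-window
average and idle-decay behavior of the QueueTimeTracker added above. It assumes the
patch is applied so that the import from vllm.v1.metrics.stats resolves; the queue
times are made up.

    from vllm.v1.metrics.stats import QueueTimeTracker

    tracker = QueueTimeTracker(max_recent_requests=100, decay_multiplier=3.0)

    # Two requests that waited 2.5 s and 3.5 s between queued_ts and scheduled_ts.
    tracker.observe(2.5)
    tracker.observe(3.5)
    print(tracker.avg_queue_time)  # 3.0, the mean over the sliding window

    # If no request finishes for max(10 s, decay_multiplier * avg) seconds, the
    # next read of avg_queue_time resets the tracker and returns 0.0, so an idle
    # server does not reject new traffic based on stale history.

On the serving side, when max_waiting_queue_time is set and the tracked average
exceeds it, _check_model short-circuits with a 503 ServiceUnavailableError, so
OpenAI-compatible clients should be prepared to back off and retry. A hedged client
sketch, assuming the frontend arg is exposed on the CLI as --max-waiting-queue-time
(underscores mapped to dashes, as for the other BaseFrontendArgs fields) and a
server on the default port 8000:

    # Assumed server invocation: vllm serve <model> --max-waiting-queue-time 5
    import time

    import requests

    url = "http://localhost:8000/v1/completions"
    payload = {"model": "<model>", "prompt": "Hello", "max_tokens": 16}

    for attempt in range(5):
        resp = requests.post(url, json=payload, timeout=60)
        if resp.status_code != 503:
            resp.raise_for_status()
            print(resp.json()["choices"][0]["text"])
            break
        # The predicted queue wait exceeds --max-waiting-queue-time; back off.
        time.sleep(2**attempt)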