diff --git a/tests/entrypoints/openai/test_server_load_limit.py b/tests/entrypoints/openai/test_server_load_limit.py
new file mode 100644
index 000000000000..02ce867c7b00
--- /dev/null
+++ b/tests/entrypoints/openai/test_server_load_limit.py
@@ -0,0 +1,258 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for server load limit functionality."""
+
+import asyncio
+import json
+import typing
+from unittest.mock import MagicMock
+
+import pytest
+from fastapi.responses import JSONResponse
+
+from vllm.entrypoints.utils import load_aware_call
+
+
+class TestServerLoadLimit:
+    """Test suite for server load limiting functionality."""
+
+    @pytest.mark.asyncio
+    async def test_load_aware_call_max_load_exceeded(self):
+        """Test that requests are rejected when max load is exceeded."""
+
+        @load_aware_call
+        async def dummy_handler(raw_request):
+            return {"message": "success"}
+
+        # Mock request with load exceeding limit
+        mock_request = MagicMock()
+        mock_request.app.state.enable_server_load_tracking = True
+        mock_request.app.state.max_server_load = 10
+        mock_request.app.state.server_load_metrics = 15  # Exceeds limit
+        mock_request.app.state.server_overload_rejections_since_last_log = 0
+
+        response = await dummy_handler(raw_request=mock_request)
+
+        assert isinstance(response, JSONResponse)
+        assert response.status_code == 503
+
+        # Verify error content
+        content = json.loads(response.body.decode('utf-8'))
+        assert content["error"]["type"] == "server_overloaded"
+        assert "Server is currently overloaded" in content["error"]["message"]
+        assert "Please try again later" in content["error"]["message"]
+
+        # A rejected request is counted but never occupies a load slot
+        state = mock_request.app.state
+        assert state.server_load_metrics == 15
+        assert state.server_overload_rejections_since_last_log == 1
+
+    @pytest.mark.asyncio
+    async def test_load_aware_call_rejection_not_shared(self):
+        """Test that every rejection gets its own response object."""
+
+        @load_aware_call
+        async def dummy_handler(raw_request):
+            return {"message": "success"}
+
+        responses = []
+        for _ in range(2):
+            # Use a fresh mock per call so no overload log flush is due
+            mock_request = MagicMock()
+            state = mock_request.app.state
+            state.enable_server_load_tracking = True
+            state.max_server_load = 10
+            state.server_load_metrics = 15
+            state.server_overload_rejections_since_last_log = 0
+            responses.append(await dummy_handler(raw_request=mock_request))
+
+        # FastAPI may mutate returned responses, so they must not be shared
+        assert responses[0] is not responses[1]
+        assert responses[0].body == responses[1].body
+
+    @pytest.mark.asyncio
+    async def test_load_aware_call_max_load_at_limit(self):
+        """Test that requests are rejected when load equals limit."""
+
+        @load_aware_call
+        async def dummy_handler(raw_request):
+            return {"message": "success"}
+
+        # Mock request with load exactly at limit
+        mock_request = MagicMock()
+        mock_request.app.state.enable_server_load_tracking = True
+        mock_request.app.state.max_server_load = 10
+        mock_request.app.state.server_load_metrics = 10  # At limit
+        mock_request.app.state.server_overload_rejections_since_last_log = 0
+
+        response = await dummy_handler(raw_request=mock_request)
+
+        assert isinstance(response, JSONResponse)
+        assert response.status_code == 503
+
+    @pytest.mark.asyncio
+    async def test_load_aware_call_max_load_under_limit(self):
+        """Test that requests proceed normally when under limit."""
+
+        @load_aware_call
+        async def dummy_handler(raw_request):
+            return {"message": "success"}
+
+        # Mock request with load under limit
+        mock_request = MagicMock()
+        mock_request.app.state.enable_server_load_tracking = True
+        mock_request.app.state.max_server_load = 10
+        mock_request.app.state.server_load_metrics = 5  # Under limit
+        mock_request.app.state.server_overload_rejections_since_last_log = 0
+
+        response = await dummy_handler(raw_request=mock_request)
+
+        # Should proceed normally
+        assert response == {"message": "success"}
+
+    @pytest.mark.asyncio
+    async def test_load_aware_call_max_load_not_set(self):
+        """Test that requests proceed normally when max_server_load is None."""
+
+        @load_aware_call
+        async def dummy_handler(raw_request):
+            return {"message": "success"}
+
+        # Mock request with no max load set
+        mock_request = MagicMock()
+        mock_request.app.state.enable_server_load_tracking = True
+        mock_request.app.state.max_server_load = None  # No limit
+        mock_request.app.state.server_load_metrics = 100  # High load
+        mock_request.app.state.server_overload_rejections_since_last_log = 0
+
+        response = await dummy_handler(raw_request=mock_request)
+
+        # Should proceed normally despite high load
+        assert response == {"message": "success"}
+
+    @pytest.mark.asyncio
+    async def test_load_aware_call_tracking_disabled(self):
+        """Test that load limiting is bypassed when tracking is disabled."""
+
+        @load_aware_call
+        async def dummy_handler(raw_request):
+            return {"message": "success"}
+
+        # Mock request with tracking disabled
+        mock_request = MagicMock()
+        mock_request.app.state.enable_server_load_tracking = False
+        mock_request.app.state.max_server_load = 5
+        mock_request.app.state.server_load_metrics = 100  # High load
+        mock_request.app.state.server_overload_rejections_since_last_log = 0
+
+        response = await dummy_handler(raw_request=mock_request)
+
+        # Should proceed normally when tracking is disabled
+        assert response == {"message": "success"}
+
+    @pytest.mark.asyncio
+    async def test_load_aware_call_with_exception(self):
+        """Test that load counter is properly decremented on exception."""
+
+        @load_aware_call
+        async def failing_handler(raw_request):
+            raise ValueError("Test exception")
+
+        # Mock request under limit
+        mock_request = MagicMock()
+        mock_request.app.state.enable_server_load_tracking = True
+        mock_request.app.state.max_server_load = 10
+        mock_request.app.state.server_load_metrics = 5
+        mock_request.app.state.server_overload_rejections_since_last_log = 0
+
+        # Should raise the original exception
+        with pytest.raises(ValueError, match="Test exception"):
+            await failing_handler(raw_request=mock_request)
+
+        # Load counter should be decremented back to 5
+        assert mock_request.app.state.server_load_metrics == 5
+
+    @pytest.mark.asyncio
+    async def test_load_aware_call_with_cancellation(self):
+        """Test that load counter is decremented on cancellation."""
+
+        @load_aware_call
+        async def cancelled_handler(raw_request):
+            # CancelledError derives from BaseException, not Exception
+            raise asyncio.CancelledError
+
+        # Mock request under limit
+        mock_request = MagicMock()
+        mock_request.app.state.enable_server_load_tracking = True
+        mock_request.app.state.max_server_load = 10
+        mock_request.app.state.server_load_metrics = 5
+        mock_request.app.state.server_overload_rejections_since_last_log = 0
+
+        with pytest.raises(asyncio.CancelledError):
+            await cancelled_handler(raw_request=mock_request)
+
+        # A leaked slot would eventually make the server reject everything
+        assert mock_request.app.state.server_load_metrics == 5
+
+    @pytest.mark.asyncio
+    async def test_load_aware_call_increments_counter(self):
+        """Test that load counter is properly incremented."""
+
+        @load_aware_call
+        async def dummy_handler(raw_request):
+            # Verify counter was incremented
+            assert raw_request.app.state.server_load_metrics == 6
+            return {"message": "success"}
+
+        # Mock request under limit
+        mock_request = MagicMock()
+        mock_request.app.state.enable_server_load_tracking = True
+        mock_request.app.state.max_server_load = 10
+        mock_request.app.state.server_load_metrics = 5
+        mock_request.app.state.server_overload_rejections_since_last_log = 0
+
+        response = await dummy_handler(raw_request=mock_request)
+
+        assert response == {"message": "success"}
+
+    @pytest.mark.asyncio
+    async def test_load_aware_call_zero_max_load(self):
+        """Test behavior when max_server_load is set to 0."""
+
+        @load_aware_call
+        async def dummy_handler(raw_request):
+            return {"message": "success"}
+
+        # Mock request with zero max load
+        mock_request = MagicMock()
+        mock_request.app.state.enable_server_load_tracking = True
+        mock_request.app.state.max_server_load = 0
+        mock_request.app.state.server_load_metrics = 0
+        mock_request.app.state.server_overload_rejections_since_last_log = 0
+
+        response = await dummy_handler(raw_request=mock_request)
+
+        # Should be rejected since 0 >= 0
+        assert isinstance(response, JSONResponse)
+        assert response.status_code == 503
+
+    def test_max_server_load_parameter_exists(self):
+        """Test that max_server_load parameter is properly defined."""
+        from vllm.entrypoints.openai.cli_args import FrontendArgs
+
+        # Check that the parameter exists in FrontendArgs
+        frontend_args = FrontendArgs()
+        assert hasattr(frontend_args, 'max_server_load')
+        assert frontend_args.max_server_load is None  # Default value
+
+    def test_frontend_args_annotation(self):
+        """Test that max_server_load has proper type annotation."""
+        from vllm.entrypoints.openai.cli_args import FrontendArgs
+
+        # Get type hints
+        annotations = FrontendArgs.__annotations__
+        assert 'max_server_load' in annotations
+
+        # Should be Optional[int]
+        expected_type = typing.Optional[int]
+        assert annotations['max_server_load'] == expected_type
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index e5d31c1fd03f..39873c2a1950 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1762,6 +1762,7 @@ async def init_app_state(
 
     state.enable_server_load_tracking = args.enable_server_load_tracking
     state.server_load_metrics = 0
+    state.max_server_load = args.max_server_load
 
 
 def create_server_socket(addr: tuple[str, int]) -> socket.socket:
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index e15f65b43082..01d795e82adb 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -164,6 +164,10 @@ class FrontendArgs:
     """If set to True, enable prompt_tokens_details in usage."""
     enable_server_load_tracking: bool = False
     """If set to True, enable tracking server_load_metrics in the app state."""
+    max_server_load: Optional[int] = None
+    """Maximum number of concurrent requests allowed. When exceeded, new
+    requests will be rejected with HTTP 503. Only effective when
+    --enable-server-load-tracking is enabled."""
     enable_force_include_usage: bool = False
     """If set to True, including usage on every request."""
     enable_tokenizer_info_endpoint: bool = False
@@ -265,6 +269,16 @@ def validate_parsed_serve_args(args: argparse.Namespace):
         raise TypeError("Error: --enable-auto-tool-choice requires "
                         "--tool-call-parser")
 
+    # Validate max_server_load
+    if args.max_server_load is not None:
+        if not args.enable_server_load_tracking:
+            raise TypeError("Error: --max-server-load requires "
+                            "--enable-server-load-tracking to be enabled")
+        if not isinstance(args.max_server_load,
+                          int) or args.max_server_load <= 0:
+            raise TypeError(
+                "Error: --max-server-load must be a positive integer")
+
 
 def create_parser_for_docs() -> FlexibleArgumentParser:
     parser_for_docs = FlexibleArgumentParser(
diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py
index d8905fc14124..e1a6caeabd15 100644
--- a/vllm/entrypoints/utils.py
+++ b/vllm/entrypoints/utils.py
@@ -8,6 +8,8 @@
 import os
 import subprocess
 import sys
+import time
+from http import HTTPStatus
 from typing import Any, Optional, Union
 
 from fastapi import Request
@@ -98,6 +100,68 @@ def decrement_server_load(request: Request):
     request.app.state.server_load_metrics -= 1
 
 
+def _server_overloaded_response() -> JSONResponse:
+    """Build the HTTP 503 response for a request rejected due to overload.
+
+    A new object is created for every rejection because a response returned
+    from a route handler may be mutated afterwards (FastAPI, for example,
+    attaches the request's background tasks to it), so sharing one instance
+    across requests is unsafe.
+    """
+    return JSONResponse(
+        content={
+            "error": {
+                "type": "server_overloaded",
+                "message":
+                "Server is currently overloaded. Please try again later."
+            }
+        },
+        status_code=HTTPStatus.SERVICE_UNAVAILABLE)
+
+
+def _flush_pending_overload_warnings(app_state):
+    """Log the aggregated overload rejections once per log interval.
+
+    Does nothing unless rejections are pending and at least
+    ``server_overload_log_interval`` seconds (60 by default) have passed
+    since ``server_overload_last_log_time``. After logging, the pending
+    counter is reset and the timestamp is moved to the current time.
+    """
+    pending = getattr(app_state, "server_overload_rejections_since_last_log",
+                      0)
+    if pending <= 0:
+        return
+
+    now = time.monotonic()
+    last_log_time = getattr(app_state, "server_overload_last_log_time", now)
+    log_interval = getattr(app_state, "server_overload_log_interval", 60.0)
+    if (now - last_log_time) < log_interval:
+        return
+
+    # The load may have dropped below the limit by the time the summary is
+    # flushed, so report the current values without claiming load >= max.
+    logger.warning(
+        "Server overloaded: rejected %d requests since last log "
+        "(current load %s, max load %s).", pending,
+        app_state.server_load_metrics,
+        getattr(app_state, "max_server_load", None))
+    app_state.server_overload_rejections_since_last_log = 0
+    app_state.server_overload_last_log_time = now
+
+
+def _aggregate_rejection_stats(app_state):
+    """Count one rejected request towards the next overload warning.
+
+    Creates the bookkeeping attributes on ``app_state`` on first use; the
+    timestamp set here starts the interval before the first warning.
+    """
+    if not hasattr(app_state, "server_overload_last_log_time"):
+        app_state.server_overload_last_log_time = time.monotonic()
+    if not hasattr(app_state, "server_overload_rejections_since_last_log"):
+        app_state.server_overload_rejections_since_last_log = 0
+    app_state.server_overload_rejections_since_last_log += 1
+
+
 def load_aware_call(func):
 
     @functools.wraps(func)
@@ -109,19 +173,33 @@ async def wrapper(*args, **kwargs):
             raise ValueError(
                 "raw_request required when server load tracking is enabled")
 
-        if not getattr(raw_request.app.state, "enable_server_load_tracking",
-                       False):
+        app_state = raw_request.app.state
+        if not getattr(app_state, "enable_server_load_tracking", False):
             return await func(*args, **kwargs)
 
         # ensure the counter exists
-        if not hasattr(raw_request.app.state, "server_load_metrics"):
-            raw_request.app.state.server_load_metrics = 0
+        if not hasattr(app_state, "server_load_metrics"):
+            app_state.server_load_metrics = 0
+
+        # Flush pending aggregated overload warnings if interval elapsed.
+        _flush_pending_overload_warnings(app_state)
+
+        max_load = getattr(app_state, "max_server_load", None)
+        if max_load is not None and app_state.server_load_metrics >= max_load:
+            # Aggregate rejections since last log
+            _aggregate_rejection_stats(app_state)
+            # Build a fresh response: a returned response object may be
+            # mutated downstream, so it must not be shared across requests.
+            return _server_overloaded_response()
 
-        raw_request.app.state.server_load_metrics += 1
+        app_state.server_load_metrics += 1
         try:
             response = await func(*args, **kwargs)
-        except Exception:
-            raw_request.app.state.server_load_metrics -= 1
+        except BaseException:
+            # Also release the slot on asyncio.CancelledError (a
+            # BaseException), e.g. when the client disconnects; otherwise
+            # the counter leaks and the limit eventually rejects everything.
+            app_state.server_load_metrics -= 1
             raise
 
         if isinstance(response, (JSONResponse, StreamingResponse)):
@@ -141,7 +219,7 @@ async def wrapper(*args, **kwargs):
                 tasks.add_task(decrement_server_load, raw_request)
                 response.background = tasks
         else:
-            raw_request.app.state.server_load_metrics -= 1
+            app_state.server_load_metrics -= 1
 
         return response
 