207 changes: 207 additions & 0 deletions tests/entrypoints/openai/test_server_load_limit.py
@@ -0,0 +1,207 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for server load limit functionality."""

import json
import typing
from unittest.mock import MagicMock

import pytest
from fastapi.responses import JSONResponse

from vllm.entrypoints.utils import load_aware_call


class TestServerLoadLimit:
"""Test suite for server load limiting functionality."""

@pytest.mark.asyncio
async def test_load_aware_call_max_load_exceeded(self):
"""Test that requests are rejected when max load is exceeded."""

@load_aware_call
async def dummy_handler(raw_request):
return {"message": "success"}

# Mock request with load exceeding limit
mock_request = MagicMock()
mock_request.app.state.enable_server_load_tracking = True
mock_request.app.state.max_server_load = 10
mock_request.app.state.server_load_metrics = 15 # Exceeds limit
mock_request.app.state.server_overload_rejections_since_last_log = 0

response = await dummy_handler(raw_request=mock_request)

assert isinstance(response, JSONResponse)
assert response.status_code == 503

        # Verify error content
        content = json.loads(response.body.decode('utf-8'))
assert content["error"]["type"] == "server_overloaded"
assert "Server is currently overloaded" in content["error"]["message"]
assert "Please try again later" in content["error"]["message"]

@pytest.mark.asyncio
async def test_load_aware_call_max_load_at_limit(self):
"""Test that requests are rejected when load equals limit."""

@load_aware_call
async def dummy_handler(raw_request):
return {"message": "success"}

# Mock request with load exactly at limit
mock_request = MagicMock()
mock_request.app.state.enable_server_load_tracking = True
mock_request.app.state.max_server_load = 10
mock_request.app.state.server_load_metrics = 10 # At limit
mock_request.app.state.server_overload_rejections_since_last_log = 0

response = await dummy_handler(raw_request=mock_request)

assert isinstance(response, JSONResponse)
assert response.status_code == 503

@pytest.mark.asyncio
async def test_load_aware_call_max_load_under_limit(self):
"""Test that requests proceed normally when under limit."""

@load_aware_call
async def dummy_handler(raw_request):
return {"message": "success"}

# Mock request with load under limit
mock_request = MagicMock()
mock_request.app.state.enable_server_load_tracking = True
mock_request.app.state.max_server_load = 10
mock_request.app.state.server_load_metrics = 5 # Under limit
mock_request.app.state.server_overload_rejections_since_last_log = 0

response = await dummy_handler(raw_request=mock_request)

# Should proceed normally
assert response == {"message": "success"}

@pytest.mark.asyncio
async def test_load_aware_call_max_load_not_set(self):
"""Test that requests proceed normally when max_server_load is None."""

@load_aware_call
async def dummy_handler(raw_request):
return {"message": "success"}

# Mock request with no max load set
mock_request = MagicMock()
mock_request.app.state.enable_server_load_tracking = True
mock_request.app.state.max_server_load = None # No limit
mock_request.app.state.server_load_metrics = 100 # High load
mock_request.app.state.server_overload_rejections_since_last_log = 0

response = await dummy_handler(raw_request=mock_request)

# Should proceed normally despite high load
assert response == {"message": "success"}

@pytest.mark.asyncio
async def test_load_aware_call_tracking_disabled(self):
"""Test that load limiting is bypassed when tracking is disabled."""

@load_aware_call
async def dummy_handler(raw_request):
return {"message": "success"}

# Mock request with tracking disabled
mock_request = MagicMock()
mock_request.app.state.enable_server_load_tracking = False
mock_request.app.state.max_server_load = 5
mock_request.app.state.server_load_metrics = 100 # High load
mock_request.app.state.server_overload_rejections_since_last_log = 0

response = await dummy_handler(raw_request=mock_request)

# Should proceed normally when tracking is disabled
assert response == {"message": "success"}

@pytest.mark.asyncio
async def test_load_aware_call_with_exception(self):
"""Test that load counter is properly decremented on exception."""

@load_aware_call
async def failing_handler(raw_request):
raise ValueError("Test exception")

# Mock request under limit
mock_request = MagicMock()
mock_request.app.state.enable_server_load_tracking = True
mock_request.app.state.max_server_load = 10
mock_request.app.state.server_load_metrics = 5
mock_request.app.state.server_overload_rejections_since_last_log = 0

# Should raise the original exception
with pytest.raises(ValueError, match="Test exception"):
await failing_handler(raw_request=mock_request)

# Load counter should be decremented back to 5
assert mock_request.app.state.server_load_metrics == 5

@pytest.mark.asyncio
async def test_load_aware_call_increments_counter(self):
"""Test that load counter is properly incremented."""

@load_aware_call
async def dummy_handler(raw_request):
# Verify counter was incremented
assert raw_request.app.state.server_load_metrics == 6
return {"message": "success"}

# Mock request under limit
mock_request = MagicMock()
mock_request.app.state.enable_server_load_tracking = True
mock_request.app.state.max_server_load = 10
mock_request.app.state.server_load_metrics = 5
mock_request.app.state.server_overload_rejections_since_last_log = 0

response = await dummy_handler(raw_request=mock_request)

assert response == {"message": "success"}

@pytest.mark.asyncio
async def test_load_aware_call_zero_max_load(self):
"""Test behavior when max_server_load is set to 0."""

@load_aware_call
async def dummy_handler(raw_request):
return {"message": "success"}

# Mock request with zero max load
mock_request = MagicMock()
mock_request.app.state.enable_server_load_tracking = True
mock_request.app.state.max_server_load = 0
mock_request.app.state.server_load_metrics = 0
mock_request.app.state.server_overload_rejections_since_last_log = 0

response = await dummy_handler(raw_request=mock_request)

# Should be rejected since 0 >= 0
assert isinstance(response, JSONResponse)
assert response.status_code == 503

def test_max_server_load_parameter_exists(self):
"""Test that max_server_load parameter is properly defined."""
from vllm.entrypoints.openai.cli_args import FrontendArgs

# Check that the parameter exists in FrontendArgs
frontend_args = FrontendArgs()
assert hasattr(frontend_args, 'max_server_load')
assert frontend_args.max_server_load is None # Default value

def test_frontend_args_annotation(self):
"""Test that max_server_load has proper type annotation."""
from vllm.entrypoints.openai.cli_args import FrontendArgs

# Get type hints
annotations = FrontendArgs.__annotations__
assert 'max_server_load' in annotations

        # Should be Optional[int]
        expected_type = typing.Optional[int]
assert annotations['max_server_load'] == expected_type
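
For context, the decorator under test is applied to FastAPI route handlers that receive the request via a raw_request keyword argument, exactly as the tests above exercise. Below is a minimal sketch of that wiring; the endpoint path, app setup, and limit value are illustrative assumptions, not part of this PR:

from fastapi import FastAPI, Request

from vllm.entrypoints.utils import load_aware_call

app = FastAPI()
app.state.enable_server_load_tracking = True
app.state.max_server_load = 100  # illustrative limit
app.state.server_load_metrics = 0


@app.post("/v1/example")  # hypothetical endpoint
@load_aware_call
async def example_endpoint(raw_request: Request):
    # Once server_load_metrics reaches max_server_load, the decorator
    # returns the 503 overload response before this body runs.
    return {"message": "success"}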
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -1762,6 +1762,7 @@ async def init_app_state(

state.enable_server_load_tracking = args.enable_server_load_tracking
state.server_load_metrics = 0
state.max_server_load = args.max_server_load


def create_server_socket(addr: tuple[str, int]) -> socket.socket:
14 changes: 14 additions & 0 deletions vllm/entrypoints/openai/cli_args.py
@@ -164,6 +164,10 @@ class FrontendArgs:
"""If set to True, enable prompt_tokens_details in usage."""
enable_server_load_tracking: bool = False
"""If set to True, enable tracking server_load_metrics in the app state."""
max_server_load: Optional[int] = None
"""Maximum number of concurrent requests allowed. When exceeded, new
requests will be rejected with HTTP 503. Only effective when
--enable-server-load-tracking is enabled."""
enable_force_include_usage: bool = False
"""If set to True, including usage on every request."""
enable_tokenizer_info_endpoint: bool = False
@@ -265,6 +269,16 @@ def validate_parsed_serve_args(args: argparse.Namespace):
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")

# Validate max_server_load
if args.max_server_load is not None:
if not args.enable_server_load_tracking:
raise TypeError("Error: --max-server-load requires "
"--enable-server-load-tracking to be enabled")
if not isinstance(args.max_server_load,
int) or args.max_server_load <= 0:
raise TypeError(
"Error: --max-server-load must be a positive integer")


def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
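
To see the new validation path above in action, here is a sketch that assumes the argparse.Namespace carries the fields validate_parsed_serve_args reads; a real parsed namespace has more, and depending on the vLLM version the validator may check others first (chat_template is included here on that assumption):

import argparse

import pytest

from vllm.entrypoints.openai.cli_args import validate_parsed_serve_args

args = argparse.Namespace(
    chat_template=None,
    enable_auto_tool_choice=False,
    tool_call_parser=None,
    enable_server_load_tracking=False,  # prerequisite left disabled
    max_server_load=10,
)

# --max-server-load without --enable-server-load-tracking is rejected.
with pytest.raises(TypeError, match="enable-server-load-tracking"):
    validate_parsed_serve_args(args)

# Non-positive limits are rejected even with tracking enabled.
args.enable_server_load_tracking = True
args.max_server_load = 0
with pytest.raises(TypeError, match="positive integer"):
    validate_parsed_serve_args(args)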
68 changes: 61 additions & 7 deletions vllm/entrypoints/utils.py
@@ -8,6 +8,8 @@
import os
import subprocess
import sys
import time
from http import HTTPStatus
from typing import Any, Optional, Union

from fastapi import Request
@@ -24,6 +26,16 @@

logger = init_logger(__name__)

SERVER_OVERLOADED_RESPONSE = JSONResponse(
content={
"error": {
"type": "server_overloaded",
"message":
"Server is currently overloaded. Please try again later."
}
},
status_code=HTTPStatus.SERVICE_UNAVAILABLE)

VLLM_SUBCMD_PARSER_EPILOG = (
"Tip: Use `vllm [serve|run-batch|bench <bench_type>] "
"--help=<keyword>` to explore arguments from help.\n"
@@ -98,6 +110,39 @@ def decrement_server_load(request: Request):
request.app.state.server_load_metrics -= 1


def _flush_pending_overload_warnings(app_state):
"""Flush pending aggregated overload warnings if interval elapsed."""
now = time.monotonic()
pending = getattr(app_state, "server_overload_rejections_since_last_log",
0)
if pending > 0:
last_log_time = getattr(app_state, "server_overload_last_log_time",
now)
log_interval = getattr(app_state, "server_overload_log_interval", 60.0)
if (now - last_log_time) >= log_interval:
max_load_snapshot = getattr(app_state, "max_server_load", None)
try:
logger.warning(
"Server overloaded: current load %s >= max load %s. "
"Rejected %d requests since last log.",
app_state.server_load_metrics, max_load_snapshot, pending)
except Exception:
logger.exception("Failed to log server overload warning")
else:
app_state.server_overload_rejections_since_last_log = 0
app_state.server_overload_last_log_time = now


def _aggregate_rejection_stats(app_state):
"""Aggregate rejections since last log."""
now = time.monotonic()
if not hasattr(app_state, "server_overload_last_log_time"):
app_state.server_overload_last_log_time = now
if not hasattr(app_state, "server_overload_rejections_since_last_log"):
app_state.server_overload_rejections_since_last_log = 0
app_state.server_overload_rejections_since_last_log += 1


def load_aware_call(func):

@functools.wraps(func)
@@ -109,19 +154,28 @@ async def wrapper(*args, **kwargs):
raise ValueError(
"raw_request required when server load tracking is enabled")

if not getattr(raw_request.app.state, "enable_server_load_tracking",
False):
app_state = raw_request.app.state
if not getattr(app_state, "enable_server_load_tracking", False):
return await func(*args, **kwargs)

# ensure the counter exists
if not hasattr(raw_request.app.state, "server_load_metrics"):
raw_request.app.state.server_load_metrics = 0
if not hasattr(app_state, "server_load_metrics"):
app_state.server_load_metrics = 0

# Flush pending aggregated overload warnings if interval elapsed.
_flush_pending_overload_warnings(app_state)

max_load = getattr(app_state, "max_server_load", None)
if max_load is not None and app_state.server_load_metrics >= max_load:
# Aggregate rejections since last log
_aggregate_rejection_stats(app_state)
return SERVER_OVERLOADED_RESPONSE

raw_request.app.state.server_load_metrics += 1
app_state.server_load_metrics += 1
try:
response = await func(*args, **kwargs)
except Exception:
raw_request.app.state.server_load_metrics -= 1
app_state.server_load_metrics -= 1
raise

if isinstance(response, (JSONResponse, StreamingResponse)):
@@ -141,7 +195,7 @@ async def wrapper(*args, **kwargs):
tasks.add_task(decrement_server_load, raw_request)
response.background = tasks
else:
raw_request.app.state.server_load_metrics -= 1
app_state.server_load_metrics -= 1

return response

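
A quick illustration of the aggregated overload logging added above. types.SimpleNamespace stands in for FastAPI's app.state, and the underscore-prefixed helpers are imported purely for illustration; the log interval defaults to 60 seconds, so the sketch forces it to 0 to flush immediately:

from types import SimpleNamespace

from vllm.entrypoints.utils import (_aggregate_rejection_stats,
                                    _flush_pending_overload_warnings)

state = SimpleNamespace(server_load_metrics=10, max_server_load=10)

# Five rejections are counted; nothing is logged while the interval runs.
for _ in range(5):
    _aggregate_rejection_stats(state)
assert state.server_overload_rejections_since_last_log == 5

# Force the interval to elapse: one warning covers all five rejections.
state.server_overload_log_interval = 0.0
_flush_pending_overload_warnings(state)
assert state.server_overload_rejections_since_last_log == 0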