70 changes: 70 additions & 0 deletions tests/entrypoints/openai/test_basic.py
@@ -7,6 +7,7 @@
import pytest
import pytest_asyncio
import requests
from prometheus_client.parser import text_string_to_metric_families

from vllm.version import __version__ as VLLM_VERSION

@@ -219,3 +220,72 @@ def make_long_completion_request():
response = requests.get(server.url_for("load"))
assert response.status_code == HTTPStatus.OK
assert response.json().get("server_load") == 0


@pytest.mark.parametrize(
"server_args",
[
pytest.param(["--enable-http-middleware"],
id="enable-http-middleware"),
],
indirect=True,
)
@pytest.mark.asyncio
async def test_http_error_count(server: RemoteOpenAIServer):
# Check initial server load
response = requests.get(server.url_for("load"))
assert response.status_code == HTTPStatus.OK
assert response.json().get("server_load") == 0

NUM_EXPECTED_ERRORS = 3

# exceed max tokens, should raise exception
bad_request_1 = requests.post(
server.url_for("v1/completions"),
headers={"Content-Type": "application/json"},
json={
"prompt": "Give me a long story",
"max_tokens": 999999999,
"temperature": 0,
},
)
# invalid payload
bad_request_2 = requests.post(
server.url_for("v1/completions"),
headers={"Content-Type": "application/json"},
json={
"bad_prompt": "Give me a long story",
"max_tokens": 100,
"temperature": 0,
},
)
# invalid content type
bad_request_3 = requests.post(
server.url_for("v1/completions"),
headers={"Content-Type": "bad/application/json"},
json={
"prompt": "Give me a long story",
"max_tokens": 100,
"temperature": 0,
},
)

assert bad_request_1.status_code != 200
assert bad_request_2.status_code != 200
assert bad_request_3.status_code != 200

response = requests.get(server.url_for("metrics"))
assert response.status_code == HTTPStatus.OK

metric_family = "http_error_count"

found_metric = False

for family in text_string_to_metric_families(response.text):
if family.name == metric_family:
for sample in family.samples:
if sample.value == NUM_EXPECTED_ERRORS:
found_metric = True
break

assert found_metric
19 changes: 18 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -82,7 +82,8 @@
from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.entrypoints.utils import (http_error_counter, http_middleware,
load_aware_call, with_cancellation)
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
@@ -294,6 +295,8 @@ async def validate_json_request(raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower()
media_type = content_type.split(";", maxsplit=1)[0]
if media_type != "application/json":
if raw_request.app.state.enable_http_middleware:
http_error_counter.inc()
raise HTTPException(
status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
detail="Unsupported Media Type: Only 'application/json' is allowed"
@@ -318,6 +321,9 @@ def mount_metrics(app: FastAPI):
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)

# Register http service level metrics
http_error_counter.registry = registry

# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
else:
@@ -453,6 +459,7 @@ async def show_version():
dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
@http_middleware
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
handler = chat(raw_request)
@@ -475,6 +482,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
@router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
@http_middleware
async def create_completion(request: CompletionRequest, raw_request: Request):
handler = completion(raw_request)
if handler is None:
@@ -494,6 +502,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
@http_middleware
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request)
if handler is None:
@@ -541,6 +550,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
@router.post("/pooling", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
@http_middleware
async def create_pooling(request: PoolingRequest, raw_request: Request):
handler = pooling(raw_request)
if handler is None:
@@ -560,6 +570,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
@router.post("/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
@http_middleware
async def create_score(request: ScoreRequest, raw_request: Request):
handler = score(raw_request)
if handler is None:
@@ -579,6 +590,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
@router.post("/v1/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
@http_middleware
async def create_score_v1(request: ScoreRequest, raw_request: Request):
logger.warning(
"To indicate that Score API is not part of standard OpenAI API, we "
@@ -590,6 +602,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
@router.post("/v1/audio/transcriptions")
@with_cancellation
@load_aware_call
@http_middleware
async def create_transcriptions(request: Annotated[TranscriptionRequest,
Form()],
raw_request: Request):
@@ -615,6 +628,7 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
@http_middleware
async def do_rerank(request: RerankRequest, raw_request: Request):
handler = rerank(raw_request)
if handler is None:
@@ -814,6 +828,8 @@ async def validation_exception_handler(_, exc):
err = ErrorResponse(message=str(exc),
type="BadRequestError",
code=HTTPStatus.BAD_REQUEST)
if app.state.enable_http_middleware:
http_error_counter.inc()
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)

@@ -982,6 +998,7 @@ async def init_app_state(
state.task = model_config.task

state.enable_server_load_tracking = args.enable_server_load_tracking
state.enable_http_middleware = args.enable_http_middleware
state.server_load_metrics = 0


6 changes: 6 additions & 0 deletions vllm/entrypoints/openai/cli_args.py
@@ -267,6 +267,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help=
"If set to True, enable tracking server_load_metrics in the app state."
)
parser.add_argument(
"--enable-http-middleware",

Contributor review comment: I view the http-middleware decorator as a code structure extension rather than a feature extension, so do we really need an explicit control for this? The feature we could control is the counter/logging, but that feels pretty standard and we should probably just enable it for all use cases. cc: @simon-mo

action="store_true",
default=False,
help="If set to True, enable http middleware decorator.",
)

return parser

39 changes: 39 additions & 0 deletions vllm/entrypoints/utils.py
@@ -5,8 +5,14 @@

from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from prometheus_client import Counter
from starlette.background import BackgroundTask, BackgroundTasks

http_error_counter = Counter(
"http_error_count",
"Total error count across requests with non 2xx response code",
)


async def listen_for_disconnect(request: Request) -> None:
"""Returns if a disconnect message is received"""
@@ -108,3 +114,36 @@ async def wrapper(*args, **kwargs):
return response

return wrapper


# to avoid the performance hit of a middleware, we can use a decorator
# to handle all http related overhead logic
def http_middleware(func):

@functools.wraps(func)
async def wrapper(*args, **kwargs):
raw_request = kwargs.get("raw_request",
args[1] if len(args) > 1 else None)

if raw_request is None:
raise ValueError(
"raw_request required when http middleware is enabled")

if not raw_request.app.state.enable_http_middleware:
return await func(*args, **kwargs)

try:
response = await func(*args, **kwargs)
except Exception:
http_error_counter.inc()
raise

status_code = (response.status_code
if hasattr(response, "status_code") else
response.code if hasattr(response, "code") else None)
if status_code and not (200 <= status_code < 300):
http_error_counter.inc()

return response

return wrapper
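

For reference, below is a minimal, self-contained sketch (illustrative only, not part of this PR; the names demo_http_error_count, count_errors, FakeResponse, and handler are hypothetical) of the same pattern the http_middleware decorator implements: wrap an async handler, increment a prometheus_client Counter when the handler raises or returns a non-2xx status, and expose the count through the standard Prometheus text format.

import asyncio
import functools

from prometheus_client import Counter, generate_latest

# Counts handler failures; registers to the default prometheus_client registry.
error_counter = Counter("demo_http_error_count",
                        "Errors observed by the demo decorator")


def count_errors(func):

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            response = await func(*args, **kwargs)
        except Exception:
            # Any raised exception counts as an error, then propagates.
            error_counter.inc()
            raise
        status = getattr(response, "status_code", None)
        if status is not None and not 200 <= status < 300:
            # Non-2xx responses also count as errors.
            error_counter.inc()
        return response

    return wrapper


class FakeResponse:
    """Stand-in for a FastAPI/Starlette response object."""

    def __init__(self, status_code: int):
        self.status_code = status_code


@count_errors
async def handler(status_code: int) -> FakeResponse:
    return FakeResponse(status_code)


async def main():
    await handler(200)  # not counted
    await handler(422)  # counted
    await handler(500)  # counted
    # The counter appears in the usual Prometheus exposition format,
    # analogous to what the /metrics endpoint returns in the test above.
    print(generate_latest().decode())


asyncio.run(main())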