From 34bf33ebaf4ccf4bc6bcd8bb07352239178b93f4 Mon Sep 17 00:00:00 2001 From: Andy Xie Date: Thu, 8 Jan 2026 00:45:16 +0800 Subject: [PATCH] [openai api] log http exception in handler Signed-off-by: Andy Xie --- .../openai/responses/test_errors.py | 27 -- .../openai/responses/test_harmony.py | 4 +- tests/entrypoints/openai/test_chat_error.py | 11 +- .../openai/test_completion_error.py | 11 +- .../entrypoints/openai/test_openai_schema.py | 30 +- tests/entrypoints/openai/test_serving_chat.py | 11 +- tests/v1/engine/test_async_llm.py | 11 +- vllm/entrypoints/launcher.py | 43 +- vllm/entrypoints/openai/api_server.py | 7 +- .../openai/chat_completion/api_router.py | 12 +- .../openai/chat_completion/serving.py | 324 +++++++------- .../openai/completion/api_router.py | 10 +- vllm/entrypoints/openai/completion/serving.py | 120 +++-- vllm/entrypoints/openai/engine/protocol.py | 9 + vllm/entrypoints/openai/engine/serving.py | 174 ++------ .../entrypoints/openai/generate/api_router.py | 4 - .../entrypoints/openai/realtime/api_router.py | 1 - vllm/entrypoints/openai/realtime/serving.py | 2 - .../openai/responses/api_router.py | 24 +- vllm/entrypoints/openai/responses/serving.py | 221 ++++------ vllm/entrypoints/openai/server_utils.py | 80 +++- .../openai/speech_to_text/api_router.py | 15 +- .../openai/speech_to_text/serving.py | 4 - .../openai/speech_to_text/speech_to_text.py | 102 ++--- vllm/entrypoints/pooling/__init__.py | 4 - vllm/entrypoints/pooling/base/serving.py | 42 +- vllm/entrypoints/pooling/embed/api_router.py | 5 +- vllm/entrypoints/pooling/embed/serving.py | 411 +++++++++--------- .../entrypoints/pooling/pooling/api_router.py | 6 +- vllm/entrypoints/pooling/pooling/serving.py | 166 ++++--- vllm/entrypoints/pooling/score/api_router.py | 11 +- vllm/entrypoints/pooling/score/serving.py | 5 - vllm/entrypoints/serve/disagg/api_router.py | 6 +- vllm/entrypoints/serve/disagg/serving.py | 65 ++- vllm/entrypoints/serve/tokenize/api_router.py | 5 +- vllm/entrypoints/serve/tokenize/serving.py | 76 ++-- vllm/entrypoints/utils.py | 44 +- 37 files changed, 912 insertions(+), 1191 deletions(-) diff --git a/tests/entrypoints/openai/responses/test_errors.py b/tests/entrypoints/openai/responses/test_errors.py index 7daa3d1fb58f..0ef9bb901a64 100644 --- a/tests/entrypoints/openai/responses/test_errors.py +++ b/tests/entrypoints/openai/responses/test_errors.py @@ -6,7 +6,6 @@ import pytest -from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.serving import GenerationError, OpenAIServing @@ -38,32 +37,6 @@ async def test_raise_if_error_raises_generation_error(): serving._raise_if_error(None, "test-request-id") # should not raise -@pytest.mark.asyncio -async def test_convert_generation_error_to_response(): - """test _convert_generation_error_to_response creates proper ErrorResponse""" - mock_engine = MagicMock() - mock_engine.model_config = MagicMock() - mock_engine.model_config.max_model_len = 100 - mock_models = MagicMock() - - serving = OpenAIServing( - engine_client=mock_engine, - models=mock_models, - request_logger=None, - ) - - # create a GenerationError - gen_error = GenerationError("Internal server error") - - # convert to ErrorResponse - error_response = serving._convert_generation_error_to_response(gen_error) - - assert isinstance(error_response, ErrorResponse) - assert error_response.error.type == "InternalServerError" - assert error_response.error.message == "Internal server error" - assert error_response.error.code == 
HTTPStatus.INTERNAL_SERVER_ERROR - - @pytest.mark.asyncio async def test_convert_generation_error_to_streaming_response(): """test _convert_generation_error_to_streaming_response output""" diff --git a/tests/entrypoints/openai/responses/test_harmony.py b/tests/entrypoints/openai/responses/test_harmony.py index 78419c92a9d0..3bc041ba485e 100644 --- a/tests/entrypoints/openai/responses/test_harmony.py +++ b/tests/entrypoints/openai/responses/test_harmony.py @@ -13,7 +13,7 @@ import pytest import pytest_asyncio import requests -from openai import BadRequestError, NotFoundError, OpenAI +from openai import InternalServerError, NotFoundError, OpenAI from openai_harmony import Message from ....utils import RemoteOpenAIServer @@ -698,7 +698,7 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): async def test_function_calling_required(client: OpenAI, model_name: str): tools = [GET_WEATHER_SCHEMA] - with pytest.raises(BadRequestError): + with pytest.raises(InternalServerError): await client.responses.create( model=model_name, input="What's the weather like in Paris today?", diff --git a/tests/entrypoints/openai/test_chat_error.py b/tests/entrypoints/openai/test_chat_error.py index 970945b4759f..2f2fe6acb53d 100644 --- a/tests/entrypoints/openai/test_chat_error.py +++ b/tests/entrypoints/openai/test_chat_error.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass, field -from http import HTTPStatus from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -11,7 +10,7 @@ from vllm.config.multimodal import MultiModalConfig from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat -from vllm.entrypoints.openai.engine.protocol import ErrorResponse +from vllm.entrypoints.openai.engine.protocol import GenerationError from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.outputs import CompletionOutput, RequestOutput @@ -145,12 +144,8 @@ async def mock_generate(*args, **kwargs): stream=False, ) - response = await serving_chat.create_chat_completion(request) - - assert isinstance(response, ErrorResponse) - assert response.error.type == "InternalServerError" - assert response.error.message == "Internal server error" - assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR + with pytest.raises(GenerationError): + await serving_chat.create_chat_completion(request) @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_completion_error.py b/tests/entrypoints/openai/test_completion_error.py index 1e7a3d0a8c39..c39b9cf4e763 100644 --- a/tests/entrypoints/openai/test_completion_error.py +++ b/tests/entrypoints/openai/test_completion_error.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass, field -from http import HTTPStatus from typing import Any from unittest.mock import MagicMock @@ -11,7 +10,7 @@ from vllm.config.multimodal import MultiModalConfig from vllm.entrypoints.openai.completion.protocol import CompletionRequest from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion -from vllm.entrypoints.openai.engine.protocol import ErrorResponse +from vllm.entrypoints.openai.engine.protocol import GenerationError from vllm.entrypoints.openai.models.protocol import BaseModelPath from 
vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.outputs import CompletionOutput, RequestOutput @@ -131,12 +130,8 @@ async def mock_generate(*args, **kwargs): stream=False, ) - response = await serving_completion.create_completion(request) - - assert isinstance(response, ErrorResponse) - assert response.error.type == "InternalServerError" - assert response.error.message == "Internal server error" - assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR + with pytest.raises(GenerationError): + await serving_completion.create_completion(request) @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 2b26ebd041d5..8efffdcaf7ef 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json +from http import HTTPStatus from typing import Final import pytest import schemathesis +from httpx import URL from hypothesis import settings from schemathesis import GenerationConfig +from schemathesis.checks import not_a_server_error +from schemathesis.internal.checks import CheckContext +from schemathesis.models import Case +from schemathesis.transports.responses import GenericResponse from ...utils import RemoteOpenAIServer @@ -127,10 +133,25 @@ def no_invalid_types(case: schemathesis.models.Case): return strategy.filter(no_invalid_types) +def customized_not_a_server_error( + ctx: CheckContext, response: GenericResponse, case: Case +) -> bool | None: + try: + return not_a_server_error(ctx, response, case) + except Exception: + if ( + URL(response.request.url).path + in ["/v1/chat/completions/render", "/v1/chat/completions"] + and response.status_code == HTTPStatus.NOT_IMPLEMENTED.value + ): + return True + raise + + @schema.parametrize() @schema.override(headers={"Content-Type": "application/json"}) @settings(deadline=LONG_TIMEOUT_SECONDS * 1000, max_examples=50) -def test_openapi_stateless(case: schemathesis.Case): +def test_openapi_stateless(case: Case): key = ( case.operation.method.upper(), case.operation.path, @@ -155,4 +176,9 @@ def test_openapi_stateless(case: schemathesis.Case): }.get(key, DEFAULT_TIMEOUT_SECONDS) # No need to verify SSL certificate for localhost - case.call_and_validate(verify=False, timeout=timeout) + case.call_and_validate( + verify=False, + timeout=timeout, + additional_checks=(customized_not_a_server_error,), + excluded_checks=(not_a_server_error,), + ) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 33c69578ce93..e1380d4290f8 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -23,6 +23,7 @@ ) from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.parser.harmony_utils import get_encoding +from vllm.exceptions import VLLMValidationError from vllm.inputs import TokensPrompt from vllm.outputs import CompletionOutput, RequestOutput from vllm.renderers.hf import HfRenderer @@ -818,9 +819,8 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated(): max_tokens=10, ) - resp = await serving_chat.create_chat_completion(req) - assert isinstance(resp, ErrorResponse) - assert "context length is only" in resp.error.message + with pytest.raises(VLLMValidationError): + await 
serving_chat.create_chat_completion(req) @pytest.mark.asyncio @@ -860,9 +860,8 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected(): max_tokens=1, ) - resp = await serving_chat.create_chat_completion(req) - assert isinstance(resp, ErrorResponse) - assert "context length is only" in resp.error.message + with pytest.raises(VLLMValidationError): + await serving_chat.create_chat_completion(req) @pytest.mark.asyncio diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 032da4a0318c..9fd95d0c5782 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -17,9 +17,6 @@ ChatCompletionResponse, ) from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat -from vllm.entrypoints.openai.engine.protocol import ( - ErrorResponse, -) from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.inputs import PromptType @@ -542,11 +539,9 @@ async def test_header_dp_rank_argument(): # Test 2: Out-of-range DP rank (1) mock_raw_request.headers = {"X-data-parallel-rank": "1"} - # should return ErrorResponse for out-of-range rank - response2 = await serving_chat.create_chat_completion(req, mock_raw_request) - assert isinstance(response2, ErrorResponse), ( - "Expected an ErrorResponse for out-of-range DP rank" - ) + # should raise ValueError for out-of-range rank + with pytest.raises(ValueError): + await serving_chat.create_chat_completion(req, mock_raw_request) @pytest.mark.asyncio diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index e75d66bbf685..b442fc70cdb0 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -4,11 +4,10 @@ import asyncio import signal import socket -from http import HTTPStatus from typing import Any import uvicorn -from fastapi import FastAPI, Request, Response +from fastapi import FastAPI from vllm import envs from vllm.engine.protocol import EngineClient @@ -19,7 +18,6 @@ from vllm.entrypoints.ssl import SSLCertRefresher from vllm.logger import init_logger from vllm.utils.network_utils import find_process_using_port -from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError logger = init_logger(__name__) @@ -75,7 +73,7 @@ async def serve_http( config.h11_max_header_count = h11_max_header_count config.load() server = uvicorn.Server(config) - _add_shutdown_handlers(app, server) + app.state.server = server loop = asyncio.get_running_loop() @@ -148,40 +146,3 @@ def terminate_if_errored(server: uvicorn.Server, engine: EngineClient): engine_errored = engine.errored and not engine.is_running if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored: server.should_exit = True - - -def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: - """ - VLLM V1 AsyncLLM catches exceptions and returns - only two types: EngineGenerateError and EngineDeadError. - - EngineGenerateError is raised by the per request generate() - method. This error could be request specific (and therefore - recoverable - e.g. if there is an error in input processing). - - EngineDeadError is raised by the background output_handler - method. This error is global and therefore not recoverable. - - We register these @app.exception_handlers to return nice - responses to the end user if they occur and shut down if needed. - See https://fastapi.tiangolo.com/tutorial/handling-errors/ - for more details on how exception handlers work. 
- - If an exception is encountered in a StreamingResponse - generator, the exception is not raised, since we already sent - a 200 status. Rather, we send an error message as the next chunk. - Since the exception is not raised, this means that the server - will not automatically shut down. Instead, we use the watchdog - background task for check for errored state. - """ - - @app.exception_handler(RuntimeError) - @app.exception_handler(EngineDeadError) - @app.exception_handler(EngineGenerateError) - async def runtime_exception_handler(request: Request, __): - terminate_if_errored( - server=server, - engine=request.app.state.engine_client, - ) - - return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 61095035fbfd..ee0b7115dd3c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -31,6 +31,8 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.server_utils import ( + engine_error_handler, + exception_handler, get_uvicorn_log_config, http_exception_handler, lifespan, @@ -57,6 +59,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.network_utils import is_valid_ipv6_address from vllm.utils.system_utils import decorate_logs, set_ulimit +from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm.version import __version__ as VLLM_VERSION prometheus_multiproc_dir: tempfile.TemporaryDirectory @@ -250,6 +253,9 @@ def build_app( app.exception_handler(HTTPException)(http_exception_handler) app.exception_handler(RequestValidationError)(validation_exception_handler) + app.exception_handler(EngineGenerateError)(engine_error_handler) + app.exception_handler(EngineDeadError)(engine_error_handler) + app.exception_handler(Exception)(exception_handler) # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]: @@ -355,7 +361,6 @@ async def init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, trust_request_chat_template=args.trust_request_chat_template, - log_error_stack=args.log_error_stack, ) if any(task in supported_tasks for task in ("generate", "render")): diff --git a/vllm/entrypoints/openai/chat_completion/api_router.py b/vllm/entrypoints/openai/chat_completion/api_router.py index 81af0af3dc52..8f2c5c14f23c 100644 --- a/vllm/entrypoints/openai/chat_completion/api_router.py +++ b/vllm/entrypoints/openai/chat_completion/api_router.py @@ -39,6 +39,7 @@ def chat(request: Request) -> OpenAIServingChat | None: HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse}, }, ) @with_cancellation @@ -54,10 +55,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re message="The model does not support Chat Completions API" ) - try: - generator = await handler.create_chat_completion(request, raw_request) - except Exception as e: - generator = handler.create_error_response(e) + generator = await handler.create_chat_completion(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse( @@ -81,6 +79,7 @@ async def 
create_chat_completion(request: ChatCompletionRequest, raw_request: Re HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse}, }, ) async def render_chat_completion(request: ChatCompletionRequest, raw_request: Request): @@ -93,10 +92,7 @@ async def render_chat_completion(request: ChatCompletionRequest, raw_request: Re message="The model does not support Chat Completions API" ) - try: - result = await handler.render_chat_request(request) - except Exception as e: - result = handler.create_error_response(e) + result = await handler.render_chat_request(request) if isinstance(result, ErrorResponse): return JSONResponse(content=result.model_dump(), status_code=result.error.code) diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 06b16cde6748..08c783f87d83 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -8,7 +8,6 @@ from collections.abc import Sequence as GenericSequence from typing import Any, Final -import jinja2 import partial_json_parser import regex as re from fastapi import Request @@ -105,7 +104,6 @@ def __init__( enable_force_include_usage: bool = False, enable_log_outputs: bool = False, enable_log_deltas: bool = True, - log_error_stack: bool = False, default_chat_template_kwargs: dict[str, Any] | None = None, ) -> None: super().__init__( @@ -113,7 +111,6 @@ def __init__( models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - log_error_stack=log_error_stack, ) self.response_role = response_role @@ -235,81 +232,76 @@ async def render_chat_request( if self.engine_client.errored: raise self.engine_client.dead_error - try: - tokenizer = self.renderer.tokenizer - - tool_parser = self.tool_parser - - if is_mistral_tokenizer(tokenizer): - # because of issues with pydantic we need to potentially - # re-serialize the tool_calls field of the request - # for more info: see comment in `maybe_serialize_tool_calls` - _mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type] - _mt.truncate_tool_call_ids(request) # type: ignore[arg-type] - _mt.validate_request_params(request) - - # Check if tool parsing is unavailable (common condition) - tool_parsing_unavailable = ( - tool_parser is None - and not is_mistral_tokenizer(tokenizer) - and not self.use_harmony - ) + tokenizer = self.renderer.tokenizer - # Validate tool_choice when tool parsing is required but unavailable - if tool_parsing_unavailable and request.tool_choice not in ( - None, - "none", - ): - if request.tool_choice == "auto" and not self.enable_auto_tools: - # for hf tokenizers, "auto" tools requires - # --enable-auto-tool-choice and --tool-call-parser - return self.create_error_response( - '"auto" tool choice requires ' - "--enable-auto-tool-choice and --tool-call-parser to be set" - ) - elif request.tool_choice != "auto": - # "required" or named tool requires tool parser - return self.create_error_response( - f'tool_choice="{request.tool_choice}" requires ' - "--tool-call-parser to be set" - ) + tool_parser = self.tool_parser + + if is_mistral_tokenizer(tokenizer): + # because of issues with pydantic we need to potentially + # re-serialize the tool_calls field of the request + # for more info: see comment in `maybe_serialize_tool_calls` + 
_mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type] + _mt.truncate_tool_call_ids(request) # type: ignore[arg-type] + _mt.validate_request_params(request) + + # Check if tool parsing is unavailable (common condition) + tool_parsing_unavailable = ( + tool_parser is None + and not is_mistral_tokenizer(tokenizer) + and not self.use_harmony + ) - if request.tools is None or ( - request.tool_choice == "none" - and self.exclude_tools_when_tool_choice_none - ): - tool_dicts = None - else: - tool_dicts = [tool.model_dump() for tool in request.tools] - - if not self.use_harmony: - # Common case. - error_check_ret = self._validate_chat_template( - request_chat_template=request.chat_template, - chat_template_kwargs=request.chat_template_kwargs, - trust_request_chat_template=self.trust_request_chat_template, + # Validate tool_choice when tool parsing is required but unavailable + if tool_parsing_unavailable and request.tool_choice not in ( + None, + "none", + ): + if request.tool_choice == "auto" and not self.enable_auto_tools: + # for hf tokenizers, "auto" tools requires + # --enable-auto-tool-choice and --tool-call-parser + return self.create_error_response( + '"auto" tool choice requires ' + "--enable-auto-tool-choice and --tool-call-parser to be set" ) - if error_check_ret is not None: - return error_check_ret - - conversation, engine_prompts = await self._preprocess_chat( - request, - request.messages, - default_template=self.chat_template, - default_template_content_format=self.chat_template_content_format, - default_template_kwargs=self.default_chat_template_kwargs, - tool_dicts=tool_dicts, - tool_parser=tool_parser, + elif request.tool_choice != "auto": + # "required" or named tool requires tool parser + return self.create_error_response( + f'tool_choice="{request.tool_choice}" requires ' + "--tool-call-parser to be set" ) - else: - # For GPT-OSS. - should_include_tools = tool_dicts is not None - conversation, engine_prompts = self._make_request_with_harmony( - request, should_include_tools - ) - except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: - logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(e) + + if request.tools is None or ( + request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none + ): + tool_dicts = None + else: + tool_dicts = [tool.model_dump() for tool in request.tools] + + if not self.use_harmony: + # Common case. + error_check_ret = self._validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + if error_check_ret is not None: + return error_check_ret + + conversation, engine_prompts = await self._preprocess_chat( + request, + request.messages, + default_template=self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=self.default_chat_template_kwargs, + tool_dicts=tool_dicts, + tool_parser=tool_parser, + ) + else: + # For GPT-OSS. 
+ should_include_tools = tool_dicts is not None + conversation, engine_prompts = self._make_request_with_harmony( + request, should_include_tools + ) return conversation, engine_prompts @@ -329,20 +321,16 @@ async def create_chat_completion( tokenizer = self.renderer.tokenizer assert tokenizer is not None reasoning_parser: ReasoningParser | None = None - try: - if self.reasoning_parser_cls: - # Pass the same chat template kwargs as used in tokenization - chat_template_kwargs = self._prepare_extra_chat_template_kwargs( - request.chat_template_kwargs, - self.default_chat_template_kwargs, - ) - reasoning_parser = self.reasoning_parser_cls( - tokenizer, - chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg] - ) - except RuntimeError as e: - logger.exception("Error in reasoning parser creation.") - return self.create_error_response(str(e)) + if self.reasoning_parser_cls: + # Pass the same chat template kwargs as used in tokenization + chat_template_kwargs = self._prepare_extra_chat_template_kwargs( + request.chat_template_kwargs, + self.default_chat_template_kwargs, + ) + reasoning_parser = self.reasoning_parser_cls( + tokenizer, + chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg] + ) result = await self.render_chat_request(request) if isinstance(result, ErrorResponse): return result @@ -357,15 +345,9 @@ async def create_chat_completion( if raw_request: raw_request.state.request_metadata = request_metadata - try: - lora_request = self._maybe_get_adapters( - request, supports_default_mm_loras=True - ) + lora_request = self._maybe_get_adapters(request, supports_default_mm_loras=True) - model_name = self.models.model_name(lora_request) - except (ValueError, TypeError, RuntimeError) as e: - logger.exception("Error preparing request components") - return self.create_error_response(e) + model_name = self.models.model_name(lora_request) # Extract data_parallel_rank from header (router can inject it) data_parallel_rank = self._get_data_parallel_rank(raw_request) @@ -373,81 +355,76 @@ async def create_chat_completion( # Schedule the request and get the result generator. max_model_len = self.model_config.max_model_len generators: list[AsyncGenerator[RequestOutput, None]] = [] - try: - for i, engine_prompt in enumerate(engine_prompts): - prompt_token_ids = self._extract_prompt_components( - engine_prompt - ).token_ids - - # If we are creating sub requests for multiple prompts, ensure that they - # have unique request ids. - sub_request_id = ( - request_id if len(engine_prompts) == 1 else f"{request_id}_{i}" - ) + for i, engine_prompt in enumerate(engine_prompts): + prompt_token_ids = self._extract_prompt_components(engine_prompt).token_ids + + # If we are creating sub requests for multiple prompts, ensure that they + # have unique request ids. 
+ sub_request_id = ( + request_id if len(engine_prompts) == 1 else f"{request_id}_{i}" + ) + + max_tokens = get_max_tokens( + max_model_len, + request.max_completion_tokens + if request.max_completion_tokens is not None + else request.max_tokens, + self._extract_prompt_len(engine_prompt), + self.default_sampling_params, + self.override_max_tokens, + ) - max_tokens = get_max_tokens( - max_model_len, - request.max_completion_tokens - if request.max_completion_tokens is not None - else request.max_tokens, - self._extract_prompt_len(engine_prompt), + sampling_params: SamplingParams | BeamSearchParams + if request.use_beam_search: + sampling_params = request.to_beam_search_params( + max_tokens, self.default_sampling_params + ) + else: + sampling_params = request.to_sampling_params( + max_tokens, self.default_sampling_params, - self.override_max_tokens, ) - sampling_params: SamplingParams | BeamSearchParams - if request.use_beam_search: - sampling_params = request.to_beam_search_params( - max_tokens, self.default_sampling_params - ) - else: - sampling_params = request.to_sampling_params( - max_tokens, - self.default_sampling_params, - ) + self._log_inputs( + sub_request_id, + engine_prompt, + params=sampling_params, + lora_request=lora_request, + ) - self._log_inputs( - sub_request_id, - engine_prompt, + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) + + if isinstance(sampling_params, BeamSearchParams): + generator = self.beam_search( + prompt=engine_prompt, + request_id=sub_request_id, params=sampling_params, lora_request=lora_request, + trace_headers=trace_headers, ) - - trace_headers = ( - None - if raw_request is None - else await self._get_trace_headers(raw_request.headers) + else: + reasoning_ended = ( + reasoning_parser.is_reasoning_end(prompt_token_ids or []) + if reasoning_parser + else None ) - if isinstance(sampling_params, BeamSearchParams): - generator = self.beam_search( - prompt=engine_prompt, - request_id=sub_request_id, - params=sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - ) - else: - reasoning_ended = ( - reasoning_parser.is_reasoning_end(prompt_token_ids or []) - if reasoning_parser - else None - ) - - generator = self.engine_client.generate( - engine_prompt, - sampling_params, - sub_request_id, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, - data_parallel_rank=data_parallel_rank, - reasoning_ended=reasoning_ended, - ) + generator = self.engine_client.generate( + engine_prompt, + sampling_params, + sub_request_id, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + data_parallel_rank=data_parallel_rank, + reasoning_ended=reasoning_ended, + ) - generators.append(generator) - except ValueError as e: - return self.create_error_response(e) + generators.append(generator) assert len(generators) == 1 (result_generator,) = generators @@ -464,21 +441,16 @@ async def create_chat_completion( reasoning_parser, ) - try: - return await self.chat_completion_full_generator( - request, - result_generator, - request_id, - model_name, - conversation, - tokenizer, - request_metadata, - reasoning_parser, - ) - except GenerationError as e: - return self._convert_generation_error_to_response(e) - except ValueError as e: - return self.create_error_response(e) + return await self.chat_completion_full_generator( + request, + result_generator, + request_id, + model_name, + conversation, + tokenizer, + request_metadata, + 
reasoning_parser, + ) def get_chat_request_role(self, request: ChatCompletionRequest) -> str: if request.add_generation_prompt: @@ -1414,8 +1386,6 @@ async def chat_completion_full_generator( final_res = res except asyncio.CancelledError: return self.create_error_response("Client disconnected") - except ValueError as e: - return self.create_error_response(e) assert final_res is not None diff --git a/vllm/entrypoints/openai/completion/api_router.py b/vllm/entrypoints/openai/completion/api_router.py index 04dfdbccbef9..466c059aae94 100644 --- a/vllm/entrypoints/openai/completion/api_router.py +++ b/vllm/entrypoints/openai/completion/api_router.py @@ -54,10 +54,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): message="The model does not support Completions API" ) - try: - generator = await handler.create_completion(request, raw_request) - except Exception as e: - generator = handler.create_error_response(e) + generator = await handler.create_completion(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse( @@ -91,10 +88,7 @@ async def render_completion(request: CompletionRequest, raw_request: Request): message="The model does not support Completions API" ) - try: - result = await handler.render_completion_request(request) - except Exception as e: - result = handler.create_error_response(e) + result = await handler.render_completion_request(request) if isinstance(result, ErrorResponse): return JSONResponse(content=result.model_dump(), status_code=result.error.code) diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py index c6534489fd34..27320cbd0eba 100644 --- a/vllm/entrypoints/openai/completion/serving.py +++ b/vllm/entrypoints/openai/completion/serving.py @@ -7,7 +7,6 @@ from collections.abc import Sequence as GenericSequence from typing import cast -import jinja2 from fastapi import Request from vllm.engine.protocol import EngineClient @@ -56,14 +55,12 @@ def __init__( return_tokens_as_token_ids: bool = False, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, - log_error_stack: bool = False, ): super().__init__( engine_client=engine_client, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - log_error_stack=log_error_stack, ) self.enable_prompt_tokens_details = enable_prompt_tokens_details @@ -110,15 +107,11 @@ async def render_completion_request( "prompt_logprobs is not compatible with prompt embeds." 
) - try: - engine_prompts = await self._preprocess_completion( - request, - prompt_input=request.prompt, - prompt_embeds=request.prompt_embeds, - ) - except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: - logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(e) + engine_prompts = await self._preprocess_completion( + request, + prompt_input=request.prompt, + prompt_embeds=request.prompt_embeds, + ) return engine_prompts @@ -149,11 +142,7 @@ async def create_completion( if raw_request: raw_request.state.request_metadata = request_metadata - try: - lora_request = self._maybe_get_adapters(request) - except (ValueError, TypeError, RuntimeError) as e: - logger.exception("Error preparing request components") - return self.create_error_response(e) + lora_request = self._maybe_get_adapters(request) # Extract data_parallel_rank from header (router can inject it) data_parallel_rank = self._get_data_parallel_rank(raw_request) @@ -161,64 +150,61 @@ async def create_completion( # Schedule the request and get the result generator. max_model_len = self.model_config.max_model_len generators: list[AsyncGenerator[RequestOutput, None]] = [] - try: - for i, engine_prompt in enumerate(engine_prompts): - max_tokens = get_max_tokens( - max_model_len, - request.max_tokens, - self._extract_prompt_len(engine_prompt), + for i, engine_prompt in enumerate(engine_prompts): + max_tokens = get_max_tokens( + max_model_len, + request.max_tokens, + self._extract_prompt_len(engine_prompt), + self.default_sampling_params, + self.override_max_tokens, + ) + + sampling_params: SamplingParams | BeamSearchParams + if request.use_beam_search: + sampling_params = request.to_beam_search_params( + max_tokens, self.default_sampling_params + ) + else: + sampling_params = request.to_sampling_params( + max_tokens, self.default_sampling_params, - self.override_max_tokens, ) - sampling_params: SamplingParams | BeamSearchParams - if request.use_beam_search: - sampling_params = request.to_beam_search_params( - max_tokens, self.default_sampling_params - ) - else: - sampling_params = request.to_sampling_params( - max_tokens, - self.default_sampling_params, - ) + request_id_item = f"{request_id}-{i}" + + self._log_inputs( + request_id_item, + engine_prompt, + params=sampling_params, + lora_request=lora_request, + ) - request_id_item = f"{request_id}-{i}" + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) - self._log_inputs( - request_id_item, - engine_prompt, + if isinstance(sampling_params, BeamSearchParams): + generator = self.beam_search( + prompt=engine_prompt, + request_id=request_id, params=sampling_params, lora_request=lora_request, + trace_headers=trace_headers, ) - - trace_headers = ( - None - if raw_request is None - else await self._get_trace_headers(raw_request.headers) + else: + generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + data_parallel_rank=data_parallel_rank, ) - if isinstance(sampling_params, BeamSearchParams): - generator = self.beam_search( - prompt=engine_prompt, - request_id=request_id, - params=sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - ) - else: - generator = self.engine_client.generate( - engine_prompt, - sampling_params, - request_id_item, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, - 
data_parallel_rank=data_parallel_rank, - ) - - generators.append(generator) - except ValueError as e: - return self.create_error_response(e) + generators.append(generator) result_generator = merge_async_iterators(*generators) @@ -273,10 +259,6 @@ async def create_completion( ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") - except GenerationError as e: - return self._convert_generation_error_to_response(e) - except ValueError as e: - return self.create_error_response(e) # When user requests streaming but we don't stream, we still need to # return a streaming response with a single event. diff --git a/vllm/entrypoints/openai/engine/protocol.py b/vllm/entrypoints/openai/engine/protocol.py index 6b5b714dc32a..f4e5fe733303 100644 --- a/vllm/entrypoints/openai/engine/protocol.py +++ b/vllm/entrypoints/openai/engine/protocol.py @@ -4,6 +4,7 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time +from http import HTTPStatus from typing import Any, ClassVar, Literal, TypeAlias import regex as re @@ -262,6 +263,14 @@ class DeltaMessage(OpenAIBaseModel): tool_calls: list[DeltaToolCall] = Field(default_factory=list) +class GenerationError(Exception): + """raised when finish_reason indicates internal server error (500)""" + + def __init__(self, message: str = "Internal server error"): + super().__init__(message) + self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + + ####### Tokens IN <> Tokens OUT ####### class GenerateRequest(BaseModel): request_id: str = Field( diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index e864f562ee1e..44954ef9d55f 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -2,9 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio import json -import sys import time -import traceback from collections.abc import AsyncGenerator, Callable, Mapping, Sequence from dataclasses import dataclass, field from http import HTTPStatus @@ -38,10 +36,10 @@ CompletionResponse, ) from vllm.entrypoints.openai.engine.protocol import ( - ErrorInfo, ErrorResponse, FunctionCall, FunctionDefinition, + GenerationError, ) from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.responses.context import ( @@ -89,7 +87,7 @@ TokenizeCompletionRequest, TokenizeResponse, ) -from vllm.entrypoints.utils import get_max_tokens, sanitize_message +from vllm.entrypoints.utils import create_error_response, get_max_tokens from vllm.exceptions import VLLMValidationError from vllm.inputs.data import ( ProcessorInputs, @@ -125,15 +123,6 @@ ) from vllm.utils.mistral import is_mistral_tokenizer - -class GenerationError(Exception): - """raised when finish_reason indicates internal server error (500)""" - - def __init__(self, message: str = "Internal server error"): - super().__init__(message) - self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR - - logger = init_logger(__name__) @@ -225,7 +214,6 @@ def __init__( *, request_logger: RequestLogger | None, return_tokens_as_token_ids: bool = False, - log_error_stack: bool = False, ): super().__init__() @@ -236,8 +224,6 @@ def __init__( self.request_logger = request_logger self.return_tokens_as_token_ids = return_tokens_as_token_ids - self.log_error_stack = log_error_stack - self.model_config = engine_client.model_config self.renderer = 
engine_client.renderer self.io_processor = engine_client.io_processor @@ -526,133 +512,79 @@ async def _prepare_generators( """Schedule the request and get the result generator.""" generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - try: - trace_headers = ( - None - if ctx.raw_request is None - else await self._get_trace_headers(ctx.raw_request.headers) - ) + trace_headers = ( + None + if ctx.raw_request is None + else await self._get_trace_headers(ctx.raw_request.headers) + ) - pooling_params = self._create_pooling_params(ctx) - if isinstance(pooling_params, ErrorResponse): - return pooling_params + pooling_params = self._create_pooling_params(ctx) + if isinstance(pooling_params, ErrorResponse): + return pooling_params - if ctx.engine_prompts is None: - return self.create_error_response("Engine prompts not available") + if ctx.engine_prompts is None: + return self.create_error_response("Engine prompts not available") - for i, engine_prompt in enumerate(ctx.engine_prompts): - request_id_item = f"{ctx.request_id}-{i}" + for i, engine_prompt in enumerate(ctx.engine_prompts): + request_id_item = f"{ctx.request_id}-{i}" - self._log_inputs( - request_id_item, - engine_prompt, - params=pooling_params, - lora_request=ctx.lora_request, - ) - - generator = self.engine_client.encode( - engine_prompt, - pooling_params, - request_id_item, - lora_request=ctx.lora_request, - trace_headers=trace_headers, - priority=getattr(ctx.request, "priority", 0), - ) + self._log_inputs( + request_id_item, + engine_prompt, + params=pooling_params, + lora_request=ctx.lora_request, + ) - generators.append(generator) + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) - ctx.result_generator = merge_async_iterators(*generators) + generators.append(generator) - return None + ctx.result_generator = merge_async_iterators(*generators) - except Exception as e: - return self.create_error_response(e) + return None async def _collect_batch( self, ctx: ServeContext, ) -> ErrorResponse | None: """Collect batch results from the result generator.""" - try: - if ctx.engine_prompts is None: - return self.create_error_response("Engine prompts not available") + if ctx.engine_prompts is None: + return self.create_error_response("Engine prompts not available") - num_prompts = len(ctx.engine_prompts) - final_res_batch: list[PoolingRequestOutput | None] - final_res_batch = [None] * num_prompts + num_prompts = len(ctx.engine_prompts) + final_res_batch: list[PoolingRequestOutput | None] + final_res_batch = [None] * num_prompts - if ctx.result_generator is None: - return self.create_error_response("Result generator not available") + if ctx.result_generator is None: + return self.create_error_response("Result generator not available") - async for i, res in ctx.result_generator: - final_res_batch[i] = res - - if None in final_res_batch: - return self.create_error_response( - "Failed to generate results for all prompts" - ) + async for i, res in ctx.result_generator: + final_res_batch[i] = res - ctx.final_res_batch = [res for res in final_res_batch if res is not None] + if None in final_res_batch: + return self.create_error_response( + "Failed to generate results for all prompts" + ) - return None + ctx.final_res_batch = [res for res in final_res_batch if res is not None] - except Exception as e: - return self.create_error_response(e) + return None + @staticmethod def 
create_error_response( - self, message: str | Exception, err_type: str = "BadRequestError", status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, param: str | None = None, ) -> ErrorResponse: - exc: Exception | None = None - - if isinstance(message, Exception): - exc = message - - from vllm.exceptions import VLLMValidationError - - if isinstance(exc, VLLMValidationError): - err_type = "BadRequestError" - status_code = HTTPStatus.BAD_REQUEST - param = exc.parameter - elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)): - # Common validation errors from user input - err_type = "BadRequestError" - status_code = HTTPStatus.BAD_REQUEST - param = None - elif isinstance(exc, NotImplementedError): - err_type = "NotImplementedError" - status_code = HTTPStatus.NOT_IMPLEMENTED - param = None - elif exc.__class__.__name__ == "TemplateError": - # jinja2.TemplateError (avoid importing jinja2) - err_type = "BadRequestError" - status_code = HTTPStatus.BAD_REQUEST - param = None - else: - err_type = "InternalServerError" - status_code = HTTPStatus.INTERNAL_SERVER_ERROR - param = None - - message = str(exc) - - if self.log_error_stack: - exc_type, _, _ = sys.exc_info() - if exc_type is not None: - traceback.print_exc() - else: - traceback.print_stack() - - return ErrorResponse( - error=ErrorInfo( - message=sanitize_message(message), - type=err_type, - code=status_code.value, - param=param, - ) - ) + return create_error_response(message, err_type, status_code, param) def create_streaming_error_response( self, @@ -680,16 +612,6 @@ def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None: ) raise GenerationError("Internal server error") - def _convert_generation_error_to_response( - self, e: GenerationError - ) -> ErrorResponse: - """Convert GenerationError to ErrorResponse.""" - return self.create_error_response( - str(e), - err_type="InternalServerError", - status_code=e.status_code, - ) - def _convert_generation_error_to_streaming_response( self, e: GenerationError ) -> str: diff --git a/vllm/entrypoints/openai/generate/api_router.py b/vllm/entrypoints/openai/generate/api_router.py index e4049331e811..5e4f184a0145 100644 --- a/vllm/entrypoints/openai/generate/api_router.py +++ b/vllm/entrypoints/openai/generate/api_router.py @@ -87,7 +87,6 @@ async def init_generate_state( enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, enable_log_outputs=args.enable_log_outputs, - log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None @@ -111,7 +110,6 @@ async def init_generate_state( enable_force_include_usage=args.enable_force_include_usage, enable_log_outputs=args.enable_log_outputs, enable_log_deltas=args.enable_log_deltas, - log_error_stack=args.log_error_stack, ) if any(task in supported_tasks for task in ("generate", "render")) else None @@ -127,7 +125,6 @@ async def init_generate_state( return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, - log_error_stack=args.log_error_stack, ) if any(task in supported_tasks for task in ("generate", "render")) else None @@ -156,7 +153,6 @@ async def init_generate_state( state.openai_serving_models, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, - log_error_stack=args.log_error_stack, enable_prompt_tokens_details=args.enable_prompt_tokens_details, 
enable_log_outputs=args.enable_log_outputs, force_no_detokenize=args.tokens_only, diff --git a/vllm/entrypoints/openai/realtime/api_router.py b/vllm/entrypoints/openai/realtime/api_router.py index fb7decbd707a..c48191d14cd4 100644 --- a/vllm/entrypoints/openai/realtime/api_router.py +++ b/vllm/entrypoints/openai/realtime/api_router.py @@ -68,7 +68,6 @@ def init_realtime_state( engine_client, state.openai_serving_models, request_logger=request_logger, - log_error_stack=args.log_error_stack, ) if "realtime" in supported_tasks else None diff --git a/vllm/entrypoints/openai/realtime/serving.py b/vllm/entrypoints/openai/realtime/serving.py index d239968e75d2..5aead4d00f0b 100644 --- a/vllm/entrypoints/openai/realtime/serving.py +++ b/vllm/entrypoints/openai/realtime/serving.py @@ -33,13 +33,11 @@ def __init__( models: OpenAIServingModels, *, request_logger: RequestLogger | None, - log_error_stack: bool = False, ): super().__init__( engine_client=engine_client, models=models, request_logger=request_logger, - log_error_stack=log_error_stack, ) self.task_type: Literal["realtime"] = "realtime" diff --git a/vllm/entrypoints/openai/responses/api_router.py b/vllm/entrypoints/openai/responses/api_router.py index 62328c045df4..0c6b4a73801f 100644 --- a/vllm/entrypoints/openai/responses/api_router.py +++ b/vllm/entrypoints/openai/responses/api_router.py @@ -63,10 +63,8 @@ async def create_responses(request: ResponsesRequest, raw_request: Request): return base_server.create_error_response( message="The model does not support Responses API" ) - try: - generator = await handler.create_responses(request, raw_request) - except Exception as e: - generator = handler.create_error_response(e) + + generator = await handler.create_responses(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse( @@ -95,14 +93,11 @@ async def retrieve_responses( message="The model does not support Responses API" ) - try: - response = await handler.retrieve_responses( - response_id, - starting_after=starting_after, - stream=stream, - ) - except Exception as e: - response = handler.create_error_response(e) + response = await handler.retrieve_responses( + response_id, + starting_after=starting_after, + stream=stream, + ) if isinstance(response, ErrorResponse): return JSONResponse( @@ -125,10 +120,7 @@ async def cancel_responses(response_id: str, raw_request: Request): message="The model does not support Responses API" ) - try: - response = await handler.cancel_responses(response_id) - except Exception as e: - response = handler.create_error_response(e) + response = await handler.cancel_responses(response_id) if isinstance(response, ErrorResponse): return JSONResponse( diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index 3cfb6fffc3ea..03a926d9ef51 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -11,7 +11,6 @@ from http import HTTPStatus from typing import Final -import jinja2 from fastapi import Request from openai.types.responses import ( ResponseContentPartAddedEvent, @@ -174,14 +173,12 @@ def __init__( enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, enable_log_outputs: bool = False, - log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - log_error_stack=log_error_stack, ) self.chat_template = chat_template @@ 
-365,28 +362,15 @@ async def create_responses( else: prev_response = None - try: - lora_request = self._maybe_get_adapters(request) - model_name = self.models.model_name(lora_request) - - if self.use_harmony: - messages, engine_prompts = self._make_request_with_harmony( - request, prev_response - ) - else: - messages, engine_prompts = await self._make_request( - request, prev_response - ) + lora_request = self._maybe_get_adapters(request) + model_name = self.models.model_name(lora_request) - except ( - ValueError, - TypeError, - RuntimeError, - jinja2.TemplateError, - NotImplementedError, - ) as e: - logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(e) + if self.use_harmony: + messages, engine_prompts = self._make_request_with_harmony( + request, prev_response + ) + else: + messages, engine_prompts = await self._make_request(request, prev_response) request_metadata = RequestResponseMetadata(request_id=request.request_id) if raw_request: @@ -424,86 +408,83 @@ async def create_responses( else: assert len(builtin_tool_list) == 0 available_tools = [] - try: - tokenizer = self.renderer.get_tokenizer() - - for engine_prompt in engine_prompts: - maybe_error = self._validate_generator_input(engine_prompt) - if maybe_error is not None: - return maybe_error - - default_max_tokens = get_max_tokens( - max_model_len, - request.max_output_tokens, - self._extract_prompt_len(engine_prompt), - self.default_sampling_params, - self.override_max_tokens, - ) + tokenizer = self.renderer.get_tokenizer() + + for engine_prompt in engine_prompts: + maybe_error = self._validate_generator_input(engine_prompt) + if maybe_error is not None: + return maybe_error + + default_max_tokens = get_max_tokens( + max_model_len, + request.max_output_tokens, + self._extract_prompt_len(engine_prompt), + self.default_sampling_params, + self.override_max_tokens, + ) - sampling_params = request.to_sampling_params( - default_max_tokens, self.default_sampling_params - ) + sampling_params = request.to_sampling_params( + default_max_tokens, self.default_sampling_params + ) - trace_headers = ( - None - if raw_request is None - else await self._get_trace_headers(raw_request.headers) - ) + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) - context: ConversationContext - if self.use_harmony: - if request.stream: - context = StreamingHarmonyContext(messages, available_tools) - else: - context = HarmonyContext(messages, available_tools) + context: ConversationContext + if self.use_harmony: + if request.stream: + context = StreamingHarmonyContext(messages, available_tools) else: - if envs.VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: - # This is a feature in development for parsing - # tokens during generation instead of at the end - context = ParsableContext( - response_messages=messages, - tokenizer=tokenizer, - reasoning_parser_cls=self.parser.reasoning_parser_cls - if self.parser - else None, - request=request, - tool_parser_cls=self.parser.tool_parser_cls - if self.parser - else None, - available_tools=available_tools, - chat_template=self.chat_template, - chat_template_content_format=self.chat_template_content_format, - ) - else: - context = SimpleContext() - - if self.parser and self.parser.reasoning_parser_cls is not None: - reasoning_parser = self.parser.reasoning_parser_cls(tokenizer) - if ( - isinstance( - struct_out := sampling_params.structured_outputs, - StructuredOutputsParams, - ) - and 
struct_out.all_non_structural_tag_constraints_none() - ): - sampling_params.structured_outputs = replace( - struct_out, - structural_tag=reasoning_parser.prepare_structured_tag( - struct_out.structural_tag, self.tool_server - ), - ) - generator = self._generate_with_builtin_tools( - request_id=request.request_id, - engine_prompt=engine_prompt, - sampling_params=sampling_params, - context=context, - lora_request=lora_request, - priority=request.priority, - trace_headers=trace_headers, - ) - generators.append(generator) - except ValueError as e: - return self.create_error_response(e) + context = HarmonyContext(messages, available_tools) + else: + if envs.VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: + # This is a feature in development for parsing + # tokens during generation instead of at the end + context = ParsableContext( + response_messages=messages, + tokenizer=tokenizer, + reasoning_parser_cls=self.parser.reasoning_parser_cls + if self.parser + else None, + request=request, + tool_parser_cls=self.parser.tool_parser_cls + if self.parser + else None, + available_tools=available_tools, + chat_template=self.chat_template, + chat_template_content_format=self.chat_template_content_format, + ) + else: + context = SimpleContext() + + if self.parser and self.parser.reasoning_parser_cls is not None: + reasoning_parser = self.parser.reasoning_parser_cls(tokenizer) + if ( + isinstance( + struct_out := sampling_params.structured_outputs, + StructuredOutputsParams, + ) + and struct_out.all_non_structural_tag_constraints_none() + ): + sampling_params.structured_outputs = replace( + struct_out, + structural_tag=reasoning_parser.prepare_structured_tag( + struct_out.structural_tag, self.tool_server + ), + ) + generator = self._generate_with_builtin_tools( + request_id=request.request_id, + engine_prompt=engine_prompt, + sampling_params=sampling_params, + context=context, + lora_request=lora_request, + priority=request.priority, + trace_headers=trace_headers, + ) + generators.append(generator) assert len(generators) == 1 (result_generator,) = generators @@ -578,20 +559,15 @@ async def create_responses( request_metadata, ) - try: - return await self.responses_full_generator( - request, - sampling_params, - result_generator, - context, - model_name, - tokenizer, - request_metadata, - ) - except GenerationError as e: - return self._convert_generation_error_to_response(e) - except Exception as e: - return self.create_error_response(e) + return await self.responses_full_generator( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + ) async def _make_request( self, @@ -675,8 +651,6 @@ async def responses_full_generator( pass except asyncio.CancelledError: return self.create_error_response("Client disconnected") - except ValueError as e: - return self.create_error_response(e) # NOTE: Implementation of status is still WIP, but for now # we guarantee that if the status is not "completed", it is accurate. 
@@ -1129,16 +1103,11 @@ async def _run_background_request_stream( new_event_signal = asyncio.Event() self.event_store[request.request_id] = (event_deque, new_event_signal) response = None + generator = self.responses_stream_generator(request, *args, **kwargs) try: - generator = self.responses_stream_generator(request, *args, **kwargs) async for event in generator: event_deque.append(event) new_event_signal.set() # Signal new event available - except GenerationError as e: - response = self._convert_generation_error_to_response(e) - except Exception as e: - logger.exception("Background request failed for %s", request.request_id) - response = self.create_error_response(e) finally: new_event_signal.set() @@ -1157,13 +1126,7 @@ async def _run_background_request( *args, **kwargs, ): - try: - response = await self.responses_full_generator(request, *args, **kwargs) - except GenerationError as e: - response = self._convert_generation_error_to_response(e) - except Exception as e: - logger.exception("Background request failed for %s", request.request_id) - response = self.create_error_response(e) + response = await self.responses_full_generator(request, *args, **kwargs) if isinstance(response, ErrorResponse): # If the request has failed, update the status to "failed". diff --git a/vllm/entrypoints/openai/server_utils.py b/vllm/entrypoints/openai/server_utils.py index 12768cb6f97c..b21126472912 100644 --- a/vllm/entrypoints/openai/server_utils.py +++ b/vllm/entrypoints/openai/server_utils.py @@ -11,7 +11,7 @@ from http import HTTPStatus import pydantic -from fastapi import FastAPI, HTTPException, Request +from fastapi import FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from starlette.concurrency import iterate_in_threadpool @@ -20,11 +20,13 @@ from vllm import envs from vllm.engine.protocol import EngineClient +from vllm.entrypoints.launcher import terminate_if_errored from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse -from vllm.entrypoints.utils import sanitize_message +from vllm.entrypoints.utils import create_error_response, sanitize_message from vllm.exceptions import VLLMValidationError from vllm.logger import init_logger from vllm.utils.gc_utils import freeze_gc_heap +from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError logger = init_logger("vllm.entrypoints.openai.server_utils") @@ -309,7 +311,69 @@ async def log_response(request: Request, call_next): return response -async def http_exception_handler(_: Request, exc: HTTPException): +async def engine_error_handler( + req: Request, exc: EngineDeadError | EngineGenerateError +): + """ + VLLM V1 AsyncLLM catches exceptions and returns + only two types: EngineGenerateError and EngineDeadError. + + EngineGenerateError is raised by the per request generate() + method. This error could be request specific (and therefore + recoverable - e.g. if there is an error in input processing). + + EngineDeadError is raised by the background output_handler + method. This error is global and therefore not recoverable. + + We register these @app.exception_handlers to return nice + responses to the end user if they occur and shut down if needed. + See https://fastapi.tiangolo.com/tutorial/handling-errors/ + for more details on how exception handlers work. + + If an exception is encountered in a StreamingResponse + generator, the exception is not raised, since we already sent + a 200 status. 
Rather, we send an error message as the next chunk.
+    Since the exception is not raised, this means that the server
+    will not automatically shut down. Instead, we use the watchdog
+    background task to check for errored state.
+    """
+
+    if req.app.state.args.log_error_stack:
+        logger.exception(
+            "Engine Exception caught. Request id: %s",
+            req.state.request_metadata.request_id
+            if hasattr(req.state, "request_metadata")
+            else None,
+        )
+
+    terminate_if_errored(
+        server=req.app.state.server,
+        engine=req.app.state.engine_client,
+    )
+    return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+
+
+async def exception_handler(req: Request, exc: Exception):
+    if req.app.state.args.log_error_stack:
+        logger.exception(
+            "Exception caught. Request id: %s",
+            req.state.request_metadata.request_id
+            if hasattr(req.state, "request_metadata")
+            else None,
+        )
+
+    err = create_error_response(exc)
+    return JSONResponse(err.model_dump(), status_code=err.error.code)
+
+
+async def http_exception_handler(req: Request, exc: HTTPException):
+    if req.app.state.args.log_error_stack:
+        logger.exception(
+            "HTTPException caught. Request id: %s",
+            req.state.request_metadata.request_id
+            if hasattr(req.state, "request_metadata")
+            else None,
+        )
     err = ErrorResponse(
         error=ErrorInfo(
             message=sanitize_message(exc.detail),
@@ -320,7 +384,15 @@ async def http_exception_handler(_: Request, exc: HTTPException):
     return JSONResponse(err.model_dump(), status_code=exc.status_code)
 
 
-async def validation_exception_handler(_: Request, exc: RequestValidationError):
+async def validation_exception_handler(req: Request, exc: RequestValidationError):
+    if req.app.state.args.log_error_stack:
+        logger.exception(
+            "RequestValidationError caught. Request id: %s",
+            req.state.request_metadata.request_id
+            if hasattr(req.state, "request_metadata")
+            else None,
+        )
+
     param = None
     errors = exc.errors()
     for error in errors:
diff --git a/vllm/entrypoints/openai/speech_to_text/api_router.py b/vllm/entrypoints/openai/speech_to_text/api_router.py
index 7477b79c08b0..2c4f6bc9a1ce 100644
--- a/vllm/entrypoints/openai/speech_to_text/api_router.py
+++ b/vllm/entrypoints/openai/speech_to_text/api_router.py
@@ -71,10 +71,9 @@ async def create_transcriptions(
     )
 
     audio_data = await request.file.read()
-    try:
-        generator = await handler.create_transcription(audio_data, request, raw_request)
-    except Exception as e:
-        return handler.create_error_response(e)
+
+    generator = await handler.create_transcription(audio_data, request, raw_request)
+
     if isinstance(generator, ErrorResponse):
         return JSONResponse(
             content=generator.model_dump(), status_code=generator.error.code
@@ -108,10 +107,8 @@ async def create_translations(
     )
 
     audio_data = await request.file.read()
-    try:
-        generator = await handler.create_translation(audio_data, request, raw_request)
-    except Exception as e:
-        return handler.create_error_response(e)
+
+    generator = await handler.create_translation(audio_data, request, raw_request)
 
     if isinstance(generator, ErrorResponse):
         return JSONResponse(
@@ -140,7 +137,6 @@ def init_transcription_state(
             engine_client,
             state.openai_serving_models,
             request_logger=request_logger,
-            log_error_stack=args.log_error_stack,
             enable_force_include_usage=args.enable_force_include_usage,
         )
         if "transcription" in supported_tasks
@@ -151,7 +147,6 @@ def init_transcription_state(
             engine_client,
             state.openai_serving_models,
             request_logger=request_logger,
-            log_error_stack=args.log_error_stack,
             enable_force_include_usage=args.enable_force_include_usage,
         )
         if 
"transcription" in supported_tasks diff --git a/vllm/entrypoints/openai/speech_to_text/serving.py b/vllm/entrypoints/openai/speech_to_text/serving.py index b5ce17d0ef79..28e798a986f7 100644 --- a/vllm/entrypoints/openai/speech_to_text/serving.py +++ b/vllm/entrypoints/openai/speech_to_text/serving.py @@ -40,7 +40,6 @@ def __init__( *, request_logger: RequestLogger | None, return_tokens_as_token_ids: bool = False, - log_error_stack: bool = False, enable_force_include_usage: bool = False, ): super().__init__( @@ -49,7 +48,6 @@ def __init__( request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, task_type="transcribe", - log_error_stack=log_error_stack, enable_force_include_usage=enable_force_include_usage, ) @@ -113,7 +111,6 @@ def __init__( *, request_logger: RequestLogger | None, return_tokens_as_token_ids: bool = False, - log_error_stack: bool = False, enable_force_include_usage: bool = False, ): super().__init__( @@ -122,7 +119,6 @@ def __init__( request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, task_type="translate", - log_error_stack=log_error_stack, enable_force_include_usage=enable_force_include_usage, ) diff --git a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py index 1c56f092029d..7f12892f4060 100644 --- a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py @@ -97,7 +97,6 @@ def __init__( request_logger: RequestLogger | None, return_tokens_as_token_ids: bool = False, task_type: Literal["transcribe", "translate"] = "transcribe", - log_error_stack: bool = False, enable_force_include_usage: bool = False, ): super().__init__( @@ -105,7 +104,6 @@ def __init__( models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - log_error_stack=log_error_stack, ) self.default_sampling_params = self.model_config.get_diff_sampling_param() @@ -517,69 +515,61 @@ async def _create_speech_to_text( if raw_request: raw_request.state.request_metadata = request_metadata - try: - lora_request = self._maybe_get_adapters(request) - - engine_prompts, duration_s = await self._preprocess_speech_to_text( - request=request, - audio_data=audio_data, - request_id=request_id, - ) + lora_request = self._maybe_get_adapters(request) - except ValueError as e: - logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(e) + engine_prompts, duration_s = await self._preprocess_speech_to_text( + request=request, + audio_data=audio_data, + request_id=request_id, + ) # Schedule the request and get the result generator. max_model_len = self.model_config.max_model_len list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None - try: - # Unlike most decoder-only models, whisper generation length is not - # constrained by the size of the input audio, which is mapped to a - # fixed-size log-mel-spectogram. Still, allow for fewer tokens to be - # generated by respecting the extra completion tokens arg. - max_tokens = get_max_tokens( - max_model_len, - request.max_completion_tokens, - 0, - self.default_sampling_params, - ) + # Unlike most decoder-only models, whisper generation length is not + # constrained by the size of the input audio, which is mapped to a + # fixed-size log-mel-spectogram. Still, allow for fewer tokens to be + # generated by respecting the extra completion tokens arg. 
+ max_tokens = get_max_tokens( + max_model_len, + request.max_completion_tokens, + 0, + self.default_sampling_params, + ) - sampling_params = request.to_sampling_params( - max_tokens, - self.default_sampling_params, + sampling_params = request.to_sampling_params( + max_tokens, + self.default_sampling_params, + ) + if request.response_format == "verbose_json": + sampling_params.logprobs = 1 + + list_result_generator = [] + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}_{i}" + + self._log_inputs( + request_id_item, + engine_prompt, + params=sampling_params, + lora_request=lora_request, ) - if request.response_format == "verbose_json": - sampling_params.logprobs = 1 - - list_result_generator = [] - for i, engine_prompt in enumerate(engine_prompts): - request_id_item = f"{request_id}_{i}" - - self._log_inputs( - request_id_item, - engine_prompt, - params=sampling_params, - lora_request=lora_request, - ) - trace_headers = ( - None - if raw_request is None - else await self._get_trace_headers(raw_request.headers) - ) + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) - generator = self.engine_client.generate( - engine_prompt, - sampling_params, - request_id_item, - lora_request=lora_request, - trace_headers=trace_headers, - ) + generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + ) - list_result_generator.append(generator) - except ValueError as e: - return self.create_error_response(e) + list_result_generator.append(generator) if request.stream: return stream_generator_method( @@ -663,8 +653,6 @@ async def _create_speech_to_text( return final_response except asyncio.CancelledError: return self.create_error_response("Client disconnected") - except ValueError as e: - return self.create_error_response(e) async def _speech_to_text_stream_generator( self, diff --git a/vllm/entrypoints/pooling/__init__.py b/vllm/entrypoints/pooling/__init__.py index 3ba131d5f831..8de8338f552d 100644 --- a/vllm/entrypoints/pooling/__init__.py +++ b/vllm/entrypoints/pooling/__init__.py @@ -72,7 +72,6 @@ def init_pooling_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, trust_request_chat_template=args.trust_request_chat_template, - log_error_stack=args.log_error_stack, ) ) if any(t in supported_tasks for t in POOLING_TASKS) @@ -86,7 +85,6 @@ def init_pooling_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, trust_request_chat_template=args.trust_request_chat_template, - log_error_stack=args.log_error_stack, ) if "embed" in supported_tasks else None @@ -99,7 +97,6 @@ def init_pooling_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, trust_request_chat_template=args.trust_request_chat_template, - log_error_stack=args.log_error_stack, ) if "classify" in supported_tasks else None @@ -114,7 +111,6 @@ def init_pooling_state( state.openai_serving_models, request_logger=request_logger, score_template=resolved_chat_template, - log_error_stack=args.log_error_stack, use_gpu_for_pooling_score=getattr(args, "use_gpu_for_pooling_score", False), ) if any(t in supported_tasks for t in ("embed", "score", "token_embed")) diff --git a/vllm/entrypoints/pooling/base/serving.py b/vllm/entrypoints/pooling/base/serving.py index 813282d3d13f..a3a5682aa540 100644 --- 
a/vllm/entrypoints/pooling/base/serving.py +++ b/vllm/entrypoints/pooling/base/serving.py @@ -41,7 +41,6 @@ from vllm.utils import random_uuid from vllm.utils.async_utils import merge_async_iterators -from ...utils import create_error_response from .io_processor import PoolingIOProcessor PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest) @@ -112,34 +111,25 @@ async def __call__( request: AnyPoolingRequest, raw_request: Request, ) -> JSONResponse: - try: - model_name = self.models.model_name() - request_id = ( - f"{self.request_id_prefix}-{self._base_request_id(raw_request)}" - ) + model_name = self.models.model_name() + request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}" - await self._check_model(request) + await self._check_model(request) - ctx = PoolingServeContext( - request=request, - raw_request=raw_request, - model_name=model_name, - request_id=request_id, - ) + ctx = PoolingServeContext( + request=request, + raw_request=raw_request, + model_name=model_name, + request_id=request_id, + ) - self._validate_request(ctx) - self._maybe_get_adapters(ctx) - await self._preprocess(ctx) - await self._prepare_generators(ctx) - await self._collect_batch(ctx) - response = await self._build_response(ctx) - return JSONResponse(content=response.model_dump()) - except Exception as e: - error_response = create_error_response(e) - return JSONResponse( - content=error_response.model_dump(), - status_code=error_response.error.code, - ) + self._validate_request(ctx) + self._maybe_get_adapters(ctx) + await self._preprocess(ctx) + await self._prepare_generators(ctx) + await self._collect_batch(ctx) + response = await self._build_response(ctx) + return JSONResponse(content=response.model_dump()) async def _preprocess( self, diff --git a/vllm/entrypoints/pooling/embed/api_router.py b/vllm/entrypoints/pooling/embed/api_router.py index f77c07069288..1c9347d37820 100644 --- a/vllm/entrypoints/pooling/embed/api_router.py +++ b/vllm/entrypoints/pooling/embed/api_router.py @@ -61,10 +61,7 @@ async def create_embedding( message="The model does not support Embeddings API" ) - try: - generator = await handler.create_embedding(request, raw_request) - except Exception as e: - generator = handler.create_error_response(e) + generator = await handler.create_embedding(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse( diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index de4dca623503..d15209ede093 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -54,13 +54,11 @@ def __init__( chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, trust_request_chat_template: bool = False, - log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, models=models, request_logger=request_logger, - log_error_stack=log_error_stack, ) self.chat_template = chat_template @@ -75,38 +73,34 @@ async def _preprocess( self, ctx: EmbeddingServeContext, ) -> ErrorResponse | None: - try: - ctx.lora_request = self._maybe_get_adapters(ctx.request) - - if isinstance(ctx.request, EmbeddingChatRequest): - error_check_ret = self._validate_chat_template( - request_chat_template=ctx.request.chat_template, - chat_template_kwargs=ctx.request.chat_template_kwargs, - trust_request_chat_template=self.trust_request_chat_template, - ) - if error_check_ret is not None: - return error_check_ret - - _, ctx.engine_prompts = 
await self._preprocess_chat( - ctx.request, - ctx.request.messages, - default_template=self.chat_template, - default_template_content_format=self.chat_template_content_format, - default_template_kwargs=None, - ) - elif isinstance(ctx.request, EmbeddingCompletionRequest): - ctx.engine_prompts = await self._preprocess_completion( - ctx.request, - prompt_input=ctx.request.input, - prompt_embeds=None, - ) - else: - return self.create_error_response("Invalid classification request type") + ctx.lora_request = self._maybe_get_adapters(ctx.request) + + if isinstance(ctx.request, EmbeddingChatRequest): + error_check_ret = self._validate_chat_template( + request_chat_template=ctx.request.chat_template, + chat_template_kwargs=ctx.request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + if error_check_ret is not None: + return error_check_ret + + _, ctx.engine_prompts = await self._preprocess_chat( + ctx.request, + ctx.request.messages, + default_template=self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=None, + ) + elif isinstance(ctx.request, EmbeddingCompletionRequest): + ctx.engine_prompts = await self._preprocess_completion( + ctx.request, + prompt_input=ctx.request.input, + prompt_embeds=None, + ) + else: + return self.create_error_response("Invalid classification request type") - return None - except (ValueError, TypeError) as e: - logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) + return None def request_output_to_embed_json_response( self, @@ -397,51 +391,47 @@ async def _prepare_generators( # Custom logic for chunked processing generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - try: - trace_headers = ( - None - if ctx.raw_request is None - else await self._get_trace_headers(ctx.raw_request.headers) - ) - - pooling_params = self._create_pooling_params(ctx) - if isinstance(pooling_params, ErrorResponse): - return pooling_params + trace_headers = ( + None + if ctx.raw_request is None + else await self._get_trace_headers(ctx.raw_request.headers) + ) - if ctx.engine_prompts is None: - return self.create_error_response("Engine prompts not available") + pooling_params = self._create_pooling_params(ctx) + if isinstance(pooling_params, ErrorResponse): + return pooling_params - max_pos_embeddings = self._get_max_position_embeddings() + if ctx.engine_prompts is None: + return self.create_error_response("Engine prompts not available") - for i, engine_prompt in enumerate(ctx.engine_prompts): - # Check if this specific prompt needs chunked processing - if "prompt_token_ids" in engine_prompt: - prompt_token_ids = engine_prompt["prompt_token_ids"] # type: ignore[typeddict-item] - - if len(prompt_token_ids) > max_pos_embeddings: - # Use chunked processing for this prompt - chunk_generators = await self._process_chunked_request( - ctx, - prompt_token_ids, - pooling_params, - trace_headers, - i, - ) - generators.extend(chunk_generators) - continue + max_pos_embeddings = self._get_max_position_embeddings() - # Normal processing for short prompts or non-token prompts - generator = await self._create_single_prompt_generator( - ctx, engine_prompt, pooling_params, trace_headers, i - ) - generators.append(generator) + for i, engine_prompt in enumerate(ctx.engine_prompts): + # Check if this specific prompt needs chunked processing + if "prompt_token_ids" in engine_prompt: + prompt_token_ids = engine_prompt["prompt_token_ids"] # type: 
ignore[typeddict-item] + + if len(prompt_token_ids) > max_pos_embeddings: + # Use chunked processing for this prompt + chunk_generators = await self._process_chunked_request( + ctx, + prompt_token_ids, + pooling_params, + trace_headers, + i, + ) + generators.extend(chunk_generators) + continue - ctx.result_generator = merge_async_iterators(*generators) + # Normal processing for short prompts or non-token prompts + generator = await self._create_single_prompt_generator( + ctx, engine_prompt, pooling_params, trace_headers, i + ) + generators.append(generator) - return None + ctx.result_generator = merge_async_iterators(*generators) - except Exception as e: - return self.create_error_response(e) + return None async def _collect_batch( self, @@ -454,164 +444,157 @@ async def _collect_batch( minimize memory usage. For regular requests, collects results normally. """ - try: - if ctx.engine_prompts is None: - return self.create_error_response("Engine prompts not available") - - # Check if we used chunked processing - use_chunked = self._should_use_chunked_processing(ctx.request) - - if not use_chunked: - return await super()._collect_batch(ctx=ctx) - - if ctx.result_generator is None: - return self.create_error_response("Result generator not available") - - # Online aggregation for chunked requests to - # minimize memory usage - # Track aggregation state for each prompt - prompt_aggregators: dict[int, dict[str, Any]] = {} - short_prompts_results: dict[int, PoolingRequestOutput] = {} - - async for result_idx, result in ctx.result_generator: - if "-chunk-" in result.request_id: - # Extract prompt_idx from chunked request_id - parts = result.request_id.split("-") - try: - prompt_idx = int(parts[parts.index("prompt") + 1]) - except (ValueError, IndexError): - # Fallback: extract from result_idx if parsing fails - prompt_idx = result_idx - - # Initialize aggregator for this prompt if needed - if prompt_idx not in prompt_aggregators: - prompt_aggregators[prompt_idx] = { - "weighted_sum": None, - "total_weight": 0, - "chunk_count": 0, - "request_id": result.request_id.split("-chunk-")[0], - } - - aggregator = prompt_aggregators[prompt_idx] - - # MEAN pooling with online weighted averaging - # Ensure result is PoolingRequestOutput - # for embedding processing - if not isinstance(result, PoolingRequestOutput): - return self.create_error_response( - f"Expected PoolingRequestOutput for " - f"chunked embedding, got " - f"{type(result).__name__}" - ) + if ctx.engine_prompts is None: + return self.create_error_response("Engine prompts not available") - # Handle both PoolingOutput and - # EmbeddingOutput types - if hasattr(result.outputs, "data"): - # PoolingOutput case - embedding_data = result.outputs.data - elif hasattr(result.outputs, "embedding"): - # EmbeddingOutput case - - # convert embedding list to tensor - embedding_data = result.outputs.embedding - else: - return self.create_error_response( - f"Unsupported output type: {type(result.outputs).__name__}" - ) + # Check if we used chunked processing + use_chunked = self._should_use_chunked_processing(ctx.request) - if not isinstance(embedding_data, torch.Tensor): - embedding_data = torch.tensor( - embedding_data, dtype=torch.float32 - ) + if not use_chunked: + return await super()._collect_batch(ctx=ctx) + + if ctx.result_generator is None: + return self.create_error_response("Result generator not available") + + # Online aggregation for chunked requests to + # minimize memory usage + # Track aggregation state for each prompt + prompt_aggregators: 
dict[int, dict[str, Any]] = {} + short_prompts_results: dict[int, PoolingRequestOutput] = {} + + async for result_idx, result in ctx.result_generator: + if "-chunk-" in result.request_id: + # Extract prompt_idx from chunked request_id + parts = result.request_id.split("-") + try: + prompt_idx = int(parts[parts.index("prompt") + 1]) + except (ValueError, IndexError): + # Fallback: extract from result_idx if parsing fails + prompt_idx = result_idx + + # Initialize aggregator for this prompt if needed + if prompt_idx not in prompt_aggregators: + prompt_aggregators[prompt_idx] = { + "weighted_sum": None, + "total_weight": 0, + "chunk_count": 0, + "request_id": result.request_id.split("-chunk-")[0], + } - if result.prompt_token_ids is None: - return self.create_error_response( - "prompt_token_ids cannot be None for chunked processing" - ) - weight = len(result.prompt_token_ids) + aggregator = prompt_aggregators[prompt_idx] + + # MEAN pooling with online weighted averaging + # Ensure result is PoolingRequestOutput + # for embedding processing + if not isinstance(result, PoolingRequestOutput): + return self.create_error_response( + f"Expected PoolingRequestOutput for " + f"chunked embedding, got " + f"{type(result).__name__}" + ) - weighted_embedding = embedding_data.to(dtype=torch.float32) * weight + # Handle both PoolingOutput and + # EmbeddingOutput types + if hasattr(result.outputs, "data"): + # PoolingOutput case + embedding_data = result.outputs.data + elif hasattr(result.outputs, "embedding"): + # EmbeddingOutput case - + # convert embedding list to tensor + embedding_data = result.outputs.embedding + else: + return self.create_error_response( + f"Unsupported output type: {type(result.outputs).__name__}" + ) - if aggregator["weighted_sum"] is None: - # First chunk - aggregator["weighted_sum"] = weighted_embedding - else: - # Accumulate - aggregator["weighted_sum"] += weighted_embedding + if not isinstance(embedding_data, torch.Tensor): + embedding_data = torch.tensor(embedding_data, dtype=torch.float32) - aggregator["total_weight"] += weight - aggregator["chunk_count"] += 1 + if result.prompt_token_ids is None: + return self.create_error_response( + "prompt_token_ids cannot be None for chunked processing" + ) + weight = len(result.prompt_token_ids) + + weighted_embedding = embedding_data.to(dtype=torch.float32) * weight + + if aggregator["weighted_sum"] is None: + # First chunk + aggregator["weighted_sum"] = weighted_embedding else: - # Non-chunked result - extract prompt_idx from request_id - parts = result.request_id.split("-") - try: - # Last part should be prompt index - prompt_idx = int(parts[-1]) - except (ValueError, IndexError): - prompt_idx = result_idx # Fallback to result_idx - - short_prompts_results[prompt_idx] = result - - # Finalize aggregated results - final_res_batch: list[PoolingRequestOutput] = [] - num_prompts = len(ctx.engine_prompts) - - for prompt_idx in range(num_prompts): - if prompt_idx in prompt_aggregators: - # Finalize MEAN aggregation for this chunked prompt - aggregator = prompt_aggregators[prompt_idx] - - weighted_sum = aggregator["weighted_sum"] - total_weight = aggregator["total_weight"] - - if ( - weighted_sum is not None - and isinstance(weighted_sum, torch.Tensor) - and isinstance(total_weight, (int, float)) - and total_weight > 0 - ): - # Compute final mean embedding - final_embedding = weighted_sum / total_weight - - # Create a PoolingRequestOutput - # for the aggregated result - pooling_output_data = PoolingOutput(data=final_embedding) - - # Get 
original prompt token IDs for this prompt - original_prompt = ctx.engine_prompts[prompt_idx] - if "prompt_token_ids" not in original_prompt: - return self.create_error_response( - f"Chunked prompt {prompt_idx} does not contain " - "token IDs" - ) - - original_token_ids = original_prompt["prompt_token_ids"] # type: ignore[typeddict-item] - - pooling_request_output = PoolingRequestOutput( - request_id=aggregator["request_id"], - prompt_token_ids=original_token_ids, - outputs=pooling_output_data, - num_cached_tokens=0, - finished=True, - ) + # Accumulate + aggregator["weighted_sum"] += weighted_embedding - final_res_batch.append(pooling_request_output) - else: + aggregator["total_weight"] += weight + aggregator["chunk_count"] += 1 + else: + # Non-chunked result - extract prompt_idx from request_id + parts = result.request_id.split("-") + try: + # Last part should be prompt index + prompt_idx = int(parts[-1]) + except (ValueError, IndexError): + prompt_idx = result_idx # Fallback to result_idx + + short_prompts_results[prompt_idx] = result + + # Finalize aggregated results + final_res_batch: list[PoolingRequestOutput] = [] + num_prompts = len(ctx.engine_prompts) + + for prompt_idx in range(num_prompts): + if prompt_idx in prompt_aggregators: + # Finalize MEAN aggregation for this chunked prompt + aggregator = prompt_aggregators[prompt_idx] + + weighted_sum = aggregator["weighted_sum"] + total_weight = aggregator["total_weight"] + + if ( + weighted_sum is not None + and isinstance(weighted_sum, torch.Tensor) + and isinstance(total_weight, (int, float)) + and total_weight > 0 + ): + # Compute final mean embedding + final_embedding = weighted_sum / total_weight + + # Create a PoolingRequestOutput + # for the aggregated result + pooling_output_data = PoolingOutput(data=final_embedding) + + # Get original prompt token IDs for this prompt + original_prompt = ctx.engine_prompts[prompt_idx] + if "prompt_token_ids" not in original_prompt: return self.create_error_response( - f"Failed to aggregate chunks for prompt {prompt_idx}" + f"Chunked prompt {prompt_idx} does not contain token IDs" ) - elif prompt_idx in short_prompts_results: - final_res_batch.append(short_prompts_results[prompt_idx]) + + original_token_ids = original_prompt["prompt_token_ids"] # type: ignore[typeddict-item] + + pooling_request_output = PoolingRequestOutput( + request_id=aggregator["request_id"], + prompt_token_ids=original_token_ids, + outputs=pooling_output_data, + num_cached_tokens=0, + finished=True, + ) + + final_res_batch.append(pooling_request_output) else: return self.create_error_response( - f"Result not found for prompt {prompt_idx}" + f"Failed to aggregate chunks for prompt {prompt_idx}" ) + elif prompt_idx in short_prompts_results: + final_res_batch.append(short_prompts_results[prompt_idx]) + else: + return self.create_error_response( + f"Result not found for prompt {prompt_idx}" + ) - ctx.final_res_batch = final_res_batch - - return None + ctx.final_res_batch = final_res_batch - except Exception as e: - return self.create_error_response(e) + return None async def create_embedding( self, diff --git a/vllm/entrypoints/pooling/pooling/api_router.py b/vllm/entrypoints/pooling/pooling/api_router.py index 6084e724dac6..538ce8dad9b3 100644 --- a/vllm/entrypoints/pooling/pooling/api_router.py +++ b/vllm/entrypoints/pooling/pooling/api_router.py @@ -41,10 +41,8 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): return base_server.create_error_response( message="The model does not support Pooling API" 
) - try: - generator = await handler.create_pooling(request, raw_request) - except Exception as e: - generator = handler.create_error_response(e) + + generator = await handler.create_pooling(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse( diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py index f27a27191f99..bcd331b01435 100644 --- a/vllm/entrypoints/pooling/pooling/serving.py +++ b/vllm/entrypoints/pooling/pooling/serving.py @@ -8,7 +8,6 @@ from functools import partial from typing import Final, Literal, cast -import jinja2 from fastapi import Request from typing_extensions import assert_never @@ -53,13 +52,11 @@ def __init__( chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, trust_request_chat_template: bool = False, - log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, models=models, request_logger=request_logger, - log_error_stack=log_error_stack, ) self.chat_template = chat_template @@ -84,101 +81,92 @@ async def create_pooling( request_id = f"pool-{self._base_request_id(raw_request)}" created_time = int(time.time()) - try: - lora_request = self._maybe_get_adapters(request) + lora_request = self._maybe_get_adapters(request) - if getattr(request, "dimensions", None) is not None: - return self.create_error_response( - "dimensions is currently not supported" - ) + if getattr(request, "dimensions", None) is not None: + return self.create_error_response("dimensions is currently not supported") - engine_prompts: Sequence[ProcessorInputs] - if use_io_processor := isinstance(request, IOProcessorRequest): - if self.io_processor is None: - raise ValueError( - "No IOProcessor plugin installed. Please refer " - "to the documentation and to the " - "'prithvi_geospatial_mae_io_processor' " - "offline inference example for more details." - ) + engine_prompts: Sequence[ProcessorInputs] + if use_io_processor := isinstance(request, IOProcessorRequest): + if self.io_processor is None: + raise ValueError( + "No IOProcessor plugin installed. Please refer " + "to the documentation and to the " + "'prithvi_geospatial_mae_io_processor' " + "offline inference example for more details." 
+ ) - validated_prompt = self.io_processor.parse_data(request.data) + validated_prompt = self.io_processor.parse_data(request.data) - raw_prompts = await self.io_processor.pre_process_async( - prompt=validated_prompt, request_id=request_id - ) - engine_prompts = await self._preprocess_cmpl( - request, - prompt_to_seq(raw_prompts), - ) - elif isinstance(request, PoolingChatRequest): - error_check_ret = self._validate_chat_template( - request_chat_template=request.chat_template, - chat_template_kwargs=request.chat_template_kwargs, - trust_request_chat_template=self.trust_request_chat_template, - ) - if error_check_ret is not None: - return error_check_ret - - _, engine_prompts = await self._preprocess_chat( - request, - request.messages, - default_template=self.chat_template, - default_template_content_format=self.chat_template_content_format, - default_template_kwargs=None, - ) - elif isinstance(request, PoolingCompletionRequest): - engine_prompts = await self._preprocess_completion( - request, - prompt_input=request.input, - prompt_embeds=None, - ) - else: - raise ValueError(f"Unsupported request of type {type(request)}") - except (ValueError, TypeError, jinja2.TemplateError) as e: - logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) + raw_prompts = await self.io_processor.pre_process_async( + prompt=validated_prompt, request_id=request_id + ) + engine_prompts = await self._preprocess_cmpl( + request, + prompt_to_seq(raw_prompts), + ) + elif isinstance(request, PoolingChatRequest): + error_check_ret = self._validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + if error_check_ret is not None: + return error_check_ret + + _, engine_prompts = await self._preprocess_chat( + request, + request.messages, + default_template=self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=None, + ) + elif isinstance(request, PoolingCompletionRequest): + engine_prompts = await self._preprocess_completion( + request, + prompt_input=request.input, + prompt_embeds=None, + ) + else: + raise ValueError(f"Unsupported request of type {type(request)}") # Schedule the request and get the result generator. 
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - try: - if use_io_processor: - assert self.io_processor is not None - - pooling_params = self.io_processor.merge_pooling_params() - if pooling_params.task is None: - pooling_params.task = "plugin" - else: - pooling_params = request.to_pooling_params() # type: ignore - - for i, engine_prompt in enumerate(engine_prompts): - request_id_item = f"{request_id}-{i}" - - self._log_inputs( - request_id_item, - engine_prompt, - params=pooling_params, - lora_request=lora_request, - ) + if use_io_processor: + assert self.io_processor is not None - trace_headers = ( - None - if raw_request is None - else await self._get_trace_headers(raw_request.headers) - ) + pooling_params = self.io_processor.merge_pooling_params() + if pooling_params.task is None: + pooling_params.task = "plugin" + else: + pooling_params = request.to_pooling_params() # type: ignore - generator = self.engine_client.encode( - engine_prompt, - pooling_params, - request_id_item, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, - ) + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs( + request_id_item, + engine_prompt, + params=pooling_params, + lora_request=lora_request, + ) + + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) - generators.append(generator) - except ValueError as e: - return self.create_error_response(e) + generators.append(generator) result_generator = merge_async_iterators(*generators) @@ -233,8 +221,6 @@ async def create_pooling( ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") - except ValueError as e: - return self.create_error_response(e) return response diff --git a/vllm/entrypoints/pooling/score/api_router.py b/vllm/entrypoints/pooling/score/api_router.py index ef64ba45ebd7..c71b67ff08fe 100644 --- a/vllm/entrypoints/pooling/score/api_router.py +++ b/vllm/entrypoints/pooling/score/api_router.py @@ -49,10 +49,7 @@ async def create_score(request: ScoreRequest, raw_request: Request): message="The model does not support Score API" ) - try: - generator = await handler.create_score(request, raw_request) - except Exception as e: - generator = handler.create_error_response(e) + generator = await handler.create_score(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse( @@ -100,10 +97,8 @@ async def do_rerank(request: RerankRequest, raw_request: Request): return base_server.create_error_response( message="The model does not support Rerank (Score) API" ) - try: - generator = await handler.do_rerank(request, raw_request) - except Exception as e: - generator = handler.create_error_response(e) + + generator = await handler.do_rerank(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse( diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py index 60d6db6a7003..a30942097fd9 100644 --- a/vllm/entrypoints/pooling/score/serving.py +++ b/vllm/entrypoints/pooling/score/serving.py @@ -62,7 +62,6 @@ def __init__( engine_client=engine_client, models=models, request_logger=request_logger, - log_error_stack=log_error_stack, ) self.score_template = score_template 
self.use_gpu_for_pooling_score = use_gpu_for_pooling_score @@ -518,8 +517,6 @@ async def create_score( ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") - except ValueError as e: - return self.create_error_response(e) async def do_rerank( self, request: RerankRequest, raw_request: Request | None = None @@ -562,8 +559,6 @@ async def do_rerank( ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") - except ValueError as e: - return self.create_error_response(e) def request_output_to_score_response( self, diff --git a/vllm/entrypoints/serve/disagg/api_router.py b/vllm/entrypoints/serve/disagg/api_router.py index 9966ba47be06..a9c6d3cdcbb7 100644 --- a/vllm/entrypoints/serve/disagg/api_router.py +++ b/vllm/entrypoints/serve/disagg/api_router.py @@ -64,10 +64,8 @@ async def generate(request: GenerateRequest, raw_request: Request): return tokenization(raw_request).create_error_response( message="The model does not support generate tokens API" ) - try: - generator = await handler.serve_tokens(request, raw_request) - except Exception as e: - generator = handler.create_error_response(e) + + generator = await handler.serve_tokens(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse( diff --git a/vllm/entrypoints/serve/disagg/serving.py b/vllm/entrypoints/serve/disagg/serving.py index f004e5269830..322314907dd8 100644 --- a/vllm/entrypoints/serve/disagg/serving.py +++ b/vllm/entrypoints/serve/disagg/serving.py @@ -49,7 +49,6 @@ def __init__( request_logger: RequestLogger | None, force_no_detokenize: bool = False, return_tokens_as_token_ids: bool = False, - log_error_stack: bool = False, enable_prompt_tokens_details: bool = False, enable_log_outputs: bool = False, ): @@ -58,7 +57,6 @@ def __init__( models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - log_error_stack=log_error_stack, ) self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_log_outputs = enable_log_outputs @@ -108,45 +106,38 @@ async def serve_tokens( # Schedule the request and get the result generator. 
result_generator: AsyncGenerator[RequestOutput, None] | None = None - try: - sampling_params = request.sampling_params - if self.force_no_detokenize: - sampling_params.detokenize = False - - self._log_inputs( - request_id, - engine_prompt, - params=sampling_params, - lora_request=lora_request, - ) - - trace_headers = ( - None - if raw_request is None - else await self._get_trace_headers(raw_request.headers) - ) + sampling_params = request.sampling_params + if self.force_no_detokenize: + sampling_params.detokenize = False + + self._log_inputs( + request_id, + engine_prompt, + params=sampling_params, + lora_request=lora_request, + ) - result_generator = self.engine_client.generate( - engine_prompt, - sampling_params, - request_id, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, - ) + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) - except ValueError as e: - return self.create_error_response(str(e)) + result_generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request_id, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) # TODO(NickLucche): Implement streaming response - try: - assert result_generator is not None - return await self.serve_tokens_full_generator( - request, result_generator, request_id, model_name, request_metadata - ) - except ValueError as e: - return self.create_error_response(str(e)) + assert result_generator is not None + return await self.serve_tokens_full_generator( + request, result_generator, request_id, model_name, request_metadata + ) async def serve_tokens_full_generator( self, @@ -165,8 +156,6 @@ async def serve_tokens_full_generator( final_res = res except asyncio.CancelledError: return self.create_error_response("Client disconnected") - except ValueError as e: - return self.create_error_response(str(e)) assert final_res is not None diff --git a/vllm/entrypoints/serve/tokenize/api_router.py b/vllm/entrypoints/serve/tokenize/api_router.py index 333acbca1077..d165b555385d 100644 --- a/vllm/entrypoints/serve/tokenize/api_router.py +++ b/vllm/entrypoints/serve/tokenize/api_router.py @@ -49,10 +49,7 @@ def tokenization(request: Request) -> OpenAIServingTokenization: async def tokenize(request: TokenizeRequest, raw_request: Request): handler = tokenization(raw_request) - try: - generator = await handler.create_tokenize(request, raw_request) - except Exception as e: - generator = handler.create_error_response(e) + generator = await handler.create_tokenize(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse( diff --git a/vllm/entrypoints/serve/tokenize/serving.py b/vllm/entrypoints/serve/tokenize/serving.py index 55d7ea827c57..77ce2787c54b 100644 --- a/vllm/entrypoints/serve/tokenize/serving.py +++ b/vllm/entrypoints/serve/tokenize/serving.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from typing import Any, Final -import jinja2 from fastapi import Request from vllm.engine.protocol import EngineClient @@ -37,13 +36,11 @@ def __init__( chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, trust_request_chat_template: bool = False, - log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, models=models, request_logger=request_logger, - log_error_stack=log_error_stack, ) self.chat_template = chat_template @@ -61,40 +58,36 @@ async def create_tokenize( request_id = 
f"tokenize-{self._base_request_id(raw_request)}" - try: - lora_request = self._maybe_get_adapters(request) - - if isinstance(request, TokenizeChatRequest): - tool_dicts = ( - None - if request.tools is None - else [tool.model_dump() for tool in request.tools] - ) - error_check_ret = self._validate_chat_template( - request_chat_template=request.chat_template, - chat_template_kwargs=request.chat_template_kwargs, - trust_request_chat_template=self.trust_request_chat_template, - ) - if error_check_ret is not None: - return error_check_ret - - _, engine_prompts = await self._preprocess_chat( - request, - request.messages, - default_template=self.chat_template, - default_template_content_format=self.chat_template_content_format, - default_template_kwargs=None, - tool_dicts=tool_dicts, - ) - else: - engine_prompts = await self._preprocess_completion( - request, - prompt_input=request.prompt, - prompt_embeds=None, - ) - except (ValueError, TypeError, jinja2.TemplateError) as e: - logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(f"{e} {e.__cause__}") + lora_request = self._maybe_get_adapters(request) + + if isinstance(request, TokenizeChatRequest): + tool_dicts = ( + None + if request.tools is None + else [tool.model_dump() for tool in request.tools] + ) + error_check_ret = self._validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + if error_check_ret is not None: + return error_check_ret + + _, engine_prompts = await self._preprocess_chat( + request, + request.messages, + default_template=self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=None, + tool_dicts=tool_dicts, + ) + else: + engine_prompts = await self._preprocess_completion( + request, + prompt_input=request.prompt, + prompt_embeds=None, + ) input_ids: list[int] = [] for engine_prompt in engine_prompts: @@ -152,12 +145,9 @@ async def get_tokenizer_info( self, ) -> TokenizerInfoResponse | ErrorResponse: """Get comprehensive tokenizer information.""" - try: - tokenizer = self.renderer.get_tokenizer() - info = TokenizerInfo(tokenizer, self.chat_template).to_dict() - return TokenizerInfoResponse(**info) - except Exception as e: - return self.create_error_response(f"Failed to get tokenizer info: {str(e)}") + tokenizer = self.renderer.get_tokenizer() + info = TokenizerInfo(tokenizer, self.chat_template).to_dict() + return TokenizerInfoResponse(**info) @dataclass diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 6390a72ce0e1..40d58e1a7fa1 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -5,13 +5,10 @@ import dataclasses import functools import os -import sys -import traceback from argparse import Namespace from http import HTTPStatus from logging import Logger from string import Template -from typing import TYPE_CHECKING import regex as re from fastapi import Request @@ -20,24 +17,17 @@ from vllm import envs from vllm.engine.arg_utils import EngineArgs -from vllm.exceptions import VLLMValidationError +from vllm.entrypoints.openai.engine.protocol import ( + ErrorInfo, + ErrorResponse, + GenerationError, + StreamOptions, +) +from vllm.entrypoints.openai.models.protocol import LoRAModulePath from vllm.logger import current_formatter_type, init_logger from vllm.platforms import current_platform from vllm.utils.argparse_utils import FlexibleArgumentParser -if 
TYPE_CHECKING: - from vllm.entrypoints.openai.engine.protocol import ( - ErrorInfo, - ErrorResponse, - StreamOptions, - ) - from vllm.entrypoints.openai.models.protocol import LoRAModulePath -else: - ErrorResponse = object - ErrorInfo = object - LoRAModulePath = object - StreamOptions = object - logger = init_logger(__name__) VLLM_SUBCMD_PARSER_EPILOG = ( @@ -307,20 +297,19 @@ def create_error_response( err_type: str = "BadRequestError", status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, param: str | None = None, - log_error_stack: bool = False, -) -> "ErrorResponse": +) -> ErrorResponse: exc: Exception | None = None - from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse - if isinstance(message, Exception): exc = message + from vllm.exceptions import VLLMValidationError + if isinstance(exc, VLLMValidationError): err_type = "BadRequestError" status_code = HTTPStatus.BAD_REQUEST param = exc.parameter - elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)): + elif isinstance(exc, (ValueError, TypeError, OverflowError)): # Common validation errors from user input err_type = "BadRequestError" status_code = HTTPStatus.BAD_REQUEST @@ -329,6 +318,10 @@ def create_error_response( err_type = "NotImplementedError" status_code = HTTPStatus.NOT_IMPLEMENTED param = None + elif isinstance(exc, GenerationError): + err_type = "InternalServerError" + status_code = exc.status_code + param = None elif exc.__class__.__name__ == "TemplateError": # jinja2.TemplateError (avoid importing jinja2) err_type = "BadRequestError" @@ -341,13 +334,6 @@ def create_error_response( message = str(exc) - if log_error_stack: - exc_type, _, _ = sys.exc_info() - if exc_type is not None: - traceback.print_exc() - else: - traceback.print_stack() - return ErrorResponse( error=ErrorInfo( message=sanitize_message(message),