diff --git a/requirements/common.txt b/requirements/common.txt
index 8562649a9c4e..81c4d6675006 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -48,3 +48,4 @@ pybase64 # fast base64 implementation
 cbor2 # Required for cross-language serialization of hashable objects
 setproctitle # Used to set process names for better debugging and monitoring
 openai-harmony >= 0.0.3 # Required for gpt-oss
+anthropic == 0.71.0
diff --git a/tests/entrypoints/anthropic/__init__.py b/tests/entrypoints/anthropic/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/entrypoints/anthropic/test_messages.py b/tests/entrypoints/anthropic/test_messages.py
new file mode 100644
index 000000000000..4e35554b4e33
--- /dev/null
+++ b/tests/entrypoints/anthropic/test_messages.py
@@ -0,0 +1,141 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import anthropic
+import pytest
+import pytest_asyncio
+
+from ...utils import RemoteAnthropicServer
+
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+
+
+@pytest.fixture(scope="module")
+def server():  # noqa: F811
+    args = [
+        "--max-model-len",
+        "2048",
+        "--enforce-eager",
+        "--enable-auto-tool-choice",
+        "--tool-call-parser",
+        "hermes",
+        "--served-model-name",
+        "claude-3-7-sonnet-latest",
+    ]
+
+    with RemoteAnthropicServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.mark.asyncio
+async def test_simple_messages(client: anthropic.AsyncAnthropic):
+    resp = await client.messages.create(
+        model="claude-3-7-sonnet-latest",
+        max_tokens=1024,
+        messages=[{"role": "user", "content": "how are you!"}],
+    )
+    assert resp.stop_reason == "end_turn"
+    assert resp.role == "assistant"
+
+    print(f"Anthropic response: {resp.model_dump_json()}")
+
+
+@pytest.mark.asyncio
+async def test_system_message(client: anthropic.AsyncAnthropic):
+    resp = await client.messages.create(
+        model="claude-3-7-sonnet-latest",
+        max_tokens=1024,
+        system="you are a helpful assistant",
+        messages=[{"role": "user", "content": "how are you!"}],
+    )
+    assert resp.stop_reason == "end_turn"
+    assert resp.role == "assistant"
+
+    print(f"Anthropic response: {resp.model_dump_json()}")
+
+
+@pytest.mark.asyncio
+async def test_anthropic_streaming(client: anthropic.AsyncAnthropic):
+    resp = await client.messages.create(
+        model="claude-3-7-sonnet-latest",
+        max_tokens=1024,
+        messages=[{"role": "user", "content": "how are you!"}],
+        stream=True,
+    )
+
+    async for chunk in resp:
+        print(chunk.model_dump_json())
+
+
+@pytest.mark.asyncio
+async def test_anthropic_tool_call(client: anthropic.AsyncAnthropic):
+    resp = await client.messages.create(
+        model="claude-3-7-sonnet-latest",
+        max_tokens=1024,
+        messages=[
+            {"role": "user", "content": "What's the weather like in New York today?"}
+        ],
+        tools=[
+            {
+                "name": "get_current_weather",
+                "description": "Useful for querying the weather in a specified city.",
+                "input_schema": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "City or region, for example: "
+                            "New York, London, Tokyo, etc.",
+                        }
+                    },
+                    "required": ["location"],
+                },
+            }
+        ],
+        stream=False,
+    )
+    assert resp.stop_reason == "tool_use"
+    assert resp.role == "assistant"
+
+    print(f"Anthropic response: {resp.model_dump_json()}")
+
+
+@pytest.mark.asyncio
+async def test_anthropic_tool_call_streaming(client: anthropic.AsyncAnthropic):
+    resp = await client.messages.create(
+        model="claude-3-7-sonnet-latest",
+        max_tokens=1024,
+        messages=[
+            {
+                "role": "user",
+                "content": "What's the weather like in New York today?",
+            }
+        ],
+        tools=[
+            {
+                "name": "get_current_weather",
+                "description": "Useful for querying the weather "
+                "in a specified city.",
+                "input_schema": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "City or region, for example: "
+                            "New York, London, Tokyo, etc.",
+                        }
+                    },
+                    "required": ["location"],
+                },
+            }
+        ],
+        stream=True,
+    )
+
+    async for chunk in resp:
+        print(chunk.model_dump_json())
diff --git a/tests/utils.py b/tests/utils.py
index c29597a26ecc..e52497cf52a1 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -23,6 +23,7 @@
 from typing import Any, Literal
 from unittest.mock import patch
 
+import anthropic
 import cloudpickle
 import httpx
 import openai
@@ -294,6 +295,131 @@ def __exit__(self, exc_type, exc_value, traceback):
             self.proc.kill()
 
 
+class RemoteAnthropicServer:
+    DUMMY_API_KEY = "token-abc123"  # vLLM's Anthropic server does not need API key
+
+    def __init__(
+        self,
+        model: str,
+        vllm_serve_args: list[str],
+        *,
+        env_dict: dict[str, str] | None = None,
+        seed: int | None = 0,
+        auto_port: bool = True,
+        max_wait_seconds: float | None = None,
+    ) -> None:
+        if auto_port:
+            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
+                raise ValueError(
+                    "You have manually specified the port when `auto_port=True`."
+                )
+
+            # Don't mutate the input args
+            vllm_serve_args = vllm_serve_args + ["--port", str(get_open_port())]
+        if seed is not None:
+            if "--seed" in vllm_serve_args:
+                raise ValueError(
+                    f"You have manually specified the seed when `seed={seed}`."
+                )
+
+            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
+
+        parser = FlexibleArgumentParser(description="vLLM's remote Anthropic server.")
+        subparsers = parser.add_subparsers(required=False, dest="subparser")
+        parser = ServeSubcommand().subparser_init(subparsers)
+        args = parser.parse_args(["--model", model, *vllm_serve_args])
+        self.host = str(args.host or "localhost")
+        self.port = int(args.port)
+
+        self.show_hidden_metrics = args.show_hidden_metrics_for_version is not None
+
+        # download the model before starting the server to avoid timeout
+        is_local = os.path.isdir(model)
+        if not is_local:
+            engine_args = AsyncEngineArgs.from_cli_args(args)
+            model_config = engine_args.create_model_config()
+            load_config = engine_args.create_load_config()
+
+            model_loader = get_model_loader(load_config)
+            model_loader.download_model(model_config)
+
+        env = os.environ.copy()
+        # the current process might initialize cuda,
+        # to be safe, we should use spawn method
+        env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+        if env_dict is not None:
+            env.update(env_dict)
+        self.proc = subprocess.Popen(
+            [
+                sys.executable,
+                "-m",
+                "vllm.entrypoints.anthropic.api_server",
+                model,
+                *vllm_serve_args,
+            ],
+            env=env,
+            stdout=sys.stdout,
+            stderr=sys.stderr,
+        )
+        max_wait_seconds = max_wait_seconds or 240
+        self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.proc.terminate()
+        try:
+            self.proc.wait(8)
+        except subprocess.TimeoutExpired:
+            # force kill if needed
+            self.proc.kill()
+
+    def _wait_for_server(self, *, url: str, timeout: float):
+        # run health check
+        start = time.time()
+        while True:
+            try:
+                if requests.get(url).status_code == 200:
+ break + except Exception: + # this exception can only be raised by requests.get, + # which means the server is not ready yet. + # the stack trace is not useful, so we suppress it + # by using `raise from None`. + result = self.proc.poll() + if result is not None and result != 0: + raise RuntimeError("Server exited unexpectedly.") from None + + time.sleep(0.5) + if time.time() - start > timeout: + raise RuntimeError("Server failed to start in time.") from None + + @property + def url_root(self) -> str: + return f"http://{self.host}:{self.port}" + + def url_for(self, *parts: str) -> str: + return self.url_root + "/" + "/".join(parts) + + def get_client(self, **kwargs): + if "timeout" not in kwargs: + kwargs["timeout"] = 600 + return anthropic.Anthropic( + base_url=self.url_for(), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs, + ) + + def get_async_client(self, **kwargs): + if "timeout" not in kwargs: + kwargs["timeout"] = 600 + return anthropic.AsyncAnthropic( + base_url=self.url_for(), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs + ) + + def _test_completion( client: openai.OpenAI, model: str, diff --git a/vllm/entrypoints/anthropic/__init__.py b/vllm/entrypoints/anthropic/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/entrypoints/anthropic/api_server.py b/vllm/entrypoints/anthropic/api_server.py new file mode 100644 index 000000000000..249a7ee0121a --- /dev/null +++ b/vllm/entrypoints/anthropic/api_server.py @@ -0,0 +1,300 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from: +# https://github.com/vllm/vllm/entrypoints/openai/api_server.py + +import asyncio +import signal +import tempfile +from argparse import Namespace +from http import HTTPStatus + +import uvloop +from fastapi import APIRouter, Depends, FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response, StreamingResponse +from starlette.datastructures import State + +import vllm.envs as envs +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.anthropic.protocol import ( + AnthropicErrorResponse, + AnthropicMessagesRequest, + AnthropicMessagesResponse, +) +from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client, + create_server_socket, + lifespan, + load_log_config, + validate_api_server_args, + validate_json_request, +) +from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args +from vllm.entrypoints.openai.protocol import ErrorResponse +from vllm.entrypoints.openai.serving_models import ( + BaseModelPath, + OpenAIServingModels, +) + +# +# yapf: enable +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.utils import ( + cli_env_setup, + load_aware_call, + process_chat_template, + process_lora_modules, + with_cancellation, +) +from vllm.logger import init_logger +from vllm.utils import FlexibleArgumentParser, set_ulimit +from vllm.utils.network_utils import is_valid_ipv6_address +from vllm.version import __version__ as VLLM_VERSION + +prometheus_multiproc_dir: tempfile.TemporaryDirectory + +# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) +logger = init_logger("vllm.entrypoints.anthropic.api_server") + +_running_tasks: 
set[asyncio.Task] = set() + +router = APIRouter() + + +def messages(request: Request) -> AnthropicServingMessages: + return request.app.state.anthropic_serving_messages + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.get("/health", response_class=Response) +async def health(raw_request: Request) -> Response: + """Health check.""" + await engine_client(raw_request).check_health() + return Response(status_code=200) + + +@router.get("/ping", response_class=Response) +@router.post("/ping", response_class=Response) +async def ping(raw_request: Request) -> Response: + """Ping check. Endpoint required for SageMaker""" + return await health(raw_request) + + +@router.post( + "/v1/messages", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse}, + }, +) +@with_cancellation +@load_aware_call +async def create_messages(request: AnthropicMessagesRequest, raw_request: Request): + handler = messages(raw_request) + if handler is None: + return messages(raw_request).create_error_response( + message="The model does not support Messages API" + ) + + generator = await handler.create_messages(request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump()) + + elif isinstance(generator, AnthropicMessagesResponse): + logger.debug( + "Anthropic Messages Response: %s", generator.model_dump(exclude_none=True) + ) + return JSONResponse(content=generator.model_dump(exclude_none=True)) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +async def init_app_state( + engine_client: EngineClient, + state: State, + args: Namespace, +) -> None: + vllm_config = engine_client.vllm_config + + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) for name in served_model_names + ] + + state.engine_client = engine_client + state.log_stats = not args.disable_log_stats + state.vllm_config = vllm_config + model_config = vllm_config.model_config + + default_mm_loras = ( + vllm_config.lora_config.default_mm_loras + if vllm_config.lora_config is not None + else {} + ) + lora_modules = process_lora_modules(args.lora_modules, default_mm_loras) + + resolved_chat_template = await process_chat_template( + args.chat_template, engine_client, model_config + ) + + state.openai_serving_models = OpenAIServingModels( + engine_client=engine_client, + base_model_paths=base_model_paths, + lora_modules=lora_modules, + ) + await state.openai_serving_models.init_static_loras() + state.anthropic_serving_messages = AnthropicServingMessages( + engine_client, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + reasoning_parser=args.reasoning_parser, + 
enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + ) + + +def setup_server(args): + """Validate API server args, set up signal handler, create socket + ready to serve.""" + + logger.info("vLLM API server version %s", VLLM_VERSION) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + validate_api_server_args(args) + + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. + # see https://github.com/vllm-project/vllm/issues/8204 + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + addr, port = sock_addr + is_ssl = args.ssl_keyfile and args.ssl_certfile + host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0" + listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" + + return listen_address, sock + + +async def run_server(args, **uvicorn_kwargs) -> None: + """Run a single-worker API server.""" + listen_address, sock = setup_server(args) + await run_server_worker(listen_address, sock, args, **uvicorn_kwargs) + + +def build_app(args: Namespace) -> FastAPI: + app = FastAPI(lifespan=lifespan) + app.include_router(router) + app.root_path = args.root_path + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + + return app + + +async def run_server_worker( + listen_address, sock, args, client_config=None, **uvicorn_kwargs +) -> None: + """Run a single API server worker.""" + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + server_index = client_config.get("client_index", 0) if client_config else 0 + + # Load logging config for uvicorn if specified + log_config = load_log_config(args.log_config_file) + if log_config is not None: + uvicorn_kwargs["log_config"] = log_config + + async with build_async_engine_client( + args, + client_config=client_config, + ) as engine_client: + app = build_app(args) + + await init_app_state(engine_client, app.state, args) + + logger.info("Starting vLLM API server %d on %s", server_index, listen_address) + shutdown_task = await serve_http( + app, + sock=sock, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + # NOTE: When the 'disable_uvicorn_access_log' value is True, + # no access log will be output. + access_log=not args.disable_uvicorn_access_log, + timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + try: + await shutdown_task + finally: + sock.close() + + +if __name__ == "__main__": + # NOTE(simon): + # This section should be in sync with vllm/entrypoints/cli/main.py for CLI + # entrypoints. 
+ cli_env_setup() + parser = FlexibleArgumentParser( + description="vLLM Anthropic-Compatible RESTful API server." + ) + parser = make_arg_parser(parser) + args = parser.parse_args() + validate_parsed_serve_args(args) + + uvloop.run(run_server(args)) diff --git a/vllm/entrypoints/anthropic/protocol.py b/vllm/entrypoints/anthropic/protocol.py new file mode 100644 index 000000000000..626ca7472ae6 --- /dev/null +++ b/vllm/entrypoints/anthropic/protocol.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Pydantic models for Anthropic API protocol""" + +import time +from typing import Any, Literal, Optional + +from pydantic import BaseModel, field_validator + + +class AnthropicError(BaseModel): + """Error structure for Anthropic API""" + + type: str + message: str + + +class AnthropicErrorResponse(BaseModel): + """Error response structure for Anthropic API""" + + type: Literal["error"] = "error" + error: AnthropicError + + +class AnthropicUsage(BaseModel): + """Token usage information""" + + input_tokens: int + output_tokens: int + cache_creation_input_tokens: int | None = None + cache_read_input_tokens: int | None = None + + +class AnthropicContentBlock(BaseModel): + """Content block in message""" + + type: Literal["text", "image", "tool_use", "tool_result"] + text: str | None = None + # For image content + source: dict[str, Any] | None = None + # For tool use/result + id: str | None = None + name: str | None = None + input: dict[str, Any] | None = None + content: str | list[dict[str, Any]] | None = None + is_error: bool | None = None + + +class AnthropicMessage(BaseModel): + """Message structure""" + + role: Literal["user", "assistant"] + content: str | list[AnthropicContentBlock] + + +class AnthropicTool(BaseModel): + """Tool definition""" + + name: str + description: str | None = None + input_schema: dict[str, Any] + + @field_validator("input_schema") + @classmethod + def validate_input_schema(cls, v): + if not isinstance(v, dict): + raise ValueError("input_schema must be a dictionary") + if "type" not in v: + v["type"] = "object" # Default to object type + return v + + +class AnthropicToolChoice(BaseModel): + """Tool Choice definition""" + + type: Literal["auto", "any", "tool"] + name: str | None = None + + +class AnthropicMessagesRequest(BaseModel): + """Anthropic Messages API request""" + + model: str + messages: list[AnthropicMessage] + max_tokens: int + metadata: dict[str, Any] | None = None + stop_sequences: list[str] | None = None + stream: bool | None = False + system: str | list[AnthropicContentBlock] | None = None + temperature: float | None = None + tool_choice: AnthropicToolChoice | None = None + tools: list[AnthropicTool] | None = None + top_k: int | None = None + top_p: float | None = None + + @field_validator("model") + @classmethod + def validate_model(cls, v): + if not v: + raise ValueError("Model is required") + return v + + @field_validator("max_tokens") + @classmethod + def validate_max_tokens(cls, v): + if v <= 0: + raise ValueError("max_tokens must be positive") + return v + + +class AnthropicDelta(BaseModel): + """Delta for streaming responses""" + + type: Literal["text_delta", "input_json_delta"] | None = None + text: str | None = None + partial_json: str | None = None + + # Message delta + stop_reason: ( + Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] | None + ) = None + stop_sequence: str | None = None + + +class AnthropicStreamEvent(BaseModel): + """Streaming 
event""" + + type: Literal[ + "message_start", + "message_delta", + "message_stop", + "content_block_start", + "content_block_delta", + "content_block_stop", + "ping", + "error", + ] + message: Optional["AnthropicMessagesResponse"] = None + delta: AnthropicDelta | None = None + content_block: AnthropicContentBlock | None = None + index: int | None = None + error: AnthropicError | None = None + usage: AnthropicUsage | None = None + + +class AnthropicMessagesResponse(BaseModel): + """Anthropic Messages API response""" + + id: str + type: Literal["message"] = "message" + role: Literal["assistant"] = "assistant" + content: list[AnthropicContentBlock] + model: str + stop_reason: ( + Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] | None + ) = None + stop_sequence: str | None = None + usage: AnthropicUsage | None = None + + def model_post_init(self, __context): + if not self.id: + self.id = f"msg_{int(time.time() * 1000)}" diff --git a/vllm/entrypoints/anthropic/serving_messages.py b/vllm/entrypoints/anthropic/serving_messages.py new file mode 100644 index 000000000000..11c96adf332f --- /dev/null +++ b/vllm/entrypoints/anthropic/serving_messages.py @@ -0,0 +1,458 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from +# https://github.com/vllm/vllm/entrypoints/openai/serving_chat.py + +"""Anthropic Messages API serving handler""" + +import json +import logging +import time +from collections.abc import AsyncGenerator +from typing import Any + +from fastapi import Request + +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.anthropic.protocol import ( + AnthropicContentBlock, + AnthropicDelta, + AnthropicError, + AnthropicMessagesRequest, + AnthropicMessagesResponse, + AnthropicStreamEvent, + AnthropicUsage, +) +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionStreamResponse, + ChatCompletionToolsParam, + ErrorResponse, + StreamOptions, +) +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_models import OpenAIServingModels + +logger = logging.getLogger(__name__) + + +def wrap_data_with_event(data: str, event: str): + return f"event: {event}\ndata: {data}\n\n" + + +class AnthropicServingMessages(OpenAIServingChat): + """Handler for Anthropic Messages API requests""" + + def __init__( + self, + engine_client: EngineClient, + models: OpenAIServingModels, + response_role: str, + *, + request_logger: RequestLogger | None, + chat_template: str | None, + chat_template_content_format: ChatTemplateContentFormatOption, + return_tokens_as_token_ids: bool = False, + reasoning_parser: str = "", + enable_auto_tools: bool = False, + tool_parser: str | None = None, + enable_prompt_tokens_details: bool = False, + enable_force_include_usage: bool = False, + ): + super().__init__( + engine_client=engine_client, + models=models, + response_role=response_role, + request_logger=request_logger, + chat_template=chat_template, + chat_template_content_format=chat_template_content_format, + return_tokens_as_token_ids=return_tokens_as_token_ids, + reasoning_parser=reasoning_parser, + enable_auto_tools=enable_auto_tools, + tool_parser=tool_parser, + enable_prompt_tokens_details=enable_prompt_tokens_details, + 
enable_force_include_usage=enable_force_include_usage, + ) + self.stop_reason_map = { + "stop": "end_turn", + "length": "max_tokens", + "tool_calls": "tool_use", + } + + def _convert_anthropic_to_openai_request( + self, anthropic_request: AnthropicMessagesRequest + ) -> ChatCompletionRequest: + """Convert Anthropic message format to OpenAI format""" + openai_messages = [] + + # Add system message if provided + if anthropic_request.system: + if isinstance(anthropic_request.system, str): + openai_messages.append( + {"role": "system", "content": anthropic_request.system} + ) + else: + system_prompt = "" + for block in anthropic_request.system: + if block.type == "text" and block.text: + system_prompt += block.text + openai_messages.append({"role": "system", "content": system_prompt}) + + for msg in anthropic_request.messages: + openai_msg: dict[str, Any] = {"role": msg.role} # type: ignore + if isinstance(msg.content, str): + openai_msg["content"] = msg.content + else: + # Handle complex content blocks + content_parts: list[dict[str, Any]] = [] + tool_calls: list[dict[str, Any]] = [] + + for block in msg.content: + if block.type == "text" and block.text: + content_parts.append({"type": "text", "text": block.text}) + elif block.type == "image" and block.source: + content_parts.append( + { + "type": "image_url", + "image_url": {"url": block.source.get("data", "")}, + } + ) + elif block.type == "tool_use": + # Convert tool use to function call format + tool_call = { + "id": block.id or f"call_{int(time.time())}", + "type": "function", + "function": { + "name": block.name or "", + "arguments": json.dumps(block.input or {}), + }, + } + tool_calls.append(tool_call) + elif block.type == "tool_result": + if msg.role == "user": + openai_messages.append( + { + "role": "tool", + "tool_call_id": block.id or "", + "content": str(block.content) + if block.content + else "", + } + ) + else: + # Assistant tool result becomes regular text + tool_result_text = ( + str(block.content) if block.content else "" + ) + content_parts.append( + { + "type": "text", + "text": f"Tool result: {tool_result_text}", + } + ) + + # Add tool calls to the message if any + if tool_calls: + openai_msg["tool_calls"] = tool_calls # type: ignore + + # Add content parts if any + if content_parts: + if len(content_parts) == 1 and content_parts[0]["type"] == "text": + openai_msg["content"] = content_parts[0]["text"] + else: + openai_msg["content"] = content_parts # type: ignore + elif not tool_calls: + continue + + openai_messages.append(openai_msg) + + req = ChatCompletionRequest( + model=anthropic_request.model, + messages=openai_messages, + max_tokens=anthropic_request.max_tokens, + max_completion_tokens=anthropic_request.max_tokens, + stop=anthropic_request.stop_sequences, + temperature=anthropic_request.temperature, + top_p=anthropic_request.top_p, + top_k=anthropic_request.top_k, + ) + + if anthropic_request.stream: + req.stream = anthropic_request.stream + req.stream_options = StreamOptions.validate({"include_usage": True}) + + if anthropic_request.tool_choice is None: + req.tool_choice = None + elif anthropic_request.tool_choice.type == "auto": + req.tool_choice = "auto" + elif anthropic_request.tool_choice.type == "any": + req.tool_choice = "required" + elif anthropic_request.tool_choice.type == "tool": + req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate( + { + "type": "function", + "function": {"name": anthropic_request.tool_choice.name}, + } + ) + + tools = [] + if anthropic_request.tools is None: + return 
req + for tool in anthropic_request.tools: + tools.append( + ChatCompletionToolsParam.model_validate( + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema, + }, + } + ) + ) + if req.tool_choice is None: + req.tool_choice = "auto" + req.tools = tools + return req + + async def create_messages( + self, + request: AnthropicMessagesRequest, + raw_request: Request | None = None, + ) -> AsyncGenerator[str, None] | AnthropicMessagesResponse | ErrorResponse: + """ + Messages API similar to Anthropic's API. + + See https://docs.anthropic.com/en/api/messages + for the API specification. This API mimics the Anthropic messages API. + """ + logger.debug("Received messages request %s", request.model_dump_json()) + chat_req = self._convert_anthropic_to_openai_request(request) + logger.debug("Convert to OpenAI request %s", request.model_dump_json()) + generator = await self.create_chat_completion(chat_req, raw_request) + + if isinstance(generator, ErrorResponse): + return generator + + elif isinstance(generator, ChatCompletionResponse): + return self.messages_full_converter(generator) + + return self.message_stream_converter(generator) + + def messages_full_converter( + self, + generator: ChatCompletionResponse, + ) -> AnthropicMessagesResponse: + result = AnthropicMessagesResponse( + id=generator.id, + content=[], + model=generator.model, + usage=AnthropicUsage( + input_tokens=generator.usage.prompt_tokens, + output_tokens=generator.usage.completion_tokens, + ), + ) + if generator.choices[0].finish_reason == "stop": + result.stop_reason = "end_turn" + elif generator.choices[0].finish_reason == "length": + result.stop_reason = "max_tokens" + elif generator.choices[0].finish_reason == "tool_calls": + result.stop_reason = "tool_use" + + content: list[AnthropicContentBlock] = [ + AnthropicContentBlock( + type="text", + text=generator.choices[0].message.content + if generator.choices[0].message.content + else "", + ) + ] + + for tool_call in generator.choices[0].message.tool_calls: + anthropic_tool_call = AnthropicContentBlock( + type="tool_use", + id=tool_call.id, + name=tool_call.function.name, + input=json.loads(tool_call.function.arguments), + ) + content += [anthropic_tool_call] + + result.content = content + + return result + + async def message_stream_converter( + self, + generator: AsyncGenerator[str, None], + ) -> AsyncGenerator[str, None]: + try: + first_item = True + finish_reason = None + content_block_index = 0 + content_block_started = False + + async for item in generator: + if item.startswith("data:"): + data_str = item[5:].strip().rstrip("\n") + if data_str == "[DONE]": + stop_message = AnthropicStreamEvent( + type="message_stop", + ) + data = stop_message.model_dump_json( + exclude_unset=True, exclude_none=True + ) + yield wrap_data_with_event(data, "message_stop") + yield "data: [DONE]\n\n" + else: + origin_chunk = ChatCompletionStreamResponse.model_validate_json( + data_str + ) + + if first_item: + chunk = AnthropicStreamEvent( + type="message_start", + message=AnthropicMessagesResponse( + id=origin_chunk.id, + content=[], + model=origin_chunk.model, + ), + ) + first_item = False + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "message_start") + continue + + # last chunk including usage info + if len(origin_chunk.choices) == 0: + if content_block_started: + stop_chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_stop", + ) + data = 
stop_chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "content_block_stop") + stop_reason = self.stop_reason_map.get( + finish_reason or "stop" + ) + chunk = AnthropicStreamEvent( + type="message_delta", + delta=AnthropicDelta(stop_reason=stop_reason), + usage=AnthropicUsage( + input_tokens=origin_chunk.usage.prompt_tokens + if origin_chunk.usage + else 0, + output_tokens=origin_chunk.usage.completion_tokens + if origin_chunk.usage + else 0, + ), + ) + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "message_delta") + continue + + if origin_chunk.choices[0].finish_reason is not None: + finish_reason = origin_chunk.choices[0].finish_reason + continue + + # content + if origin_chunk.choices[0].delta.content is not None: + if not content_block_started: + chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_start", + content_block=AnthropicContentBlock( + type="text", text="" + ), + ) + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "content_block_start") + content_block_started = True + + if origin_chunk.choices[0].delta.content == "": + continue + chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_delta", + delta=AnthropicDelta( + type="text_delta", + text=origin_chunk.choices[0].delta.content, + ), + ) + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "content_block_delta") + continue + + # tool calls + elif len(origin_chunk.choices[0].delta.tool_calls) > 0: + tool_call = origin_chunk.choices[0].delta.tool_calls[0] + if tool_call.id is not None: + if content_block_started: + stop_chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_stop", + ) + data = stop_chunk.model_dump_json( + exclude_unset=True + ) + yield wrap_data_with_event( + data, "content_block_stop" + ) + content_block_started = False + content_block_index += 1 + + chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_start", + content_block=AnthropicContentBlock( + type="tool_use", + id=tool_call.id, + name=tool_call.function.name + if tool_call.function + else None, + input={}, + ), + ) + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "content_block_start") + content_block_started = True + + else: + chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_delta", + delta=AnthropicDelta( + type="input_json_delta", + partial_json=tool_call.function.arguments + if tool_call.function + else None, + ), + ) + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "content_block_delta") + continue + else: + error_response = AnthropicStreamEvent( + type="error", + error=AnthropicError( + type="internal_error", + message="Invalid data format received", + ), + ) + data = error_response.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "error") + yield "data: [DONE]\n\n" + + except Exception as e: + logger.exception("Error in message stream converter.") + error_response = AnthropicStreamEvent( + type="error", + error=AnthropicError(type="internal_error", message=str(e)), + ) + data = error_response.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "error") + yield "data: [DONE]\n\n" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 555c95effd1d..abc772c79a5a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ 
b/vllm/entrypoints/openai/api_server.py @@ -41,11 +41,6 @@ from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import ( - load_chat_template, - resolve_hf_chat_template, - resolve_mistral_chat_template, -) from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args @@ -88,7 +83,6 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import ( BaseModelPath, - LoRAModulePath, OpenAIServingModels, ) from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling @@ -105,11 +99,12 @@ cli_env_setup, load_aware_call, log_non_default_args, + process_chat_template, + process_lora_modules, with_cancellation, ) from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager -from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import ( Device, @@ -1638,32 +1633,9 @@ async def init_app_state( supported_tasks = await engine_client.get_supported_tasks() logger.info("Supported tasks: %s", supported_tasks) - resolved_chat_template = load_chat_template(args.chat_template) - if resolved_chat_template is not None: - # Get the tokenizer to check official template - tokenizer = await engine_client.get_tokenizer() - - if isinstance(tokenizer, MistralTokenizer): - # The warning is logged in resolve_mistral_chat_template. - resolved_chat_template = resolve_mistral_chat_template( - chat_template=resolved_chat_template - ) - else: - hf_chat_template = resolve_hf_chat_template( - tokenizer=tokenizer, - chat_template=None, - tools=None, - model_config=vllm_config.model_config, - ) - - if hf_chat_template != resolved_chat_template: - logger.warning( - "Using supplied chat template: %s\n" - "It is different from official chat template '%s'. 
" - "This discrepancy may lead to performance degradation.", - resolved_chat_template, - args.model, - ) + resolved_chat_template = await process_chat_template( + args.chat_template, engine_client, vllm_config.model_config + ) if args.tool_server == "demo": tool_server: ToolServer | None = DemoToolServer() @@ -1682,19 +1654,12 @@ async def init_app_state( else {} ) - lora_modules = args.lora_modules - if default_mm_loras: - default_mm_lora_paths = [ - LoRAModulePath( - name=modality, - path=lora_path, - ) - for modality, lora_path in default_mm_loras.items() - ] - if args.lora_modules is None: - lora_modules = default_mm_lora_paths - else: - lora_modules += default_mm_lora_paths + default_mm_loras = ( + vllm_config.lora_config.default_mm_loras + if vllm_config.lora_config is not None + else {} + ) + lora_modules = process_lora_modules(args.lora_modules, default_mm_loras) state.openai_serving_models = OpenAIServingModels( engine_client=engine_client, diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index c006a76d3cdf..ec5fb3b56b7f 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -6,21 +6,31 @@ import functools import os from argparse import Namespace +from pathlib import Path from typing import Any from fastapi import Request from fastapi.responses import JSONResponse, StreamingResponse from starlette.background import BackgroundTask, BackgroundTasks +from vllm.config import ModelConfig from vllm.engine.arg_utils import EngineArgs +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ( + load_chat_template, + resolve_hf_chat_template, + resolve_mistral_chat_template, +) from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, CompletionRequest, StreamOptions, ) +from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.transformers_utils.tokenizers import MistralTokenizer from vllm.utils import FlexibleArgumentParser logger = init_logger(__name__) @@ -254,3 +264,56 @@ def should_include_usage( else: include_usage, include_continuous_usage = enable_force_include_usage, False return include_usage, include_continuous_usage + + +def process_lora_modules( + args_lora_modules: list[LoRAModulePath], default_mm_loras: dict[str, str] | None +) -> list[LoRAModulePath]: + lora_modules = args_lora_modules + if default_mm_loras: + default_mm_lora_paths = [ + LoRAModulePath( + name=modality, + path=lora_path, + ) + for modality, lora_path in default_mm_loras.items() + ] + if args_lora_modules is None: + lora_modules = default_mm_lora_paths + else: + lora_modules += default_mm_lora_paths + return lora_modules + + +async def process_chat_template( + args_chat_template: Path | str | None, + engine_client: EngineClient, + model_config: ModelConfig, +) -> str | None: + resolved_chat_template = load_chat_template(args_chat_template) + if resolved_chat_template is not None: + # Get the tokenizer to check official template + tokenizer = await engine_client.get_tokenizer() + + if isinstance(tokenizer, MistralTokenizer): + # The warning is logged in resolve_mistral_chat_template. 
+            resolved_chat_template = resolve_mistral_chat_template(
+                chat_template=resolved_chat_template
+            )
+        else:
+            hf_chat_template = resolve_hf_chat_template(
+                tokenizer=tokenizer,
+                chat_template=None,
+                tools=None,
+                model_config=model_config,
+            )
+
+            if hf_chat_template != resolved_chat_template:
+                logger.warning(
+                    "Using supplied chat template: %s\n"
+                    "It is different from official chat template '%s'. "
+                    "This discrepancy may lead to performance degradation.",
+                    resolved_chat_template,
+                    model_config.model,
+                )
+    return resolved_chat_template
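
For reference, a minimal sketch of exercising the new Messages endpoint end to end. It assumes the Anthropic-compatible server added by this patch is already running locally, e.g. started with python -m vllm.entrypoints.anthropic.api_server Qwen/Qwen3-0.6B --served-model-name claude-3-7-sonnet-latest --port 8000; the model, served name, port, and prompt are illustrative, mirroring the test fixture, and the API key is a dummy just like DUMMY_API_KEY in RemoteAnthropicServer.

# Illustrative client call; not part of the patch.
import anthropic

client = anthropic.Anthropic(
    base_url="http://localhost:8000",  # assumed local server address
    api_key="token-abc123",  # any value works; the server does not validate it
)

resp = client.messages.create(
    model="claude-3-7-sonnet-latest",
    max_tokens=256,
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(resp.stop_reason, resp.content[0].text)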
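
When stream=True, message_stream_converter reframes vLLM's OpenAI-style chunks into Anthropic's event stream: message_start, then content_block_start / content_block_delta / content_block_stop per block (the block index advances when a tool_use block follows text), then a message_delta carrying stop_reason and usage, then message_stop, and finally data: [DONE]. A rough sketch of observing that framing on the wire against the same assumed local server (httpx usage and the prompt are illustrative):

# Illustrative only: prints the raw SSE lines produced by wrap_data_with_event.
import httpx

payload = {
    "model": "claude-3-7-sonnet-latest",
    "max_tokens": 128,
    "stream": True,
    "messages": [{"role": "user", "content": "Count to three."}],
}

with httpx.stream(
    "POST", "http://localhost:8000/v1/messages", json=payload, timeout=600
) as response:
    for line in response.iter_lines():
        # Alternating "event: <type>" and "data: <json>" lines, separated by
        # blank lines and terminated by "data: [DONE]".
        if line:
            print(line)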