diff --git a/examples/online_serving/batched_chat_completions.py b/examples/online_serving/batched_chat_completions.py new file mode 100644 index 000000000000..0f76d5cf3620 --- /dev/null +++ b/examples/online_serving/batched_chat_completions.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Examples of batched chat completions via the vLLM OpenAI-compatible API. + +The /v1/chat/completions/batch endpoint accepts ``messages`` as a list of +conversations. Each conversation is processed independently and the response +contains one choice per conversation, indexed 0, 1, ..., N-1. + +Start a server first, e.g.: + vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000 + +Current limitations compared to /v1/chat/completions: + - Streaming is not supported. + - Tool use is not supported. + - Beam search is not supported. +""" + +import json +import os + +import httpx + +BASE_URL = os.environ.get("VLLM_BASE_URL", "http://localhost:8000") +MODEL = os.environ.get("VLLM_MODEL", "Qwen/Qwen2.5-1.5B-Instruct") +BATCH_URL = f"{BASE_URL}/v1/chat/completions/batch" + + +def post_batch(payload: dict) -> dict: + response = httpx.post(BATCH_URL, json=payload, timeout=60) + response.raise_for_status() + return response.json() + + +def main() -> None: + print("=== Example 1a: single conversation (standard endpoint) ===") + response = httpx.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "model": MODEL, + "messages": [{"role": "user", "content": "What is the capital of Japan?"}], + }, + timeout=60, + ) + response.raise_for_status() + data = response.json() + for choice in data["choices"]: + print(f" [{choice['index']}] {choice['message']['content']}") + + print("\n=== Example 1b: batched plain text (2 conversations) ===") + data = post_batch( + { + "model": MODEL, + "messages": [ + [{"role": "user", "content": "What is the capital of France?"}], + [{"role": "user", "content": "What is the capital of Japan?"}], + ], + } + ) + for choice in data["choices"]: + print(f" [{choice['index']}] {choice['message']['content']}") + + print("\n=== Example 2: batch with regex constraint (yes|no) ===") + data = post_batch( + { + "model": MODEL, + "messages": [ + [{"role": "user", "content": "Is the sky blue? Answer yes or no."}], + [{"role": "user", "content": "Is fire cold? 
Answer yes or no."}], + ], + "structured_outputs": {"regex": "(yes|no)"}, + } + ) + for choice in data["choices"]: + print(f" [{choice['index']}] {choice['message']['content']}") + + print("\n=== Example 3: batch with json_schema ===") + person_schema = { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Full name of the person"}, + "age": {"type": "integer", "description": "Age in years"}, + }, + "required": ["name", "age"], + } + data = post_batch( + { + "model": MODEL, + "messages": [ + [ + { + "role": "user", + "content": "Describe the person: name Alice, age 30.", + } + ], + [{"role": "user", "content": "Describe the person: name Bob, age 25."}], + ], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "person", + "strict": True, + "schema": person_schema, + }, + }, + } + ) + for choice in data["choices"]: + person = json.loads(choice["message"]["content"]) + print(f" [{choice['index']}] {person}") + + print("\n=== Example 4: batch book summaries ===") + book_schema = { + "type": "object", + "properties": { + "author": { + "type": "string", + "description": "Full name of the author", + }, + "num_pages": { + "type": "integer", + "description": "Number of pages in the book", + }, + "short_summary": { + "type": "string", + "description": "A one-sentence summary of the book", + }, + "long_summary": { + "type": "string", + "description": ( + "A detailed two to three sentence summary covering " + "the main themes and plot" + ), + }, + }, + "required": ["author", "num_pages", "short_summary", "long_summary"], + } + system_msg = { + "role": "system", + "content": ( + "You are a literary analyst. Extract structured information " + "from book descriptions." + ), + } + data = post_batch( + { + "model": MODEL, + "messages": [ + [ + system_msg, + { + "role": "user", + "content": ( + "Extract information from this book: '1984' by George" + " Orwell, published in 1949, 328 pages. A dystopian" + " novel set in a totalitarian society ruled by Big" + " Brother, following Winston Smith as he secretly" + " rebels against the oppressive Party that surveils" + " and controls every aspect of life." + ), + }, + ], + [ + system_msg, + { + "role": "user", + "content": ( + "Extract information from this book: 'The Hitchhiker's" + " Guide to the Galaxy' by Douglas Adams, published in" + " 1979, 193 pages. A comedic science fiction novel" + " following Arthur Dent, an ordinary Englishman who is" + " whisked off Earth moments before it is demolished to" + " make way for a hyperspace bypass, and his subsequent" + " absurd adventures across the universe." 
+ ), + }, + ], + ], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "book_summary", + "strict": True, + "schema": book_schema, + }, + }, + } + ) + for choice in data["choices"]: + book = json.loads(choice["message"]["content"]) + print(f" [{choice['index']}] {book}") + + +if __name__ == "__main__": + main() diff --git a/tests/entrypoints/openai/chat_completion/test_batched_chat_completions.py b/tests/entrypoints/openai/chat_completion/test_batched_chat_completions.py new file mode 100644 index 000000000000..c3a8d0b2bdec --- /dev/null +++ b/tests/entrypoints/openai/chat_completion/test_batched_chat_completions.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json + +import httpx +import pytest + +from tests.utils import RemoteOpenAIServer + +# any model with a chat template defined in tokenizer_config should work here +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # keep context length and batch size small and use eager mode for speed and memory savings in the CI environment + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager", + ] + + +@pytest.fixture(scope="module") +def server(default_server_args): + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_batched_chat_completions( + server: RemoteOpenAIServer, model_name: str +) -> None: + conversations = [ + [{"role": "user", "content": "Reply with exactly the word: alpha"}], + [{"role": "user", "content": "Reply with exactly the word: beta"}], + ] + + async with httpx.AsyncClient() as http_client: + response = await http_client.post( + f"{server.url_for('v1/chat/completions/batch')}", + json={ + "model": model_name, + "messages": conversations, + }, + timeout=60, + ) + + assert response.status_code == 200, response.text + data = response.json() + + choices = data["choices"] + assert len(choices) == 2 + + indices = {choice["index"] for choice in choices} + assert indices == {0, 1} + + # Each conversation should produce a non-empty text response. + for choice in choices: + assert choice["message"]["content"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_batched_chat_completions_with_json_schema( + server: RemoteOpenAIServer, model_name: str +) -> None: + schema = { + "type": "object", + "properties": { + "answer": {"type": "string", "enum": ["yes", "no"]}, + }, + "required": ["answer"], + } + conversations = [ + [{"role": "user", "content": "Is the sky blue? Answer in JSON."}], + [{"role": "user", "content": "Is fire cold?
Answer in JSON."}], + ] + + async with httpx.AsyncClient() as http_client: + response = await http_client.post( + f"{server.url_for('v1/chat/completions/batch')}", + json={ + "model": model_name, + "messages": conversations, + "response_format": { + "type": "json_schema", + "json_schema": {"name": "answer", "schema": schema, "strict": True}, + }, + }, + timeout=60, + ) + + assert response.status_code == 200, response.text + data = response.json() + + choices = data["choices"] + assert len(choices) == 2 + + for choice in choices: + parsed = json.loads(choice["message"]["content"]) + assert "answer" in parsed + assert parsed["answer"] in ("yes", "no") diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index d83fffa12ded..083290ed5b3a 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -174,6 +174,7 @@ def test_openapi_stateless(case: Case): timeout = { # requires a longer timeout ("POST", "/v1/chat/completions"): LONG_TIMEOUT_SECONDS, + ("POST", "/v1/chat/completions/batch"): LONG_TIMEOUT_SECONDS, ("POST", "/v1/completions"): LONG_TIMEOUT_SECONDS, ("POST", "/v1/messages"): LONG_TIMEOUT_SECONDS, }.get(key, DEFAULT_TIMEOUT_SECONDS) diff --git a/vllm/entrypoints/openai/chat_completion/api_router.py b/vllm/entrypoints/openai/chat_completion/api_router.py index 28a2eab679c0..cdaaa27fcdab 100644 --- a/vllm/entrypoints/openai/chat_completion/api_router.py +++ b/vllm/entrypoints/openai/chat_completion/api_router.py @@ -7,7 +7,9 @@ from fastapi import APIRouter, Depends, FastAPI, Request from fastapi.responses import JSONResponse, StreamingResponse +from vllm.entrypoints.openai.chat_completion.batch_serving import OpenAIServingChatBatch from vllm.entrypoints.openai.chat_completion.protocol import ( + BatchChatCompletionRequest, ChatCompletionRequest, ChatCompletionResponse, ) @@ -31,6 +33,10 @@ def chat(request: Request) -> OpenAIServingChat | None: return request.app.state.openai_serving_chat +def batch_chat(request: Request) -> OpenAIServingChatBatch | None: + return request.app.state.openai_serving_chat_batch + + @router.post( "/v1/chat/completions", dependencies=[Depends(validate_json_request)], @@ -68,5 +74,33 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re return StreamingResponse(content=generator, media_type="text/event-stream") +@router.post( + "/v1/chat/completions/batch", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +@load_aware_call +async def create_batch_chat_completion( + request: BatchChatCompletionRequest, raw_request: Request +): + handler = batch_chat(raw_request) + if handler is None: + raise NotImplementedError("The model does not support Chat Completions API") + + result = await handler.create_batch_chat_completion(request, raw_request) + + if isinstance(result, ErrorResponse): + return JSONResponse(content=result.model_dump(), status_code=result.error.code) + + return JSONResponse(content=result.model_dump()) + + def attach_router(app: FastAPI): app.include_router(router) diff --git a/vllm/entrypoints/openai/chat_completion/batch_serving.py b/vllm/entrypoints/openai/chat_completion/batch_serving.py new file mode 
100644 index 000000000000..f97c93bb03c8 --- /dev/null +++ b/vllm/entrypoints/openai/chat_completion/batch_serving.py @@ -0,0 +1,317 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import time +from collections.abc import AsyncGenerator +from http import HTTPStatus + +from fastapi import Request + +from vllm.entrypoints.chat_utils import ConversationMessage +from vllm.entrypoints.openai.chat_completion.protocol import ( + BatchChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, +) +from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat +from vllm.entrypoints.openai.engine.protocol import ( + ErrorResponse, + RequestResponseMetadata, + UsageInfo, +) +from vllm.entrypoints.utils import get_max_tokens +from vllm.inputs import EngineInput +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.reasoning import ReasoningParser +from vllm.tokenizers import TokenizerLike +from vllm.utils.async_utils import merge_async_iterators +from vllm.utils.collection_utils import as_list + +logger = init_logger(__name__) + + +class OpenAIServingChatBatch(OpenAIServingChat): + """Extends OpenAIServingChat with the /v1/chat/completions/batch endpoint. + + Processes N conversations from a single request concurrently and returns + one choice per conversation indexed 0, 1, ..., N-1. + """ + + async def render_batch_chat_request( + self, + request: BatchChatCompletionRequest, + ) -> tuple[list[list[ConversationMessage]], list[EngineInput]] | ErrorResponse: + """Validate the model and preprocess a batched chat completion request. + + Performs engine-aware checks then delegates per-conversation + preprocessing to OpenAIServingRender, validating the chat template + once for the whole batch. + + Returns: + A tuple of (all_conversations, engine_prompts) on success — one + entry per conversation — or an ErrorResponse on failure. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + logger.error("Error with model %s", error_check_ret) + return error_check_ret + + if self.engine_client.errored: + raise self.engine_client.dead_error + + render = self.openai_serving_render + + if not render.use_harmony: + # Common case: validate the chat template once for the whole batch. 
+ error_check_ret = render.validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=render.trust_request_chat_template, + ) + if error_check_ret is not None: + return error_check_ret + + tool_parser = render.tool_parser + tool_dicts: list[dict] | None = None + + all_conversations: list[list[ConversationMessage]] = [] + all_engine_prompts: list[EngineInput] = [] + + for messages in request.messages: + single_request = request.to_chat_completion_request(messages) + if render.use_harmony: + conversation, engine_prompts = render._make_request_with_harmony( + single_request, should_include_tools=tool_dicts is not None + ) + else: + conversation, engine_prompts = await render.preprocess_chat( + single_request, + messages, + default_template=render.chat_template, + default_template_content_format=render.chat_template_content_format, + default_template_kwargs=render.default_chat_template_kwargs, + tool_dicts=tool_dicts, + tool_parser=tool_parser, + ) + all_conversations.append(conversation) + all_engine_prompts.append(engine_prompts[0]) + + return all_conversations, all_engine_prompts + + async def create_batch_chat_completion( + self, + request: BatchChatCompletionRequest, + raw_request: Request | None = None, + ) -> ChatCompletionResponse | ErrorResponse: + """Batch Chat Completion endpoint (/v1/chat/completions/batch). + + Processes N conversations from a single request concurrently and + returns one choice per conversation indexed 0, 1, ..., N-1. + Streaming, tool use, and beam search are not supported. + """ + tokenizer = self.renderer.tokenizer + assert tokenizer is not None + + reasoning_parser: ReasoningParser | None = None + if self.reasoning_parser_cls: + chat_template_kwargs = self._prepare_extra_chat_template_kwargs( + request.chat_template_kwargs, + self.default_chat_template_kwargs, + ) + reasoning_parser = self.reasoning_parser_cls( + tokenizer, + chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg] + ) + + render_result = await self.render_batch_chat_request(request) + if isinstance(render_result, ErrorResponse): + return render_result + all_conversations, engine_prompts = render_result + + request_id = ( + f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}" + ) + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + lora_request = self._maybe_get_adapters(request, supports_default_mm_loras=True) + model_name = self.models.model_name(lora_request) + data_parallel_rank = self._get_data_parallel_rank(raw_request) + max_model_len = self.model_config.max_model_len + + generators: list[AsyncGenerator[RequestOutput, None]] = [] + for i, engine_prompt in enumerate(engine_prompts): + sub_request_id = f"{request_id}_{i}" + max_tokens = get_max_tokens( + max_model_len, + request.max_completion_tokens + if request.max_completion_tokens is not None + else request.max_tokens, + self._extract_prompt_len(engine_prompt), + self.default_sampling_params, + self.override_max_tokens, + ) + single_request = request.to_chat_completion_request(request.messages[i]) + sampling_params = single_request.to_sampling_params( + max_tokens, self.default_sampling_params + ) + self._log_inputs( + sub_request_id, + engine_prompt, + params=sampling_params, + lora_request=lora_request, + ) + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) + 
generators.append( + self.engine_client.generate( + engine_prompt, + sampling_params, + sub_request_id, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority if hasattr(request, "priority") else 0, + data_parallel_rank=data_parallel_rank, + reasoning_ended=None, + ) + ) + + return await self.chat_completion_full_generator_batch( + request, # type: ignore[arg-type] + generators, + request_id, + model_name, + all_conversations, + tokenizer, + request_metadata, + reasoning_parser, + ) + + async def chat_completion_full_generator_batch( + self, + request: BatchChatCompletionRequest, # type: ignore[override] + generators: list[AsyncGenerator[RequestOutput, None]], + request_id: str, + model_name: str, + all_conversations: list[list[ConversationMessage]], + tokenizer: TokenizerLike, + request_metadata: RequestResponseMetadata, + reasoning_parser: ReasoningParser | None = None, + ) -> ErrorResponse | ChatCompletionResponse: + """Handle batched (non-streaming) chat completions. + + Fans out N generators (one per conversation in the batch), collects + the final output for each, and assembles a single + ``ChatCompletionResponse`` whose ``choices`` are indexed 0,...,N-1. + + Tool use and streaming are not supported by the batch endpoint, so + neither needs to be handled here. + """ + created_time = int(time.time()) + role = self.get_chat_request_role(request) # type: ignore[arg-type] + + final_results: dict[int, RequestOutput] = {} + try: + async for prompt_idx, res in merge_async_iterators(*generators): + final_results[prompt_idx] = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + + choices: list[ChatCompletionResponseChoice] = [] + total_prompt_tokens = 0 + total_completion_tokens = 0 + + for prompt_idx in range(len(generators)): + final_res = final_results.get(prompt_idx) + if final_res is None: + return self.create_error_response( + f"No output received from the engine for prompt {prompt_idx}.", + err_type="InternalServerError", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) + + assert final_res.prompt_token_ids is not None + num_prompt_tokens = len(final_res.prompt_token_ids) + if final_res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(final_res.encoder_prompt_token_ids) + total_prompt_tokens += num_prompt_tokens + total_completion_tokens += sum( + len(output.token_ids) for output in final_res.outputs + ) + + for output in final_res.outputs: + self._raise_if_error(output.finish_reason, request_id) + + if request.logprobs and request.top_logprobs is not None: + assert output.logprobs is not None, "Did not output logprobs" + logprobs = self._create_chat_logprobs( + token_ids=output.token_ids, + top_logprobs=output.logprobs, + num_output_top_logprobs=request.top_logprobs, + tokenizer=tokenizer, + return_as_token_id=request.return_token_ids, + ) + else: + logprobs = None + + if reasoning_parser: + reasoning, content = reasoning_parser.extract_reasoning( + output.text, + request=request, # type: ignore[arg-type] + ) + if not getattr(request, "include_reasoning", True): + reasoning = None + else: + reasoning = None + content = output.text + + message = ChatMessage(role=role, reasoning=reasoning, content=content) + + if request.echo: + conversation = all_conversations[prompt_idx] + last_msg_content: str | list[dict[str, str]] = "" + if conversation and "content" in conversation[-1]: + last_msg_content = conversation[-1]["content"] or "" + if isinstance(last_msg_content, list):
+ last_msg_content = "\n".join( + msg["text"] for msg in last_msg_content + ) + message.content = last_msg_content + (message.content or "") + + choice_data = ChatCompletionResponseChoice( + index=prompt_idx, + message=message, + logprobs=logprobs, + finish_reason=output.finish_reason + if output.finish_reason + else "stop", + stop_reason=output.stop_reason, + token_ids=( + as_list(output.token_ids) if request.return_token_ids else None + ), + ) + choices.append(choice_data) + + usage = UsageInfo( + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens, + ) + request_metadata.final_usage_info = usage + + choices.sort(key=lambda c: c.index) + + return ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index 09ef448f41ad..533959df6094 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -787,3 +787,83 @@ def set_include_reasoning_for_none_effort(cls, data: Any) -> Any: if data.get("reasoning_effort") == "none": data["include_reasoning"] = False return data + + +class BatchChatCompletionRequest(OpenAIBaseModel): + """Request model for the /v1/chat/completions/batch endpoint. + + Accepts the same fields as ChatCompletionRequest except that ``messages`` + is a list of conversations (each conversation is a + ``list[ChatCompletionMessageParam]``). Each conversation is processed + independently and the response contains one choice per conversation, + indexed 0, 1, ..., N-1. + + Current limitations compared to the single-conversation endpoint: + - Streaming is not supported (``stream`` must be False or omitted). + - Tool use is not supported (``tools`` must be omitted). + - Beam search is not supported (``use_beam_search`` must be False or omitted). + - The ``n`` parameter must be 1 (or omitted). + """ + + messages: list[list[ChatCompletionMessageParam]] = Field(..., min_length=1) + model: str | None = None + + # Shared sampling / generation fields — mirror ChatCompletionRequest. 
+ frequency_penalty: float | None = 0.0 + logit_bias: dict[str, float] | None = None + logprobs: bool | None = False + top_logprobs: int | None = 0 + max_tokens: int | None = None + max_completion_tokens: int | None = None + n: int | None = 1 + presence_penalty: float | None = 0.0 + response_format: Any | None = None + seed: int | None = Field(None, ge=_INT64_MIN, le=_INT64_MAX) + stop: str | list[str] | None = Field(default_factory=list) + temperature: float | None = 0.7 + top_p: float | None = 1.0 + user: str | None = None + + # vLLM extensions + best_of: int | None = None + use_beam_search: bool = False + top_k: int | None = None + min_p: float | None = 0.0 + repetition_penalty: float | None = 1.0 + length_penalty: float | None = 1.0 + early_stopping: bool = False + structured_outputs: StructuredOutputsParams | None = None + request_id: str | None = None + add_generation_prompt: bool = True + continue_final_message: bool = False + chat_template: str | None = None + chat_template_kwargs: dict[str, Any] | None = None + include_stop_str_in_output: bool = False + guided_decoding_backend: str | None = None + echo: bool = False + return_token_ids: bool = False + + @model_validator(mode="before") + @classmethod + def check_batch_mode(cls, data: Any) -> Any: + if isinstance(data, BatchChatCompletionRequest): + data = data.model_dump(exclude_unset=True) + if data.get("use_beam_search"): + raise ValueError( + "Batch chat completions do not support beam search. " + "Please set `use_beam_search` to False." + ) + n = data.get("n", 1) + if n is not None and n != 1: + raise ValueError( + "Batch chat completions do not support `n > 1`. Please set `n` to 1." + ) + return data + + def to_chat_completion_request( + self, messages: list[ChatCompletionMessageParam] + ) -> ChatCompletionRequest: + """Build a single-conversation ChatCompletionRequest from one conversation.""" + data = self.model_dump(exclude={"messages"}, exclude_none=True) + data["messages"] = messages + return ChatCompletionRequest.model_validate(data) diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index d8df1d3c4d96..73e126082837 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -26,6 +26,7 @@ ) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.chat_completion.protocol import ( + BatchChatCompletionRequest, ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionResponse, @@ -124,7 +125,10 @@ def build_chat_params( ) ChatLikeRequest: TypeAlias = ( - ChatCompletionRequest | TokenizeChatRequest | PoolingChatRequest + ChatCompletionRequest + | BatchChatCompletionRequest + | TokenizeChatRequest + | PoolingChatRequest ) SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest diff --git a/vllm/entrypoints/openai/generate/api_router.py b/vllm/entrypoints/openai/generate/api_router.py index 77c3580de622..9a64db929e8f 100644 --- a/vllm/entrypoints/openai/generate/api_router.py +++ b/vllm/entrypoints/openai/generate/api_router.py @@ -56,6 +56,9 @@ async def init_generate_state( MCPToolServer, ToolServer, ) + from vllm.entrypoints.openai.chat_completion.batch_serving import ( + OpenAIServingChatBatch, + ) from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses @@ -96,27 +99,31 @@ async def 
init_generate_state( if "generate" in supported_tasks else None ) + _chat_kwargs = dict( + engine_client=engine_client, + models=state.openai_serving_models, + response_role=args.response_role, + openai_serving_render=state.openai_serving_render, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + default_chat_template_kwargs=args.default_chat_template_kwargs, + trust_request_chat_template=args.trust_request_chat_template, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, + tool_parser=args.tool_call_parser, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + enable_log_deltas=args.enable_log_deltas, + ) state.openai_serving_chat = ( - OpenAIServingChat( - engine_client, - state.openai_serving_models, - args.response_role, - openai_serving_render=state.openai_serving_render, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - default_chat_template_kwargs=args.default_chat_template_kwargs, - trust_request_chat_template=args.trust_request_chat_template, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_auto_tools=args.enable_auto_tool_choice, - exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, - tool_parser=args.tool_call_parser, - reasoning_parser=args.structured_outputs_config.reasoning_parser, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_force_include_usage=args.enable_force_include_usage, - enable_log_outputs=args.enable_log_outputs, - enable_log_deltas=args.enable_log_deltas, - ) + OpenAIServingChat(**_chat_kwargs) if "generate" in supported_tasks else None + ) + state.openai_serving_chat_batch = ( + OpenAIServingChatBatch(**_chat_kwargs) if "generate" in supported_tasks else None )
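
For quick manual verification of the new endpoint, the sketch below pairs each reply back to its input conversation by the documented choice index. It assumes a locally running server started as in the example file's docstring (vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000); the `batch_chat` helper is illustrative only and is not part of this patch.

```python
# Minimal client-side round trip against the new /v1/chat/completions/batch
# endpoint. Assumes a local vLLM server is already running; this helper is
# illustrative and not part of the patch above.
import httpx

BASE_URL = "http://localhost:8000"
MODEL = "Qwen/Qwen2.5-1.5B-Instruct"


def batch_chat(conversations: list[list[dict]]) -> list[str]:
    """Send N conversations and return their replies in input order."""
    response = httpx.post(
        f"{BASE_URL}/v1/chat/completions/batch",
        json={"model": MODEL, "messages": conversations},
        timeout=60,
    )
    response.raise_for_status()
    choices = response.json()["choices"]
    # The endpoint returns one choice per conversation, indexed 0..N-1,
    # so sorting by index restores the order of the input conversations.
    choices.sort(key=lambda c: c["index"])
    return [c["message"]["content"] for c in choices]


if __name__ == "__main__":
    replies = batch_chat(
        [
            [{"role": "user", "content": "Name one prime number."}],
            [{"role": "user", "content": "Name one even number."}],
        ]
    )
    for i, reply in enumerate(replies):
        print(f"[{i}] {reply}")
```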
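The core of chat_completion_full_generator_batch is a fan-out/merge pattern: N per-conversation generators are merged into one stream of (index, output) pairs and only the final output per index is kept. The stdlib-only toy below illustrates that pattern under those assumptions; `fake_engine` and `merge_with_index` are stand-ins invented for the illustration, not vLLM's merge_async_iterators or engine API.

```python
# Toy illustration of the fan-out/merge pattern: merge N async generators
# into (index, item) pairs and keep only the last item seen per index.
import asyncio
from collections.abc import AsyncGenerator


async def fake_engine(index: int, steps: int) -> AsyncGenerator[str, None]:
    """Stand-in for an engine generator: yields growing partial outputs."""
    text = ""
    for step in range(steps):
        await asyncio.sleep(0.01)
        text += f"tok{index}.{step} "
        yield text.strip()


async def merge_with_index(*gens: AsyncGenerator[str, None]):
    """Yield (index, item) pairs as soon as any generator produces one."""
    queue: asyncio.Queue = asyncio.Queue()
    done = object()

    async def pump(i: int, gen: AsyncGenerator[str, None]) -> None:
        async for item in gen:
            await queue.put((i, item))
        await queue.put((i, done))

    # Keep references so the pump tasks are not garbage collected early.
    tasks = [asyncio.create_task(pump(i, g)) for i, g in enumerate(gens)]
    remaining = len(tasks)
    while remaining:
        i, item = await queue.get()
        if item is done:
            remaining -= 1
        else:
            yield i, item


async def main() -> None:
    generators = [fake_engine(i, steps=3) for i in range(2)]
    final: dict[int, str] = {}
    async for i, item in merge_with_index(*generators):
        final[i] = item  # keep only the last output per conversation index
    for i in sorted(final):
        print(f"[{i}] {final[i]}")


asyncio.run(main())
```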