@@ -111,7 +111,7 @@ async def _fake_preprocess_chat(*args, **kwargs):
[{"prompt_token_ids": [1, 2, 3]}],
)

serving_chat.openai_serving_render._preprocess_chat = AsyncMock(
serving_chat.openai_serving_render.preprocess_chat = AsyncMock(
side_effect=_fake_preprocess_chat
)
return serving_chat
4 changes: 4 additions & 0 deletions tests/entrypoints/openai/test_serving_responses.py
@@ -159,6 +159,7 @@ async def serving_responses_instance(self):
instance = OpenAIServingResponses(
engine_client=engine_client,
models=models,
openai_serving_render=MagicMock(),
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
@@ -245,6 +246,7 @@ async def serving_responses_instance(self):
instance = OpenAIServingResponses(
engine_client=engine_client,
models=models,
openai_serving_render=MagicMock(),
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
@@ -308,6 +310,7 @@ def get_vocab(self):
serving = OpenAIServingResponses(
engine_client=engine_client,
models=models,
openai_serving_render=MagicMock(),
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
@@ -607,6 +610,7 @@ def _make_serving_instance_with_reasoning():
serving = OpenAIServingResponses(
engine_client=engine_client,
models=models,
openai_serving_render=MagicMock(),
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
19 changes: 19 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -46,6 +46,7 @@
from vllm.entrypoints.serve.elastic_ep.middleware import (
ScalingMiddleware,
)
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
cli_env_setup,
@@ -365,9 +366,27 @@ async def init_app_state(
lora_modules=lora_modules,
)
await state.openai_serving_models.init_static_loras()

state.openai_serving_render = OpenAIServingRender(
model_config=engine_client.model_config,
renderer=engine_client.renderer,
io_processor=engine_client.io_processor,
model_registry=state.openai_serving_models.registry,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser,
default_chat_template_kwargs=args.default_chat_template_kwargs,
log_error_stack=args.log_error_stack,
)

state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
state.openai_serving_models,
state.openai_serving_render,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
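For orientation, here is a minimal, self-contained sketch of the dependency-injection pattern this wiring enables: a serving class receives the render helper and delegates chat preprocessing to it. The class name `ChatLikeServing`, its constructor, and the exact arguments of `preprocess_chat` are illustrative assumptions; only the attribute name `openai_serving_render`, the async `preprocess_chat` method, and the (conversation, engine_prompts) return shape are inferred from the test changes above.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


class ChatLikeServing:
    """Hypothetical consumer that receives the render helper by injection."""

    def __init__(self, openai_serving_render):
        self.openai_serving_render = openai_serving_render

    async def handle(self, request, messages):
        # Delegate prompt rendering to the shared helper instead of a
        # private _preprocess_chat method on the serving class itself.
        conversation, engine_prompts = await self.openai_serving_render.preprocess_chat(
            request, messages
        )
        return engine_prompts


# Usage with a stubbed render helper, mirroring the test setup above.
render = MagicMock()
render.preprocess_chat = AsyncMock(
    return_value=([], [{"prompt_token_ids": [1, 2, 3]}])
)
serving = ChatLikeServing(openai_serving_render=render)
print(asyncio.run(serving.handle(request=None, messages=[])))
```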
229 changes: 3 additions & 226 deletions vllm/entrypoints/openai/engine/serving.py
@@ -4,7 +4,7 @@
import contextlib
import json
import time
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from collections.abc import AsyncGenerator, Callable, Mapping
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
@@ -22,9 +22,7 @@
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
ConversationMessage,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import (
@@ -43,19 +41,9 @@
GenerationError,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.responses.context import (
ConversationContext,
HarmonyContext,
ParsableContext,
StreamingHarmonyContext,
)
from vllm.entrypoints.openai.responses.protocol import (
ResponseInputOutputItem,
ResponsesRequest,
)
from vllm.entrypoints.openai.responses.utils import (
construct_input_messages,
)
from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionRequest,
TranscriptionResponse,
@@ -82,26 +70,22 @@
TokenizeCompletionRequest,
TokenizeResponse,
)
from vllm.entrypoints.utils import create_error_response, get_max_tokens
from vllm.entrypoints.utils import create_error_response
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import (
ProcessorInputs,
PromptType,
SingletonPrompt,
TokensPrompt,
token_inputs,
)
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers import ChatParams, TokenizeParams
from vllm.renderers.inputs.preprocess import (
extract_prompt_components,
extract_prompt_len,
parse_model_prompt,
prompt_to_seq,
)
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
@@ -116,7 +100,6 @@
collect_from_async_generator,
merge_async_iterators,
)
from vllm.utils.mistral import is_mistral_tokenizer

logger = init_logger(__name__)

@@ -823,109 +806,6 @@ def _prepare_extra_chat_template_kwargs(
# Apply server defaults first, then request kwargs override.
return default_chat_template_kwargs | request_chat_template_kwargs

async def _preprocess_completion(
self,
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[ProcessorInputs]:
prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))

return await self._preprocess_cmpl(request, prompts)

async def _preprocess_cmpl(
self,
request: RendererRequest,
prompts: Sequence[PromptType | bytes],
) -> list[ProcessorInputs]:
renderer = self.renderer
model_config = self.model_config

parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt)
)
for prompt in prompts
]
tok_params = request.build_tok_params(model_config)

return await renderer.render_cmpl_async(
parsed_prompts,
tok_params,
prompt_extras={
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
},
)

async def _preprocess_chat(
self,
request: RendererChatRequest,
messages: list[ChatCompletionMessageParam],
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
renderer = self.renderer

default_template_kwargs = merge_kwargs(
default_template_kwargs,
dict(
tools=tool_dicts,
tokenize=is_mistral_tokenizer(renderer.tokenizer),
),
)

mm_config = self.model_config.multimodal_config

tok_params = request.build_tok_params(self.model_config)
chat_params = request.build_chat_params(
default_template, default_template_content_format
).with_defaults(
default_template_kwargs,
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None),
)

(conversation,), (engine_prompt,) = await renderer.render_chat_async(
[messages],
chat_params,
tok_params,
prompt_extras={
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
},
)

# tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
# is set, we want to prevent parsing a tool_call hallucinated by the LLM
if tool_parser is not None:
tool_choice = getattr(request, "tool_choice", "none")
if tool_choice != "none":
if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
msg = (
"Tool usage is only supported for Chat Completions API "
"or Responses API requests."
)
raise NotImplementedError(msg)

# TODO: Update adjust_request to accept ResponsesRequest
tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type]

return conversation, [engine_prompt]

def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs):
return extract_prompt_components(self.model_config, prompt)

@@ -935,109 +815,6 @@ def _extract_prompt_text(self, prompt: ProcessorInputs):
def _extract_prompt_len(self, prompt: ProcessorInputs):
return extract_prompt_len(self.model_config, prompt)

async def _render_next_turn(
self,
request: ResponsesRequest,
messages: list[ResponseInputOutputItem],
tool_dicts: list[dict[str, Any]] | None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
):
new_messages = construct_input_messages(
request_input=messages,
)

_, engine_prompts = await self._preprocess_chat(
request,
new_messages,
default_template=chat_template,
default_template_content_format=chat_template_content_format,
default_template_kwargs=None,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
)
return engine_prompts

async def _generate_with_builtin_tools(
self,
request_id: str,
engine_prompt: ProcessorInputs,
sampling_params: SamplingParams,
context: ConversationContext,
lora_request: LoRARequest | None = None,
priority: int = 0,
trace_headers: Mapping[str, str] | None = None,
):
max_model_len = self.model_config.max_model_len

orig_priority = priority
sub_request = 0
while True:
# Ensure that each sub-request has a unique request id.
sub_request_id = f"{request_id}_{sub_request}"

self._log_inputs(
sub_request_id,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)

generator = self.engine_client.generate(
engine_prompt,
sampling_params,
sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
)

async for res in generator:
context.append_output(res)
# NOTE(woosuk): The stop condition is handled by the engine.
yield context

if not context.need_builtin_tool_call():
# The model did not ask for a tool call, so we're done.
break

# Call the tool and update the context with the result.
tool_output = await context.call_tool()
context.append_tool_output(tool_output)

# TODO: uncomment this and enable tool output streaming
# yield context

# Create inputs for the next turn.
# Render the next prompt token ids and update sampling_params.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
token_ids = context.render_for_completion()
engine_prompt = token_inputs(token_ids)

sampling_params.max_tokens = max_model_len - len(token_ids)
elif isinstance(context, ParsableContext):
(engine_prompt,) = await self._render_next_turn(
context.request,
context.parser.response_messages,
context.tool_dicts,
context.tool_parser_cls,
context.chat_template,
context.chat_template_content_format,
)

sampling_params.max_tokens = get_max_tokens(
max_model_len,
context.request.max_output_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params, # type: ignore
self.override_max_tokens, # type: ignore
)

# OPTIMIZATION
priority = orig_priority - 1
sub_request += 1

def _log_inputs(
self,
request_id: str,
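To make the removed control flow easier to follow, below is a condensed, runnable sketch of the multi-turn builtin-tool loop that `_generate_with_builtin_tools` implemented: a unique sub-request id per turn, tool results appended to the context, and follow-up turns demoted to a lower priority. Everything here uses toy stand-ins — `ToyContext`, the generator body, and the turn counting are illustrative, not the vLLM implementation.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class ToyContext:
    """Illustrative stand-in for ConversationContext in the removed code."""

    turns_remaining: int
    outputs: list = field(default_factory=list)

    def need_builtin_tool_call(self) -> bool:
        return self.turns_remaining > 0

    async def call_tool(self) -> str:
        self.turns_remaining -= 1
        return f"tool-result-{self.turns_remaining}"


async def generate_with_builtin_tools(request_id: str, context: ToyContext, priority: int = 0):
    orig_priority = priority
    sub_request = 0
    while True:
        # Each sub-request gets a unique id, mirroring the removed helper.
        sub_request_id = f"{request_id}_{sub_request}"
        context.outputs.append(sub_request_id)
        yield context

        if not context.need_builtin_tool_call():
            break  # no further tool call requested; we're done

        # Call the tool and record its output before rendering the next turn.
        context.outputs.append(await context.call_tool())

        # Follow-up turns run at a lower priority than the original request.
        priority = orig_priority - 1
        sub_request += 1


async def main():
    ctx = ToyContext(turns_remaining=2)
    async for _ in generate_with_builtin_tools("req-1", ctx):
        pass
    print(ctx.outputs)


asyncio.run(main())
```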