132 changes: 130 additions & 2 deletions vllm/entrypoints/pooling/base/io_processor.py
@@ -82,7 +82,31 @@ def pre_process_online(self, ctx: PoolingServeContext):
ctx.engine_inputs = engine_inputs

async def pre_process_online_async(self, ctx: PoolingServeContext):
-        self.pre_process_online(ctx)
request = ctx.request

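        # Dispatch on the request type: chat-like requests go through the
        # chat-template path, completion-like requests through the plain
        # prompt path; both use the async renderer helpers defined below.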
if isinstance(ctx.request, PoolingChatLikeRequest):
self._validate_chat_template(
request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
_, engine_inputs = await self._preprocess_chat_online_async(
request,
request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(request, PoolingCompletionLikeRequest):
engine_inputs = await self._preprocess_completion_online_async(
request,
prompt_input=request.input,
prompt_embeds=None,
)
else:
raise ValueError(f"Invalid {self.name} request type")

ctx.engine_inputs = engine_inputs

def post_process_online(
self,
@@ -109,7 +133,13 @@ def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput
)

async def pre_process_offline_async(self, ctx: OfflineInputsContext):
-        return self.pre_process_offline(ctx)
assert not isinstance(ctx.prompts, ScoringData)
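        # Merge per-call tokenization kwargs over the renderer's default
        # completion tokenization params.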
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
**(ctx.tokenization_kwargs or {})
)
return await self._preprocess_completion_offline_async(
prompts=ctx.prompts, tok_params=tok_params
)

def post_process_offline(
self,
@@ -126,6 +156,104 @@ async def post_process_offline_async(
#######################################
# helpers

async def _preprocess_completion_online_async(
self,
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[EngineInput]:
renderer = self.renderer
model_config = self.model_config

prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))

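        # bytes entries are serialized prompt embeddings and pass through
        # unparsed; text and token prompts are validated against the model
        # config.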
parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt)
)
for prompt in prompts
]
tok_params = request.build_tok_params(model_config)

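        # Forward optional per-request extras (mm_processor_kwargs, cache_salt)
        # only when they are actually set on the request.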
return await renderer.render_cmpl_async(
parsed_prompts,
tok_params,
prompt_extras={
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
},
)

async def _preprocess_chat_online_async(
self,
request: RendererChatRequest,
messages: list[ChatCompletionMessageParam],
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: type[ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[EngineInput]]:
renderer = self.renderer

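        # Fold tool definitions and the tokenize flag into the default template
        # kwargs; Mistral tokenizers tokenize as part of applying the chat
        # template, hence the is_mistral_tokenizer check.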
default_template_kwargs = merge_kwargs(
default_template_kwargs,
dict(
tools=tool_dicts,
tokenize=is_mistral_tokenizer(renderer.tokenizer),
),
)

mm_config = self.model_config.multimodal_config

tok_params = request.build_tok_params(self.model_config)
chat_params = request.build_chat_params(
default_template, default_template_content_format
).with_defaults(
default_template_kwargs,
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
)

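        # A single conversation is rendered, so unpack the one-element batches.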
(conversation,), (engine_input,) = await renderer.render_chat_async(
[messages],
chat_params,
tok_params,
prompt_extras={
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
},
)

return conversation, [engine_input]

async def _preprocess_completion_offline_async(
self,
prompts: "PromptType | Sequence[PromptType]",
tok_params: TokenizeParams,
prompt_extras: dict[str, Any] | None = None,
) -> Sequence[EngineInput]:
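        # Normalize a single prompt to a sequence, then parse every non-bytes
        # prompt against the model config before rendering asynchronously.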
prompts = prompt_to_seq(prompts)
parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(self.model_config, prompt)
)
for prompt in prompts
]

return await self.renderer.render_cmpl_async(
parsed_prompts, tok_params, prompt_extras=prompt_extras
)

def _preprocess_completion_online(
self,
request: RendererRequest,
148 changes: 148 additions & 0 deletions vllm/entrypoints/pooling/embed/io_processor.py
@@ -65,6 +65,15 @@ def pre_process_online(self, ctx: PoolingServeContext):
if self.enable_chunked_processing:
self._pre_process_chunked(ctx)

async def pre_process_online_async(self, ctx: PoolingServeContext):
if isinstance(ctx.request, CohereEmbedRequest):
await self._pre_process_cohere_online_async(ctx)
else:
await super().pre_process_online_async(ctx)

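        # Chunked processing splits over-long inputs after the engine inputs
        # have been rendered, mirroring the sync path above.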
if self.enable_chunked_processing:
self._pre_process_chunked(ctx)

def post_process_online(
self,
ctx: PoolingServeContext,
@@ -438,6 +447,82 @@ def _pre_process_cohere_online(self, ctx: PoolingServeContext) -> None:
request, all_messages, truncate_prompt_tokens, truncation_side
)

async def _pre_process_cohere_online_async(self, ctx: PoolingServeContext) -> None:
"""Async version of ``_pre_process_cohere_online``.

Uses the async renderer path so that CPU-bound multimodal
preprocessing runs in a thread pool instead of blocking the
asyncio event loop.
"""
request = ctx.request
assert isinstance(request, CohereEmbedRequest)

if request.texts is None and request.images is None and request.inputs is None:
raise ValueError("One of texts, images, or inputs must be provided")

truncate_prompt_tokens, truncation_side = self._resolve_cohere_truncation(
request
)
input_type = request.input_type
self._validate_input_type(input_type)

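        # Normalize the three accepted input forms (image URIs, structured
        # inputs, plain texts) below; images take priority, then inputs.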
if request.images is not None:
input: list[CohereEmbedInput] = [
CohereEmbedInput(
content=[
CohereEmbedContent(type="image_url", image_url={"url": uri})
]
)
for uri in request.images
]
elif request.inputs is not None:
input = request.inputs
else:
texts = request.texts or []
task_prefix = self._get_task_instruction_prefix(input_type)

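            # Without a task instruction prefix, the texts can be embedded
            # directly through the completion path, skipping chat rendering.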
if task_prefix is None:
ctx.engine_inputs = await self._preprocess_cohere_text_completion_async(
request,
texts,
truncate_prompt_tokens,
truncation_side,
)
return

all_messages = [
self._mixed_input_to_messages(
CohereEmbedInput(
content=[CohereEmbedContent(type="text", text=text)]
),
task_prefix=task_prefix,
)
for text in texts
]
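            # With a chat template, render the prefixed messages through it;
            # otherwise prepend the instruction to the raw texts instead.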
if self._has_chat_template():
ctx.engine_inputs = await self._batch_render_chat_async(
request,
all_messages,
truncate_prompt_tokens,
truncation_side,
)
else:
ctx.engine_inputs = await self._preprocess_cohere_text_completion_async(
request,
self._apply_task_instruction(texts, input_type),
truncate_prompt_tokens,
truncation_side,
)
return

task_prefix = self._get_task_instruction_prefix(input_type)
all_messages = [
self._mixed_input_to_messages(inp, task_prefix=task_prefix) for inp in input
]
ctx.engine_inputs = await self._batch_render_chat_async(
request, all_messages, truncate_prompt_tokens, truncation_side
)

def _has_chat_template(self) -> bool:
return (
resolve_chat_template(
@@ -509,6 +594,69 @@ def _batch_render_chat(
_, engine_inputs = renderer.render_chat(all_messages, chat_params, tok_params)
return engine_inputs

async def _batch_render_chat_async(
self,
request: CohereEmbedRequest,
all_messages: Sequence[list[ChatCompletionMessageParam]],
truncate_prompt_tokens: int | None,
truncation_side: Literal["left", "right"] | None,
) -> list[EngineInput]:
"""Async version of ``_batch_render_chat``."""
if not all_messages:
return []

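        # Build a proxy EmbeddingChatRequest from the first conversation so the
        # shared tokenization and chat params can be derived once for the batch.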
proxy = EmbeddingChatRequest(
model=request.model,
messages=list(all_messages[0]),
dimensions=request.output_dimension,
encoding_format="float",
truncate_prompt_tokens=truncate_prompt_tokens,
truncation_side=truncation_side,
)

renderer = self.renderer
mm_config = self.model_config.multimodal_config

tok_params = proxy.build_tok_params(self.model_config)
chat_params = proxy.build_chat_params(
self.chat_template,
self.chat_template_content_format,
).with_defaults(
merge_kwargs(
None,
dict(
tools=None,
tokenize=is_mistral_tokenizer(renderer.tokenizer),
),
),
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
)

_, engine_inputs = await renderer.render_chat_async(
all_messages, chat_params, tok_params
)
return engine_inputs

async def _preprocess_cohere_text_completion_async(
self,
request: CohereEmbedRequest,
texts: list[str],
truncate_prompt_tokens: int | None,
truncation_side: Literal["left", "right"] | None,
) -> list[EngineInput]:
"""Async version of ``_preprocess_cohere_text_completion``."""
proxy = EmbeddingCompletionRequest(
model=request.model,
input=texts,
dimensions=request.output_dimension,
encoding_format="float",
truncate_prompt_tokens=truncate_prompt_tokens,
truncation_side=truncation_side,
)
return await self._preprocess_completion_online_async(
proxy, prompt_input=proxy.input, prompt_embeds=None
)

def _validate_input_type(self, input_type: str | None) -> None:
"""Raise if *input_type* is not supported by this model."""
if input_type is None: