Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions docs/source/dev/logits_processors/logits_processor_plugins.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
.. Logits Processor Plugins:

Logits Processor Plugins
========================

vLLM supports using custom logits processors through plugins.
This means you can use custom logits processors, and even create your own without having to change vLLM.

Installing a logits processor plugin
------------------------------------

To install a logits processor plugin,
all you have to do is install the Python package containing the plugin in the same environment as vLLM.


Using the installed plugin
--------------------------------

To use the logits processor plugins you installed,
you can use the :code:`logits_processors` field in the generation request body as such:

.. code-block:: console

$ curl http://localhost:8000/v1/completions \
$ -H "Content-Type: application/json" \
$ -d '{
$ "model": "...",
$ "prompt": "Hello!",
$ "max_tokens": 32,
$ "temperature": 0,
$ "logits_processors": {
$ "my_logits_processor": {
$ "my_param": 100
$ }
$ }
$ }'

.. note::
    This is only an example; in reality, each logits processor plugin has a name which is the key in the :code:`logits_processors` dictionary.
The value is a dictionary of parameters passed to the plugin's implementation.


Creating your own logits processor plugin
-----------------------------------------

Advanced users might want to build their custom logits processors and publish them as plugins.
This can be done by simply creating a Python package with your implementation.

Here is an example :code:`main.py` for a logits processor plugin implementation
that takes a token ID and multiplies its logit by 100:

.. code-block:: python

from pydantic import BaseModel


class MyParameters(BaseModel):
    # Pydantic model that validates the per-request plugin parameters
    # (the dict sent under "logits_processors" in the request body).
    # `token_id` selects which token's logit gets boosted.
    token_id: int


class MyLogitsProcessor:
    """Example logits processor: boosts one token's logit by a factor of 100.

    Instantiated by vLLM once per request with the engine tokenizer and the
    validated request parameters.
    """

    def __init__(self, tokenizer, parameters: MyParameters):
        # Keep both around; the tokenizer is unused here but is part of the
        # plugin construction contract.
        self.tokenizer = tokenizer
        self.parameters = parameters

    def __call__(self, token_ids, logits):
        # Operate on a copy so the caller's logits tensor is never mutated.
        boosted = logits.clone()
        target = self.parameters.token_id
        boosted[target] = boosted[target] * 100
        return boosted


# The object vLLM loads through the 'vllm.logits_processors' entry point:
# it pairs the processor implementation with its parameters model.
LOGITS_PROCESSOR_PLUGIN = dict(
    logits_processor_class=MyLogitsProcessor,
    parameters_model=MyParameters,
)


The :code:`setup.py` file for the plugin package should look something like this:

.. code-block:: python

from setuptools import setup

# Registering under the 'vllm.logits_processors' entry-point group is what
# lets vLLM discover this plugin at startup. The name left of '=' (here
# 'example_plugin') becomes the key clients use in the request body's
# "logits_processors" field; the path right of '=' points at the
# LOGITS_PROCESSOR_PLUGIN dict in the package's main module.
setup(name='example_logits_processor',
      version='0.1',
      install_requires=[
          "pydantic>=1.8.2"
      ],
      entry_points={
          'vllm.logits_processors': ['example_plugin=example_plugin.main:LOGITS_PROCESSOR_PLUGIN']
      }
)

After installing the plugin package in the same environment as vLLM,
you can run vLLM and use your custom logits processor as such:

.. code-block:: console

$ curl http://localhost:8000/v1/completions \
$ -H "Content-Type: application/json" \
$ -d '{
$ "model": "...",
$ "prompt": "Hello!",
$ "max_tokens": 32,
$ "temperature": 0,
$ "logits_processors": {
$ "example_plugin": {
$ "token_id": 10
$ }
$ }
$ }'
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ Documentation
dev/engine/engine_index
dev/kernel/paged_attention
dev/dockerfile/dockerfile
dev/logits_processors/logits_processor_plugins

Indices and tables
==================
Expand Down
32 changes: 26 additions & 6 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
from contextlib import asynccontextmanager
from http import HTTPStatus
from importlib.metadata import entry_points
from typing import Optional, Set

import fastapi
Expand All @@ -13,6 +14,7 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from pydantic import ValidationError
from starlette.routing import Mount

import vllm
Expand All @@ -26,6 +28,7 @@
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.logger import init_logger
from vllm.plugins import LogitsProcessorPlugin
from vllm.usage.usage_lib import UsageContext

TIMEOUT_KEEP_ALIVE = 5 # seconds
Expand Down Expand Up @@ -134,6 +137,24 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
allow_headers=args.allowed_headers,
)

logits_processors_plugins = {}
for entry_point in entry_points().\
select(group='vllm.logits_processors'): # type: ignore
logits_processor_plugin = entry_point.load()
try:
logits_processor_plugin = LogitsProcessorPlugin(
**logits_processor_plugin)
except ValidationError as e:
raise ValueError(
f"Invalid logits processor plugin {entry_point.name}. "
f"Please check the configuration. {e}") from e

logits_processors_plugins[entry_point.name] = logits_processor_plugin

logger.info('Loaded %d logits processor plugins (%s)',
len(logits_processors_plugins),
", ".join(logits_processors_plugins.keys()))

if token := envs.VLLM_API_KEY or args.api_key:

@app.middleware("http")
Expand Down Expand Up @@ -183,13 +204,12 @@ async def authentication(request: Request, call_next):
# When using single vLLM without engine_use_ray
model_config = asyncio.run(engine.get_model_config())

openai_serving_chat = OpenAIServingChat(engine, model_config,
served_model_names,
args.response_role,
args.lora_modules,
args.chat_template)
openai_serving_chat = OpenAIServingChat(
engine, model_config, served_model_names, args.response_role,
args.lora_modules, logits_processors_plugins, args.chat_template)
openai_serving_completion = OpenAIServingCompletion(
engine, model_config, served_model_names, args.lora_modules)
engine, model_config, served_model_names, args.lora_modules,
logits_processors_plugins)

app.root_path = args.root_path
uvicorn.run(app,
Expand Down
61 changes: 58 additions & 3 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@

import torch
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import (BaseModel, ConfigDict, Field, ValidationError,
model_validator)
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from vllm.plugins import LogitsProcessorPlugin
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

Expand Down Expand Up @@ -153,10 +156,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
logits_processors: Optional[Dict[str, dict]] = Field(
default=None,
description=("If specified, will use custom logit processor plugins"))

# doc: end-chat-completion-extra-params

def to_sampling_params(self) -> SamplingParams:
def to_sampling_params(self, logits_processor_plugins: Dict[
str, LogitsProcessorPlugin],
tokenizer: PreTrainedTokenizer) -> SamplingParams:
if self.logprobs and not self.top_logprobs:
raise ValueError("Top logprobs must be set when logprobs is.")

Expand All @@ -175,6 +183,27 @@ def logit_bias_logits_processor(

logits_processors = [logit_bias_logits_processor]

if self.logits_processors is not None:
logits_processors = logits_processors or []
for lp_name, lp_parameters in self.logits_processors.items():
lp_plugin = logits_processor_plugins.get(lp_name, None)
if lp_plugin is None:
available_lps = list(logits_processor_plugins.keys())
raise ValueError(
f"Logits processor {lp_name} not found in available "
f"logits processors ({available_lps}).")

try:
lp_parameters = lp_plugin.parameters_model.parse_obj(
lp_parameters)
logits_processor = lp_plugin.logits_processor_class(
tokenizer, lp_parameters)
logits_processors.append(logits_processor)
except ValidationError as e:
raise ValueError(
f"Invalid parameters for logits processor "
f"{lp_name}: {e}") from e

return SamplingParams(
n=self.n,
presence_penalty=self.presence_penalty,
Expand Down Expand Up @@ -299,10 +328,15 @@ class CompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
logits_processors: Optional[Dict[str, dict]] = Field(
default=None,
description=("If specified, will use custom logit processor plugins"))

# doc: end-completion-extra-params

def to_sampling_params(self):
def to_sampling_params(
self, logits_processor_plugins: Dict[str, LogitsProcessorPlugin],
tokenizer: PreTrainedTokenizer):
echo_without_generation = self.echo and self.max_tokens == 0

logits_processors = None
Expand All @@ -320,6 +354,27 @@ def logit_bias_logits_processor(

logits_processors = [logit_bias_logits_processor]

if self.logits_processors is not None:
logits_processors = logits_processors or []
for lp_name, lp_parameters in self.logits_processors.items():
lp_plugin = logits_processor_plugins.get(lp_name, None)
if lp_plugin is None:
available_lps = list(logits_processor_plugins.keys())
raise ValueError(
f"Logits processor {lp_name} not found in available "
f"logits processors ({available_lps}).")

try:
lp_parameters = lp_plugin.parameters_model.parse_obj(
lp_parameters)
logits_processor = lp_plugin.logits_processor_class(
tokenizer, lp_parameters)
logits_processors.append(logits_processor)
except ValidationError as e:
raise ValueError(
f"Invalid parameters for logits processor "
f"{lp_name}: {e}") from e

return SamplingParams(
n=self.n,
best_of=self.best_of,
Expand Down
24 changes: 14 additions & 10 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import codecs
import time
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
Optional, Tuple, TypedDict, Union, final)
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
List, Optional, Tuple, TypedDict, Union, final)

from fastapi import Request
from openai.types.chat import (ChatCompletionContentPartParam,
Expand All @@ -20,6 +20,7 @@
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.plugins import LogitsProcessorPlugin
from vllm.utils import random_uuid

logger = init_logger(__name__)
Expand All @@ -39,12 +40,14 @@ def __init__(self,
served_model_names: List[str],
response_role: str,
lora_modules: Optional[List[LoRAModulePath]] = None,
logits_processor_plugins: Optional[Dict[
str, LogitsProcessorPlugin]] = None,
chat_template: Optional[str] = None):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules)

lora_modules=lora_modules,
logits_processor_plugins=logits_processor_plugins)
self.response_role = response_role
self._load_chat_template(chat_template)

Expand Down Expand Up @@ -136,18 +139,19 @@ async def create_chat_completion(

request_id = f"cmpl-{random_uuid()}"
try:
# Tokenize/detokenize depending on prompt format (string/token list)
prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
request, prompt=prompt)
sampling_params = request.to_sampling_params()
tokenizer = await self.engine.get_tokenizer()
sampling_params = request.to_sampling_params(
self.logits_processor_plugins, tokenizer)
lora_request = self._maybe_get_lora(request)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logits_processor = (
await get_guided_decoding_logits_processor(
guided_decoding_backend, request, await
self.engine.get_tokenizer()))
await
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logits_processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
Expand Down Expand Up @@ -387,4 +391,4 @@ async def chat_completion_full_generator(
usage=usage,
)

return response
return response
16 changes: 12 additions & 4 deletions vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.plugins import LogitsProcessorPlugin
from vllm.utils import merge_async_iterators, random_uuid

logger = init_logger(__name__)
Expand Down Expand Up @@ -53,13 +54,18 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:

class OpenAIServingCompletion(OpenAIServing):

def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
def __init__(self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]]):
lora_modules: Optional[List[LoRAModulePath]],
logits_processor_plugins: Optional[Dict[
str, LogitsProcessorPlugin]] = None):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules)
lora_modules=lora_modules,
logits_processor_plugins=logits_processor_plugins)

async def create_completion(self, request: CompletionRequest,
raw_request: Request):
Expand Down Expand Up @@ -88,7 +94,9 @@ async def create_completion(self, request: CompletionRequest,
# Schedule the request and get the result generator.
generators: List[AsyncIterator[RequestOutput]] = []
try:
sampling_params = request.to_sampling_params()
tokenizer = await self.engine.get_tokenizer()
sampling_params = request.to_sampling_params(
self.logits_processor_plugins, tokenizer)
lora_request = self._maybe_get_lora(request)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
Expand Down
Loading