Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 21 additions & 19 deletions docs/design/io_processor_plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,26 @@ IOProcessorOutput = TypeVar("IOProcessorOutput")

class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
def __init__(self, vllm_config: VllmConfig):
super().__init__()

self.vllm_config = vllm_config

@abstractmethod
def parse_data(self, data: object) -> IOProcessorInput:
raise NotImplementedError

def merge_sampling_params(
    self,
    params: SamplingParams | None = None,
) -> SamplingParams:
    """Return ``params`` if provided, otherwise a default SamplingParams.

    Plugins may override this to merge user-supplied sampling parameters
    with plugin-specific defaults; the base implementation simply falls
    back to a fresh default instance when ``params`` is None.
    """
    return params or SamplingParams()

def merge_pooling_params(
    self,
    params: PoolingParams | None = None,
) -> PoolingParams:
    """Return ``params`` if provided, otherwise a default PoolingParams.

    Plugins may override this to merge user-supplied pooling parameters
    with plugin-specific defaults; the base implementation simply falls
    back to a fresh default instance when ``params`` is None.
    """
    return params or PoolingParams()

@abstractmethod
def pre_process(
self,
Expand Down Expand Up @@ -55,29 +73,13 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
[(i, item) async for i, item in model_output], key=lambda output: output[0]
)
collected_output = [output[1] for output in sorted_output]
return self.post_process(collected_output, request_id, **kwargs)

@abstractmethod
def parse_request(self, request: Any) -> IOProcessorInput:
raise NotImplementedError

def validate_or_generate_params(
self, params: SamplingParams | PoolingParams | None = None
) -> SamplingParams | PoolingParams:
return params or PoolingParams()

@abstractmethod
def output_to_response(
self, plugin_output: IOProcessorOutput
) -> IOProcessorResponse:
raise NotImplementedError
return self.post_process(collected_output, request_id=request_id, **kwargs)
```

The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods.
The `parse_data` method is used for validating the user data and converting it into the input expected by the `pre_process*` methods.
The `merge_sampling_params` and `merge_pooling_params` methods merge the input `SamplingParams` or `PoolingParams` (if provided) with the plugin's defaults; when no parameters are supplied, a default instance is returned.
The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.
The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters.
The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/pooling/pooling/serving.py](../../vllm/entrypoints/pooling/pooling/serving.py).

An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/pooling/plugin/prithvi_geospatial_mae_online.py](../../examples/pooling/plugin/prithvi_geospatial_mae_online.py)) and offline ([examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py](../../examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py)) inference examples.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,10 @@
from terratorch.datamodules import Sen1Floods11NonGeoDataModule

from vllm.config import VllmConfig
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
IOProcessorResponse,
)
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import (
IOProcessor,
IOProcessorInput,
IOProcessorOutput,
)
from vllm.plugins.io_processors.interface import IOProcessor

from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput

Expand Down Expand Up @@ -227,7 +219,7 @@ def load_image(
return imgs, temporal_coords, location_coords, metas


class PrithviMultimodalDataProcessor(IOProcessor):
class PrithviMultimodalDataProcessor(IOProcessor[ImagePrompt, ImageRequestOutput]):
indices = [0, 1, 2, 3, 4, 5]

def __init__(self, vllm_config: VllmConfig):
Expand All @@ -251,34 +243,15 @@ def __init__(self, vllm_config: VllmConfig):
self.requests_cache: dict[str, dict[str, Any]] = {}
self.indices = DEFAULT_INPUT_INDICES

def parse_request(self, request: Any) -> IOProcessorInput:
if type(request) is dict:
image_prompt = ImagePrompt(**request)
return image_prompt
if isinstance(request, IOProcessorRequest):
if not hasattr(request, "data"):
raise ValueError("missing 'data' field in OpenAIBaseModel Request")

request_data = request.data

if type(request_data) is dict:
return ImagePrompt(**request_data)
else:
raise ValueError("Unable to parse the request data")

raise ValueError("Unable to parse request")

def output_to_response(
self, plugin_output: IOProcessorOutput
) -> IOProcessorResponse:
return IOProcessorResponse(
request_id=plugin_output.request_id,
data=plugin_output,
)
def parse_data(self, data: object) -> ImagePrompt:
    """Validate raw user prompt data and convert it to an ``ImagePrompt``.

    Args:
        data: Arbitrary user-supplied prompt data. Only a dict of
            ``ImagePrompt`` fields is accepted.

    Returns:
        The validated ``ImagePrompt`` instance.

    Raises:
        ValueError: If ``data`` is not a dict (also raised by pydantic
            if the dict does not match the ``ImagePrompt`` schema).
    """
    if isinstance(data, dict):
        # ImagePrompt is a pydantic BaseModel, so construction validates
        # the supplied fields.
        return ImagePrompt(**data)

    raise ValueError("Prompt data should be an `ImagePrompt`")

def pre_process(
self,
prompt: IOProcessorInput,
prompt: ImagePrompt,
request_id: str | None = None,
**kwargs,
) -> PromptType | Sequence[PromptType]:
Expand Down Expand Up @@ -364,7 +337,7 @@ def post_process(
model_output: Sequence[PoolingRequestOutput],
request_id: str | None = None,
**kwargs,
) -> IOProcessorOutput:
) -> ImageRequestOutput:
pred_imgs_list = []

if request_id and (request_id in self.requests_cache):
Expand Down Expand Up @@ -409,5 +382,7 @@ def post_process(
)

return ImageRequestOutput(
type=out_format, format="tiff", data=out_data, request_id=request_id
type=out_format,
format="tiff",
data=out_data,
)
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ class ImagePrompt(BaseModel):
"""


MultiModalPromptType = ImagePrompt


class ImageRequestOutput(BaseModel):
"""
The output data of an image request to vLLM.
Expand All @@ -54,4 +51,3 @@ class ImageRequestOutput(BaseModel):
type: Literal["path", "b64_json"]
format: str
data: str
request_id: str | None = None
8 changes: 2 additions & 6 deletions tests/plugins_tests/test_io_processor_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,7 @@ async def test_prithvi_mae_plugin_online(
# verify the output is formatted as expected for this plugin
plugin_data = parsed_response.data

assert all(
plugin_data.get(attr) for attr in ["type", "format", "data", "request_id"]
)
assert all(plugin_data.get(attr) for attr in ["type", "format", "data"])

# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
Expand Down Expand Up @@ -110,9 +108,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
output = pooler_output[0].outputs

# verify the output is formatted as expected for this plugin
assert all(
hasattr(output, attr) for attr in ["type", "format", "data", "request_id"]
)
assert all(hasattr(output, attr) for attr in ["type", "format", "data"])

# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
Expand Down
74 changes: 35 additions & 39 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils.collection_utils import as_iter, is_list_of
from vllm.utils.counter import Counter
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor
Expand All @@ -95,6 +94,7 @@

logger = init_logger(__name__)

_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
_R = TypeVar("_R", default=Any)


Expand Down Expand Up @@ -1056,9 +1056,7 @@ def encode(
dict(truncate_prompt_tokens=truncate_prompt_tokens),
)

io_processor_prompt = False
if isinstance(prompts, dict) and "data" in prompts:
io_processor_prompt = True
if use_io_processor := (isinstance(prompts, dict) and "data" in prompts):
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
Expand All @@ -1068,40 +1066,42 @@ def encode(
)

# Validate the request data is valid for the loaded plugin
validated_prompt = self.io_processor.parse_request(prompts)
validated_prompt = self.io_processor.parse_data(prompts)

# obtain the actual model prompts from the pre-processor
prompts = self.io_processor.pre_process(prompt=validated_prompt)
prompts_seq = prompt_to_seq(prompts)

if io_processor_prompt:
assert self.io_processor is not None
if is_list_of(pooling_params, PoolingParams):
validated_pooling_params: list[PoolingParams] = []
for param in as_iter(pooling_params):
validated_pooling_params.append(
self.io_processor.validate_or_generate_params(param)
)
pooling_params = validated_pooling_params
else:
assert not isinstance(pooling_params, Sequence)
pooling_params = self.io_processor.validate_or_generate_params(
pooling_params
params_seq: Sequence[PoolingParams] = [
self.io_processor.merge_pooling_params(param)
for param in self._params_to_seq(
pooling_params,
len(prompts_seq),
)

if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()

for param in as_iter(pooling_params):
if param.task is None:
param.task = pooling_task
elif param.task != pooling_task:
msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
raise ValueError(msg)
]
for p in params_seq:
if p.task is None:
p.task = "plugin"
else:
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()

prompts_seq = prompt_to_seq(prompts)
params_seq = self._params_to_seq(pooling_params, len(prompts_seq))

for param in params_seq:
if param.task is None:
param.task = pooling_task
elif param.task != pooling_task:
msg = (
f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
)
raise ValueError(msg)

outputs = self._run_completion(
prompts=prompts,
params=pooling_params,
prompts=prompts_seq,
params=params_seq,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
Expand All @@ -1111,12 +1111,10 @@ def encode(
outputs, PoolingRequestOutput
)

if io_processor_prompt:
if use_io_processor:
# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(
model_output=model_outputs
)
processed_outputs = self.io_processor.post_process(model_outputs)

return [
PoolingRequestOutput[Any](
Expand Down Expand Up @@ -1662,11 +1660,9 @@ def get_metrics(self) -> list["Metric"]:

def _params_to_seq(
self,
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
params: _P | Sequence[_P],
num_requests: int,
) -> Sequence[SamplingParams | PoolingParams]:
) -> Sequence[_P]:
if isinstance(params, Sequence):
if len(params) != num_requests:
raise ValueError(
Expand Down
3 changes: 0 additions & 3 deletions vllm/entrypoints/pooling/pooling/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,6 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
data: T
task: PoolingTask = "plugin"

def to_pooling_params(self):
return PoolingParams(task=self.task)


class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
request_id: str | None = None
Expand Down
Loading