diff --git a/docs/design/io_processor_plugins.md b/docs/design/io_processor_plugins.md index 3e029259e025..c6945e443c37 100644 --- a/docs/design/io_processor_plugins.md +++ b/docs/design/io_processor_plugins.md @@ -14,8 +14,26 @@ IOProcessorOutput = TypeVar("IOProcessorOutput") class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): def __init__(self, vllm_config: VllmConfig): + super().__init__() + self.vllm_config = vllm_config + @abstractmethod + def parse_data(self, data: object) -> IOProcessorInput: + raise NotImplementedError + + def merge_sampling_params( + self, + params: SamplingParams | None = None, + ) -> SamplingParams: + return params or SamplingParams() + + def merge_pooling_params( + self, + params: PoolingParams | None = None, + ) -> PoolingParams: + return params or PoolingParams() + @abstractmethod def pre_process( self, @@ -55,29 +73,13 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): [(i, item) async for i, item in model_output], key=lambda output: output[0] ) collected_output = [output[1] for output in sorted_output] - return self.post_process(collected_output, request_id, **kwargs) - - @abstractmethod - def parse_request(self, request: Any) -> IOProcessorInput: - raise NotImplementedError - - def validate_or_generate_params( - self, params: SamplingParams | PoolingParams | None = None - ) -> SamplingParams | PoolingParams: - return params or PoolingParams() - - @abstractmethod - def output_to_response( - self, plugin_output: IOProcessorOutput - ) -> IOProcessorResponse: - raise NotImplementedError + return self.post_process(collected_output, request_id=request_id, **kwargs) ``` -The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods. +The `parse_data` method is used for validating the user data and converting it into the input expected by the `pre_process*` methods. 
+The `merge_sampling_params` and `merge_pooling_params` methods merge input `SamplingParams` or `PoolingParams` (if any) with the default one. The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference. The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output. -The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters. -The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/pooling/pooling/serving.py). An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/pooling/plugin/prithvi_geospatial_mae_online.py](../../examples/pooling/plugin/prithvi_geospatial_mae_online.py)) and offline ([examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py](../../examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py)) inference examples. 
diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py index 329b09c6824a..7915da94f88c 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py @@ -18,18 +18,10 @@ from terratorch.datamodules import Sen1Floods11NonGeoDataModule from vllm.config import VllmConfig -from vllm.entrypoints.pooling.pooling.protocol import ( - IOProcessorRequest, - IOProcessorResponse, -) from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput -from vllm.plugins.io_processors.interface import ( - IOProcessor, - IOProcessorInput, - IOProcessorOutput, -) +from vllm.plugins.io_processors.interface import IOProcessor from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput @@ -227,7 +219,7 @@ def load_image( return imgs, temporal_coords, location_coords, metas -class PrithviMultimodalDataProcessor(IOProcessor): +class PrithviMultimodalDataProcessor(IOProcessor[ImagePrompt, ImageRequestOutput]): indices = [0, 1, 2, 3, 4, 5] def __init__(self, vllm_config: VllmConfig): @@ -251,34 +243,15 @@ def __init__(self, vllm_config: VllmConfig): self.requests_cache: dict[str, dict[str, Any]] = {} self.indices = DEFAULT_INPUT_INDICES - def parse_request(self, request: Any) -> IOProcessorInput: - if type(request) is dict: - image_prompt = ImagePrompt(**request) - return image_prompt - if isinstance(request, IOProcessorRequest): - if not hasattr(request, "data"): - raise ValueError("missing 'data' field in OpenAIBaseModel Request") - - request_data = request.data - - if type(request_data) is dict: - return ImagePrompt(**request_data) - else: - raise ValueError("Unable to parse the request data") - - raise ValueError("Unable to parse request") - - def output_to_response( - self, 
plugin_output: IOProcessorOutput - ) -> IOProcessorResponse: - return IOProcessorResponse( - request_id=plugin_output.request_id, - data=plugin_output, - ) + def parse_data(self, data: object) -> ImagePrompt: + if isinstance(data, dict): + return ImagePrompt(**data) + + raise ValueError("Prompt data should be an `ImagePrompt`") def pre_process( self, - prompt: IOProcessorInput, + prompt: ImagePrompt, request_id: str | None = None, **kwargs, ) -> PromptType | Sequence[PromptType]: @@ -364,7 +337,7 @@ def post_process( model_output: Sequence[PoolingRequestOutput], request_id: str | None = None, **kwargs, - ) -> IOProcessorOutput: + ) -> ImageRequestOutput: pred_imgs_list = [] if request_id and (request_id in self.requests_cache): @@ -409,5 +382,7 @@ def post_process( ) return ImageRequestOutput( - type=out_format, format="tiff", data=out_data, request_id=request_id + type=out_format, + format="tiff", + data=out_data, ) diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py index d1d7873211f2..3a1a9c3be41e 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py @@ -38,9 +38,6 @@ class ImagePrompt(BaseModel): """ -MultiModalPromptType = ImagePrompt - - class ImageRequestOutput(BaseModel): """ The output data of an image request to vLLM. 
@@ -54,4 +51,3 @@ class ImageRequestOutput(BaseModel): type: Literal["path", "b64_json"] format: str data: str - request_id: str | None = None diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py index 2088ee36e89a..6e820f1a4def 100644 --- a/tests/plugins_tests/test_io_processor_plugins.py +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -75,9 +75,7 @@ async def test_prithvi_mae_plugin_online( # verify the output is formatted as expected for this plugin plugin_data = parsed_response.data - assert all( - plugin_data.get(attr) for attr in ["type", "format", "data", "request_id"] - ) + assert all(plugin_data.get(attr) for attr in ["type", "format", "data"]) # We just check that the output is a valid base64 string. # Raises an exception and fails the test if the string is corrupted. @@ -110,9 +108,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): output = pooler_output[0].outputs # verify the output is formatted as expected for this plugin - assert all( - hasattr(output, attr) for attr in ["type", "format", "data", "request_id"] - ) + assert all(hasattr(output, attr) for attr in ["type", "format", "data"]) # We just check that the output is a valid base64 string. # Raises an exception and fails the test if the string is corrupted. 
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b9147b99c985..2b4ed86957f4 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -85,7 +85,6 @@ from vllm.tokenizers import TokenizerLike from vllm.tokenizers.mistral import MistralTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils.collection_utils import as_iter, is_list_of from vllm.utils.counter import Counter from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.sample.logits_processor import LogitsProcessor @@ -95,6 +94,7 @@ logger = init_logger(__name__) +_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None) _R = TypeVar("_R", default=Any) @@ -1056,9 +1056,7 @@ def encode( dict(truncate_prompt_tokens=truncate_prompt_tokens), ) - io_processor_prompt = False - if isinstance(prompts, dict) and "data" in prompts: - io_processor_prompt = True + if use_io_processor := (isinstance(prompts, dict) and "data" in prompts): if self.io_processor is None: raise ValueError( "No IOProcessor plugin installed. 
Please refer " @@ -1068,40 +1066,42 @@ def encode( ) # Validate the request data is valid for the loaded plugin - validated_prompt = self.io_processor.parse_request(prompts) + validated_prompt = self.io_processor.parse_data(prompts) # obtain the actual model prompts from the pre-processor prompts = self.io_processor.pre_process(prompt=validated_prompt) + prompts_seq = prompt_to_seq(prompts) - if io_processor_prompt: - assert self.io_processor is not None - if is_list_of(pooling_params, PoolingParams): - validated_pooling_params: list[PoolingParams] = [] - for param in as_iter(pooling_params): - validated_pooling_params.append( - self.io_processor.validate_or_generate_params(param) - ) - pooling_params = validated_pooling_params - else: - assert not isinstance(pooling_params, Sequence) - pooling_params = self.io_processor.validate_or_generate_params( - pooling_params + params_seq: Sequence[PoolingParams] = [ + self.io_processor.merge_pooling_params(param) + for param in self._params_to_seq( + pooling_params, + len(prompts_seq), ) - - if pooling_params is None: - # Use default pooling params. - pooling_params = PoolingParams() - - for param in as_iter(pooling_params): - if param.task is None: - param.task = pooling_task - elif param.task != pooling_task: - msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!" - raise ValueError(msg) + ] + for p in params_seq: + if p.task is None: + p.task = "plugin" + else: + if pooling_params is None: + # Use default pooling params. + pooling_params = PoolingParams() + + prompts_seq = prompt_to_seq(prompts) + params_seq = self._params_to_seq(pooling_params, len(prompts_seq)) + + for param in params_seq: + if param.task is None: + param.task = pooling_task + elif param.task != pooling_task: + msg = ( + f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!" 
+ ) + raise ValueError(msg) outputs = self._run_completion( - prompts=prompts, - params=pooling_params, + prompts=prompts_seq, + params=params_seq, use_tqdm=use_tqdm, lora_request=lora_request, tokenization_kwargs=tokenization_kwargs, @@ -1111,12 +1111,10 @@ def encode( outputs, PoolingRequestOutput ) - if io_processor_prompt: + if use_io_processor: # get the post-processed model outputs assert self.io_processor is not None - processed_outputs = self.io_processor.post_process( - model_output=model_outputs - ) + processed_outputs = self.io_processor.post_process(model_outputs) return [ PoolingRequestOutput[Any]( @@ -1662,11 +1660,9 @@ def get_metrics(self) -> list["Metric"]: def _params_to_seq( self, - params: SamplingParams - | PoolingParams - | Sequence[SamplingParams | PoolingParams], + params: _P | Sequence[_P], num_requests: int, - ) -> Sequence[SamplingParams | PoolingParams]: + ) -> Sequence[_P]: if isinstance(params, Sequence): if len(params) != num_requests: raise ValueError( diff --git a/vllm/entrypoints/pooling/pooling/protocol.py b/vllm/entrypoints/pooling/pooling/protocol.py index ab2d82d8e94a..6a5a743cd68c 100644 --- a/vllm/entrypoints/pooling/pooling/protocol.py +++ b/vllm/entrypoints/pooling/pooling/protocol.py @@ -100,9 +100,6 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic data: T task: PoolingTask = "plugin" - def to_pooling_params(self): - return PoolingParams(task=self.task) - class IOProcessorResponse(OpenAIBaseModel, Generic[T]): request_id: str | None = None diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py index 3ad5786db0e6..5c5d649f67fd 100644 --- a/vllm/entrypoints/pooling/pooling/serving.py +++ b/vllm/entrypoints/pooling/pooling/serving.py @@ -85,7 +85,6 @@ async def create_pooling( request_id = f"pool-{self._base_request_id(raw_request)}" created_time = int(time.time()) - is_io_processor_request = isinstance(request, IOProcessorRequest) try: 
lora_request = self._maybe_get_adapters(request) @@ -95,7 +94,7 @@ async def create_pooling( ) engine_prompts: Sequence[PromptType | TokPrompt] - if is_io_processor_request: + if use_io_processor := isinstance(request, IOProcessorRequest): if self.io_processor is None: raise ValueError( "No IOProcessor plugin installed. Please refer " @@ -104,7 +103,7 @@ async def create_pooling( "offline inference example for more details." ) - validated_prompt = self.io_processor.parse_request(request) + validated_prompt = self.io_processor.parse_data(request.data) raw_prompts = await self.io_processor.pre_process_async( prompt=validated_prompt, request_id=request_id @@ -141,13 +140,18 @@ async def create_pooling( # Schedule the request and get the result generator. generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] try: - if is_io_processor_request: - assert self.io_processor is not None and isinstance( - request, IOProcessorRequest - ) - pooling_params = self.io_processor.validate_or_generate_params() + if use_io_processor: + assert self.io_processor is not None + + pooling_params = self.io_processor.merge_pooling_params() + if pooling_params.task is None: + pooling_params.task = "plugin" + + tokenization_kwargs: dict[str, Any] = {} else: - pooling_params = request.to_pooling_params() + pooling_params = request.to_pooling_params() # type: ignore + tok_params = request.build_tok_params(self.model_config) # type: ignore + tokenization_kwargs = tok_params.get_encode_kwargs() for i, engine_prompt in enumerate(engine_prompts): request_id_item = f"{request_id}-{i}" @@ -165,12 +169,6 @@ async def create_pooling( else await self._get_trace_headers(raw_request.headers) ) - if is_io_processor_request: - tokenization_kwargs: dict[str, Any] = {} - else: - tok_params = request.build_tok_params(self.model_config) # type: ignore - tokenization_kwargs = tok_params.get_encode_kwargs() - generator = self.engine_client.encode( engine_prompt, pooling_params, @@ -187,13 +185,31 @@ 
async def create_pooling( result_generator = merge_async_iterators(*generators) - if is_io_processor_request: + if use_io_processor: assert self.io_processor is not None output = await self.io_processor.post_process_async( - model_output=result_generator, + result_generator, request_id=request_id, ) - return self.io_processor.output_to_response(output) + + if callable( + output_to_response := getattr( + self.io_processor, "output_to_response", None + ) + ): + logger.warning_once( + "`IOProcessor.output_to_response` is deprecated. To ensure " + "consistency between offline and online APIs, " + "`IOProcessorResponse` will become a transparent wrapper " + "around output data from v0.19 onwards.", + ) + + if hasattr(output, "request_id") and output.request_id is None: + output.request_id = request_id # type: ignore + + return output_to_response(output) # type: ignore + + return IOProcessorResponse(request_id=request_id, data=output) assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest)) num_prompts = len(engine_prompts) diff --git a/vllm/plugins/io_processors/interface.py b/vllm/plugins/io_processors/interface.py index d2dd8b1bdc1f..a978b1e74865 100644 --- a/vllm/plugins/io_processors/interface.py +++ b/vllm/plugins/io_processors/interface.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import warnings from abc import ABC, abstractmethod from collections.abc import AsyncGenerator, Sequence -from typing import Any, Generic, TypeVar +from typing import Generic, TypeVar from vllm.config import VllmConfig -from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse from vllm.inputs.data import PromptType from vllm.outputs import PoolingRequestOutput from vllm.pooling_params import PoolingParams @@ -18,8 +17,68 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): def __init__(self, vllm_config: VllmConfig): + super().__init__() + 
self.vllm_config = vllm_config + def parse_data(self, data: object) -> IOProcessorInput: + if callable(parse_request := getattr(self, "parse_request", None)): + warnings.warn( + "`parse_request` has been renamed to `parse_data`. " + "Please update your IO Processor Plugin to use the new name. " + "The old name will be removed in v0.19.", + DeprecationWarning, + stacklevel=2, + ) + + return parse_request(data) # type: ignore + + raise NotImplementedError + + def merge_sampling_params( + self, + params: SamplingParams | None = None, + ) -> SamplingParams: + if callable( + validate_or_generate_params := getattr( + self, "validate_or_generate_params", None + ) + ): + warnings.warn( + "`validate_or_generate_params` has been split into " + "`merge_sampling_params` and `merge_pooling_params`. " + "Please update your IO Processor Plugin to use the new methods. " + "The old name will be removed in v0.19.", + DeprecationWarning, + stacklevel=2, + ) + + return validate_or_generate_params(params) # type: ignore + + return params or SamplingParams() + + def merge_pooling_params( + self, + params: PoolingParams | None = None, + ) -> PoolingParams: + if callable( + validate_or_generate_params := getattr( + self, "validate_or_generate_params", None + ) + ): + warnings.warn( + "`validate_or_generate_params` has been split into " + "`merge_sampling_params` and `merge_pooling_params`. " + "Please update your IO Processor Plugin to use the new methods. 
" + "The old name will be removed in v0.19.", + DeprecationWarning, + stacklevel=2, + ) + + return validate_or_generate_params(params) # type: ignore + + return params or PoolingParams(task="plugin") + @abstractmethod def pre_process( self, @@ -59,19 +118,4 @@ async def post_process_async( [(i, item) async for i, item in model_output], key=lambda output: output[0] ) collected_output = [output[1] for output in sorted_output] - return self.post_process(collected_output, request_id, **kwargs) - - @abstractmethod - def parse_request(self, request: Any) -> IOProcessorInput: - raise NotImplementedError - - def validate_or_generate_params( - self, params: SamplingParams | PoolingParams | None = None - ) -> SamplingParams | PoolingParams: - return params or PoolingParams() - - @abstractmethod - def output_to_response( - self, plugin_output: IOProcessorOutput - ) -> IOProcessorResponse: - raise NotImplementedError + return self.post_process(collected_output, request_id=request_id, **kwargs) diff --git a/vllm/utils/collection_utils.py b/vllm/utils/collection_utils.py index aefaf84ee8e8..e0bd2045f701 100644 --- a/vllm/utils/collection_utils.py +++ b/vllm/utils/collection_utils.py @@ -51,12 +51,6 @@ def as_list(maybe_list: Iterable[T]) -> list[T]: return maybe_list if isinstance(maybe_list, list) else list(maybe_list) -def as_iter(obj: T | Iterable[T]) -> Iterable[T]: - if isinstance(obj, str) or not isinstance(obj, Iterable): - return [obj] # type: ignore[list-item] - return obj - - def is_list_of( value: object, typ: type[T] | tuple[type[T], ...],