From cfd7ebf4ab3bec5a42f87f3e6eb48c52a914b4a0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 28 Nov 2025 14:51:29 +0000 Subject: [PATCH 1/7] [Chore] Rename `Processor` to `InputProcessor` Signed-off-by: DarkLight1337 --- .../entrypoints/openai/test_lora_resolvers.py | 2 +- tests/entrypoints/openai/test_serving_chat.py | 14 +- .../entrypoints/openai/test_serving_engine.py | 2 +- .../entrypoints/openai/test_serving_models.py | 2 +- .../openai/test_serving_responses.py | 4 +- .../test_processor_multi_modal_uuids.py | 8 +- vllm/engine/protocol.py | 4 +- vllm/entrypoints/llm.py | 6 +- vllm/entrypoints/openai/serving_engine.py | 8 +- vllm/entrypoints/openai/serving_models.py | 2 +- vllm/model_executor/models/nemotron_vl.py | 2 +- vllm/v1/engine/async_llm.py | 14 +- vllm/v1/engine/input_processor.py | 637 +++++++++++++++++ vllm/v1/engine/llm_engine.py | 14 +- vllm/v1/engine/processor.py | 641 +----------------- 15 files changed, 690 insertions(+), 670 deletions(-) create mode 100644 vllm/v1/engine/input_processor.py diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index b05fa379c69f..4856cafef44b 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -114,7 +114,7 @@ async def mock_generate(*args, **kwargs): mock_engine.add_lora.reset_mock() mock_engine.model_config = MockModelConfig() - mock_engine.processor = MagicMock() + mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() models = OpenAIServingModels( diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index dd10384a7e8c..492e15fc82a6 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -429,7 +429,7 @@ async def test_serving_chat_returns_correct_model_name(): mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = MockModelConfig() - mock_engine.processor = MagicMock() + mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() serving_chat = _build_serving_chat(mock_engine) @@ -459,7 +459,7 @@ async def test_serving_chat_should_set_correct_max_tokens(): mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = MockModelConfig() - mock_engine.processor = MagicMock() + mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() serving_chat = _build_serving_chat(mock_engine) @@ -492,7 +492,7 @@ async def test_serving_chat_should_set_correct_max_tokens(): mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = mock_model_config - mock_engine.processor = MagicMock() + mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() # Initialize the serving chat @@ -537,7 +537,7 @@ async def test_serving_chat_should_set_correct_max_tokens(): mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = mock_model_config - mock_engine.processor = MagicMock() + mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() # Initialize the serving chat @@ -583,7 +583,7 @@ async def test_serving_chat_could_load_correct_generation_config(): mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = mock_model_config - mock_engine.processor = MagicMock() + mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() # Initialize the serving chat @@ -629,7 +629,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = mock_model_config - mock_engine.processor = MagicMock() + mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() serving_chat = _build_serving_chat(mock_engine) @@ -662,7 +662,7 @@ async def test_serving_chat_data_parallel_rank_extraction(): mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False mock_engine.model_config = MockModelConfig() - mock_engine.processor = MagicMock() + mock_engine.input_processor = MagicMock() mock_engine.io_processor = MagicMock() # Mock the generate method to return an async generator diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index 46d8871441a7..29892d0bf38a 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -23,7 +23,7 @@ def serving() -> OpenAIServing: model_config.max_model_len = 32768 models = Mock(spec=OpenAIServingModels) models.model_config = model_config - models.processor = Mock() + models.input_processor = Mock() models.io_processor = Mock() serving = OpenAIServing( diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py index 3c022870dba4..b585835a0667 100644 --- a/tests/entrypoints/openai/test_serving_models.py +++ b/tests/entrypoints/openai/test_serving_models.py @@ -30,7 +30,7 @@ async def _async_serving_models_init() -> OpenAIServingModels: mock_model_config = MagicMock(spec=ModelConfig) mock_model_config.max_model_len = 2048 mock_engine_client.model_config = mock_model_config - mock_engine_client.processor = MagicMock() + mock_engine_client.input_processor = MagicMock() mock_engine_client.io_processor = MagicMock() serving_models = OpenAIServingModels( diff --git a/tests/entrypoints/openai/test_serving_responses.py b/tests/entrypoints/openai/test_serving_responses.py index 93e11b61020c..6af32774cc5c 100644 --- a/tests/entrypoints/openai/test_serving_responses.py +++ b/tests/entrypoints/openai/test_serving_responses.py @@ -127,7 +127,7 @@ async def serving_responses_instance(self): model_config.get_diff_sampling_param.return_value = {} engine_client.model_config = model_config - engine_client.processor = MagicMock() + engine_client.input_processor = MagicMock() engine_client.io_processor = MagicMock() models = MagicMock() @@ -213,7 +213,7 @@ async def serving_responses_instance(self): model_config.get_diff_sampling_param.return_value = {} engine_client.model_config = model_config - engine_client.processor = MagicMock() + engine_client.input_processor = MagicMock() engine_client.io_processor = MagicMock() models = MagicMock() diff --git a/tests/v1/engine/test_processor_multi_modal_uuids.py b/tests/v1/engine/test_processor_multi_modal_uuids.py index cb6865e42ef8..f6b6eeeaade1 100644 --- a/tests/v1/engine/test_processor_multi_modal_uuids.py +++ b/tests/v1/engine/test_processor_multi_modal_uuids.py @@ -7,8 +7,8 @@ from vllm.assets.video import VideoAsset from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig from vllm.sampling_params import SamplingParams -from vllm.v1.engine import processor as processor_mod -from vllm.v1.engine.processor import Processor +from vllm.v1.engine import input_processor as processor_mod +from vllm.v1.engine.input_processor import InputProcessor cherry_pil_image = ImageAsset("cherry_blossom").pil_image stop_pil_image = ImageAsset("stop_sign").pil_image @@ -18,7 +18,7 @@ # Mock processor for testing def _mk_processor( monkeypatch, *, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True -) -> Processor: +) -> InputProcessor: """ Create a Processor instance with minimal configuration suitable for unit tests without accessing external resources. @@ -65,7 +65,7 @@ def __init__(self, gb: float): device_config=DeviceConfig(device="cpu"), ) - return Processor(vllm_config, tokenizer=None) + return InputProcessor(vllm_config, tokenizer=None) def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 6b3ee042daf3..02741e50f6aa 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -15,7 +15,7 @@ from vllm.tasks import SupportedTask from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.processor import Processor +from vllm.v1.engine.input_processor import InputProcessor class EngineClient(ABC): @@ -23,7 +23,7 @@ class EngineClient(ABC): vllm_config: VllmConfig model_config: ModelConfig - processor: Processor + input_processor: InputProcessor io_processor: IOProcessor | None @property diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f6ee74678998..2b34f36253ed 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -347,7 +347,7 @@ def __init__( self.supported_tasks = supported_tasks self.model_config = self.llm_engine.model_config - self.processor = self.llm_engine.processor + self.input_processor = self.llm_engine.input_processor self.io_processor = self.llm_engine.io_processor def get_tokenizer(self) -> AnyTokenizer: @@ -364,7 +364,7 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer) def reset_mm_cache(self) -> None: - self.processor.clear_mm_cache() + self.input_processor.clear_mm_cache() self.llm_engine.reset_mm_cache() def get_default_sampling_params(self) -> SamplingParams: @@ -1674,7 +1674,7 @@ def _process_inputs( tokenization_kwargs, ) - engine_request = self.processor.process_inputs( + engine_request = self.input_processor.process_inputs( request_id, engine_prompt, params, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index d9feee917ff4..e3748141c6da 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -284,7 +284,7 @@ def __init__( self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {} self.log_error_stack = log_error_stack - self.processor = self.models.processor + self.input_processor = self.models.input_processor self.io_processor = self.models.io_processor self.model_config = self.models.model_config self.max_model_len = self.model_config.max_model_len @@ -330,7 +330,7 @@ def _get_reasoning_parser( return parser async def reset_mm_cache(self) -> None: - self.processor.clear_mm_cache() + self.input_processor.clear_mm_cache() await self.engine_client.reset_mm_cache() async def beam_search( @@ -348,7 +348,7 @@ async def beam_search( length_penalty = params.length_penalty include_stop_str_in_output = params.include_stop_str_in_output - processor = self.processor + processor = self.input_processor tokenizer = processor.tokenizer if tokenizer is None: raise ValueError( @@ -1214,7 +1214,7 @@ async def _process_inputs( self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs ) - engine_request = self.processor.process_inputs( + engine_request = self.input_processor.process_inputs( request_id, engine_prompt, params, diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py index 24b9587010ca..165de5b618c4 100644 --- a/vllm/entrypoints/openai/serving_models.py +++ b/vllm/entrypoints/openai/serving_models.py @@ -69,7 +69,7 @@ def __init__( ) self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock) - self.processor = self.engine_client.processor + self.input_processor = self.engine_client.input_processor self.io_processor = self.engine_client.io_processor self.model_config = self.engine_client.model_config self.max_model_len = self.model_config.max_model_len diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index 5a1dda8aac2c..ba9a41b59303 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -34,7 +34,7 @@ from vllm.multimodal.image import convert_image_mode from vllm.multimodal.processing import PromptUpdateDetails from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.processor import cached_image_processor_from_config +from vllm.transformers_utils import cached_image_processor_from_config from vllm.transformers_utils.tokenizer import AnyTokenizer from .interfaces import ( diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 827a2736af28..3204018c078f 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -35,9 +35,9 @@ from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError +from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector from vllm.v1.engine.parallel_sampling import ParentRequest -from vllm.v1.engine.processor import Processor from vllm.v1.executor import Executor from vllm.v1.metrics.loggers import ( StatLoggerFactory, @@ -112,7 +112,7 @@ def __init__( else: tokenizer = init_tokenizer_from_configs(self.model_config) - self.processor = Processor(self.vllm_config, tokenizer) + self.input_processor = InputProcessor(self.vllm_config, tokenizer) self.io_processor = get_io_processor( self.vllm_config, self.model_config.io_processor_plugin, @@ -297,7 +297,7 @@ async def add_request( "Processor has been moved under OpenAIServing and will " "be removed from AsyncLLM in v0.13." ) - request = self.processor.process_inputs( + request = self.input_processor.process_inputs( request_id, prompt, params, @@ -481,7 +481,7 @@ def _run_output_handler(self): output_processor = self.output_processor log_stats = self.log_stats logger_manager = self.logger_manager - processor = self.processor + processor = self.input_processor async def output_handler(): try: @@ -699,11 +699,11 @@ async def encode( @property def tokenizer(self) -> AnyTokenizer | None: - return self.processor.tokenizer + return self.input_processor.tokenizer @tokenizer.setter def tokenizer(self, tokenizer: AnyTokenizer | None) -> None: - self.processor.tokenizer = tokenizer + self.input_processor.tokenizer = tokenizer async def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: @@ -738,7 +738,7 @@ async def stop_profile(self) -> None: await asyncio.gather(*coros) async def reset_mm_cache(self) -> None: - self.processor.clear_mm_cache() + self.input_processor.clear_mm_cache() await self.engine_core.reset_mm_cache_async() async def reset_prefix_cache(self) -> None: diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py new file mode 100644 index 000000000000..cfd637931a1c --- /dev/null +++ b/vllm/v1/engine/input_processor.py @@ -0,0 +1,637 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time +from collections.abc import Mapping +from typing import Any, Literal, cast + +from vllm.config import VllmConfig +from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs +from vllm.inputs.parse import split_enc_dec_inputs +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import processor_cache_from_config +from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict +from vllm.multimodal.parse import MultiModalDataParser +from vllm.multimodal.processing import EncDecMultiModalProcessor +from vllm.multimodal.utils import argsort_mm_positions +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer +from vllm.utils import length_from_prompt_token_ids_or_embeds +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.metrics.stats import MultiModalCacheStats +from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar +from vllm.v1.structured_output.backend_lm_format_enforcer import ( + validate_structured_output_request_lm_format_enforcer, +) +from vllm.v1.structured_output.backend_outlines import ( + validate_structured_output_request_outlines, +) +from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar + +logger = init_logger(__name__) + + +class InputProcessor: + def __init__( + self, + vllm_config: VllmConfig, + tokenizer: AnyTokenizer | None, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.structured_outputs_config = vllm_config.structured_outputs_config + + self.generation_config_fields = self.model_config.try_get_generation_config() + + self.mm_registry = mm_registry + self.mm_processor_cache = processor_cache_from_config(vllm_config, mm_registry) + + self.input_preprocessor = InputPreprocessor( + self.model_config, + tokenizer, + mm_registry, + mm_processor_cache=self.mm_processor_cache, + ) + + @property + def tokenizer(self) -> AnyTokenizer | None: + return self.input_preprocessor.tokenizer + + @tokenizer.setter + def tokenizer(self, tokenizer: AnyTokenizer | None) -> None: + self.input_preprocessor.tokenizer = tokenizer + + def _validate_logprobs( + self, + params: SamplingParams, + ) -> None: + max_logprobs = self.model_config.max_logprobs + if max_logprobs == -1: + max_logprobs = self.model_config.get_vocab_size() + + # Validate sample logprobs. + if params.logprobs: + num_logprobs = params.logprobs + if num_logprobs == -1: + num_logprobs = self.model_config.get_vocab_size() + if num_logprobs > max_logprobs: + raise ValueError( + f"Requested sample logprobs of {num_logprobs}, " + f"which is greater than max allowed: {max_logprobs}" + ) + + # Validate prompt logprobs. + if params.prompt_logprobs: + num_prompt_logprobs = params.prompt_logprobs + if num_prompt_logprobs == -1: + num_prompt_logprobs = self.model_config.get_vocab_size() + if num_prompt_logprobs > max_logprobs: + raise ValueError( + f"Requested prompt logprobs of {num_prompt_logprobs}, " + f"which is greater than max allowed: {max_logprobs}" + ) + + def _validate_sampling_params( + self, + params: SamplingParams, + ) -> None: + self._validate_structured_output(params) + self._validate_logit_bias(params) + + if params.allowed_token_ids is None: + return + if not params.allowed_token_ids: + raise ValueError("allowed_token_ids is not None and empty!") + if self.tokenizer is None: + # When skip_tokenizer_init=True, we can't validate token IDs + # Skip validation and let the model handle invalid tokens + return + vocab_size = len(self.tokenizer) + if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids): + raise ValueError("allowed_token_ids contains out-of-vocab token id!") + + def _validate_logit_bias( + self, + params: SamplingParams, + ) -> None: + """Validate logit_bias token IDs are within vocabulary range.""" + if not params.logit_bias: + return + + vocab_size = self.model_config.get_vocab_size() + invalid_token_ids = [] + + for token_id in params.logit_bias: + if token_id < 0 or token_id >= vocab_size: + invalid_token_ids.append(token_id) + + if invalid_token_ids: + raise ValueError( + f"token_id(s) {invalid_token_ids} in logit_bias contain " + f"out-of-vocab token ids. Vocabulary size: {vocab_size}" + ) + + def _validate_supported_sampling_params( + self, + params: SamplingParams, + ) -> None: + # Logits processors not supported. + if params.logits_processors: + raise ValueError( + "vLLM V1 does not support per request user provided logits processors." + ) + # Async scheduling + spec decode currently incompatible with some + # sampling parameters. + if ( + self.vllm_config.speculative_config is not None + and self.vllm_config.scheduler_config.async_scheduling + and ( + params.frequency_penalty != 0.0 + or params.presence_penalty != 0.0 + or params.repetition_penalty != 1.0 + or params.bad_words_token_ids + or params.structured_outputs + ) + ): + raise ValueError( + "async scheduling with spec decoding doesn't yet support " + "penalties, bad words or structured outputs in sampling parameters." + ) + + def _validate_params( + self, + params: SamplingParams | PoolingParams, + ): + """ + Validate supported SamplingParam. + Should raise ValueError if unsupported for API Server. + """ + + if isinstance(params, PoolingParams): + return + + self._validate_logprobs(params) + self._validate_sampling_params(params) + self._validate_supported_sampling_params(params) + + def _validate_multi_modal_uuids(self, prompt: PromptType) -> None: + """ + Validate that user-provided multi_modal_uuids align with + multi_modal_data in the incoming request prompt(s). + Only checks lengths; `None` entries are allowed and will be + auto-hashed downstream. + """ + + def _validate_single_prompt(single_prompt: dict | str) -> None: + if not isinstance(single_prompt, dict): + return + mm_data = single_prompt.get("multi_modal_data") + mm_uuids = single_prompt.get("multi_modal_uuids") + if not mm_data or not mm_uuids: + return + + for modality, items in mm_data.items(): + if modality in mm_uuids: + data_len = len(items) if isinstance(items, list) else 1 + uuid_len = ( + len(mm_uuids[modality]) + if isinstance(mm_uuids[modality], list) + else 1 + ) + if uuid_len != data_len: + raise ValueError( + f"multi_modal_uuids for modality '{modality}' " + "must have same length as data: got " + f"{uuid_len} uuids vs " + f"{data_len} items." + ) + else: + raise ValueError( + f"multi_modal_uuids for modality '{modality}' must " + "be provided if multi_modal_data is provided." + ) + + # Handle explicit encoder/decoder prompts or singleton prompt + if isinstance(prompt, dict) and "encoder_prompt" in prompt: + enc = prompt.get("encoder_prompt") + dec = prompt.get("decoder_prompt") + if enc is not None: + _validate_single_prompt(cast(dict | str, enc)) + if dec is not None: + _validate_single_prompt(cast(dict | str, dec)) + else: + _validate_single_prompt(prompt) # type: ignore[arg-type] + + def _validate_lora(self, lora_request: LoRARequest | None) -> None: + if lora_request is None: + return + + # LoRA request passed in while LoRA is not enabled + if not self.lora_config: + raise ValueError( + f"Got lora_request {lora_request} but LoRA is not enabled!" + ) + + if self.tokenizer is not None: + logger.warning_once( + "vLLM has deprecated support for supporting different " + "tokenizers for different LoRAs. By default, vLLM uses base " + "model's tokenizer. If you are using a LoRA " + "with its own tokenizer, consider specifying `--tokenizer " + "[lora_path]` to use the LoRA tokenizer." + ) + + def _validate_structured_output(self, params: SamplingParams) -> None: + if not params.structured_outputs or not self.structured_outputs_config: + return + + if self.model_config.skip_tokenizer_init and params.structured_outputs: + raise ValueError( + "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501 + ) + + backend = self.structured_outputs_config.backend + if _backend := params.structured_outputs._backend: + # Request-level backend selection is not supported. + # The values may differ if `params` is reused and was set + # to a specific backend based on `auto` behavior in a previous + # request. We remember that it was set as a result of `auto` + # using the `_backend_was_auto` field set in the params. + if backend != _backend and not ( + backend == "auto" and params.structured_outputs._backend_was_auto + ): + raise ValueError( + "Request-level structured output backend selection is not " + f"supported. The request specified '{_backend}', but vLLM " + f"was initialised with '{backend}'. This error can be " + "resolved by removing '_backend' from the request." + ) + else: + params.structured_outputs._backend = backend + + # Request content validation + if ( + isinstance(params.structured_outputs.choice, list) + and not params.structured_outputs.choice + ): + # It is invalid for choice to be an empty list + raise ValueError( + f"Choice '{params.structured_outputs.choice}' cannot be an empty list" # noqa: E501 + ) + # Reject empty string grammar early to avoid engine-side crashes + if ( + isinstance(params.structured_outputs.grammar, str) + and params.structured_outputs.grammar.strip() == "" + ): + raise ValueError("structured_outputs.grammar cannot be an empty string") + + if backend.startswith("xgrammar"): + # xgrammar with no fallback + validate_xgrammar_grammar(params) + elif backend.startswith("guidance"): + # TODO: ideally we would have the LLTokenizer here as Lark syntax + # allows <|special_token|> and similar, see + # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens + # Without tokenizer these are disallowed in grammars. + if isinstance(self.tokenizer, MistralTokenizer): + raise ValueError( + "Mistral tokenizer is not supported for the 'guidance' " + "structured output backend. Please use ['xgrammar', 'outlines'] " + "backends or tokenizer_mode='hf' instead." + ) + validate_guidance_grammar(params, tokenizer=None) + elif backend == "outlines": + # outlines backend + validate_structured_output_request_outlines(params) + elif backend == "lm-format-enforcer": + # lm format enforcer backend + if isinstance(self.tokenizer, MistralTokenizer): + raise ValueError( + "Mistral tokenizer is not supported for the 'lm-format-enforcer' " + "structured output backend. Please use ['xgrammar', 'outlines'] " + "backends or tokenizer_mode='hf' instead." + ) + validate_structured_output_request_lm_format_enforcer(params) + else: + # NOTE: backend must be "auto" here, because we have + # checked supported_backends above. + # In this mode, we set opinionated defaults based on what we think + # will satisfy the most use cases without having to worry about + # this setting. We include fallback behavior here, but not with any + # other setting where a specific backend was specified. + try: + validate_xgrammar_grammar(params) + params.structured_outputs._backend = "xgrammar" + except ValueError: + # The request either failed validation + # or includes some jsonschema feature(s) that + # are not supported in xgrammar. + if isinstance(self.tokenizer, MistralTokenizer): + # Fall back to outlines if the tokenizer is Mistral + validate_structured_output_request_outlines(params) + params.structured_outputs._backend = "outlines" + else: + # Fall back to guidance by default. + validate_guidance_grammar(params, tokenizer=None) + params.structured_outputs._backend = "guidance" + # Remember that this backend was set automatically + params.structured_outputs._backend_was_auto = True + + def _maybe_build_mm_uuids( + self, + request_id: str, + prompt: PromptType, + ) -> MultiModalUUIDDict | None: + """Build per-item multimodal hash overrides when enabled. In this case, + multimodal data items are identified by their request id, modality and + index rather than their content. + + Returns a dictionary of modality -> list[str] of overrides, or None if + disabled or no multimodal data is present. + """ + + def _extract_mm_data(p: PromptType): + if isinstance(p, dict) and "encoder_prompt" in p: + enc = p.get("encoder_prompt") + if isinstance(enc, dict): + return enc.get("multi_modal_data") + return None + if isinstance(p, dict): + return p.get("multi_modal_data") + return None + + mm_data = _extract_mm_data(prompt) + if not mm_data: + return None + + mm_uuids: dict[str, list[str | None] | str] = {} + for modality, data in mm_data.items(): + # Hash each item for embedding inputs. + n = ( + len(data) + if isinstance(data, list) or MultiModalDataParser.is_embeddings(data) + else 1 + ) + mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)] + return mm_uuids + + def process_inputs( + self, + request_id: str, + prompt: PromptType, + params: SamplingParams | PoolingParams, + arrival_time: float | None = None, + lora_request: LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, + trace_headers: Mapping[str, str] | None = None, + priority: int = 0, + data_parallel_rank: int | None = None, + ) -> EngineCoreRequest: + self._validate_lora(lora_request) + self._validate_params(params) + + data_parallel_size = self.vllm_config.parallel_config.data_parallel_size + if data_parallel_rank is not None and not ( + 0 <= data_parallel_rank < data_parallel_size + ): + raise ValueError( + f"data_parallel_rank {data_parallel_rank} " + f"is out of range [0, {data_parallel_size})." + ) + + if arrival_time is None: + arrival_time = time.time() + + # Optionally generate multimodal hash overrides to avoid hashing + # multimodal data items by their content as their identifiers. + + # NOTE: when users explicitly turn off BOTH prefix caching and input + # processing caching, no multimodal features or embeddings will be + # reused across requests, therefore identifying multimodal data items + # by their content is no longer necessary, and we create uuids with + # request id-modality-index as multimodal hash overrides. + if ( + self.model_config.multimodal_config + and self.model_config.multimodal_config.mm_processor_cache_gb == 0 + and not self.cache_config.enable_prefix_caching + ): + mm_uuids = self._maybe_build_mm_uuids(request_id, prompt) + else: + # Otherwise, use user-provided uuids as multimodal hash overrides + # if provided. + self._validate_multi_modal_uuids(prompt) + if isinstance(prompt, dict): + mm_uuids = cast( + MultiModalUUIDDict | None, prompt.get("multi_modal_uuids") + ) + else: + mm_uuids = None + + # Process inputs, which includes: + # 1. Tokenize text prompt, with LoRA request if one exists. + # 2. For multimodal models with a merged preprocessor, preprocess + # multimodal data and expand prompt token ids accordingly. + processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( + prompt, + tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, + ) + from vllm.platforms import current_platform + + current_platform.validate_request( + prompt=prompt, + params=params, + processed_inputs=processed_inputs, + ) + + eos_token_id = self.input_preprocessor.get_eos_token_id() + + encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) + self._validate_model_inputs(encoder_inputs, decoder_inputs) + + # Mypy can be conservative for TypedDict unions; normalize access. + if decoder_inputs["type"] == "embeds": + prompt_token_ids = None + prompt_embeds = decoder_inputs["prompt_embeds"] + else: + prompt_token_ids = decoder_inputs["prompt_token_ids"] + prompt_embeds = None + + sampling_params = None + pooling_params = None + if isinstance(params, SamplingParams): + # TODO: can we avoid cloning here in multiproc case? + sampling_params = params.clone() + # If unset max tokens, then generate up to the max_model_len. + if sampling_params.max_tokens is None: + seq_len = length_from_prompt_token_ids_or_embeds( + prompt_token_ids, prompt_embeds + ) + sampling_params.max_tokens = self.model_config.max_model_len - seq_len + sampling_params.update_from_generation_config( + self.generation_config_fields, eos_token_id + ) + if self.tokenizer is not None: + sampling_params.update_from_tokenizer(self.tokenizer) + else: + pooling_params = params.clone() + + # Multimodal related. + mm_features: list[MultiModalFeatureSpec] | None = None + + if decoder_inputs["type"] == "multimodal": + decoder_mm_inputs = decoder_inputs["mm_kwargs"] + decoder_mm_positions = decoder_inputs["mm_placeholders"] + decoder_mm_hashes = decoder_inputs["mm_hashes"] + + # Merge and flatten multimodal placeholders, hashes and inputs + # from dictionaries to lists, and sort them by each item's position + # in the input sequence. + sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions) + + mm_features = [] + for modality, idx in sorted_mm_idxs: + mm_features.append( + MultiModalFeatureSpec( + data=decoder_mm_inputs[modality][idx], + modality=modality, + identifier=decoder_mm_hashes[modality][idx], + mm_position=decoder_mm_positions[modality][idx], + ) + ) + + return EngineCoreRequest( + request_id=request_id, + prompt_token_ids=prompt_token_ids, + prompt_embeds=prompt_embeds, + mm_features=mm_features, + sampling_params=sampling_params, + pooling_params=pooling_params, + eos_token_id=eos_token_id, + arrival_time=arrival_time, + lora_request=lora_request, + cache_salt=decoder_inputs.get("cache_salt"), + priority=priority, + data_parallel_rank=data_parallel_rank, + trace_headers=trace_headers, + ) + + def _validate_model_inputs( + self, encoder_inputs: SingletonInputs | None, decoder_inputs: SingletonInputs + ): + if encoder_inputs is not None: + self._validate_model_input(encoder_inputs, prompt_type="encoder") + + self._validate_model_input(decoder_inputs, prompt_type="decoder") + + def _validate_model_input( + self, + prompt_inputs: SingletonInputs, + *, + prompt_type: Literal["encoder", "decoder"], + ): + model_config = self.model_config + + prompt_ids = ( + None + if prompt_inputs["type"] == "embeds" + else prompt_inputs["prompt_token_ids"] + ) + prompt_embeds = ( + prompt_inputs["prompt_embeds"] + if prompt_inputs["type"] == "embeds" + else None + ) + prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds) + if not prompt_ids: + if prompt_type == "encoder" and model_config.is_multimodal_model: + pass # Mllama may have empty encoder inputs for text-only data + elif prompt_inputs["type"] == "embeds": + pass # Prompt embeds should not have prompt_ids. + else: + raise ValueError(f"The {prompt_type} prompt cannot be empty") + + tokenizer = self.tokenizer + if tokenizer is not None: + max_input_id = max(prompt_ids or [], default=0) + + # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while + # self.model_config.get_vocab_size() is the model’s vocab size. + # For Qwen3 models, the language model has extra tokens that do + # not exist in the tokenizer, and vice versa for multimodal + # placeholder tokens in some multimodal models. + # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501 + # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501 + + # Here we take the max of the two to determine if a token id is + # truly out-of-vocabulary. + if max_input_id > max( + tokenizer.max_token_id, self.model_config.get_vocab_size() - 1 + ): + raise ValueError(f"Token id {max_input_id} is out of vocabulary") + + max_prompt_len = self.model_config.max_model_len + if prompt_len > max_prompt_len: + if prompt_type == "encoder" and model_config.is_multimodal_model: + mm_registry = self.input_preprocessor.mm_registry + mm_processor = mm_registry.create_processor( + model_config, + tokenizer=tokenizer, + ) + assert isinstance(mm_processor, EncDecMultiModalProcessor) + + if mm_processor.pad_dummy_encoder_prompt: + return # Skip encoder length check for Whisper + + if model_config.is_multimodal_model: + suggestion = ( + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens plus multimodal tokens. For image " + "inputs, the number of image tokens depends on the number " + "of images, and possibly their aspect ratios as well." + ) + else: + suggestion = ( + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens." + ) + + raise ValueError( + f"The {prompt_type} prompt (length {prompt_len}) is " + f"longer than the maximum model length of {max_prompt_len}. " + f"{suggestion}" + ) + + # TODO: Find out how many placeholder tokens are there so we can + # check that chunked prefill does not truncate them + # max_batch_len = self.scheduler_config.max_num_batched_tokens + + if ( + prompt_len == max_prompt_len + and prompt_type == "decoder" + and not model_config.is_multimodal_model + and self.model_config.runner_type != "pooling" + ): + suggestion = ( + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens (prompt + requested output tokens)." + ) + raise ValueError( + f"The {prompt_type} prompt (length {prompt_len}) plus the number of " + f"requested output tokens (at least 1) is longer than the maximum " + f"model length of {max_prompt_len}. {suggestion}" + ) + + def stat_mm_cache(self) -> MultiModalCacheStats | None: + return self.input_preprocessor.stat_mm_cache() + + def clear_mm_cache(self) -> None: + self.input_preprocessor.clear_mm_cache() diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index dffe05445ee4..3ac4751e08fd 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -28,9 +28,9 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import EngineCoreClient +from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.parallel_sampling import ParentRequest -from vllm.v1.engine.processor import Processor from vllm.v1.executor import Executor from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager from vllm.v1.metrics.reader import Metric, get_metrics_snapshot @@ -88,7 +88,7 @@ def __init__( else: tokenizer = init_tokenizer_from_configs(self.model_config) - self.processor = Processor(self.vllm_config, tokenizer) + self.input_processor = InputProcessor(self.vllm_config, tokenizer) self.io_processor = get_io_processor( self.vllm_config, self.model_config.io_processor_plugin, @@ -235,7 +235,7 @@ def add_request( "Processor has been moved under LLM and will " "be removed from LLMEngine in v0.13." ) - request = self.processor.process_inputs( + request = self.input_processor.process_inputs( request_id, prompt, params, @@ -307,7 +307,7 @@ def step(self) -> list[RequestOutput | PoolingRequestOutput]: self.logger_manager.record( scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, - mm_cache_stats=self.processor.stat_mm_cache(), + mm_cache_stats=self.input_processor.stat_mm_cache(), ) self.do_log_stats_with_interval() @@ -320,7 +320,7 @@ def stop_profile(self): self.engine_core.profile(False) def reset_mm_cache(self): - self.processor.clear_mm_cache() + self.input_processor.clear_mm_cache() self.engine_core.reset_mm_cache() def reset_prefix_cache(self): @@ -347,11 +347,11 @@ def get_metrics(self) -> list[Metric]: @property def tokenizer(self) -> AnyTokenizer | None: - return self.processor.tokenizer + return self.input_processor.tokenizer @tokenizer.setter def tokenizer(self, tokenizer: AnyTokenizer | None) -> None: - self.processor.tokenizer = tokenizer + self.input_processor.tokenizer = tokenizer def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index af4f0e410e25..bc5c7fc400fd 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -1,637 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import warnings -import time -from collections.abc import Mapping -from typing import Any, Literal, cast -from vllm.config import VllmConfig -from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs -from vllm.inputs.parse import split_enc_dec_inputs -from vllm.inputs.preprocess import InputPreprocessor -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.multimodal.cache import processor_cache_from_config -from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict -from vllm.multimodal.parse import MultiModalDataParser -from vllm.multimodal.processing import EncDecMultiModalProcessor -from vllm.multimodal.utils import argsort_mm_positions -from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer -from vllm.utils import length_from_prompt_token_ids_or_embeds -from vllm.v1.engine import EngineCoreRequest -from vllm.v1.metrics.stats import MultiModalCacheStats -from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar -from vllm.v1.structured_output.backend_lm_format_enforcer import ( - validate_structured_output_request_lm_format_enforcer, -) -from vllm.v1.structured_output.backend_outlines import ( - validate_structured_output_request_outlines, -) -from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar +def __getattr__(name: str): + if name == "Processor": + from .input_processor import InputProcessor -logger = init_logger(__name__) - - -class Processor: - def __init__( - self, - vllm_config: VllmConfig, - tokenizer: AnyTokenizer | None, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - ) -> None: - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.structured_outputs_config = vllm_config.structured_outputs_config - - self.generation_config_fields = self.model_config.try_get_generation_config() - - self.mm_registry = mm_registry - self.mm_processor_cache = processor_cache_from_config(vllm_config, mm_registry) - - self.input_preprocessor = InputPreprocessor( - self.model_config, - tokenizer, - mm_registry, - mm_processor_cache=self.mm_processor_cache, + warnings.warn( + "`vllm.v1.engine.processor.Processor` has been moved to " + "`vllm.v1.engine.input_processor.InputProcessor`. " + "The old name will be removed in v0.13.", + DeprecationWarning, + stacklevel=2, ) - @property - def tokenizer(self) -> AnyTokenizer | None: - return self.input_preprocessor.tokenizer - - @tokenizer.setter - def tokenizer(self, tokenizer: AnyTokenizer | None) -> None: - self.input_preprocessor.tokenizer = tokenizer - - def _validate_logprobs( - self, - params: SamplingParams, - ) -> None: - max_logprobs = self.model_config.max_logprobs - if max_logprobs == -1: - max_logprobs = self.model_config.get_vocab_size() - - # Validate sample logprobs. - if params.logprobs: - num_logprobs = params.logprobs - if num_logprobs == -1: - num_logprobs = self.model_config.get_vocab_size() - if num_logprobs > max_logprobs: - raise ValueError( - f"Requested sample logprobs of {num_logprobs}, " - f"which is greater than max allowed: {max_logprobs}" - ) - - # Validate prompt logprobs. - if params.prompt_logprobs: - num_prompt_logprobs = params.prompt_logprobs - if num_prompt_logprobs == -1: - num_prompt_logprobs = self.model_config.get_vocab_size() - if num_prompt_logprobs > max_logprobs: - raise ValueError( - f"Requested prompt logprobs of {num_prompt_logprobs}, " - f"which is greater than max allowed: {max_logprobs}" - ) - - def _validate_sampling_params( - self, - params: SamplingParams, - ) -> None: - self._validate_structured_output(params) - self._validate_logit_bias(params) - - if params.allowed_token_ids is None: - return - if not params.allowed_token_ids: - raise ValueError("allowed_token_ids is not None and empty!") - if self.tokenizer is None: - # When skip_tokenizer_init=True, we can't validate token IDs - # Skip validation and let the model handle invalid tokens - return - vocab_size = len(self.tokenizer) - if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids): - raise ValueError("allowed_token_ids contains out-of-vocab token id!") - - def _validate_logit_bias( - self, - params: SamplingParams, - ) -> None: - """Validate logit_bias token IDs are within vocabulary range.""" - if not params.logit_bias: - return - - vocab_size = self.model_config.get_vocab_size() - invalid_token_ids = [] - - for token_id in params.logit_bias: - if token_id < 0 or token_id >= vocab_size: - invalid_token_ids.append(token_id) - - if invalid_token_ids: - raise ValueError( - f"token_id(s) {invalid_token_ids} in logit_bias contain " - f"out-of-vocab token ids. Vocabulary size: {vocab_size}" - ) - - def _validate_supported_sampling_params( - self, - params: SamplingParams, - ) -> None: - # Logits processors not supported. - if params.logits_processors: - raise ValueError( - "vLLM V1 does not support per request user provided logits processors." - ) - # Async scheduling + spec decode currently incompatible with some - # sampling parameters. - if ( - self.vllm_config.speculative_config is not None - and self.vllm_config.scheduler_config.async_scheduling - and ( - params.frequency_penalty != 0.0 - or params.presence_penalty != 0.0 - or params.repetition_penalty != 1.0 - or params.bad_words_token_ids - or params.structured_outputs - ) - ): - raise ValueError( - "async scheduling with spec decoding doesn't yet support " - "penalties, bad words or structured outputs in sampling parameters." - ) - - def _validate_params( - self, - params: SamplingParams | PoolingParams, - ): - """ - Validate supported SamplingParam. - Should raise ValueError if unsupported for API Server. - """ - - if isinstance(params, PoolingParams): - return - - self._validate_logprobs(params) - self._validate_sampling_params(params) - self._validate_supported_sampling_params(params) - - def _validate_multi_modal_uuids(self, prompt: PromptType) -> None: - """ - Validate that user-provided multi_modal_uuids align with - multi_modal_data in the incoming request prompt(s). - Only checks lengths; `None` entries are allowed and will be - auto-hashed downstream. - """ - - def _validate_single_prompt(single_prompt: dict | str) -> None: - if not isinstance(single_prompt, dict): - return - mm_data = single_prompt.get("multi_modal_data") - mm_uuids = single_prompt.get("multi_modal_uuids") - if not mm_data or not mm_uuids: - return - - for modality, items in mm_data.items(): - if modality in mm_uuids: - data_len = len(items) if isinstance(items, list) else 1 - uuid_len = ( - len(mm_uuids[modality]) - if isinstance(mm_uuids[modality], list) - else 1 - ) - if uuid_len != data_len: - raise ValueError( - f"multi_modal_uuids for modality '{modality}' " - "must have same length as data: got " - f"{uuid_len} uuids vs " - f"{data_len} items." - ) - else: - raise ValueError( - f"multi_modal_uuids for modality '{modality}' must " - "be provided if multi_modal_data is provided." - ) - - # Handle explicit encoder/decoder prompts or singleton prompt - if isinstance(prompt, dict) and "encoder_prompt" in prompt: - enc = prompt.get("encoder_prompt") - dec = prompt.get("decoder_prompt") - if enc is not None: - _validate_single_prompt(cast(dict | str, enc)) - if dec is not None: - _validate_single_prompt(cast(dict | str, dec)) - else: - _validate_single_prompt(prompt) # type: ignore[arg-type] - - def _validate_lora(self, lora_request: LoRARequest | None) -> None: - if lora_request is None: - return - - # LoRA request passed in while LoRA is not enabled - if not self.lora_config: - raise ValueError( - f"Got lora_request {lora_request} but LoRA is not enabled!" - ) - - if self.tokenizer is not None: - logger.warning_once( - "vLLM has deprecated support for supporting different " - "tokenizers for different LoRAs. By default, vLLM uses base " - "model's tokenizer. If you are using a LoRA " - "with its own tokenizer, consider specifying `--tokenizer " - "[lora_path]` to use the LoRA tokenizer." - ) - - def _validate_structured_output(self, params: SamplingParams) -> None: - if not params.structured_outputs or not self.structured_outputs_config: - return - - if self.model_config.skip_tokenizer_init and params.structured_outputs: - raise ValueError( - "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501 - ) - - backend = self.structured_outputs_config.backend - if _backend := params.structured_outputs._backend: - # Request-level backend selection is not supported. - # The values may differ if `params` is reused and was set - # to a specific backend based on `auto` behavior in a previous - # request. We remember that it was set as a result of `auto` - # using the `_backend_was_auto` field set in the params. - if backend != _backend and not ( - backend == "auto" and params.structured_outputs._backend_was_auto - ): - raise ValueError( - "Request-level structured output backend selection is not " - f"supported. The request specified '{_backend}', but vLLM " - f"was initialised with '{backend}'. This error can be " - "resolved by removing '_backend' from the request." - ) - else: - params.structured_outputs._backend = backend - - # Request content validation - if ( - isinstance(params.structured_outputs.choice, list) - and not params.structured_outputs.choice - ): - # It is invalid for choice to be an empty list - raise ValueError( - f"Choice '{params.structured_outputs.choice}' cannot be an empty list" # noqa: E501 - ) - # Reject empty string grammar early to avoid engine-side crashes - if ( - isinstance(params.structured_outputs.grammar, str) - and params.structured_outputs.grammar.strip() == "" - ): - raise ValueError("structured_outputs.grammar cannot be an empty string") - - if backend.startswith("xgrammar"): - # xgrammar with no fallback - validate_xgrammar_grammar(params) - elif backend.startswith("guidance"): - # TODO: ideally we would have the LLTokenizer here as Lark syntax - # allows <|special_token|> and similar, see - # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens - # Without tokenizer these are disallowed in grammars. - if isinstance(self.tokenizer, MistralTokenizer): - raise ValueError( - "Mistral tokenizer is not supported for the 'guidance' " - "structured output backend. Please use ['xgrammar', 'outlines'] " - "backends or tokenizer_mode='hf' instead." - ) - validate_guidance_grammar(params, tokenizer=None) - elif backend == "outlines": - # outlines backend - validate_structured_output_request_outlines(params) - elif backend == "lm-format-enforcer": - # lm format enforcer backend - if isinstance(self.tokenizer, MistralTokenizer): - raise ValueError( - "Mistral tokenizer is not supported for the 'lm-format-enforcer' " - "structured output backend. Please use ['xgrammar', 'outlines'] " - "backends or tokenizer_mode='hf' instead." - ) - validate_structured_output_request_lm_format_enforcer(params) - else: - # NOTE: backend must be "auto" here, because we have - # checked supported_backends above. - # In this mode, we set opinionated defaults based on what we think - # will satisfy the most use cases without having to worry about - # this setting. We include fallback behavior here, but not with any - # other setting where a specific backend was specified. - try: - validate_xgrammar_grammar(params) - params.structured_outputs._backend = "xgrammar" - except ValueError: - # The request either failed validation - # or includes some jsonschema feature(s) that - # are not supported in xgrammar. - if isinstance(self.tokenizer, MistralTokenizer): - # Fall back to outlines if the tokenizer is Mistral - validate_structured_output_request_outlines(params) - params.structured_outputs._backend = "outlines" - else: - # Fall back to guidance by default. - validate_guidance_grammar(params, tokenizer=None) - params.structured_outputs._backend = "guidance" - # Remember that this backend was set automatically - params.structured_outputs._backend_was_auto = True - - def _maybe_build_mm_uuids( - self, - request_id: str, - prompt: PromptType, - ) -> MultiModalUUIDDict | None: - """Build per-item multimodal hash overrides when enabled. In this case, - multimodal data items are identified by their request id, modality and - index rather than their content. - - Returns a dictionary of modality -> list[str] of overrides, or None if - disabled or no multimodal data is present. - """ - - def _extract_mm_data(p: PromptType): - if isinstance(p, dict) and "encoder_prompt" in p: - enc = p.get("encoder_prompt") - if isinstance(enc, dict): - return enc.get("multi_modal_data") - return None - if isinstance(p, dict): - return p.get("multi_modal_data") - return None - - mm_data = _extract_mm_data(prompt) - if not mm_data: - return None - - mm_uuids: dict[str, list[str | None] | str] = {} - for modality, data in mm_data.items(): - # Hash each item for embedding inputs. - n = ( - len(data) - if isinstance(data, list) or MultiModalDataParser.is_embeddings(data) - else 1 - ) - mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)] - return mm_uuids - - def process_inputs( - self, - request_id: str, - prompt: PromptType, - params: SamplingParams | PoolingParams, - arrival_time: float | None = None, - lora_request: LoRARequest | None = None, - tokenization_kwargs: dict[str, Any] | None = None, - trace_headers: Mapping[str, str] | None = None, - priority: int = 0, - data_parallel_rank: int | None = None, - ) -> EngineCoreRequest: - self._validate_lora(lora_request) - self._validate_params(params) - - data_parallel_size = self.vllm_config.parallel_config.data_parallel_size - if data_parallel_rank is not None and not ( - 0 <= data_parallel_rank < data_parallel_size - ): - raise ValueError( - f"data_parallel_rank {data_parallel_rank} " - f"is out of range [0, {data_parallel_size})." - ) - - if arrival_time is None: - arrival_time = time.time() - - # Optionally generate multimodal hash overrides to avoid hashing - # multimodal data items by their content as their identifiers. - - # NOTE: when users explicitly turn off BOTH prefix caching and input - # processing caching, no multimodal features or embeddings will be - # reused across requests, therefore identifying multimodal data items - # by their content is no longer necessary, and we create uuids with - # request id-modality-index as multimodal hash overrides. - if ( - self.model_config.multimodal_config - and self.model_config.multimodal_config.mm_processor_cache_gb == 0 - and not self.cache_config.enable_prefix_caching - ): - mm_uuids = self._maybe_build_mm_uuids(request_id, prompt) - else: - # Otherwise, use user-provided uuids as multimodal hash overrides - # if provided. - self._validate_multi_modal_uuids(prompt) - if isinstance(prompt, dict): - mm_uuids = cast( - MultiModalUUIDDict | None, prompt.get("multi_modal_uuids") - ) - else: - mm_uuids = None - - # Process inputs, which includes: - # 1. Tokenize text prompt, with LoRA request if one exists. - # 2. For multimodal models with a merged preprocessor, preprocess - # multimodal data and expand prompt token ids accordingly. - processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( - prompt, - tokenization_kwargs=tokenization_kwargs, - mm_uuids=mm_uuids, - ) - from vllm.platforms import current_platform - - current_platform.validate_request( - prompt=prompt, - params=params, - processed_inputs=processed_inputs, - ) - - eos_token_id = self.input_preprocessor.get_eos_token_id() - - encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) - self._validate_model_inputs(encoder_inputs, decoder_inputs) - - # Mypy can be conservative for TypedDict unions; normalize access. - if decoder_inputs["type"] == "embeds": - prompt_token_ids = None - prompt_embeds = decoder_inputs["prompt_embeds"] - else: - prompt_token_ids = decoder_inputs["prompt_token_ids"] - prompt_embeds = None - - sampling_params = None - pooling_params = None - if isinstance(params, SamplingParams): - # TODO: can we avoid cloning here in multiproc case? - sampling_params = params.clone() - # If unset max tokens, then generate up to the max_model_len. - if sampling_params.max_tokens is None: - seq_len = length_from_prompt_token_ids_or_embeds( - prompt_token_ids, prompt_embeds - ) - sampling_params.max_tokens = self.model_config.max_model_len - seq_len - sampling_params.update_from_generation_config( - self.generation_config_fields, eos_token_id - ) - if self.tokenizer is not None: - sampling_params.update_from_tokenizer(self.tokenizer) - else: - pooling_params = params.clone() - - # Multimodal related. - mm_features: list[MultiModalFeatureSpec] | None = None - - if decoder_inputs["type"] == "multimodal": - decoder_mm_inputs = decoder_inputs["mm_kwargs"] - decoder_mm_positions = decoder_inputs["mm_placeholders"] - decoder_mm_hashes = decoder_inputs["mm_hashes"] - - # Merge and flatten multimodal placeholders, hashes and inputs - # from dictionaries to lists, and sort them by each item's position - # in the input sequence. - sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions) - - mm_features = [] - for modality, idx in sorted_mm_idxs: - mm_features.append( - MultiModalFeatureSpec( - data=decoder_mm_inputs[modality][idx], - modality=modality, - identifier=decoder_mm_hashes[modality][idx], - mm_position=decoder_mm_positions[modality][idx], - ) - ) - - return EngineCoreRequest( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - prompt_embeds=prompt_embeds, - mm_features=mm_features, - sampling_params=sampling_params, - pooling_params=pooling_params, - eos_token_id=eos_token_id, - arrival_time=arrival_time, - lora_request=lora_request, - cache_salt=decoder_inputs.get("cache_salt"), - priority=priority, - data_parallel_rank=data_parallel_rank, - trace_headers=trace_headers, - ) - - def _validate_model_inputs( - self, encoder_inputs: SingletonInputs | None, decoder_inputs: SingletonInputs - ): - if encoder_inputs is not None: - self._validate_model_input(encoder_inputs, prompt_type="encoder") - - self._validate_model_input(decoder_inputs, prompt_type="decoder") - - def _validate_model_input( - self, - prompt_inputs: SingletonInputs, - *, - prompt_type: Literal["encoder", "decoder"], - ): - model_config = self.model_config - - prompt_ids = ( - None - if prompt_inputs["type"] == "embeds" - else prompt_inputs["prompt_token_ids"] - ) - prompt_embeds = ( - prompt_inputs["prompt_embeds"] - if prompt_inputs["type"] == "embeds" - else None - ) - prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds) - if not prompt_ids: - if prompt_type == "encoder" and model_config.is_multimodal_model: - pass # Mllama may have empty encoder inputs for text-only data - elif prompt_inputs["type"] == "embeds": - pass # Prompt embeds should not have prompt_ids. - else: - raise ValueError(f"The {prompt_type} prompt cannot be empty") - - tokenizer = self.tokenizer - if tokenizer is not None: - max_input_id = max(prompt_ids or [], default=0) - - # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while - # self.model_config.get_vocab_size() is the model’s vocab size. - # For Qwen3 models, the language model has extra tokens that do - # not exist in the tokenizer, and vice versa for multimodal - # placeholder tokens in some multimodal models. - # See https://github.com/QwenLM/Qwen3/issues/29#issuecomment-1933720399 # noqa: E501 - # and https://github.com/vllm-project/vllm/pull/22471#discussion_r2312251421 # noqa: E501 - - # Here we take the max of the two to determine if a token id is - # truly out-of-vocabulary. - if max_input_id > max( - tokenizer.max_token_id, self.model_config.get_vocab_size() - 1 - ): - raise ValueError(f"Token id {max_input_id} is out of vocabulary") - - max_prompt_len = self.model_config.max_model_len - if prompt_len > max_prompt_len: - if prompt_type == "encoder" and model_config.is_multimodal_model: - mm_registry = self.input_preprocessor.mm_registry - mm_processor = mm_registry.create_processor( - model_config, - tokenizer=tokenizer, - ) - assert isinstance(mm_processor, EncDecMultiModalProcessor) - - if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper - - if model_config.is_multimodal_model: - suggestion = ( - "Make sure that `max_model_len` is no smaller than the " - "number of text tokens plus multimodal tokens. For image " - "inputs, the number of image tokens depends on the number " - "of images, and possibly their aspect ratios as well." - ) - else: - suggestion = ( - "Make sure that `max_model_len` is no smaller than the " - "number of text tokens." - ) - - raise ValueError( - f"The {prompt_type} prompt (length {prompt_len}) is " - f"longer than the maximum model length of {max_prompt_len}. " - f"{suggestion}" - ) - - # TODO: Find out how many placeholder tokens are there so we can - # check that chunked prefill does not truncate them - # max_batch_len = self.scheduler_config.max_num_batched_tokens - - if ( - prompt_len == max_prompt_len - and prompt_type == "decoder" - and not model_config.is_multimodal_model - and self.model_config.runner_type != "pooling" - ): - suggestion = ( - "Make sure that `max_model_len` is no smaller than the " - "number of text tokens (prompt + requested output tokens)." - ) - raise ValueError( - f"The {prompt_type} prompt (length {prompt_len}) plus the number of " - f"requested output tokens (at least 1) is longer than the maximum " - f"model length of {max_prompt_len}. {suggestion}" - ) - - def stat_mm_cache(self) -> MultiModalCacheStats | None: - return self.input_preprocessor.stat_mm_cache() + return InputProcessor - def clear_mm_cache(self) -> None: - self.input_preprocessor.clear_mm_cache() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") From 40ceb6faf7f1b4b48659c43cf0f225b490853172 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 28 Nov 2025 14:53:59 +0000 Subject: [PATCH 2/7] Rename Signed-off-by: DarkLight1337 --- tests/v1/engine/test_processor_multi_modal_uuids.py | 4 ++-- vllm/entrypoints/openai/serving_engine.py | 4 ++-- vllm/v1/engine/async_llm.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/v1/engine/test_processor_multi_modal_uuids.py b/tests/v1/engine/test_processor_multi_modal_uuids.py index f6b6eeeaade1..c4c9b32b2890 100644 --- a/tests/v1/engine/test_processor_multi_modal_uuids.py +++ b/tests/v1/engine/test_processor_multi_modal_uuids.py @@ -7,7 +7,7 @@ from vllm.assets.video import VideoAsset from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig from vllm.sampling_params import SamplingParams -from vllm.v1.engine import input_processor as processor_mod +from vllm.v1.engine import input_processor as input_processor_mod from vllm.v1.engine.input_processor import InputProcessor cherry_pil_image = ImageAsset("cherry_blossom").pil_image @@ -36,7 +36,7 @@ def _mk_processor( raising=True, ) monkeypatch.setattr( - processor_mod, + input_processor_mod, "processor_cache_from_config", lambda vllm_config, mm_registry: None, raising=True, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index e3748141c6da..cca2fd982fe0 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -348,8 +348,8 @@ async def beam_search( length_penalty = params.length_penalty include_stop_str_in_output = params.include_stop_str_in_output - processor = self.input_processor - tokenizer = processor.tokenizer + input_processor = self.input_processor + tokenizer = input_processor.tokenizer if tokenizer is None: raise ValueError( "You cannot use beam search when `skip_tokenizer_init` is True" diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 3204018c078f..71c60b9fd6ad 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -481,7 +481,7 @@ def _run_output_handler(self): output_processor = self.output_processor log_stats = self.log_stats logger_manager = self.logger_manager - processor = self.input_processor + input_processor = self.input_processor async def output_handler(): try: @@ -532,7 +532,7 @@ async def output_handler(): engine_idx=outputs.engine_index, scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, - mm_cache_stats=processor.stat_mm_cache(), + mm_cache_stats=input_processor.stat_mm_cache(), ) except Exception as e: logger.exception("AsyncLLM output_handler failed.") From 17c9ec489296bd03116619d10d643067060cc726 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 28 Nov 2025 14:55:14 +0000 Subject: [PATCH 3/7] Oops Signed-off-by: DarkLight1337 --- vllm/model_executor/models/nemotron_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index ba9a41b59303..5a1dda8aac2c 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -34,7 +34,7 @@ from vllm.multimodal.image import convert_image_mode from vllm.multimodal.processing import PromptUpdateDetails from vllm.sequence import IntermediateTensors -from vllm.transformers_utils import cached_image_processor_from_config +from vllm.transformers_utils.processor import cached_image_processor_from_config from vllm.transformers_utils.tokenizer import AnyTokenizer from .interfaces import ( From bab6f294e8a34cdde2cf174b672403bde18953f7 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 28 Nov 2025 14:59:05 +0000 Subject: [PATCH 4/7] More renames Signed-off-by: DarkLight1337 --- .../test_processor_multi_modal_uuids.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/v1/engine/test_processor_multi_modal_uuids.py b/tests/v1/engine/test_processor_multi_modal_uuids.py index c4c9b32b2890..1b11b8af49d1 100644 --- a/tests/v1/engine/test_processor_multi_modal_uuids.py +++ b/tests/v1/engine/test_processor_multi_modal_uuids.py @@ -15,8 +15,7 @@ baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays -# Mock processor for testing -def _mk_processor( +def _mock_input_processor( monkeypatch, *, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True ) -> InputProcessor: """ @@ -69,7 +68,7 @@ def __init__(self, gb: float): def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): - processor = _mk_processor(monkeypatch) + input_processor = _mock_input_processor(monkeypatch) prompt = { "prompt": "USER: \nDescribe\nASSISTANT:", @@ -79,7 +78,7 @@ def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): } with pytest.raises(ValueError, match="must have same length as data"): - processor.process_inputs( + input_processor.process_inputs( request_id="req-1", prompt=prompt, # type: ignore[arg-type] params=SamplingParams(), @@ -87,7 +86,7 @@ def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): def test_multi_modal_uuids_missing_modality_raises(monkeypatch): - processor = _mk_processor(monkeypatch) + input_processor = _mock_input_processor(monkeypatch) prompt = { "prompt": "USER: