
Commit de56065

DarkLight1337 and ywang96 authored and committed
[Core] Registry for processing model inputs (vllm-project#5214)
Co-authored-by: ywang96 <[email protected]>
Signed-off-by: Alvant <[email protected]>
1 parent 26eb509 commit de56065

26 files changed: +778 −392 lines changed
docs/source/dev/input_processing/input_processing_pipeline.rst

+20

@@ -0,0 +1,20 @@
+.. _input_processing_pipeline:
+
+Input Processing Pipeline
+=========================
+
+1. Input data is passed to :class:`~vllm.LLMEngine` (or :class:`~vllm.AsyncLLMEngine`).
+
+2. Tokenize the data if necessary.
+
+3. Process the inputs using :meth:`INPUT_REGISTRY.process_input <vllm.inputs.registry.InputRegistry.process_input>`.
+
+   - For example, add placeholder tokens to reserve KV cache for multi-modal embeddings.
+
+4. Send the processed inputs to :class:`~vllm.executor.executor_base.ExecutorBase`.
+
+5. Distribute the inputs via :class:`~vllm.worker.worker_base.WorkerBase` to :class:`~vllm.worker.model_runner_base.ModelRunnerBase`.
+
+6. If the data contains multi-modal data, convert it into keyword arguments using :meth:`MULTIMODAL_REGISTRY.map_input <vllm.multimodal.MultiModalRegistry.map_input>`.
+
+   - For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision language model.
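
For orientation, here is a minimal sketch of how steps 3 and 6 look in code, based on the calls this commit makes in vllm/engine/llm_engine.py and tests/multimodal/test_mapper.py. The helper name run_input_pipeline and its arguments (a prebuilt ModelConfig plus the prompt, token IDs, and multi-modal data) are assumptions for illustration, not part of the commit:

from vllm.inputs import INPUT_REGISTRY, LLMInputs
from vllm.multimodal import MULTIMODAL_REGISTRY


def run_input_pipeline(model_config, prompt, prompt_token_ids,
                       multi_modal_data):
    # Step 3 (engine side): build the per-model input processor from the
    # registry and apply it to the tokenized request, mirroring
    # LLMEngine.process_model_inputs in this commit.
    input_processor = INPUT_REGISTRY.create_input_processor(model_config)
    llm_inputs = input_processor(
        LLMInputs(prompt_token_ids=prompt_token_ids,
                  prompt=prompt,
                  multi_modal_data=multi_modal_data))

    # Step 6 (model runner side): map the multi-modal data (e.g. an
    # ImagePixelData wrapping a PIL image) to model keyword arguments such
    # as pixel values, mirroring tests/multimodal/test_mapper.py.
    mm_kwargs = MULTIMODAL_REGISTRY.map_input(model_config, multi_modal_data)
    return llm_inputs, mm_kwargs
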
docs/source/dev/input_processing/model_inputs_index.rst

+39

@@ -0,0 +1,39 @@
+.. _input_processing:
+
+Input Processing
+================
+
+.. currentmodule:: vllm.inputs
+
+vLLM provides a mechanism for defining input processors for each model so that the inputs are processed
+in :class:`~vllm.LLMEngine` before they are passed to model executors.
+
+Currently, this mechanism is only utilized in **multi-modal models** for preprocessing multi-modal input
+data in addition to the input prompt, but it can be extended to text-only language models when needed.
+
+Guides
+++++++
+
+.. toctree::
+   :maxdepth: 1
+
+   input_processing_pipeline
+
+Module Contents
++++++++++++++++
+
+LLM Engine Inputs
+-----------------
+
+.. autoclass:: vllm.inputs.LLMInputs
+    :members:
+    :show-inheritance:
+
+Registry
+--------
+
+.. autodata:: vllm.inputs.INPUT_REGISTRY
+
+.. automodule:: vllm.inputs.registry
+    :members:
+    :show-inheritance:
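
As a rough sketch of the registration side of this mechanism: a model class can have an input processor attached through the registry. The decorator name register_input_processor is assumed to match the registration API added in this commit, and the processor signature shown here (a context argument followed by the LLMInputs to transform) as well as the model class are placeholders to be checked against vllm.inputs.registry in your version:

from vllm.inputs import INPUT_REGISTRY, LLMInputs


def my_input_processor(ctx, llm_inputs: LLMInputs) -> LLMInputs:
    # Hypothetical processor: the first argument (a context carrying the
    # model config) is an assumption; a real processor returns a possibly
    # modified LLMInputs, e.g. with placeholder tokens inserted so that KV
    # cache is reserved for multi-modal embeddings.
    return llm_inputs


@INPUT_REGISTRY.register_input_processor(my_input_processor)
class MyModelForCausalLM:  # stands in for the model's nn.Module class
    ...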

docs/source/dev/multimodal/multimodal_index.rst

+1 -7

@@ -12,10 +12,6 @@ By default, vLLM models do not support multi-modal inputs. To enable multi-modal
 you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data <MultiModalRegistry.register_dummy_data>`,
 as well as :meth:`MULTIMODAL_REGISTRY.register_input <MultiModalRegistry.register_input>` for each modality type to support.
 
-.. contents::
-   :local:
-   :backlinks: none
-
 Module Contents
 +++++++++++++++
 
@@ -24,9 +20,7 @@ Module Contents
 Registry
 --------
 
-.. data:: vllm.multimodal.MULTIMODAL_REGISTRY
-
-   The global :class:`MultiModalRegistry` which is used by model runners.
+.. autodata:: vllm.multimodal.MULTIMODAL_REGISTRY
 
 .. autoclass:: vllm.multimodal.MultiModalRegistry
     :members:
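
To illustrate the decorator pattern described in the hunk above, here is a rough sketch of registering multi-modal support on a model class. Only the method names register_dummy_data and register_input come from the documentation; the decorator arguments, the dummy-data factory, and the ImagePixelData import path are assumptions and may not match the actual MultiModalRegistry signatures:

import torch.nn as nn

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData  # import path assumed


def dummy_image_data_factory(seq_len, model_config):
    # Hypothetical factory producing dummy multi-modal data for memory
    # profiling; the real signature is defined by register_dummy_data.
    ...


@MULTIMODAL_REGISTRY.register_dummy_data(dummy_image_data_factory)
@MULTIMODAL_REGISTRY.register_input(ImagePixelData)  # one call per modality
class MyVisionLanguageModel(nn.Module):
    ...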

docs/source/index.rst

+1

@@ -120,6 +120,7 @@ Documentation
    dev/offline_inference/offline_index
    dev/engine/engine_index
    dev/kernel/paged_attention
+   dev/input_processing/model_inputs_index
    dev/multimodal/multimodal_index
    dev/dockerfile/dockerfile
 

docs/source/models/adding_model.rst

+2 -2

@@ -37,7 +37,7 @@ For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/
 2. Rewrite the :code:`forward` methods
 --------------------------------------
 
-Next, you need to rewrite the :code:`forward` methods of your model by following these steps:
+Next, you need to rewrite the :meth:`~torch.nn.Module.forward` method of your model by following these steps:
 
 1. Remove any unnecessary code, such as the code only used for training.
 2. Change the input parameters:
@@ -75,7 +75,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
 
 If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it.
 To do this, substitute your model's linear and embedding layers with their tensor-parallel versions.
-For the embedding layer, you can simply replace :code:`nn.Embedding` with :code:`VocabParallelEmbedding`. For the output LM head, you can use :code:`ParallelLMHead`.
+For the embedding layer, you can simply replace :class:`torch.nn.Embedding` with :code:`VocabParallelEmbedding`. For the output LM head, you can use :code:`ParallelLMHead`.
 When it comes to the linear layers, we provide the following options to parallelize them:
 
 * :code:`ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving.
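
As a concrete illustration of the substitution this hunk documents, the sketch below swaps a model's embedding and LM-head layers for their tensor-parallel counterparts. The model class and config attributes are placeholders, and the import path reflects vLLM's layout around the time of this commit:

import torch.nn as nn

from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)


class MyModelForCausalLM(nn.Module):  # placeholder model class

    def __init__(self, config):
        super().__init__()
        # Was: nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        # Was: nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)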

examples/phi3v_example.py

+2 -1

@@ -11,14 +11,15 @@ def run_phi3v():
     model_path = "microsoft/Phi-3-vision-128k-instruct"
 
     # Note: The model has 128k context length by default which may cause OOM
-    # If that's the case, override `max_model_len` with a smaller value via args
+    # In this example, we override max_model_len to 2048.
     llm = LLM(
         model=model_path,
         trust_remote_code=True,
         image_input_type="pixel_values",
         image_token_id=32044,
         image_input_shape="1,3,1008,1344",
         image_feature_size=1921,
+        max_model_len=2048,
     )
 
     image = Image.open("images/cherry_blossom.jpg")

tests/multimodal/test_processor.py renamed to tests/multimodal/test_mapper.py

+32 -37

@@ -25,25 +25,24 @@ def test_clip_image_processor(image_assets, dtype):
         seed=0,
         dtype=dtype,
         revision=None,
-    )
-    vlm_config = VisionLanguageConfig(
-        image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
-        image_token_id=32000,
-        image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
-        image_feature_size=576,
-        image_processor=MODEL_NAME,
-        image_processor_revision=None,
+        multimodal_config=VisionLanguageConfig(
+            image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
+            image_token_id=32000,
+            image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
+            image_feature_size=576,
+            image_processor=MODEL_NAME,
+            image_processor_revision=None,
+        ),
     )
 
     for asset in image_assets:
         hf_result = hf_processor.preprocess(
             asset.pil_image,
             return_tensors="pt",
         ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
-        vllm_result = MULTIMODAL_REGISTRY.process_input(
+        vllm_result = MULTIMODAL_REGISTRY.map_input(
+            model_config,
             ImagePixelData(asset.pil_image),
-            model_config=model_config,
-            vlm_config=vlm_config,
         )
 
         assert hf_result.keys() == vllm_result.keys()
@@ -74,25 +73,24 @@ def test_llava_next_image_processor(image_assets, dtype):
         seed=0,
         dtype=dtype,
         revision=None,
-    )
-    vlm_config = VisionLanguageConfig(
-        image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
-        image_token_id=64000,
-        image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
-        image_feature_size=2928,
-        image_processor=MODEL_NAME,
-        image_processor_revision=None,
+        multimodal_config=VisionLanguageConfig(
+            image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
+            image_token_id=64000,
+            image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
+            image_feature_size=2928,
+            image_processor=MODEL_NAME,
+            image_processor_revision=None,
+        ),
    )
 
     for asset in image_assets:
         hf_result = hf_processor.preprocess(
             asset.pil_image,
             return_tensors="pt",
         ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
-        vllm_result = MULTIMODAL_REGISTRY.process_input(
+        vllm_result = MULTIMODAL_REGISTRY.map_input(
+            model_config,
             ImagePixelData(asset.pil_image),
-            model_config=model_config,
-            vlm_config=vlm_config,
         )
 
         assert hf_result.keys() == vllm_result.keys()
@@ -119,26 +117,23 @@ def test_image_pixel_types(image_assets, dtype):
         seed=0,
         dtype=dtype,
         revision=None,
-    )
-    vlm_config = VisionLanguageConfig(
-        image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
-        image_token_id=32000,
-        image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
-        image_feature_size=576,
-        image_processor=MODEL_NAME,
-        image_processor_revision=None,
-    )
+        multimodal_config=VisionLanguageConfig(
+            image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
+            image_token_id=32000,
+            image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
+            image_feature_size=576,
+            image_processor=MODEL_NAME,
+            image_processor_revision=None,
+        ))
 
     for asset in image_assets:
-        image_result = MULTIMODAL_REGISTRY.process_input(
+        image_result = MULTIMODAL_REGISTRY.map_input(
+            model_config,
             ImagePixelData(asset.pil_image),
-            model_config=model_config,
-            vlm_config=vlm_config,
         )
-        tensor_result = MULTIMODAL_REGISTRY.process_input(
+        tensor_result = MULTIMODAL_REGISTRY.map_input(
+            model_config,
             ImagePixelData(asset.pixel_values),
-            model_config=model_config,
-            vlm_config=vlm_config,
         )
 
         assert image_result.keys() == tensor_result.keys()

vllm/config.py

+3

@@ -109,6 +109,7 @@ def __init__(
         disable_sliding_window: bool = False,
         skip_tokenizer_init: bool = False,
         served_model_name: Optional[Union[str, List[str]]] = None,
+        multimodal_config: Optional["VisionLanguageConfig"] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -159,6 +160,8 @@ def __init__(
             sliding_window_len=self.get_hf_config_sliding_window())
         self.served_model_name = get_served_model_name(model,
                                                        served_model_name)
+        self.multimodal_config = multimodal_config
+
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
             self._verify_embedding_mode()

vllm/engine/arg_utils.py

+32 -32

@@ -643,6 +643,36 @@ def create_engine_config(self, ) -> EngineConfig:
             raise ValueError(
                 "BitsAndBytes load format and QLoRA adapter only support "
                 f"'bitsandbytes' quantization, but got {self.quantization}")
+        if self.image_input_type:
+            if (not self.image_token_id or not self.image_input_shape
+                    or not self.image_feature_size):
+                raise ValueError(
+                    'Specify `image_token_id`, `image_input_shape` and '
+                    '`image_feature_size` together with `image_input_type`.')
+
+            if self.image_processor is None:
+                self.image_processor = self.model
+            if self.disable_image_processor:
+                if self.image_processor != self.model:
+                    warnings.warn(
+                        "You've specified an image processor "
+                        f"({self.image_processor}) but also disabled "
+                        "it via `--disable-image-processor`.",
+                        stacklevel=2)
+
+                self.image_processor = None
+
+            vision_language_config = VisionLanguageConfig(
+                image_input_type=VisionLanguageConfig.
+                get_image_input_enum_type(self.image_input_type),
+                image_token_id=self.image_token_id,
+                image_input_shape=str_to_int_tuple(self.image_input_shape),
+                image_feature_size=self.image_feature_size,
+                image_processor=self.image_processor,
+                image_processor_revision=self.image_processor_revision,
+            )
+        else:
+            vision_language_config = None
 
         device_config = DeviceConfig(device=self.device)
         model_config = ModelConfig(
@@ -666,7 +696,8 @@ def create_engine_config(self, ) -> EngineConfig:
             max_logprobs=self.max_logprobs,
             disable_sliding_window=self.disable_sliding_window,
             skip_tokenizer_init=self.skip_tokenizer_init,
-            served_model_name=self.served_model_name)
+            served_model_name=self.served_model_name,
+            multimodal_config=vision_language_config)
         cache_config = CacheConfig(
             block_size=self.block_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
@@ -742,37 +773,6 @@ def create_engine_config(self, ) -> EngineConfig:
             model_loader_extra_config=self.model_loader_extra_config,
         )
 
-        if self.image_input_type:
-            if (not self.image_token_id or not self.image_input_shape
-                    or not self.image_feature_size):
-                raise ValueError(
-                    'Specify `image_token_id`, `image_input_shape` and '
-                    '`image_feature_size` together with `image_input_type`.')
-
-            if self.image_processor is None:
-                self.image_processor = self.model
-            if self.disable_image_processor:
-                if self.image_processor != self.model:
-                    warnings.warn(
-                        "You've specified an image processor "
-                        f"({self.image_processor}) but also disabled "
-                        "it via `--disable-image-processor`.",
-                        stacklevel=2)
-
-                self.image_processor = None
-
-            vision_language_config = VisionLanguageConfig(
-                image_input_type=VisionLanguageConfig.
-                get_image_input_enum_type(self.image_input_type),
-                image_token_id=self.image_token_id,
-                image_input_shape=str_to_int_tuple(self.image_input_shape),
-                image_feature_size=self.image_feature_size,
-                image_processor=self.image_processor,
-                image_processor_revision=self.image_processor_revision,
-            )
-        else:
-            vision_language_config = None
-
 
         decoding_config = DecodingConfig(
             guided_decoding_backend=self.guided_decoding_backend)

vllm/engine/async_llm_engine.py

+5 -3

@@ -278,9 +278,11 @@ async def process_model_inputs_async(
         else:
             prompt_token_ids = inputs["prompt_token_ids"]
 
-        return LLMInputs(prompt_token_ids=prompt_token_ids,
-                         prompt=inputs.get("prompt"),
-                         multi_modal_data=inputs.get("multi_modal_data"))
+        llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
+                               prompt=inputs.get("prompt"),
+                               multi_modal_data=inputs.get("multi_modal_data"))
+
+        return self.input_processor(llm_inputs)
 
     async def add_request_async(
         self,

vllm/engine/llm_engine.py

+9 -4

@@ -20,7 +20,7 @@
 from vllm.engine.output_processor.util import create_output_by_sequence_group
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import LLMInputs, PromptInputs
+from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
@@ -227,6 +227,9 @@ def __init__(
         self.generation_config_fields = _load_generation_config_dict(
             model_config)
 
+        self.input_processor = INPUT_REGISTRY.create_input_processor(
+            self.model_config)
+
         self.model_executor = executor_class(
             model_config=model_config,
             cache_config=cache_config,
@@ -513,9 +516,11 @@ def process_model_inputs(
         else:
             prompt_token_ids = inputs["prompt_token_ids"]
 
-        return LLMInputs(prompt_token_ids=prompt_token_ids,
-                         prompt=inputs.get("prompt"),
-                         multi_modal_data=inputs.get("multi_modal_data"))
+        llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
+                               prompt=inputs.get("prompt"),
+                               multi_modal_data=inputs.get("multi_modal_data"))
+
+        return self.input_processor(llm_inputs)
 
     def process_model_params(
         self,
