diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 4edd1cadfb2f..4778cdbecbed 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -94,13 +94,13 @@ steps: #mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - - pytest -v -s models --ignore=models/test_llava.py + - pytest -v -s models -m \"not llava\" - label: Llava Test mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - - pytest -v -s models/test_llava.py + - pytest -v -s models -m llava - label: Prefix Caching Test mirror_hardwares: [amd] diff --git a/docs/source/dev/input_processing/model_inputs_index.rst b/docs/source/dev/input_processing/model_inputs_index.rst new file mode 100644 index 000000000000..8529fc1d9af4 --- /dev/null +++ b/docs/source/dev/input_processing/model_inputs_index.rst @@ -0,0 +1,28 @@ +Input Processing +================ + +.. currentmodule:: vllm.inputs + +vLLM provides a mechanism for defining input processors for each model so that the inputs are processed +in :class:`~vllm.LLMEngine` before they are passed to model executors. + +.. contents:: + :local: + :backlinks: none + +Module Contents ++++++++++++++++ + +LLM Engine Inputs +----------------- + +.. autoclass:: vllm.inputs.LLMInputs + :members: + :show-inheritance: + +Registry +-------- + +.. automodule:: vllm.inputs.registry + :members: + :show-inheritance: diff --git a/docs/source/dev/multimodal/adding_multimodal_model.rst b/docs/source/dev/multimodal/adding_multimodal_model.rst new file mode 100644 index 000000000000..4a0010d47ba3 --- /dev/null +++ b/docs/source/dev/multimodal/adding_multimodal_model.rst @@ -0,0 +1,94 @@ +.. _adding_a_new_multimodal_model: + +Adding a New Multimodal Model +============================= + +This document provides a high-level guide on integrating a :ref:`multimodal model ` into vLLM. + +.. note:: + The complexity of adding a new model depends heavily on the model's architecture. + The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM. + However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex. + +.. tip:: + If you are encountering issues while integrating your model into vLLM, feel free to open an issue on our `GitHub `_ repository. + We will be happy to help you out! + + +0. Set up a base vLLM model +--------------------------- + +Follow :ref:`these steps ` to first implement the model in vLLM. +While implementing the :meth:`~torch.nn.Module.forward` method, reserve a keyword parameter +for each input tensor that corresponds to a multi-modal input, as shown in the following example: + +.. code-block:: diff + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + + pixel_values: torch.Tensor, + ) -> SamplerOutput: + +.. note:: + The model class does not have to be named :code:`*ForCausalLM`. + Check out `the HuggingFace Transformers documentation `__ for some examples. + + +1. Register input mappers +------------------------- + +For each modality type to support, decorate the model class with :meth:`vllm.INPUT_REGISTRY.MULTIMODAL.register_input_mapper `. +This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`. + +.. 
code-block:: diff + + + from vllm.inputs import INPUT_REGISTRY + + + @INPUT_REGISTRY.MULTIMODAL.register_image_feature_input_mapper() + + @INPUT_REGISTRY.MULTIMODAL.register_image_pixel_input_mapper() + class YourModelForImage2Seq(nn.Module): + +A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function. + + +2. (Optional) Register dummy data +--------------------------------- + +During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models. +In such cases, you can define your own dummy data by registering a factory method via :meth:`vllm.inputs.INPUT_REGISTRY.register_dummy_data `. + +.. code-block:: diff + + from vllm.inputs import INPUT_REGISTRY + + @INPUT_REGISTRY.MULTIMODAL.register_image_feature_input_mapper() + @INPUT_REGISTRY.MULTIMODAL.register_image_pixel_input_mapper() + + @INPUT_REGISTRY.register_dummy_data() + class YourModelForImage2Seq(nn.Module): + +Refer to :class:`vllm.multimodal.image.DummyImageDataFactories` for some examples of dummy data factories. + + +3. (Optional) Register input processor +-------------------------------------- + +Sometimes, there is a need to process inputs at the :class:~vllm.LLMEngine` level before they are passed to the model executor. +You can register input processors via :meth:`vllm.inputs.INPUT_REGISTRY.register_input_processor `. + +.. code-block:: diff + + from vllm.inputs import INPUT_REGISTRY + + @INPUT_REGISTRY.MULTIMODAL.register_image_feature_input_mapper() + @INPUT_REGISTRY.MULTIMODAL.register_image_pixel_input_mapper() + @INPUT_REGISTRY.register_dummy_data() + + @INPUT_REGISTRY.register_input_processor() + class YourModelForImage2Seq(nn.Module): + +A common use case of input processors is inserting extra image tokens to leverage the vLLM framework for attention mask generation. +More details can be found in :class:`vllm.multimodal.image.ImageInputProcessors`. + diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index a25eceecc276..ad5ba07aefd9 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -1,3 +1,5 @@ +.. _multi_modality: + Multi-Modality ============== @@ -8,9 +10,15 @@ vLLM provides experimental support for multi-modal models through the :mod:`vllm :class:`vllm.inputs.PromptStrictInputs` accepts an additional attribute ``multi_modal_data`` which allows you to pass in multi-modal input alongside text and token prompts. -By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model, -you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data `, -as well as :meth:`MULTIMODAL_REGISTRY.register_input ` for each modality type to support. +By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model, please follow :ref:`this guide `. + +Guides +++++++ + +.. toctree:: + :maxdepth: 1 + + adding_multimodal_model .. contents:: :local: @@ -24,9 +32,7 @@ Module Contents Registry -------- -.. data:: vllm.multimodal.MULTIMODAL_REGISTRY - - The global :class:`MultiModalRegistry` which is used by model runners. +.. autodata:: vllm.multimodal.MULTIMODAL_REGISTRY .. 
autoclass:: vllm.multimodal.MultiModalRegistry :members: diff --git a/docs/source/index.rst b/docs/source/index.rst index fad3c3b05b0c..dcca28b3b88c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -107,6 +107,7 @@ Documentation dev/offline_inference/offline_index dev/engine/engine_index dev/kernel/paged_attention + dev/input_processing/model_inputs_index dev/multimodal/multimodal_index dev/dockerfile/dockerfile diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index cbc8099e6f70..f282b594590b 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -37,7 +37,7 @@ For instance, vLLM's `OPT model \nUSER: What's the content of the image?\nASSISTANT:", - "\nUSER: What is the season?\nASSISTANT:" -] -assert len(_PIXEL_VALUES_FILES) == len(_IMAGE_FEATURES_FILES) == len( - _IMAGE_FILES) == len(_IMAGE_PROMPTS) +assert len(PIXEL_VALUES_FILES) == len(IMAGE_FEATURES_FILES) == len(IMAGE_FILES) def _read_prompts(filename: str) -> List[str]: @@ -84,14 +79,9 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): cleanup() -@pytest.fixture(scope="session") -def hf_image_prompts() -> List[str]: - return _IMAGE_PROMPTS - - @pytest.fixture(scope="session") def hf_images() -> List[Image.Image]: - return [Image.open(filename) for filename in _IMAGE_FILES] + return [Image.open(filename) for filename in IMAGE_FILES] @pytest.fixture() @@ -101,26 +91,17 @@ def vllm_images(request) -> List[MultiModalData]: VisionLanguageConfig.ImageInputType.IMAGE_FEATURES): return [ ImageFeatureData(torch.load(filename)) - for filename in _IMAGE_FEATURES_FILES + for filename in IMAGE_FEATURES_FILES ] else: return [ - ImagePixelData(Image.open(filename)) for filename in _IMAGE_FILES + ImagePixelData(Image.open(filename)) for filename in IMAGE_FILES ] @pytest.fixture() def vllm_image_tensors(request) -> List[torch.Tensor]: - return [torch.load(filename) for filename in _PIXEL_VALUES_FILES] - - -@pytest.fixture() -def vllm_image_prompts(request) -> List[str]: - vision_language_config = request.getfixturevalue("model_and_config")[1] - return [ - "" * (vision_language_config.image_feature_size - 1) + p - for p in _IMAGE_PROMPTS - ] + return [torch.load(filename) for filename in PIXEL_VALUES_FILES] @pytest.fixture diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index 839a9f78d1bb..ec719c6654e7 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -1,14 +1,22 @@ -import gc -from dataclasses import fields -from enum import Enum -from typing import Any, Dict, List, Tuple +from typing import List, Tuple import pytest -import torch from transformers import AutoTokenizer from vllm.config import VisionLanguageConfig +from ..conftest import IMAGE_FILES + +pytestmark = pytest.mark.llava + +# The image token is placed before "user" on purpose so that the test can pass +HF_IMAGE_PROMPTS = [ + "\nUSER: What's the content of the image?\nASSISTANT:", + "\nUSER: What is the season?\nASSISTANT:", +] + +assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES) + def iter_llava_configs(model_name: str): image_hw_to_feature_size = { @@ -31,58 +39,37 @@ def iter_llava_configs(model_name: str): model_and_vl_config = [ *iter_llava_configs("llava-hf/llava-1.5-7b-hf"), - # Not enough memory - # *iter_llava_configs("llava-hf/llava-1.5-13b-hf"), ] -def as_dict(vlm_config: VisionLanguageConfig) -> Dict[str, Any]: - """Flatten vision language config to pure args. - - Compatible with what llm entrypoint expects. 
- """ - result = {} - for field in fields(vlm_config): - value = getattr(vlm_config, field.name) - if isinstance(value, Enum): - result[field.name] = value.name.lower() - elif isinstance(value, tuple): - result[field.name] = ",".join([str(item) for item in value]) - else: - result[field.name] = value - - result["disable_image_processor"] = vlm_config.image_processor is None - - return result - - -def sanitize_vllm_output(vllm_output: Tuple[List[int], str], - vision_language_config: VisionLanguageConfig, - model_id: str): +def vllm_to_hf_output(vllm_output: Tuple[List[int], str], + vlm_config: VisionLanguageConfig, model_id: str): """Sanitize vllm output to be comparable with hf output. The function reduces `input_ids` from 1, 32000, 32000, ..., 32000, x1, x2, x3 ... to 1, 32000, x1, x2, x3 ... It also reduces `output_str` from "bla" to "bla". """ - tokenizer = AutoTokenizer.from_pretrained(model_id) - image_token_str = tokenizer.decode(vision_language_config.image_token_id) - image_token_str_len = len(image_token_str) input_ids, output_str = vllm_output - sanitized_input_ids = input_ids[0:2] + input_ids[2 + vision_language_config - .image_feature_size - 1:] - sanitzied_output_str = output_str[vision_language_config. - image_feature_size * - image_token_str_len:] - return sanitized_input_ids, sanitzied_output_str + image_token_id = vlm_config.image_token_id + + tokenizer = AutoTokenizer.from_pretrained(model_id) + image_token_str = tokenizer.decode(image_token_id) + + hf_input_ids = [ + input_id for idx, input_id in enumerate(input_ids) + if input_id != image_token_id or input_ids[idx - 1] != image_token_id + ] + hf_output_str = output_str \ + .replace(image_token_str * vlm_config.image_feature_size, "") + + return hf_input_ids, hf_output_str -@pytest.mark.parametrize("worker_use_ray", [False]) @pytest.mark.parametrize("model_and_config", model_and_vl_config) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) -def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images, - vllm_image_prompts, vllm_images, model_and_config, dtype: str, - max_tokens: int, worker_use_ray: bool) -> None: +def test_models(hf_runner, vllm_runner, hf_images, vllm_images, + model_and_config, dtype: str, max_tokens: int) -> None: """Inference result should be the same between hf and vllm. All the image fixtures for the test is under tests/images. @@ -92,36 +79,28 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images, Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. 
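    Note that the vLLM runner now receives the same prompts as the HF runner:
    the single image token is expanded to ``image_feature_size`` repetitions
    by the model's registered input processor rather than by a test fixture.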
""" - model_id, vision_language_config = model_and_config + model_id, vlm_config = model_and_config hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True) - hf_outputs = hf_model.generate_greedy(hf_image_prompts, + hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS, max_tokens, images=hf_images) del hf_model vllm_model = vllm_runner(model_id, dtype=dtype, - worker_use_ray=worker_use_ray, enforce_eager=True, - **as_dict(vision_language_config)) - vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, + **vlm_config.as_cli_args_dict()) + vllm_outputs = vllm_model.generate_greedy(HF_IMAGE_PROMPTS, max_tokens, images=vllm_images) del vllm_model - gc.collect() - torch.cuda.empty_cache() - - for i in range(len(hf_image_prompts)): + for i in range(len(HF_IMAGE_PROMPTS)): hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = sanitize_vllm_output( - vllm_outputs[i], vision_language_config, model_id) + vllm_output_ids, vllm_output_str = vllm_to_hf_output( + vllm_outputs[i], vlm_config, model_id) assert hf_output_str == vllm_output_str, ( f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") assert hf_output_ids == vllm_output_ids, ( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") - - -# TODO: Add test for `tensor_parallel_size` [ref: PR #3883] -# (Requires multiple GPUs) diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py new file mode 100644 index 000000000000..610e2a2fc310 --- /dev/null +++ b/tests/models/test_llava_next.py @@ -0,0 +1,117 @@ +import re +from typing import List, Tuple + +import pytest +from transformers import AutoTokenizer + +from vllm.config import VisionLanguageConfig + +from ..conftest import IMAGE_FILES + +pytestmark = pytest.mark.llava + +_PREFACE = ( + "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's " + "questions.") + +# The image token is placed before "user" on purpose so that the test can pass +HF_IMAGE_PROMPTS = [ + f"{_PREFACE} \nUSER: What's the content of the image? ASSISTANT:", + f"{_PREFACE} \nUSER: What is the season? ASSISTANT:", +] + +assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES) + + +def iter_llava_next_configs(model_name: str): + image_hw_to_feature_size = { + (336, 336): 1176, + (672, 672): 2928, + (1344, 336): 1944, + (336, 1344): 1890, + } + + for (h, w), f in image_hw_to_feature_size.items(): + for input_type, input_shape in [ + (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)), + ]: + yield (model_name, + VisionLanguageConfig(image_input_type=input_type, + image_feature_size=f, + image_token_id=32000, + image_input_shape=input_shape, + image_processor=model_name, + image_processor_revision=None)) + + +model_and_vl_config = [ + *iter_llava_next_configs("llava-hf/llava-v1.6-vicuna-7b-hf"), +] + + +def vllm_to_hf_output(vllm_output: Tuple[List[int], str], + vlm_config: VisionLanguageConfig, model_id: str): + """Sanitize vllm output to be comparable with hf output. + The function reduces `input_ids` from 1, 32000, 32000, ..., 32000, + x1, x2, x3 ... to 1, 32000, x1, x2, x3 ... + It also reduces `output_str` from "bla" to "bla". 
+ """ + input_ids, output_str = vllm_output + image_token_id = vlm_config.image_token_id + + tokenizer = AutoTokenizer.from_pretrained(model_id) + image_token_str = tokenizer.decode(image_token_id) + + hf_input_ids = [ + input_id for idx, input_id in enumerate(input_ids) + if input_id != image_token_id or input_ids[idx - 1] != image_token_id + ] + hf_output_str = re.sub(fr"({image_token_str})+", " ", output_str) + + return hf_input_ids, hf_output_str + + +@pytest.mark.parametrize("model_and_config", model_and_vl_config) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +def test_models(hf_runner, vllm_runner, hf_images, vllm_images, + model_and_config, dtype: str, max_tokens: int) -> None: + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test is under tests/images. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalData objects and corresponding + vision language config as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + model_id, vlm_config = model_and_config + + hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True) + hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS, + max_tokens, + images=hf_images) + del hf_model + + vllm_model = vllm_runner( + model_id, + dtype=dtype, + # should be greater than image_feature_size + max_model_len=4096, + enforce_eager=True, + **vlm_config.as_cli_args_dict(), + ) + vllm_outputs = vllm_model.generate_greedy(HF_IMAGE_PROMPTS, + max_tokens, + images=vllm_images) + del vllm_model + + for i in range(len(HF_IMAGE_PROMPTS)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_to_hf_output( + vllm_outputs[i], vlm_config, model_id) + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/multimodal/test_mapper.py b/tests/multimodal/test_mapper.py new file mode 100644 index 000000000000..e398546256cb --- /dev/null +++ b/tests/multimodal/test_mapper.py @@ -0,0 +1,141 @@ +import numpy as np +import pytest +from transformers import CLIPImageProcessor, LlavaNextImageProcessor + +from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.image import ImagePixelData + +from ..conftest import _STR_DTYPE_TO_TORCH_DTYPE + + +@pytest.mark.parametrize("dtype", ["half", "float"]) +def test_clip_image_processor(hf_images, dtype): + MODEL_NAME = "llava-hf/llava-1.5-7b-hf" + IMAGE_HEIGHT = IMAGE_WIDTH = 560 + + hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME) + assert isinstance(hf_processor, CLIPImageProcessor) + + model_config = ModelConfig( + model=MODEL_NAME, + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype=dtype, + revision=None, + multimodal_config=VisionLanguageConfig( + image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, + image_token_id=32000, + image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), + image_feature_size=576, + image_processor=MODEL_NAME, + image_processor_revision=None, + ), + ) + + for image in hf_images: + hf_result = hf_processor.preprocess( + image, + return_tensors="pt", + ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype]) + vllm_result = 
MULTIMODAL_REGISTRY.map_input( + model_config, + ImagePixelData(image), + ) + + assert hf_result.keys() == vllm_result.keys() + for key, hf_tensor in hf_result.items(): + hf_arr: np.ndarray = hf_tensor.numpy() + vllm_arr: np.ndarray = vllm_result[key].numpy() + + assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}" + assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" + + +@pytest.mark.parametrize("dtype", ["half", "float"]) +def test_llava_next_image_processor(hf_images, dtype): + MODEL_NAME = "llava-hf/llava-v1.6-34b-hf" + IMAGE_HEIGHT = IMAGE_WIDTH = 560 + + hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME) + assert isinstance(hf_processor, LlavaNextImageProcessor) + + model_config = ModelConfig( + model=MODEL_NAME, + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype=dtype, + revision=None, + multimodal_config=VisionLanguageConfig( + image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, + image_token_id=64000, + image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), + image_feature_size=2928, + image_processor=MODEL_NAME, + image_processor_revision=None, + ), + ) + + for image in hf_images: + hf_result = hf_processor.preprocess( + image, + return_tensors="pt", + ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype]) + vllm_result = MULTIMODAL_REGISTRY.map_input( + model_config, + ImagePixelData(image), + ) + + assert hf_result.keys() == vllm_result.keys() + for key, hf_tensor in hf_result.items(): + hf_arr: np.ndarray = hf_tensor.numpy() + vllm_arr: np.ndarray = vllm_result[key].numpy() + + assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}" + assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" + + +@pytest.mark.xfail( + reason="Example image pixels were not processed using HuggingFace") +@pytest.mark.parametrize("dtype", ["float"]) +def test_image_pixel_types(hf_images, vllm_image_tensors, dtype): + MODEL_NAME = "llava-hf/llava-1.5-7b-hf" + IMAGE_HEIGHT = IMAGE_WIDTH = 560 + + model_config = ModelConfig( + model=MODEL_NAME, + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype=dtype, + revision=None, + multimodal_config=VisionLanguageConfig( + image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, + image_token_id=32000, + image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), + image_feature_size=576, + image_processor=MODEL_NAME, + image_processor_revision=None, + )) + + for image, tensor in zip(hf_images, vllm_image_tensors): + image_result = MULTIMODAL_REGISTRY.map_input( + model_config, + ImagePixelData(image), + ) + tensor_result = MULTIMODAL_REGISTRY.map_input( + model_config, + ImagePixelData(tensor), + ) + + assert image_result.keys() == tensor_result.keys() + for key, image_arr in image_result.items(): + tensor_arr: np.ndarray = tensor_result[key].numpy() + + assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}" + assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}" diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py deleted file mode 100644 index 3df28e782dd8..000000000000 --- a/tests/multimodal/test_processor.py +++ /dev/null @@ -1,101 +0,0 @@ -import numpy as np -import pytest -from transformers import CLIPImageProcessor - -from vllm.config import ModelConfig, VisionLanguageConfig -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import ImagePixelData - -from ..conftest import _STR_DTYPE_TO_TORCH_DTYPE - - 
-@pytest.mark.parametrize("dtype", ["half", "float"]) -def test_clip_image_processor(hf_images, dtype): - MODEL_NAME = "llava-hf/llava-1.5-7b-hf" - IMAGE_HEIGHT = IMAGE_WIDTH = 33 - - hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME) - assert isinstance(hf_processor, CLIPImageProcessor) - - model_config = ModelConfig( - model=MODEL_NAME, - tokenizer=MODEL_NAME, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype=dtype, - revision=None, - ) - vlm_config = VisionLanguageConfig( - image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, - image_token_id=32000, - image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), - image_feature_size=576, - image_processor=MODEL_NAME, - image_processor_revision=None, - ) - - for image in hf_images: - hf_result = hf_processor.preprocess( - image, - return_tensors="pt", - ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype]) - vllm_result = MULTIMODAL_REGISTRY.process_input( - ImagePixelData(image), - model_config=model_config, - vlm_config=vlm_config, - ) - - assert hf_result.keys() == vllm_result.keys() - for key, hf_tensor in hf_result.items(): - hf_arr: np.ndarray = hf_tensor.numpy() - vllm_arr: np.ndarray = vllm_result[key].numpy() - - assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}" - assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" - - -@pytest.mark.parametrize("dtype", ["float"]) -def test_image_pixel_types(hf_images, vllm_image_tensors, dtype): - MODEL_NAME = "llava-hf/llava-1.5-7b-hf" - IMAGE_HEIGHT = IMAGE_WIDTH = 33 - - model_config = ModelConfig( - model=MODEL_NAME, - tokenizer=MODEL_NAME, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype=dtype, - revision=None, - ) - vlm_config = VisionLanguageConfig( - image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, - image_token_id=32000, - image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), - image_feature_size=576, - image_processor=MODEL_NAME, - image_processor_revision=None, - ) - - for image, tensor in zip(hf_images, vllm_image_tensors): - image_result = MULTIMODAL_REGISTRY.process_input( - ImagePixelData(image), - model_config=model_config, - vlm_config=vlm_config, - ) - tensor_result = MULTIMODAL_REGISTRY.process_input( - ImagePixelData(tensor), - model_config=model_config, - vlm_config=vlm_config, - ) - - assert image_result.keys() == tensor_result.keys() - for key, image_arr in image_result.items(): - tensor_arr: np.ndarray = tensor_result[key].numpy() - - assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}" - - # The examples in PR#3042 have slightly different preprocessing from - # HuggingFace's LlavaProcessor, causing the test to fail. 
- # assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}" diff --git a/vllm/config.py b/vllm/config.py index 7fd417bd745a..312f331ab0e4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,7 +1,8 @@ import enum import json from dataclasses import dataclass, field, fields -from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union +from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, + Union) import torch from transformers import PretrainedConfig @@ -103,6 +104,7 @@ def __init__( disable_sliding_window: bool = False, skip_tokenizer_init: bool = False, served_model_name: Optional[Union[str, List[str]]] = None, + multimodal_config: Optional["VisionLanguageConfig"] = None, ) -> None: self.model = model self.tokenizer = tokenizer @@ -137,6 +139,8 @@ def __init__( sliding_window_len=self.get_hf_config_sliding_window()) self.served_model_name = get_served_model_name(model, served_model_name) + self.multimodal_config = multimodal_config + if not self.skip_tokenizer_init: self._verify_tokenizer_mode() self._verify_embedding_mode() @@ -1114,6 +1118,24 @@ def get_image_input_enum_type(cls, value: str) -> ImageInputType: f"Expecting to choose from " f"{[x.name for x in cls.ImageInputType]}.") from e + def as_cli_args_dict(self) -> Dict[str, Any]: + """Flatten vision language config to pure args. + Compatible with what llm entrypoint expects. + """ + result: Dict[str, Any] = {} + for f in fields(self): + value = getattr(self, f.name) + if isinstance(value, enum.Enum): + result[f.name] = value.name.lower() + elif isinstance(value, tuple): + result[f.name] = ",".join([str(item) for item in value]) + else: + result[f.name] = value + + result["disable_image_processor"] = self.image_processor is None + + return result + _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 72787d369c0f..fb8ca0b3b764 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -619,6 +619,37 @@ def create_engine_config(self, ) -> EngineConfig: "BitsAndBytes load format and QLoRA adapter only support " f"'bitsandbytes' quantization, but got {self.quantization}") + if self.image_input_type: + if (not self.image_token_id or not self.image_input_shape + or not self.image_feature_size): + raise ValueError( + 'Specify `image_token_id`, `image_input_shape` and ' + '`image_feature_size` together with `image_input_type`.') + + if self.image_processor is None: + self.image_processor = self.model + if self.disable_image_processor: + if self.image_processor != self.model: + warnings.warn( + "You've specified an image processor " + f"({self.image_processor}) but also disabled " + "it via `--disable-image-processor`.", + stacklevel=2) + + self.image_processor = None + + vision_language_config = VisionLanguageConfig( + image_input_type=VisionLanguageConfig. 
+ get_image_input_enum_type(self.image_input_type), + image_token_id=self.image_token_id, + image_input_shape=str_to_int_tuple(self.image_input_shape), + image_feature_size=self.image_feature_size, + image_processor=self.image_processor, + image_processor_revision=self.image_processor_revision, + ) + else: + vision_language_config = None + device_config = DeviceConfig(self.device) model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, @@ -628,7 +659,8 @@ def create_engine_config(self, ) -> EngineConfig: self.quantization_param_path, self.enforce_eager, self.max_context_len_to_capture, self.max_seq_len_to_capture, self.max_logprobs, self.disable_sliding_window, - self.skip_tokenizer_init, self.served_model_name) + self.skip_tokenizer_init, self.served_model_name, + vision_language_config) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, @@ -700,37 +732,6 @@ def create_engine_config(self, ) -> EngineConfig: model_loader_extra_config=self.model_loader_extra_config, ) - if self.image_input_type: - if (not self.image_token_id or not self.image_input_shape - or not self.image_feature_size): - raise ValueError( - 'Specify `image_token_id`, `image_input_shape` and ' - '`image_feature_size` together with `image_input_type`.') - - if self.image_processor is None: - self.image_processor = self.model - if self.disable_image_processor: - if self.image_processor != self.model: - warnings.warn( - "You've specified an image processor " - f"({self.image_processor}) but also disabled " - "it via `--disable-image-processor`.", - stacklevel=2) - - self.image_processor = None - - vision_language_config = VisionLanguageConfig( - image_input_type=VisionLanguageConfig. - get_image_input_enum_type(self.image_input_type), - image_token_id=self.image_token_id, - image_input_shape=str_to_int_tuple(self.image_input_shape), - image_feature_size=self.image_feature_size, - image_processor=self.image_processor, - image_processor_revision=self.image_processor_revision, - ) - else: - vision_language_config = None - decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index db4d2849b3f0..ddca39d67c9d 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -265,9 +265,11 @@ async def process_model_inputs_async( else: prompt_token_ids = inputs["prompt_token_ids"] - return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) + + return self.input_processor(llm_inputs) async def add_request_async( self, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index cb5893e707c8..59327f3b56eb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -21,7 +21,7 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import LLMInputs, PromptInputs +from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, @@ -219,6 +219,9 @@ def __init__( self.generation_config_fields = 
_load_generation_config_dict( model_config) + self.input_processor = INPUT_REGISTRY.create_input_processor( + self.model_config) + self.model_executor = executor_class( model_config=model_config, cache_config=cache_config, @@ -484,9 +487,11 @@ def process_model_inputs( else: prompt_token_ids = inputs["prompt_token_ids"] - return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) + + return self.input_processor(llm_inputs) def add_request( self, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py new file mode 100644 index 000000000000..6288503bfe19 --- /dev/null +++ b/vllm/inputs/__init__.py @@ -0,0 +1,17 @@ +from vllm.multimodal import MULTIMODAL_REGISTRY + +from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, + PromptStrictInputs, TextPrompt, TextTokensPrompt, + TokensPrompt, parse_and_batch_prompt) +from .registry import InputRegistry + +INPUT_REGISTRY = InputRegistry(multimodal_registry=MULTIMODAL_REGISTRY) +"""The global :class:`~InputRegistry` which is used by model runners.""" + +del MULTIMODAL_REGISTRY + +__all__ = [ + "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt", + "TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs", + "LLMInputs", "INPUT_REGISTRY", "InputRegistry" +] diff --git a/vllm/inputs.py b/vllm/inputs/data.py similarity index 97% rename from vllm/inputs.py rename to vllm/inputs/data.py index 85c9cd84f5ed..2c600e9793f3 100644 --- a/vllm/inputs.py +++ b/vllm/inputs/data.py @@ -125,6 +125,11 @@ class TextTokensPrompt(TypedDict): class LLMInputs(TypedDict): + """ + The inputs in :class:`~vllm.LLMEngine` before they are + passed to the model executor. + """ + prompt_token_ids: List[int] prompt: NotRequired[Optional[str]] multi_modal_data: NotRequired[Optional["MultiModalData"]] diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py new file mode 100644 index 000000000000..78f1fdea8945 --- /dev/null +++ b/vllm/inputs/registry.py @@ -0,0 +1,181 @@ +import functools +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, + TypeVar) + +from torch import nn +from transformers import PretrainedConfig + +from vllm.logger import init_logger + +from .data import LLMInputs + +if TYPE_CHECKING: + from vllm.config import ModelConfig, VisionLanguageConfig + from vllm.multimodal import MultiModalData, MultiModalRegistry + from vllm.sequence import SequenceData + +logger = init_logger(__name__) + +C = TypeVar("C", bound=PretrainedConfig) + + +@dataclass(frozen=True) +class InputContext: + model_config: "ModelConfig" + + def get_multimodal_config(self) -> "VisionLanguageConfig": + multimodal_config = self.model_config.multimodal_config + if multimodal_config is None: + raise ValueError("No multimodal config found") + + return multimodal_config + + def get_hf_config(self, hf_config_type: Type[C]) -> C: + hf_config = self.model_config.hf_config + if not isinstance(hf_config, hf_config_type): + raise TypeError("Invalid type of HuggingFace config. " + f"Expected type: {hf_config_type}, but " + f"found type: {type(hf_config)}") + + return hf_config + + +N = TypeVar("N", bound=Type[nn.Module]) + +DummyDataFactory = Callable[[InputContext, int], + Tuple["SequenceData", Optional["MultiModalData"]]] +""" +Create dummy data to be inputted into the model. 
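+
+A minimal factory (an illustrative sketch mirroring the registry's default
+below; ``dummy_data_for_my_model`` is just a placeholder name) could look
+like::
+
+    def dummy_data_for_my_model(ctx: InputContext, seq_len: int):
+        return SequenceData([0] * seq_len), None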
+ +Note: + :data:`InputProcessor` is not applied to the dummy data. +""" + +InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs] +"""Preprocess the inputs to the model.""" + + +class InputRegistry: + """ + This registry is used by :class:`~vllm.LLMEngine` to dispatch data + processing according to the target model. + """ + + def __init__(self, *, multimodal_registry: "MultiModalRegistry") -> None: + self._multimodal_registry = multimodal_registry + + self._dummy_factories_by_model_type: Dict[Type[nn.Module], + DummyDataFactory] = {} + self._input_processors_by_model_type: Dict[Type[nn.Module], + InputProcessor] = {} + + @property + def MULTIMODAL(self) -> "MultiModalRegistry": + """Access the registry for processing multimodal inputs.""" + return self._multimodal_registry + + def _default_dummy_data_factory( + self, + ctx: InputContext, + seq_len: int, + ) -> Tuple["SequenceData", Optional["MultiModalData"]]: + """ + Create dummy data to be inputted into the model. + + Note: + :data:`InputProcessor` is not applied to the dummy data. + """ + # Avoid circular import + from vllm.sequence import SequenceData + + dummy_seq_data = SequenceData([0] * seq_len) + dummy_multi_modal_data = None + + return dummy_seq_data, dummy_multi_modal_data + + def register_dummy_data(self, factory: DummyDataFactory): + """ + Register a dummy data factory to a model class. + + During memory profiling, the provided function is invoked to create + dummy data to be inputted into the model. The resulting memory usage + should be an upper bound of what the model would use at inference time. + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._dummy_factories_by_model_type: + logger.warning( + "Model class %s already has dummy data " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._dummy_factories_by_model_type[model_cls] = factory + + return model_cls + + return wrapper + + def dummy_data_for_profiling(self, model_config: "ModelConfig", + seq_len: int): + """Create dummy data for memory profiling.""" + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + dummy_factory = self._dummy_factories_by_model_type \ + .get(model_cls, self._default_dummy_data_factory) + + return dummy_factory(InputContext(model_config), seq_len) + + def _default_input_processor(self, ctx: InputContext, + inputs: LLMInputs) -> LLMInputs: + """Preprocess the inputs to the model.""" + return inputs + + def register_input_processor(self, processor: InputProcessor): + """ + Register an input processor to a model class. + + The provided function is invoked on each input to the model. This + happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`. + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._input_processors_by_model_type: + logger.warning( + "Model class %s already has input processor " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._input_processors_by_model_type[model_cls] = processor + + return model_cls + + return wrapper + + def process_input(self, model_config: "ModelConfig", + inputs: LLMInputs) -> LLMInputs: + """ + Apply an input processor to an instance of model inputs. + + The model is identified by ``model_config``. ``vlm_config`` is + for compatibility purposes and may be merged into ``model_config`` + in the near future. 
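+        In practice this is invoked through the partial returned by
+        :meth:`create_input_processor`, which :class:`~vllm.LLMEngine`
+        stores as its ``input_processor`` and applies to every
+        :class:`LLMInputs` instance it builds.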
+ """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + + processor = self._input_processors_by_model_type \ + .get(model_cls, self._default_input_processor) + + return processor(InputContext(model_config), inputs) + + def create_input_processor(self, model_config: "ModelConfig"): + """ + Create an input processor (see :meth:`process_input`) for a + specific model. + """ + return functools.partial(self.process_input, model_config) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index a92abe6b5b8d..4446914c67c8 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -33,6 +33,8 @@ "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), + "LlavaNextForConditionalGeneration": + ("llava_next", "LlavaNextForConditionalGeneration"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 3332bcc57846..51b3e48709dd 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -8,6 +8,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, VisionLanguageConfig +from vllm.inputs import INPUT_REGISTRY from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -17,8 +18,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import get_dummy_image_data +from vllm.multimodal.image import DummyImageDataFactories, ImageInputProcessors from vllm.sequence import SamplerOutput from .vlm_base import VisionLanguageModelBase @@ -84,9 +84,12 @@ class LlavaImageFeatureInputs(TypedDict): LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs] -@MULTIMODAL_REGISTRY.register_image_feature_input() -@MULTIMODAL_REGISTRY.register_image_pixel_input() -@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data) +@INPUT_REGISTRY.MULTIMODAL.register_image_feature_input_mapper() +@INPUT_REGISTRY.MULTIMODAL.register_image_pixel_input_mapper() +@INPUT_REGISTRY.register_dummy_data( + DummyImageDataFactories.for_model(LlavaConfig)) +@INPUT_REGISTRY.register_input_processor( + ImageInputProcessors.for_model(LlavaConfig)) class LlavaForConditionalGeneration(VisionLanguageModelBase): def __init__(self, @@ -151,7 +154,8 @@ def _parse_and_validate_image_input( return None if not isinstance(pixel_values, torch.Tensor): - raise ValueError("Incorrect type of pixel values") + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") return LlavaImagePixelInputs( type="pixel_values", @@ -166,7 +170,8 @@ def _parse_and_validate_image_input( return None if not isinstance(image_features, torch.Tensor): - raise ValueError("Incorrect type of image features") + raise ValueError("Incorrect type of image features. 
" + f"Got type: {type(image_features)}") return LlavaImageFeatureInputs( type="image_features", diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py new file mode 100644 index 000000000000..3095013dd8c1 --- /dev/null +++ b/vllm/model_executor/models/llava_next.py @@ -0,0 +1,221 @@ +from typing import Optional, TypedDict + +import torch +import torch.nn as nn +from transformers import LlavaNextConfig +from transformers.models.llava_next.modeling_llava_next import ( + get_anyres_image_grid_shape, unpad_image) + +from vllm.config import CacheConfig, VisionLanguageConfig +from vllm.inputs import INPUT_REGISTRY +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.multimodal.image import DummyImageDataFactories, ImageInputProcessors + +from .llava import LlavaForConditionalGeneration, LlavaImagePixelInputs + + +class ImageSizesMixin(TypedDict, total=False): + image_sizes: torch.Tensor + """Shape: (batch_size, 2)""" + + +class LlavaNextImagePixelInputs(ImageSizesMixin, LlavaImagePixelInputs): + data: torch.Tensor + """Shape: (batch_size, 1 + num_patches, num_channels, height, width)""" + + +@INPUT_REGISTRY.MULTIMODAL.register_image_pixel_input_mapper() +@INPUT_REGISTRY.register_dummy_data( + DummyImageDataFactories.for_model(LlavaNextConfig)) +@INPUT_REGISTRY.register_input_processor( + ImageInputProcessors.for_model(LlavaNextConfig)) +class LlavaNextForConditionalGeneration(LlavaForConditionalGeneration): + """ + Args to `forward()`: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values: For PIXEL_VALUES, expects a batch with shape + [1, num_patches, 3, 336, 336]. + image_features: For IMAGE_FEATURES, expects a batch with shape + [1, num_patches, 1176, 1024]. + """ + + def __init__(self, + config: LlavaNextConfig, + vision_language_config: VisionLanguageConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__( + config, # type: ignore + vision_language_config, + cache_config, + quant_config, + ) + + # Update the type annotation from that of its superclass + self.config = config + + self.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size)) + + def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor: + _, num_channels, _, _ = self.vision_language_config.image_input_shape + + # Note that this is different from that of vLLM vision_language_config + # since the image is resized by the HuggingFace preprocessor + height = width = self.config.vision_config.image_size + + if list(data.shape[2:]) != [num_channels, height, width]: + raise ValueError( + f"The expected image tensor shape is batch dimension plus " + f"num_patches plus {[num_channels, height, width]}. " + f"You supplied {data.shape}. " + f"If you are using vLLM's entrypoint, make sure your " + f"supplied image input is consistent with " + f"image_input_shape in engine args.") + + return data + + def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: + if list(data.shape[1:]) != [2]: + raise ValueError( + f"The expected image sizes shape is batch dimension plus " + f"{[2]}. 
You supplied {data.shape}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[LlavaImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) + image_features = kwargs.pop("image_features", None) + + expected_input_type = self.vision_language_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + + if expected_input_type == ImageInputType.PIXEL_VALUES: + if image_features is not None: + raise ValueError( + "Expected pixel values but got image features") + if pixel_values is None: + return None + + if not isinstance(pixel_values, torch.Tensor): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(image_sizes, torch.Tensor): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(image_sizes)}") + + return LlavaNextImagePixelInputs( + type="pixel_values", + data=self._validate_image_pixels(pixel_values), + image_sizes=self._validate_image_sizes(image_sizes), + ) + + if expected_input_type == ImageInputType.IMAGE_FEATURES: + raise TypeError("Image features are not supported by LLaVA-NeXT") + + return None + + def _merge_image_patch_embeddings(self, image_size: torch.Tensor, + patch_embeddings: torch.Tensor, *, + strategy: str) -> torch.Tensor: + if strategy == "flat": + return patch_embeddings.flatten(0, 1) + + if strategy.startswith("spatial"): + orig_width, orig_height = image_size + height = width = self.config.vision_config.image_size \ + // self.config.vision_config.patch_size + + base_patch_embeds = patch_embeddings[0] + if height * width != base_patch_embeds.shape[0]: + raise ValueError( + "The number of patches is not consistent with the " + "image size.") + + if patch_embeddings.shape[0] > 1: + other_patch_embeds = patch_embeddings[1:] + + # image_aspect_ratio == "anyres" + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + (orig_width, orig_height), + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + other_patch_embeds = other_patch_embeds \ + .view(num_patch_width, num_patch_height, height, width, -1) + + if "unpad" in strategy: + other_patch_embeds = other_patch_embeds \ + .permute(4, 0, 2, 1, 3).contiguous() \ + .flatten(1, 2).flatten(2, 3) + other_patch_embeds = unpad_image(other_patch_embeds, + image_size) + other_patch_embeds = torch.cat(( + other_patch_embeds, + self.image_newline[:, None, None] \ + .expand(*other_patch_embeds.shape[:-1], 1) \ + .to(other_patch_embeds.device), + ), dim=-1) + other_patch_embeds = other_patch_embeds \ + .flatten(1, 2).transpose(0, 1) + else: + other_patch_embeds = other_patch_embeds \ + .permute(0, 2, 1, 3, 4).contiguous() \ + .flatten(0, 3) + + merged_patch_embeddings = torch.cat( + (base_patch_embeds, other_patch_embeds), dim=0) + else: + if "unpad" in strategy: + merged_patch_embeddings = torch.cat( + (base_patch_embeds, + self.image_newline[None] \ + .to(base_patch_embeds.device) + ), dim=0) + else: + merged_patch_embeddings = base_patch_embeds + + return merged_patch_embeddings + + raise ValueError(f"Unexpected patch merge strategy: {strategy}") + + def _process_image_pixels( + self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor: + assert self.vision_tower is not None + + pixel_values = inputs["data"] + + b, num_patches, c, h, w = pixel_values.shape + stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w) + + stacked_image_features = self._image_pixels_to_features( + 
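+            # Each of the b * num_patches crops is encoded independently by
+            # the vision tower; the flat batch is reshaped back to
+            # (b, num_patches, ...) below.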
self.vision_tower, stacked_pixel_values) + + return stacked_image_features.view(b, num_patches, + *stacked_image_features.shape[-2:]) + + def _process_image_input( + self, image_input: LlavaNextImagePixelInputs) -> torch.Tensor: + patch_embeddings = super()._process_image_input(image_input) + + image_sizes = image_input.get("image_sizes") + if image_sizes is None: + batch_size = image_input["data"].shape[0] + vision_config = self.config.vision_config + default_width = default_height = vision_config.image_size + image_sizes = torch.as_tensor([[default_width, default_height] + for _ in range(batch_size)]) + + merged_patch_embeddings = [ + self._merge_image_patch_embeddings(image_sizes[i], + patch_features, + strategy="spatial_unpad") + for i, patch_features in enumerate(patch_embeddings) + ] + + return torch.stack(merged_patch_embeddings, dim=0) diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 270012e7d1c3..c97586258c90 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,5 +1,8 @@ from .base import MultiModalData, MultiModalPlugin -from .registry import MULTIMODAL_REGISTRY, MultiModalRegistry +from .registry import MultiModalRegistry + +MULTIMODAL_REGISTRY = MultiModalRegistry() +"""The global :class:`~MultiModalRegistry` which is used by model runners.""" __all__ = [ "MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY", diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 847752449ba8..6b4684a54d33 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -2,7 +2,8 @@ from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type, TypeVar) -from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.config import ModelConfig +from vllm.inputs.registry import InputContext from vllm.logger import init_logger if TYPE_CHECKING: @@ -32,10 +33,9 @@ class MultiModalData: D = TypeVar("D", bound=MultiModalData) N = TypeVar("N", bound=Type["nn.Module"]) -MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig], - Dict[str, "torch.Tensor"]] +MultiModalInputMapper = Callable[[InputContext, D], Dict[str, "torch.Tensor"]] """Return a dictionary to be passed as keyword arguments to -:meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers +:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers and processors in HuggingFace Transformers.""" @@ -50,16 +50,9 @@ class MultiModalPlugin(ABC, Generic[D]): (i.e., the modality of the data). """ - @classmethod - def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]: - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - - return get_model_architecture(model_config)[0] - def __init__(self) -> None: - self._input_processors: Dict[Type["nn.Module"], - MultiModalInputProcessor[D]] = {} + self._input_mappers: Dict[Type["nn.Module"], + MultiModalInputMapper[D]] = {} @abstractmethod def get_data_type(self) -> Type[D]: @@ -70,57 +63,55 @@ def get_data_type(self) -> Type[D]: raise NotImplementedError @abstractmethod - def _default_input_processor( - self, data: D, model_config: ModelConfig, - vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]: + def _default_input_mapper(self, ctx: InputContext, + data: D) -> Dict[str, "torch.Tensor"]: """Return a dictionary to be passed as keyword arguments to - :meth:`torch.nn.Module.forward`. This is similar in concept to + :meth:`~torch.nn.Module.forward`. 
This is similar in concept to tokenizers and processors in HuggingFace Transformers. """ raise NotImplementedError - def register_input_processor(self, - processor: Optional[ - MultiModalInputProcessor[D]] = None): + def register_input_mapper( + self, + mapper: Optional[MultiModalInputMapper[D]] = None, + ): """ - Register an input processor to a model class. + Register an input mapper to a model class. When the model receives input data that matches the modality served by - this plugin (see :meth:`get_data_type`), the provided input processor is - applied to preprocess the data. If `None` is provided, then the default - input processor is applied instead. + this plugin (see :meth:`get_data_type`), the provided function is + invoked to transform the data into a dictionary of model inputs. + If `None` is provided, then the default input mapper is used instead. """ def wrapper(model_cls: N) -> N: - if model_cls in self._input_processors: + if model_cls in self._input_mappers: logger.warning( - "Model class %s already has an input processor " + "Model class %s already has an input mapper " "registered to %s. It is overwritten by the new one.", model_cls, self) - self._input_processors[model_cls] = processor \ - or self._default_input_processor + self._input_mappers[model_cls] = mapper \ + or self._default_input_mapper return model_cls return wrapper - def process_input( - self, data: D, model_config: ModelConfig, - vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]: + def map_input(self, model_config: ModelConfig, + data: D) -> Dict[str, "torch.Tensor"]: """ - Apply an input processor to a :class:`~MultiModalData` instance passed - to the model. - - The model is identified by ``model_config``. ``vlm_config`` is - for compatibility purposes and may be merged into ``model_config`` - in the near future. + Apply an input mapper to a :class:`~MultiModalData` instance passed + to the model, transforming the data into a dictionary of model inputs. 
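+
+        For example, for a LLaVA-style model an :class:`ImagePixelData`
+        instance would typically be mapped (by the default mapper, assuming
+        a CLIP image processor) to a ``{"pixel_values": ...}`` dictionary,
+        which is then passed as keyword arguments to
+        :meth:`~torch.nn.Module.forward`.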
""" - model_cls = self.get_model_cls(model_config) + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) - processor = self._input_processors.get(model_cls) - if processor is None: - raise KeyError(f"No input processor in {self} is registered for " + mapper = self._input_mappers.get(model_cls) + if mapper is None: + raise KeyError(f"No input mapper in {self} is registered for " f"model class {model_cls.__name__}.") - return processor(data, model_config, vlm_config) + return mapper(InputContext(model_config), data) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index b964e9ee4262..16ed99c9fc02 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,70 +1,485 @@ -from typing import Dict, Tuple, Type, Union +from functools import lru_cache +from typing import (TYPE_CHECKING, Dict, List, Optional, Tuple, Type, TypeVar, + Union) import torch from PIL import Image +from transformers import (CLIPVisionConfig, LlavaConfig, LlavaNextConfig, + PretrainedConfig, PreTrainedTokenizerBase) +from transformers.models.llava_next.modeling_llava_next import ( + get_anyres_image_grid_shape) from vllm.config import ModelConfig, VisionLanguageConfig +from vllm.inputs.registry import DummyDataFactory, InputContext, InputProcessor from vllm.logger import init_logger from vllm.sequence import SequenceData -from vllm.transformers_utils.image_processor import cached_get_image_processor +from vllm.transformers_utils.image_processor import get_image_processor +from vllm.transformers_utils.tokenizer import get_tokenizer from .base import MultiModalData, MultiModalPlugin +if TYPE_CHECKING: + from vllm.inputs import LLMInputs +else: + LLMInputs = dict + logger = init_logger(__name__) +_cached_get_tokenizer = lru_cache(get_tokenizer) +_cached_get_image_processor = lru_cache(get_image_processor) + + +def _get_clip_num_patches(hf_config: CLIPVisionConfig) -> int: + image_size = hf_config.image_size + patch_size = hf_config.patch_size -def _get_dummy_seq_data(seq_len: int, - vlm_config: VisionLanguageConfig) -> SequenceData: - # NOTE: We assume that token is repeated `image_feature_size` times - # and then concatenated with the text prompt - # TODO: Enable other ways of inserting the image into the prompt + assert image_size % patch_size == 0 + return image_size // patch_size - token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size - token_ids += [0] * (seq_len - vlm_config.image_feature_size) - return SequenceData(token_ids) +def _get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int: + num_patches = _get_clip_num_patches(hf_config) + return num_patches * num_patches -def _get_dummy_values(vlm_config: VisionLanguageConfig) -> torch.Tensor: - if vlm_config.image_processor is None: - values_dtype = torch.float16 +def _get_llava_next_num_unpadded_features( + height: int, + width: int, + npatches: int, + num_patch_height: int, + num_patch_width: int, +) -> Tuple[int, int]: + # Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111 + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width + + aspect_ratio: float = width / height + current_aspect_ratio: float = current_width / current_height + if aspect_ratio > current_aspect_ratio: + new_height = (height * current_width) // width + current_height = new_height else: - 
values_dtype = torch.uint8 + new_width = (width * current_height) // height + current_width = new_width + + unpadded_features = current_height * current_width + newline_features = current_height + return (unpadded_features, newline_features) + + +def _get_llava_next_image_feature_size( + hf_config: LlavaNextConfig, + *, + input_height: int, + input_width: int, +) -> int: + vision_config = hf_config.vision_config + + if isinstance(vision_config, CLIPVisionConfig): + num_patches = _get_clip_num_patches(vision_config) + base_feature_size = num_patches * num_patches + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_size=(input_height, input_width), + grid_pinpoints=hf_config.image_grid_pinpoints, + patch_size=vision_config.image_size, + ) - return torch.zeros(vlm_config.image_input_shape, dtype=values_dtype) + ( + unpadded_feature_size, + newline_feature_size, + ) = _get_llava_next_num_unpadded_features(input_height, input_width, + num_patches, + num_patch_height, + num_patch_width) + return unpadded_feature_size + newline_feature_size + base_feature_size -def get_dummy_image_data( - seq_len: int, - model_config: ModelConfig, - vlm_config: VisionLanguageConfig, -) -> Tuple[SequenceData, MultiModalData]: - """Standard dummy data factory for image data (to be used in - :meth:`vlm.multimodal.MultiModalRegistry.register_dummy_data`).""" - seq_data = _get_dummy_seq_data(seq_len, vlm_config) - values = _get_dummy_values(vlm_config) + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) - config_input_type = vlm_config.image_input_type - ImageInputType = VisionLanguageConfig.ImageInputType - fake_mm_data: MultiModalData - if config_input_type == ImageInputType.PIXEL_VALUES: - fake_mm_data = ImagePixelData(values) - elif config_input_type == ImageInputType.IMAGE_FEATURES: - fake_mm_data = ImageFeatureData(values) - else: - raise NotImplementedError +class DummyImageDataFactories: + """ + Contains factories for dummy image data factories. 
+ + See Also: + :data:`vllm.inputs.registry.DummyDataFactory` + """ + + @classmethod + def _dummy_seq_data_for_clip( + cls, + hf_config: CLIPVisionConfig, + seq_len: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, + ): + if image_feature_size_override is None: + image_feature_size = _get_clip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + token_ids = [image_token_id] * image_feature_size + token_ids += [0] * (seq_len - image_feature_size) + return SequenceData(token_ids) + + @classmethod + def _dummy_pixel_data_for_clip( + cls, + hf_config: CLIPVisionConfig, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, + ): + width = height = hf_config.image_size + if image_width_override is not None: + width = image_width_override + if image_height_override is not None: + height = image_height_override + + image = Image.new("RGB", (width, height), color=0) + return ImagePixelData(image) + + @classmethod + def _dummy_feature_data_for_clip( + cls, + hf_config: CLIPVisionConfig, + *, + image_feature_size_override: Optional[int] = None, + ): + if image_feature_size_override is None: + image_feature_size = _get_clip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + values = torch.zeros((1, image_feature_size, hf_config.hidden_size), + dtype=torch.float16) + return ImageFeatureData(values) + + @classmethod + def _dummy_data_for_llava( + cls, + model_config: ModelConfig, + multimodal_config: VisionLanguageConfig, + hf_config: LlavaConfig, + seq_len: int, + ): + vision_config = hf_config.vision_config + + if isinstance(vision_config, CLIPVisionConfig): + seq_data = cls._dummy_seq_data_for_clip( + vision_config, + seq_len, + image_token_id=hf_config.image_token_index, + ) + + image_input_type = multimodal_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + multi_modal_data: MultiModalData + if image_input_type == ImageInputType.PIXEL_VALUES: + multi_modal_data = cls._dummy_pixel_data_for_clip( + vision_config) + elif image_input_type == ImageInputType.IMAGE_FEATURES: + multi_modal_data = cls._dummy_feature_data_for_clip( + vision_config) + + return seq_data, multi_modal_data + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + @classmethod + def _dummy_data_for_llava_next( + cls, + model_config: ModelConfig, + multimodal_config: VisionLanguageConfig, + hf_config: LlavaNextConfig, + seq_len: int, + ): + vision_config = hf_config.vision_config + + # Result in the max possible feature size + dummy_height = dummy_width = 448 + image_feature_size = _get_llava_next_image_feature_size( + hf_config, input_height=dummy_height, input_width=dummy_width) + + if isinstance(vision_config, CLIPVisionConfig): + seq_data = cls._dummy_seq_data_for_clip( + vision_config, + seq_len, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + image_input_type = multimodal_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + multi_modal_data: MultiModalData + if image_input_type == ImageInputType.PIXEL_VALUES: + multi_modal_data = cls._dummy_pixel_data_for_clip( + vision_config, + image_width_override=dummy_width, + image_height_override=dummy_height, + ) + elif image_input_type == ImageInputType.IMAGE_FEATURES: + multi_modal_data = cls._dummy_feature_data_for_clip( + vision_config, + 
image_feature_size_override=image_feature_size, + ) + + return seq_data, multi_modal_data + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + @classmethod + def for_model( + cls, + hf_config_type: Type[PretrainedConfig], + ) -> DummyDataFactory: + """ + Create an dummy image data factory for a model as identified + by the config type. + """ + if hf_config_type == LlavaConfig: + return lambda ctx, seq_len: cls._dummy_data_for_llava( + ctx.model_config, + ctx.get_multimodal_config(), + ctx.get_hf_config(LlavaConfig), + seq_len=seq_len, + ) + if hf_config_type == LlavaNextConfig: + return lambda ctx, seq_len: cls._dummy_data_for_llava_next( + ctx.model_config, + ctx.get_multimodal_config(), + ctx.get_hf_config(LlavaNextConfig), + seq_len=seq_len, + ) + + msg = f"Unsupported model config: {type(hf_config_type)}" + raise NotImplementedError(msg) + + +_T = TypeVar("_T", str, int) + + +class ImageInputProcessors: + """ + Contains factories for image input processors. + + See Also: + :data:`vllm.inputs.registry.InputProcessor` + """ + + @classmethod + def _repeat_and_pad_token( + cls, + token: _T, + *, + repeat_count: int = 1, + pad_token_left: Optional[_T] = None, + pad_token_right: Optional[_T] = None, + ) -> List[_T]: + replacement = [token] * repeat_count + if pad_token_left is not None: + replacement = [pad_token_left] + replacement + if pad_token_right is not None: + replacement = replacement + [pad_token_right] + + return replacement + + @classmethod + def _repeat_and_pad_image_tokens( + cls, + tokenizer: PreTrainedTokenizerBase, + prompt: Optional[str], + prompt_token_ids: List[int], + *, + image_token_id: int, + repeat_count: int = 1, + pad_token_left: Optional[int] = None, + pad_token_right: Optional[int] = None, + ) -> Tuple[Optional[str], List[int]]: + # To avoid invoking the tokenizer, we assume that the + # image token is called "" + if prompt is None: + new_prompt = None + else: + image_token_str = tokenizer.decode(image_token_id) + pad_token_str_left = (None if pad_token_left is None else + tokenizer.decode(pad_token_left)) + pad_token_str_right = (None if pad_token_right is None else + tokenizer.decode(pad_token_right)) + replacement_str = "".join( + cls._repeat_and_pad_token( + image_token_str, + repeat_count=repeat_count, + pad_token_left=pad_token_str_left, + pad_token_right=pad_token_str_right, + )) + + # The image tokens are removed to be consistent with HuggingFace + new_prompt = prompt.replace(image_token_str, replacement_str, 1) + + new_token_ids: List[int] = [] + for i, token in enumerate(prompt_token_ids): + if token == image_token_id: + replacement_ids = cls._repeat_and_pad_token( + image_token_id, + repeat_count=repeat_count, + pad_token_left=pad_token_left, + pad_token_right=pad_token_right, + ) + new_token_ids.extend(replacement_ids) + + # No need to further scan the list since we only replace once + new_token_ids.extend(prompt_token_ids[i + 1:]) + break + else: + new_token_ids.append(token) + + return new_prompt, new_token_ids + + @classmethod + def _input_processor_for_clip( + cls, + model_config: ModelConfig, + multimodal_config: VisionLanguageConfig, + hf_config: CLIPVisionConfig, + llm_inputs: LLMInputs, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, + ): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or not isinstance( + multi_modal_data, (ImagePixelData, ImageFeatureData)): + return llm_inputs + + tokenizer = 
_cached_get_tokenizer(model_config.tokenizer) + + if image_feature_size_override is None: + image_feature_size = _get_clip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + new_prompt, new_token_ids = cls._repeat_and_pad_image_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + image_token_id=image_token_id, + repeat_count=image_feature_size, + ) - return seq_data, fake_mm_data + # NOTE: Create a defensive copy of the original inputs + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + @classmethod + def _input_processor_for_llava( + cls, + model_config: ModelConfig, + multimodal_config: VisionLanguageConfig, + hf_config: LlavaConfig, + llm_inputs: LLMInputs, + ): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or not isinstance( + multi_modal_data, (ImagePixelData, ImageFeatureData)): + return llm_inputs + + vision_config = hf_config.vision_config + + if isinstance(vision_config, CLIPVisionConfig): + return cls._input_processor_for_clip( + model_config, + multimodal_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + @classmethod + def _input_processor_for_llava_next( + cls, + model_config: ModelConfig, + multimodal_config: VisionLanguageConfig, + hf_config: LlavaNextConfig, + llm_inputs: LLMInputs, + ): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or not isinstance( + multi_modal_data, (ImagePixelData, ImageFeatureData)): + return llm_inputs + + if isinstance(multi_modal_data, ImagePixelData): + image = multi_modal_data.image + if isinstance(image, torch.Tensor): + _, _, _, height, width = image.shape + else: + width, height = image.size + + image_feature_size = _get_llava_next_image_feature_size( + hf_config, input_height=height, input_width=width) + else: + image_features = multi_modal_data.image_features + image_feature_size = image_features.shape[-2] + + vision_config = hf_config.vision_config + + if isinstance(vision_config, CLIPVisionConfig): + return cls._input_processor_for_clip( + model_config, + multimodal_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + @classmethod + def for_model( + cls, + hf_config_type: Type[PretrainedConfig], + ) -> InputProcessor: + """ + Create an input processor for a model as identified + by the config type. + """ + if hf_config_type == LlavaConfig: + return lambda ctx, llm_inputs: cls._input_processor_for_llava( + ctx.model_config, + ctx.get_multimodal_config(), + ctx.get_hf_config(LlavaConfig), + llm_inputs=llm_inputs, + ) + if hf_config_type == LlavaNextConfig: + return lambda ctx, llm_inputs: cls._input_processor_for_llava_next( + ctx.model_config, + ctx.get_multimodal_config(), + ctx.get_hf_config(LlavaNextConfig), + llm_inputs=llm_inputs, + ) + + msg = f"Unsupported model config: {type(hf_config_type)}" + raise NotImplementedError(msg) class ImagePixelData(MultiModalData): """ The pixel data of an image. Can be one of: - - :class:``PIL.Image``: An image object. Requires that a HuggingFace + - :class:`PIL.Image.Image`: An image object. Requires that a HuggingFace processor is available to the model. 
- - :class:``torch.Tensor``: The raw pixel data which is passed to the model + - :class:`torch.Tensor`: The raw pixel data which is passed to the model without additional pre-processing. """ @@ -75,31 +490,38 @@ def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None: self.image = image + def __repr__(self) -> str: + image = self.image + if isinstance(image, Image.Image): + return f"{type(self).__name__}(image={image})" + + return (f"{type(self).__name__}(image=torch.Tensor(shape=" + f"{image.shape}, dtype={image.dtype}))") + class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]): def get_data_type(self) -> Type[ImagePixelData]: return ImagePixelData - def _get_hf_image_processor(self, model_config: ModelConfig, - vlm_config: VisionLanguageConfig): + def _get_hf_image_processor(self, model_config: ModelConfig): + vlm_config = model_config.multimodal_config if vlm_config is None or vlm_config.image_processor is None: return None - return cached_get_image_processor( + return _cached_get_image_processor( vlm_config.image_processor, trust_remote_code=model_config.trust_remote_code, revision=vlm_config.image_processor_revision, ) - def _default_input_processor( - self, data: ImagePixelData, model_config: ModelConfig, - vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]: + def _default_input_mapper(self, ctx: InputContext, + data: ImagePixelData) -> Dict[str, torch.Tensor]: + model_config = ctx.model_config image = data.image - image_processor = self._get_hf_image_processor(model_config, - vlm_config) if isinstance(image, Image.Image): + image_processor = self._get_hf_image_processor(model_config) if image_processor is None: raise RuntimeError("No HuggingFace processor is available" "to process the image object") @@ -127,15 +549,22 @@ class ImageFeatureData(MultiModalData): def __init__(self, image_features: torch.Tensor) -> None: self.image_features = image_features + def __repr__(self) -> str: + image_features = self.image_features + + return (f"{type(self).__name__}(image_features=torch.Tensor(shape=" + f"{image_features.shape}, dtype={image_features.dtype}))") + class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]): def get_data_type(self) -> Type[ImageFeatureData]: return ImageFeatureData - def _default_input_processor( - self, data: ImageFeatureData, model_config: ModelConfig, - vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]: + def _default_input_mapper( + self, ctx: InputContext, + data: ImageFeatureData) -> Dict[str, torch.Tensor]: + model_config = ctx.model_config image_features = data.image_features.to(model_config.dtype) return {"image_features": image_features} diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 4789ce5ce4cf..758bf43ca8fd 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,29 +1,19 @@ import functools -from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, - Tuple, Type, TypeVar) +from typing import Any, Optional, Sequence, Type, TypeVar -from vllm.config import ModelConfig, VisionLanguageConfig +from torch import nn + +from vllm.config import ModelConfig from vllm.logger import init_logger -from .base import MultiModalData, MultiModalPlugin +from .base import MultiModalData, MultiModalInputMapper, MultiModalPlugin from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData, ImagePixelPlugin) -if TYPE_CHECKING: - import torch - from torch import nn - - from vllm.sequence import SequenceData - logger = init_logger(__name__) D = 
TypeVar("D", bound=MultiModalData) -N = TypeVar("N", bound=Type["nn.Module"]) - -MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig], - Dict[str, "torch.Tensor"]] -MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig], - Tuple["SequenceData", MultiModalData]] +N = TypeVar("N", bound=Type[nn.Module]) class MultiModalRegistry: @@ -34,13 +24,12 @@ class MultiModalRegistry: DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin()) - def __init__(self, - *, - plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS - ) -> None: + def __init__( + self, + *, + plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS, + ) -> None: self._plugins_by_data_type = {p.get_data_type(): p for p in plugins} - self._dummy_factories_by_model_type: Dict[Type["nn.Module"], - MultiModalDummyFactory] = {} def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None: data_type = plugin.get_data_type() @@ -62,95 +51,53 @@ def _get_plugin_for_data_type(self, data_type: Type[MultiModalData]): msg = f"Unknown multi-modal data type: {data_type}" raise NotImplementedError(msg) - def register_dummy_data(self, factory: MultiModalDummyFactory): - """ - Register a dummy data factory to a model class. - - During memory profiling, the provided function is invoked to create - dummy data to be inputted into the model. The modality and shape of - the dummy data should be an upper bound of what the model would receive - at inference time. - """ - - def wrapper(model_cls: N) -> N: - if model_cls in self._dummy_factories_by_model_type: - logger.warning( - "Model class %s already has dummy data " - "registered to %s. It is overwritten by the new one.", - model_cls, self) - - self._dummy_factories_by_model_type[model_cls] = factory - - return model_cls - - return wrapper - - def dummy_data_for_profiling(self, seq_len: int, model_config: ModelConfig, - vlm_config: VisionLanguageConfig): - """Create dummy data for memory profiling.""" - model_cls = MultiModalPlugin.get_model_cls(model_config) - dummy_factory = self._dummy_factories_by_model_type.get(model_cls) - if dummy_factory is None: - msg = f"No dummy data defined for model class: {model_cls}" - raise NotImplementedError(msg) - - return dummy_factory(seq_len, model_config, vlm_config) - - def register_input( - self, - data_type: Type[D], - processor: Optional[MultiModalInputProcessor[D]] = None): + def register_input_mapper( + self, + data_type: Type[D], + mapper: Optional[MultiModalInputMapper[D]] = None, + ): """ - Register an input processor for a specific modality to a model class. + Register an input mapper for a specific modality to a model class. - See :meth:`MultiModalPlugin.register_input_processor` for more details. + See :meth:`MultiModalPlugin.register_input_mapper` for more details. """ return self._get_plugin_for_data_type(data_type) \ - .register_input_processor(processor) + .register_input_mapper(mapper) - def register_image_pixel_input( - self, - processor: Optional[ - MultiModalInputProcessor[ImagePixelData]] = None): + def register_image_pixel_input_mapper( + self, + mapper: Optional[MultiModalInputMapper[ImagePixelData]] = None, + ): """ - Register an input processor for image pixel data to a model class. + Register an input mapper for image pixel data to a model class. - See :meth:`MultiModalPlugin.register_input_processor` for more details. + See :meth:`MultiModalPlugin.register_input_mapper` for more details. 
""" - return self.register_input(ImagePixelData, processor) + return self.register_input_mapper(ImagePixelData, mapper) - def register_image_feature_input( + def register_image_feature_input_mapper( self, - processor: Optional[ - MultiModalInputProcessor[ImageFeatureData]] = None): + mapper: Optional[MultiModalInputMapper[ImageFeatureData]] = None, + ): """ - Register an input processor for image feature data to a model class. + Register an input mapper for image feature data to a model class. - See :meth:`MultiModalPlugin.register_input_processor` for more details. + See :meth:`MultiModalPlugin.register_input_mapper` for more details. """ - return self.register_input(ImageFeatureData, processor) + return self.register_input_mapper(ImageFeatureData, mapper) - def process_input(self, data: MultiModalData, model_config: ModelConfig, - vlm_config: VisionLanguageConfig): + def map_input(self, model_config: ModelConfig, data: MultiModalData): """ - Apply an input processor to a :class:`~MultiModalData` instance passed + Apply an input mapper to a :class:`~MultiModalData` instance passed to the model. - See :meth:`MultiModalPlugin.process_input` for more details. + See :meth:`MultiModalPlugin.map_input` for more details. """ return self._get_plugin_for_data_type(type(data)) \ - .process_input(data, model_config, vlm_config) + .map_input(model_config, data) - def create_input_processor(self, model_config: ModelConfig, - vlm_config: VisionLanguageConfig): + def create_input_mapper(self, model_config: ModelConfig): """ - Create an input processor (see :meth:`process_input`) for a - specific model. + Create an input mapper (see :meth:`map_input`) for a specific model. """ - return functools.partial(self.process_input, - model_config=model_config, - vlm_config=vlm_config) - - -MULTIMODAL_REGISTRY = MultiModalRegistry() -"""The global :class:`~MultiModalRegistry` which is used by model runners.""" + return functools.partial(self.map_input, model_config) diff --git a/vllm/sequence.py b/vllm/sequence.py index 2f27bf33b166..4d549893f982 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -8,12 +8,12 @@ import torch from vllm.block import LogicalTokenBlock -from vllm.inputs import LLMInputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams if TYPE_CHECKING: + from vllm.inputs import LLMInputs from vllm.multimodal import MultiModalData from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -221,7 +221,7 @@ class Sequence: def __init__( self, seq_id: int, - inputs: LLMInputs, + inputs: "LLMInputs", block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py index 3239b1d0cfa2..2bb5215d4846 100644 --- a/vllm/transformers_utils/image_processor.py +++ b/vllm/transformers_utils/image_processor.py @@ -1,4 +1,3 @@ -from functools import lru_cache from typing import Optional from transformers import AutoImageProcessor @@ -40,6 +39,3 @@ def get_image_processor( raise e return processor - - -cached_get_image_processor = lru_cache(get_image_processor) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index eaf43247d4fc..95d8e44f5111 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -9,10 +9,10 @@ ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict +from 
vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad @@ -66,14 +66,8 @@ def __init__( ) # Create processor for multi-modal data - if self.vision_language_config is not None: - self.multi_modal_input_processor = MULTIMODAL_REGISTRY \ - .create_input_processor( - self.model_config, - self.vision_language_config, - ) - else: - self.multi_modal_input_processor = None + self.multi_modal_input_mapper = INPUT_REGISTRY.MULTIMODAL \ + .create_input_mapper(self.model_config) # Lazy initialization. self.model: nn.Module # Set after init_Model @@ -123,13 +117,7 @@ def _prepare_prompt( mm_data = seq_group_metadata.multi_modal_data if mm_data is not None: - # Process multi-modal data - if self.multi_modal_input_processor is None: - raise ValueError( - "Multi-modal inputs are only supported by " - "vision language models.") - - mm_kwargs = self.multi_modal_input_processor(mm_data) + mm_kwargs = self.multi_modal_input_mapper(mm_data) for k, v in mm_kwargs.items(): multi_modal_kwargs_list[k].append(v) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 67c03ad60008..54cba07344c7 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -13,15 +13,15 @@ VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict from vllm.distributed.communication_op import graph_capture +from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) @@ -124,14 +124,8 @@ def __init__( ) # Create processor for multi-modal data - if self.vision_language_config is not None: - self.multi_modal_input_processor = MULTIMODAL_REGISTRY \ - .create_input_processor( - self.model_config, - self.vision_language_config, - ) - else: - self.multi_modal_input_processor = None + self.multi_modal_input_mapper = INPUT_REGISTRY.MULTIMODAL \ + .create_input_mapper(self.model_config) # Lazy initialization self.model: nn.Module # Set after load_model @@ -432,12 +426,7 @@ def _prepare_model_input( mm_data = seq_group_metadata.multi_modal_data if mm_data is not None: # Process multi-modal data - if self.multi_modal_input_processor is None: - raise ValueError( - "Multi-modal inputs are only supported by " - "vision language models.") - - mm_kwargs = self.multi_modal_input_processor(mm_data) + mm_kwargs = self.multi_modal_input_mapper(mm_data) for k, v in mm_kwargs.items(): multi_modal_kwargs_list[k].append(v) @@ -806,12 +795,8 @@ def profile_run(self) -> None: seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) - if vlm_config is None: - seq_data = SequenceData([0] * seq_len) - dummy_multi_modal_data = None - else: - seq_data, dummy_multi_modal_data = 
MULTIMODAL_REGISTRY \ - .dummy_data_for_profiling(seq_len, model_config, vlm_config) + seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ + .dummy_data_for_profiling(model_config, seq_len) seq = SequenceGroupMetadata( request_id=str(group_id),
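Taken together, these changes mean the model runners no longer special-case vision-language models when preparing inputs: they always obtain an input mapper bound to the model config from ``INPUT_REGISTRY.MULTIMODAL`` and invoke it on each request's multi-modal data. The sketch below illustrates that call path outside the runner. It is only a sketch under the API shown in this diff: the helper name ``map_image_for_model`` is hypothetical, and the ``ModelConfig`` is assumed to be constructed elsewhere (its constructor arguments are not part of this change).

.. code-block:: python

    from typing import Dict

    import torch
    from PIL import Image

    from vllm.config import ModelConfig
    from vllm.inputs import INPUT_REGISTRY
    from vllm.multimodal.image import ImagePixelData


    def map_image_for_model(model_config: ModelConfig,
                            image: Image.Image) -> Dict[str, torch.Tensor]:
        # Bind the mapper to the model config once, as the model runners now do...
        mapper = INPUT_REGISTRY.MULTIMODAL.create_input_mapper(model_config)
        # ...then apply it per request to turn multi-modal data into model kwargs
        # (e.g. pixel inputs for a LLaVA-style model registered with the default
        # image pixel mapper).
        return mapper(ImagePixelData(image))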