Skip to content

Commit

Permalink
[VLM][Core] Support profiling with multiple multi-modal inputs per pr…
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored Aug 14, 2024
1 parent 70b746e commit 3f674a4
Show file tree
Hide file tree
Showing 38 changed files with 573 additions and 217 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ Input Processing Pipeline

6. If the data contains multi-modal data, convert it into keyword arguments using :meth:`MULTIMODAL_REGISTRY.map_input <vllm.multimodal.MultiModalRegistry.map_input>`.

- For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision language model.
- For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision model.
3 changes: 3 additions & 0 deletions docs/source/dev/multimodal/multimodal_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ by following :ref:`this guide <adding_multimodal_plugin>`.

Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here <enabling_multimodal_inputs>`.

..
TODO: Add usage of --limit-mm-per-prompt when multi-image input is officially supported
Guides
++++++

Expand Down
2 changes: 1 addition & 1 deletion docs/source/models/enabling_multimodal_inputs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i
3. Register maximum number of multi-modal tokens
------------------------------------------------

For each modality type that the model accepts as input, calculate the maximum possible number of tokens
For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data instance
and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_max_multimodal_tokens>`.

.. code-block:: diff
Expand Down
24 changes: 24 additions & 0 deletions tests/engine/test_arg_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser


@pytest.mark.parametrize(("arg", "expected"), [
(None, None),
("image=16", {
"image": 16
}),
("image=16,video=2", {
"image": 16,
"video": 2
}),
])
def test_limit_mm_per_prompt_parser(arg, expected):
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
if arg is None:
args = parser.parse_args([])
else:
args = parser.parse_args(["--limit-mm-per-prompt", arg])

assert args.limit_mm_per_prompt == expected
2 changes: 1 addition & 1 deletion tests/models/test_blip2.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalData objects and corresponding
vision language config as input.
MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_fuyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
Expand Down
5 changes: 3 additions & 2 deletions tests/models/test_minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
Expand Down Expand Up @@ -176,7 +176,7 @@ def run_multi_image_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
Expand All @@ -197,6 +197,7 @@ def run_multi_image_test(
with vllm_runner(model,
max_model_len=4096,
max_num_seqs=1,
limit_mm_per_prompt={"image": len(images)},
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
Expand Down
84 changes: 78 additions & 6 deletions tests/multimodal/test_mapper.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
from contextlib import nullcontext

import numpy as np
import pytest
from transformers import CLIPImageProcessor, LlavaNextImageProcessor

from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.config import ModelConfig, MultiModalConfig
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size


@pytest.fixture
def mm_registry():
return MultiModalRegistry()


@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_clip_image_processor(image_assets, dtype, size_factor):
def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
Expand All @@ -24,6 +31,9 @@ def test_clip_image_processor(image_assets, dtype, size_factor):
dtype=dtype,
revision=None,
)
mm_config = MultiModalConfig(limit_per_prompt={"image": 1})

mm_registry.init_mm_limits_per_prompt(model_config, mm_config)

for asset in image_assets:
image = rescale_image_size(asset.pil_image, size_factor)
Expand All @@ -32,7 +42,7 @@ def test_clip_image_processor(image_assets, dtype, size_factor):
image,
return_tensors="pt",
)
vllm_result = MULTIMODAL_REGISTRY.map_input(
vllm_result = mm_registry.map_input(
model_config,
{"image": image},
)
Expand All @@ -48,7 +58,8 @@ def test_clip_image_processor(image_assets, dtype, size_factor):

@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_llava_next_image_processor(image_assets, dtype, size_factor):
def test_llava_next_image_processor(image_assets, mm_registry, dtype,
size_factor):
MODEL_NAME = "llava-hf/llava-v1.6-vicuna-7b-hf"

hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
Expand All @@ -63,6 +74,9 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor):
dtype=dtype,
revision=None,
)
mm_config = MultiModalConfig(limit_per_prompt={"image": 1})

mm_registry.init_mm_limits_per_prompt(model_config, mm_config)

for asset in image_assets:
image = rescale_image_size(asset.pil_image, size_factor)
Expand All @@ -71,7 +85,7 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor):
image,
return_tensors="pt",
)
vllm_result = MULTIMODAL_REGISTRY.map_input(
vllm_result = mm_registry.map_input(
model_config,
{"image": image},
)
Expand All @@ -83,3 +97,61 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor):

assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"


@pytest.mark.parametrize(
("num_images", "limit", "is_valid"),
[(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
(2, 1, False), (2, 2, True)],
)
def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

model_config = ModelConfig(
model=MODEL_NAME,
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
revision=None,
)
mm_config = MultiModalConfig(limit_per_prompt={"image": limit})

mm_registry.init_mm_limits_per_prompt(model_config, mm_config)

image = image_assets[0].pil_image
if num_images == 0:
mm_inputs = {}
elif num_images == 1:
mm_inputs = {"image": image}
else:
mm_inputs = {"image": [image] * num_images}

with nullcontext() if is_valid else pytest.raises(ValueError):
mm_registry.map_input(model_config, mm_inputs)


# NOTE: We don't test zero images since the HF processor doesn't support it
@pytest.mark.parametrize("num_images", [1, 2])
def test_image_mapper_multi(image_assets, mm_registry, num_images):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

model_config = ModelConfig(
model=MODEL_NAME,
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
revision=None,
)
mm_config = MultiModalConfig(limit_per_prompt={"image": num_images})

mm_registry.init_mm_limits_per_prompt(model_config, mm_config)

image = image_assets[0].pil_image
mm_inputs = {"image": [image] * num_images}

mapped_inputs = mm_registry.map_input(model_config, mm_inputs)
assert len(mapped_inputs["pixel_values"]) == num_images
14 changes: 10 additions & 4 deletions vllm/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import enum
import json
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, ClassVar, List, Mapping, Optional, Tuple,
Type, Union)

import torch
from transformers import PretrainedConfig
Expand Down Expand Up @@ -1429,10 +1430,15 @@ def verify_with_model_config(self, model_config: ModelConfig):

@dataclass
class MultiModalConfig:
"""Configs the input data format and how models should run for
multimodal models."""
"""Controls the behavior of multimodal models."""

limit_per_prompt: Mapping[str, int]
"""
The maximum number of multi-modal input instances allowed per prompt
for each :class:`~vllm.multimodal.MultiModalPlugin`.
"""

# TODO: Add configs to init vision tower or not.
pass


_STR_DTYPE_TO_TORCH_DTYPE = {
Expand Down
48 changes: 43 additions & 5 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import dataclasses
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
Union)

from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
Expand All @@ -15,8 +16,7 @@
from vllm.utils import FlexibleArgumentParser

if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

logger = init_logger(__name__)

Expand All @@ -29,11 +29,32 @@ def nullable_str(val: str):
return val


def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
if len(val) == 0:
return None

out_dict: Dict[str, int] = {}
for item in val.split(","):
try:
key, value = item.split("=")
except TypeError as exc:
msg = "Each item should be in the form KEY=VALUE"
raise ValueError(msg) from exc

try:
out_dict[key] = int(value)
except ValueError as exc:
msg = f"Failed to parse value of item {key}={value}"
raise ValueError(msg) from exc

return out_dict


@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
model: str = 'facebook/opt-125m'
served_model_name: Optional[Union[List[str]]] = None
served_model_name: Optional[Union[str, List[str]]] = None
tokenizer: Optional[str] = None
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
Expand Down Expand Up @@ -81,6 +102,7 @@ class EngineArgs:
# notice.
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
tokenizer_pool_extra_config: Optional[dict] = None
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
Expand Down Expand Up @@ -435,6 +457,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'This should be a JSON string that will be '
'parsed into a dictionary. Ignored if '
'tokenizer_pool_size is 0.')

# Multimodal related configs
parser.add_argument(
'--limit-mm-per-prompt',
type=nullable_kvs,
default=EngineArgs.limit_mm_per_prompt,
# The default value is given in
# MultiModalRegistry.init_mm_limits_per_prompt
help=('For each multimodal plugin, limit how many '
'input instances to allow for each prompt. '
'Expects a comma-separated list of items, '
'e.g.: `image=16,video=2` allows a maximum of 16 '
'images and 2 videos per prompt. Defaults to 1 for '
'each modality.'))

# LoRA related configs
parser.add_argument('--enable-lora',
action='store_true',
Expand Down Expand Up @@ -709,7 +746,8 @@ def create_engine_config(self, ) -> EngineConfig:
"CPU offload space must be non-negative"
f", but got {self.cpu_offload_gb}")

multimodal_config = MultiModalConfig()
multimodal_config = MultiModalConfig(
limit_per_prompt=self.limit_mm_per_prompt or {})

device_config = DeviceConfig(device=self.device)
model_config = ModelConfig(
Expand Down
Loading

0 comments on commit 3f674a4

Please sign in to comment.