6 changes: 5 additions & 1 deletion docs/source/models/supported_models.md
@@ -1028,7 +1028,7 @@ Specified using `--task generate`.
* ✅︎
- * `InternVLChatModel`
* InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0
* T + I<sup>E+</sup>
* T + I<sup>E+</sup> + (V<sup>E+</sup>)
* `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc.
*
* ✅︎
@@ -1241,6 +1241,10 @@ V1 currently uses a simplified attention pattern:
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
:::

:::{note}
Currently, only `InternVLChatModel` with a Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2_5-1B`, etc.) supports video inputs.
:::

:::{note}
`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80.
:::
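
The note above enables video prompts for InternVL3-style checkpoints. A minimal offline-inference sketch of that path, assuming frames are supplied as a `uint8` NumPy array of shape `(num_frames, height, width, 3)` and reusing the chat-template approach from the example script changed below; the generation settings are illustrative only:

```python
import numpy as np
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_name = "OpenGVLab/InternVL3-2B"

# One video per prompt, mirroring limit_mm_per_prompt in the example script.
llm = LLM(
    model=model_name,
    trust_remote_code=True,
    max_model_len=8192,
    limit_mm_per_prompt={"video": 1},
)

# Build the prompt with the model's own chat template and the <video> placeholder.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [{"role": "user", "content": "<video>\nDescribe this video."}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

# Dummy 8-frame clip; replace with real decoded frames.
video = np.zeros((8, 448, 448, 3), dtype=np.uint8)

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"video": video}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```
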
15 changes: 11 additions & 4 deletions examples/offline_inference/vision_language.py
@@ -329,22 +329,26 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:

# InternVL
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"

model_name = "OpenGVLab/InternVL2-2B"
model_name = "OpenGVLab/InternVL3-2B"

engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_model_len=8192,
limit_mm_per_prompt={modality: 1},
)

if modality == "image":
placeholder = "<image>"
elif modality == "video":
placeholder = "<video>"

tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [[{
'role': 'user',
'content': f"<image>\n{question}"
'content': f"{placeholder}\n{question}"
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
@@ -356,6 +360,9 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
# https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
stop_token_ids = [
token_id for token_id in stop_token_ids if token_id is not None
]

return ModelRequestData(
engine_args=engine_args,
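
The new filter on `stop_token_ids` drops `None` entries: on tokenizers without an `<unk>` token (such as the Qwen2-family tokenizer used by InternVL3), `convert_tokens_to_ids` can return `None` for tokens like `<|end|>` that are absent from the vocabulary. A short sketch of passing the cleaned list to `SamplingParams`; the model name and settings here are illustrative:

```python
from transformers import AutoTokenizer
from vllm import SamplingParams

tokenizer = AutoTokenizer.from_pretrained("OpenGVLab/InternVL3-2B",
                                          trust_remote_code=True)

stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in stop_tokens]
# Drop tokens the tokenizer does not know about instead of passing None to vLLM.
stop_token_ids = [t for t in stop_token_ids if t is not None]

sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)
```
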
11 changes: 11 additions & 0 deletions tests/models/multimodal/generation/test_common.py
@@ -349,6 +349,17 @@
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
),
"intern_vl-video": VLMTestInfo(
models=[
"OpenGVLab/InternVL3-1B",
],
test_type=VLMTestType.VIDEO,
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
video_idx_to_prompt=lambda idx: "<video>",
max_model_len=8192,
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
),
"kimi_vl": VLMTestInfo(
models=["moonshotai/Kimi-VL-A3B-Instruct"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
86 changes: 66 additions & 20 deletions tests/models/multimodal/generation/vlm_utils/model_utils.py
@@ -8,6 +8,8 @@
from pathlib import PosixPath
from typing import Optional, Union

import numpy as np
import numpy.typing as npt
import torch
from PIL.Image import Image
from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
@@ -495,30 +497,74 @@ def __init__(self, hf_runner: HfRunner):
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size

def __call__(self, text: str, images: Union[Image, list[Image]],
**kwargs):
def __call__(
self,
text: str,
images: Union[Image, list[Image]] = None,
videos: Union[npt.NDArray, list[npt.NDArray]] = None,
**kwargs,
):
from vllm.model_executor.models.internvl import (
IMG_CONTEXT, IMG_END, IMG_START,
image_to_pixel_values_internvl)
image_to_pixel_values_internvl, video_to_pixel_values_internvl)
images = [images] if isinstance(images, Image) else images
pixel_values = [
image_to_pixel_values_internvl(
image,
input_size=self.image_size,
min_num=self.min_num,
max_num=self.max_num,
use_thumbnail=self.use_thumbnail,
) for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
]
videos = [videos] if isinstance(videos, np.ndarray) else videos
if images is not None:
pixel_values_images = [
image_to_pixel_values_internvl(
image,
input_size=self.image_size,
min_num=self.min_num,
max_num=self.max_num,
use_thumbnail=self.use_thumbnail,
) for image in images
]
num_patches_images = [
pixel_value.shape[0] for pixel_value in pixel_values_images
]
else:
pixel_values_images, num_patches_images = [], []

if videos is not None:
pixel_values_videos = [
video_to_pixel_values_internvl(
video,
input_size=self.image_size,
min_num=1,
max_num=1,
use_thumbnail=False,
) for video in videos
]
num_patches_videos = [
pixel_value.shape[0] for pixel_value in pixel_values_videos
]
else:
pixel_values_videos, num_patches_videos = [], []

pixel_values = []
while ("<image>" in text) or ("<video>" in text):
image_index = text.find("<image>")
video_index = text.find("<video>")
if image_index == -1 or (video_index > -1
and video_index < image_index):
num_patches = num_patches_videos.pop(0)
pixel_values.append(pixel_values_videos.pop(0))
context_tokens = IMG_START + \
IMG_CONTEXT * self.num_image_token + IMG_END
video_tokens = ''.join([
f'Frame{i+1}: {context_tokens}'
for i in range(num_patches)
])
text = text.replace('<video>', video_tokens, 1)
else:
num_patches = num_patches_images.pop(0)
pixel_values.append(pixel_values_images.pop(0))
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
pixel_values = torch.cat(pixel_values, dim=0)
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)

prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt
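
The rewritten `__call__` walks the prompt and expands `<image>`/`<video>` placeholders in order of first appearance, so mixed prompts pull from the correct per-modality patch lists. A standalone toy reproduction of just that text-expansion logic, with made-up token strings and counts:

```python
# Toy stand-ins for the real IMG_* tokens and per-patch token count.
IMG_START, IMG_CONTEXT, IMG_END = "<img>", "<IMG_CONTEXT>", "</img>"
num_image_token = 4


def expand(text: str, image_patches: list[int], video_frames: list[int]) -> str:
    image_patches, video_frames = list(image_patches), list(video_frames)
    while "<image>" in text or "<video>" in text:
        i, v = text.find("<image>"), text.find("<video>")
        if i == -1 or (v != -1 and v < i):
            # Next placeholder is a video: one context block per frame.
            frames = video_frames.pop(0)
            frame_block = IMG_START + IMG_CONTEXT * num_image_token + IMG_END
            text = text.replace(
                "<video>",
                "".join(f"Frame{k + 1}: {frame_block}" for k in range(frames)),
                1)
        else:
            # Next placeholder is an image: one block covering all its patches.
            patches = image_patches.pop(0)
            text = text.replace(
                "<image>",
                IMG_START + IMG_CONTEXT * num_image_token * patches + IMG_END,
                1)
    return text


print(expand("<video>\nThen compare with <image>.", [2], [3]))
```
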
1 change: 1 addition & 0 deletions tests/models/multimodal/processing/test_common.py
@@ -258,6 +258,7 @@ def _test_processing_correctness_mistral(
"ibm-granite/granite-speech-3.3-8b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
"OpenGVLab/InternVL3-1B",
"HuggingFaceM4/Idefics3-8B-Llama3",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
"moonshotai/Kimi-VL-A3B-Instruct",
3 changes: 2 additions & 1 deletion tests/models/registry.py
@@ -319,7 +319,8 @@ def check_available_online(
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible."), # noqa: E501
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501
extras={"2B": "OpenGVLab/InternVL2-2B",
"3.0": "OpenGVLab/InternVL3-1B"}, # noqa: E501
trust_remote_code=True),
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501
2 changes: 2 additions & 0 deletions vllm/entrypoints/chat_utils.py
@@ -556,6 +556,8 @@ def _placeholder_str(self, modality: ModalityStr,
return "(<audio>./</audio>)"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
if model_type == "internvl_chat":
return "<video>"
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
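
With `_placeholder_str` returning `<video>` for `internvl_chat`, a video sent through the OpenAI-compatible chat API should be mapped to that placeholder automatically. A hedged client-side sketch, assuming a server already launched with an InternVL3 checkpoint and that the deployment accepts `video_url` content parts; the base URL and video URL are placeholders:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat = client.chat.completions.create(
    model="OpenGVLab/InternVL3-2B",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What happens in this clip?"},
            # vLLM's chat utils translate this part into the model's <video>
            # placeholder before tokenization.
            {"type": "video_url",
             "video_url": {"url": "http://example.com/clip.mp4"}},
        ],
    }],
    max_tokens=64,
)
print(chat.choices[0].message.content)
```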