Merged
1 change: 1 addition & 0 deletions docs/models/supported_models.md
@@ -641,6 +641,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgenerate)
| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ |
| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
| `Ovis2_5` | Ovis2.5 | T + I<sup>+</sup> + V | `AIDC-AI/Ovis2.5-9B`, etc. | | | ✅︎ |
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
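For context, a minimal offline-inference sketch against the newly listed `Ovis2_5` entry could look like the following; the model choice, test image, and sampling settings are illustrative, and the example scripts below build prompts via the model's chat template rather than a raw placeholder string:

```python
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# Hedged sketch: engine settings mirror the examples added in this PR.
llm = LLM(
    model="AIDC-AI/Ovis2.5-9B",
    trust_remote_code=True,
    max_model_len=4096,
    limit_mm_per_prompt={"image": 1},
)
image = ImageAsset("stop_sign").pil_image  # bundled vLLM test asset
outputs = llm.generate(
    {
        "prompt": "<image>\nWhat does the sign say?",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```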
33 changes: 33 additions & 0 deletions examples/offline_inference/vision_language.py
@@ -998,6 +998,38 @@ def run_ovis(questions: list[str], modality: str) -> ModelRequestData:
    )


# Ovis2_5
def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData:
    model_name = "AIDC-AI/Ovis2.5-2B"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=2,
        trust_remote_code=True,
        dtype="half",
        limit_mm_per_prompt={modality: 1},
    )
    if modality == "image":
        placeholder = "<image>"
    elif modality == "video":
        placeholder = "<video>"

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    messages = [
        [{"role": "user", "content": f"{placeholder}\n{question}"}]
        for question in questions
    ]
    prompts = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
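For reference, with the Qwen2-style chat template Ovis2.5 ships, the prompt rendered above for one question looks roughly like this (illustrative; it matches the `prompt_formatter` used in the tests further down):

```python
# Approximate result of apply_chat_template(..., add_generation_prompt=True)
# for a single "<image>\n{question}" user message; illustrative only.
rendered = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n<image>\nWhat is in the image?<|im_end|>\n"
    "<|im_start|>assistant\n"
)
```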


# PaliGemma
def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -1469,6 +1501,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
"nemotron_vl": run_nemotron_vl,
"NVLM_D": run_nvlm_d,
"ovis": run_ovis,
"ovis2_5": run_ovis2_5,
"paligemma": run_paligemma,
"paligemma2": run_paligemma2,
"phi3_v": run_phi3v,
31 changes: 31 additions & 0 deletions examples/offline_inference/vision_language_multi_image.py
@@ -680,6 +680,36 @@ def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData:
    )


# Ovis2_5
def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "AIDC-AI/Ovis2.5-2B"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        trust_remote_code=True,
        dtype="half",
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = "\n".join(
        f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
    )
    messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=[fetch_image(url) for url in image_urls],
    )
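As a quick illustration of the placeholder block `load_ovis2_5` builds, with two hypothetical URLs:

```python
# Hypothetical URLs; only the rendered placeholder text matters here.
image_urls = ["https://example.com/a.jpg", "https://example.com/b.jpg"]
placeholders = "\n".join(
    f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
print(placeholders + "\nWhat differs between the images?")
# Image-1: <image>
#
# Image-2: <image>
#
# What differs between the images?
```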


def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "mistral-community/pixtral-12b"

@@ -1085,6 +1115,7 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData:
"mllama": load_mllama,
"NVLM_D": load_nvlm_d,
"ovis": load_ovis,
"ovis2_5": load_ovis2_5,
"phi3_v": load_phi3v,
"phi4_mm": load_phi4mm,
"phi4_multimodal": load_phi4_multimodal,
21 changes: 21 additions & 0 deletions tests/models/multimodal/generation/test_common.py
@@ -11,6 +11,7 @@
import pytest
from transformers import (AutoModel, AutoModelForImageTextToText,
                          AutoModelForTextToWaveform, AutoModelForVision2Seq)
from transformers.utils import is_flash_attn_2_available

from vllm.platforms import current_platform
from vllm.utils import identity
@@ -621,6 +622,26 @@
        hf_model_kwargs={"llm_attn_implementation": "sdpa"},
        patch_hf_runner=model_utils.ovis_patch_hf_runner,
    ),
"ovis2_5": VLMTestInfo(
models=["AIDC-AI/Ovis2.5-2B"],
test_type=(
VLMTestType.IMAGE,
VLMTestType.MULTI_IMAGE,
VLMTestType.VIDEO
),
prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501
video_idx_to_prompt=lambda idx: "<video>\n",
max_model_len=4096,
max_num_seqs=2,
dtype="half",
num_logprobs=10,
patch_hf_runner=model_utils.ovis2_5_patch_hf_runner,
marks=[pytest.mark.skipif(
not is_flash_attn_2_available(),
reason="HF model needs `flash_attn` installed"
)],
Comment on lines +640 to +643

Member: @myselvess The HF implementation's ViT needs `flash_attn` installed (https://huggingface.co/AIDC-AI/Ovis2.5-2B/blob/main/modeling_ovis2_5.py#L221-L250), which won't be installed on our CI. Do we have plans to support SDPA as a fallback for it?

Contributor Author: I have supported SDPA in siglip2navit (https://github.com/myselvess/vllm/blob/ovis2_5_new/vllm/model_executor/models/siglip2navit.py#L257-L281). Do you mean that SDPA is also required in the HF implementation?

Member:

> Do you mean that SDPA is also required in the HF implementation?

Yes, because we need to run the HF implementation on CI without FA installed for correctness comparison.

Contributor Author: OK, I'll tell my colleagues to add this part.

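For the record, the fallback under discussion follows a common pattern; a minimal sketch with a hypothetical helper, not the actual modeling_ovis2_5.py or siglip2navit.py code:

```python
import torch
import torch.nn.functional as F


def vit_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  use_flash_attn: bool) -> torch.Tensor:
    """Hypothetical ViT attention with an SDPA fallback when flash_attn
    is unavailable. q/k/v: (batch, seq_len, num_heads, head_dim)."""
    if use_flash_attn:
        from flash_attn import flash_attn_func  # optional dependency
        return flash_attn_func(q, k, v)
    # torch's built-in SDPA expects (batch, num_heads, seq_len, head_dim).
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    return out.transpose(1, 2)
```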
    ),
    "phi3v": VLMTestInfo(
        models=["microsoft/Phi-3.5-vision-instruct"],
        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
58 changes: 58 additions & 0 deletions tests/models/multimodal/generation/vlm_utils/model_utils.py
@@ -10,6 +10,7 @@

import numpy as np
import numpy.typing as npt
import PIL.Image
import pytest
import regex as re
import torch
@@ -810,6 +811,63 @@ def processor(*args, text="", images=None, **kwargs):
    return hf_model


def ovis2_5_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    """Patches and returns an instance of the HfRunner to use for Ovis2.5."""
    hf_model.model.get_output_embeddings = lambda: \
        hf_model.model.llm.get_output_embeddings()

    def processor(*args, text="", images=None, videos=None, **kwargs):
        if images is None:
            images = []
        else:
            images = [images] if isinstance(images, Image) else images
        if videos is None:
            videos = []
        else:
            videos = [videos] if isinstance(videos, np.ndarray) else videos
            videos = [[PIL.Image.fromarray(frame) for frame in vid]
                      for vid in videos]

        # Strip the chat-template wrapper so only the raw user text remains.
        prompt_start_and_end = {
            "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"),
            "llama":
            ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"),
            "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"),
        }
        for start, end in prompt_start_and_end.values():
            if start in text and end in text:
                text = text.split(start)[1].split(end)[0]
                break

        images_message = [{"type": "image", "image": img} for img in images]
        videos_message = [{"type": "video", "video": vid} for vid in videos]

        messages = [{
            "role": "user",
            "content": [
                *images_message,
                *videos_message,
                {"type": "text", "text": text},
            ],
        }]

        input_ids, pixel_values, grid_thws = hf_model.model.preprocess_inputs(
            messages=messages, enable_thinking=True)
        inputs = {
            "inputs": input_ids,
            "pixel_values": pixel_values,
            "grid_thws": grid_thws,
        }
        return BatchFeature(data=inputs, tensor_type="pt")

    hf_model.processor = processor
    return hf_model
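A quick trace of the wrapper-stripping loop above, using the Qwen2-style markers (the input string is illustrative):

```python
text = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n<image>\nWhat is shown?<|im_end|>\n"
        "<|im_start|>assistant\n")
start, end = "<|im_start|>user\n", "<|im_end|>\n"
inner = text.split(start)[1].split(end)[0]
print(repr(inner))  # '<image>\nWhat is shown?'
```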


def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    """Patches and returns an instance of the HfRunner for Qwen2.5-Omni."""
    thinker = hf_model.model.thinker
2 changes: 2 additions & 0 deletions tests/models/multimodal/processing/test_common.py
@@ -162,6 +162,7 @@ def _test_processing_correctness(
_ADD_SPECIAL_TOKENS_OVERRIDES = {
    "mllama": False,
    "ovis": False,
    "ovis2_5": False,
    "paligemma": False,
    "ultravox": False,
    "whisper": False,
@@ -301,6 +302,7 @@ def _test_processing_correctness_one(
"AIDC-AI/Ovis1.6-Gemma2-9B",
"AIDC-AI/Ovis1.6-Llama3.2-3B",
"AIDC-AI/Ovis2-1B",
"AIDC-AI/Ovis2.5-2B",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
"microsoft/Phi-3.5-vision-instruct",
3 changes: 3 additions & 0 deletions tests/models/registry.py
@@ -464,6 +464,9 @@ def check_available_online(
                            transformers_version_reason="HF model is not compatible",  # noqa: E501
                            extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
                                    "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}),  # noqa: E501
    "Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B", trust_remote_code=True,
                               max_transformers_version="4.53",
                               transformers_version_reason="HF model is not compatible"),  # noqa: E501
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",