3 changes: 2 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -581,7 +581,8 @@ steps:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pip freeze | grep -E 'torch'
- pytest -v -s models/multimodal/processing
- pytest -v -s --ignore models/multimodal/generation/test_whisper.py models/multimodal -m core_model
- pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model
- pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn"
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work

- label: Multi-Modal Models Test (Extended) 1
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -775,7 +775,7 @@ def __init__(
tokenizer_mode: str = "auto",
trust_remote_code: bool = True,
seed: Optional[int] = 0,
max_model_len: int = 1024,
max_model_len: Optional[int] = 1024,
dtype: str = "auto",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
195 changes: 195 additions & 0 deletions tests/models/multimodal/test_tensor_schema.py
@@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch

import pytest
from transformers import PretrainedConfig

from vllm.config import ModelConfig
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import GiB_bytes, set_default_torch_num_threads
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.engine.core import EngineCore as V1EngineCore

from ...conftest import VllmRunner
from ...utils import fork_new_process_for_each_test
from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS

ARCH_TO_SKIP = {
"MolmoForCausalLM": "incompatible requirements",
"MiniMaxVL01ForConditionalGeneration": "broken model",
Member Author comment:
Even after fixing the shape check for MiniMax, this model still can't work due to mismatched mm tokens and placeholders:

[rank0]:     inputs_embeds = self.get_input_embeddings(input_ids,
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/kaggle/working/vllm/vllm/model_executor/models/minimax_vl_01.py", line 218, in get_input_embeddings
[rank0]:     inputs_embeds = merge_multimodal_embeddings(
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/kaggle/working/vllm/vllm/model_executor/models/utils.py", line 511, in merge_multimodal_embeddings
[rank0]:     return _merge_multimodal_embeddings(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/kaggle/working/vllm/vllm/model_executor/models/utils.py", line 427, in _merge_multimodal_embeddings
[rank0]:     raise ValueError(
[rank0]: ValueError: Attempted to assign 19 x 576 = 10944 multimodal tokens to 11088 placeholders

To avoid blocking other migration PRs, I will fix this model in a separate PR.
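
For context, a rough sketch of the consistency check that produces this error. This is a simplified stand-in for vLLM's `_merge_multimodal_embeddings`, with illustrative names and signature rather than the real ones: the merge step counts placeholder positions in `input_ids` and requires the flattened multimodal embeddings to match that count exactly.

import torch


def merge_multimodal_embeddings_sketch(
    input_ids: torch.Tensor,  # (num_tokens, )
    inputs_embeds: torch.Tensor,  # (num_tokens, hidden_size)
    mm_embeds: torch.Tensor,  # (num_mm_tokens, hidden_size)
    placeholder_token_id: int,
) -> torch.Tensor:
    is_placeholder = input_ids == placeholder_token_id
    num_placeholders = int(is_placeholder.sum())
    if num_placeholders != mm_embeds.shape[0]:
        # This is the mismatch MiniMax-VL-01 hits above:
        # 19 x 576 = 10944 mm tokens vs. 11088 placeholders.
        raise ValueError(
            f"Attempted to assign {mm_embeds.shape[0]} multimodal tokens "
            f"to {num_placeholders} placeholders")
    inputs_embeds = inputs_embeds.clone()
    inputs_embeds[is_placeholder] = mm_embeds
    return inputs_embeds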

}


def create_batched_mm_kwargs(
    model_config: ModelConfig,
    processor: BaseMultiModalProcessor,
) -> MultiModalKwargs:
    processing_info = processor.info
    dummy_inputs = processor.dummy_inputs
    supported_mm_limits = processing_info.get_supported_mm_limits()
    mm_counts = {
        modality: 3 if limit is None else limit
        for modality, limit in supported_mm_limits.items()
    }
    processor_inputs = dummy_inputs.get_dummy_processor_inputs(
        seq_len=model_config.max_model_len,
        mm_counts=mm_counts,
    )
    mm_kwargs = processor.apply(
        prompt=processor_inputs.prompt,
        mm_data=processor_inputs.mm_data,
        hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
        tokenization_kwargs=processor_inputs.tokenization_kwargs,
    )["mm_kwargs"]
    mm_kwargs = MultiModalKwargs.batch([mm_kwargs])
    return mm_kwargs


@pytest.mark.core_model
@pytest.mark.parametrize("model_arch", list(_MULTIMODAL_EXAMPLE_MODELS.keys()))
@fork_new_process_for_each_test
def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
                             monkeypatch):
    if model_arch in ARCH_TO_SKIP:
        pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")

    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
    model_info.check_available_online(on_fail="skip")

    model_id = model_info.default

    # Avoid OOM and reduce initialization time by only using 1 layer
    def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
        hf_config.update(model_info.hf_overrides)
        text_config = hf_config.get_text_config()
        # Ensure at least 2 experts per group,
        # since `grouped_topk` assumes top-2
        n_group = getattr(text_config, 'n_group', None)
        num_experts = n_group * 2 if n_group is not None else 2
        # We use three layers for Gemma-3n to check
        # both the normal layer and the kv_shared_layer
        text_config.update({
            "num_layers": 1,
            "num_hidden_layers": 1,
            "num_experts": num_experts,
            "num_experts_per_tok": 2,
            "num_local_experts": num_experts,
            # Otherwise there will not be any expert layers
            "first_k_dense_replace": 0,
            # To avoid OOM on DeepSeek-V3
            "n_routed_experts": num_experts,
            # For Gemma-3n
            "num_kv_shared_layers": 1,
        })
        if hasattr(hf_config, "vision_config"):
            hf_config.vision_config.update({
                "num_layers": 1,
                "num_hidden_layers": 1,
            })
        # e.g.: ibm-granite/granite-speech-3.3-2b
        if hasattr(hf_config, "encoder_config"):
            hf_config.encoder_config.update({
                "num_layers": 1,
                "num_hidden_layers": 1,
            })

        # e.g.: Qwen/Qwen2-Audio-7B-Instruct
        if hasattr(hf_config, "audio_config"):
            hf_config.audio_config.update({
                "num_layers": 1,
                "num_hidden_layers": 1,
                "encoder_layers": 1,
            })

        return hf_config

    model_config = ModelConfig(
        model_id,
        tokenizer=model_info.tokenizer or model_id,
        tokenizer_mode=model_info.tokenizer_mode,
        revision=model_info.revision,
        trust_remote_code=model_info.trust_remote_code,
        hf_overrides=model_info.hf_overrides,
    )
    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
    factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]

    if not any(
            hasattr(model_cls, f"_parse_and_validate_{m}_input")
            for m in ["image", "video", "audio"]):
        pytest.skip(f"{model_arch} does not support tensor schema validation.")

    ctx = InputProcessingContext(
        model_config,
        tokenizer=cached_tokenizer_from_config(model_config),
    )
    processing_info = factories.info(ctx)
    supported_mm_limits = processing_info.get_supported_mm_limits()
    limit_mm_per_prompt = {
        modality: 3 if limit is None else limit
        for modality, limit in supported_mm_limits.items()
    }

    # Avoid calling model.forward()
    def _initialize_kv_caches_v0(self) -> None:
        self.cache_config.num_gpu_blocks = 0
        self.cache_config.num_cpu_blocks = 0

    def _initialize_kv_caches_v1(self, vllm_config):
        kv_cache_specs = self.model_executor.get_kv_cache_specs()
        scheduler_kv_cache_config = get_kv_cache_config(
            vllm_config,
            kv_cache_specs[0],
            10 * GiB_bytes,
        )

        # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
        return 1, 0, scheduler_kv_cache_config

    with (patch.object(V0LLMEngine, "_initialize_kv_caches",
                       _initialize_kv_caches_v0),
          patch.object(V1EngineCore, "_initialize_kv_caches",
                       _initialize_kv_caches_v1), monkeypatch.context() as m):
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        if model_info.v0_only:
            m.setenv("VLLM_USE_V1", "0")

        with (
                set_default_torch_num_threads(1),
                vllm_runner(
                    model_id,
                    tokenizer_name=model_info.tokenizer,
                    tokenizer_mode=model_info.tokenizer_mode,
                    revision=model_info.revision,
                    trust_remote_code=model_info.trust_remote_code,
                    max_model_len=model_info.max_model_len,
                    load_format="dummy",
                    hf_overrides=hf_overrides,
                    limit_mm_per_prompt=limit_mm_per_prompt,
                ) as vllm_model,
        ):
            model_config = vllm_model.llm.llm_engine.model_config
            llm_engine = vllm_model.llm.llm_engine

            if hasattr(llm_engine, "processor"):
                # v1 processor
                mm_registry = llm_engine.processor.mm_registry
            else:
                # v0 input_preprocessor
                mm_registry = llm_engine.input_preprocessor.mm_registry

            processor = mm_registry.create_processor(model_config)
            mm_kwargs = create_batched_mm_kwargs(model_config, processor)

            def validate_model_input(model):
                for modality in ("audio", "image", "video"):
                    method_name = f"_parse_and_validate_{modality}_input"
                    if hasattr(model, method_name):
                        getattr(model, method_name)(**mm_kwargs)

            vllm_model.apply_model(validate_model_input)
7 changes: 4 additions & 3 deletions tests/models/registry.py
@@ -382,6 +382,7 @@ def check_available_online(
min_transformers_version="4.54",
is_available_online=False), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
trust_remote_code=True,
extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible."), # noqa: E501
@@ -431,16 +432,16 @@ def check_available_online(
trust_remote_code=True),
"Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", # noqa: E501
trust_remote_code=True),
"Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
trust_remote_code=True,
max_transformers_version="4.48",
transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501
extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501
"Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
trust_remote_code=True),
"Phi4MultimodalForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", # noqa: E501
3 changes: 2 additions & 1 deletion vllm/model_executor/models/deepseek_vl2.py
@@ -51,13 +51,14 @@ class DeepseekVL2ImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- p: Number of patches
- c: Number of channels (3)
- h: Height of each image
- w: Width of each image
"""
type: Literal["pixel_values"]
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", 3, "h", "w")]
TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
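
As a side note on what the added dynamic "p" dimension buys here — a minimal sketch in plain torch, with made-up sizes, of the shapes this schema is meant to accept if I read `dynamic_dims` correctly: each image may contribute a different number of patches, while the remaining dims stay fixed.

import torch

# Illustrative sizes only; real values come from the processor.
h = w = 384

# With dynamic_dims={"p"}, each entry may have a different patch count p_i,
# so a list of (p_i, 3, h, w) tensors is acceptable for `data`.
data = [
    torch.randn(5, 3, h, w),  # image 1: 5 patches
    torch.randn(9, 3, h, w),  # image 2: 9 patches
]

# The static dims (3, h, w) must still agree across entries; only "p" varies.
assert all(t.shape[1:] == (3, h, w) for t in data)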


17 changes: 11 additions & 6 deletions vllm/model_executor/models/keye.py
Contributor comment:
Thanks @Isotr0py for improving model coverage and for the fixes! Please feel free to share any feedback to help make these changes safer in the future.

@@ -104,13 +104,16 @@ def smart_resize(
class KeyeImagePixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- np: Number of patches
- cps: Number of channels * patch_size * patch_size
- c: Number of channels
- ps: Patch size
- ni: Number of images
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["pixel_values"]
pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
pixel_values: Annotated[torch.Tensor,
TensorShape("b", "np", 3, "ps", "ps")]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


@@ -134,14 +137,16 @@ class KeyeImageEmbeddingInputs(TensorSchema):
class KeyeVideoPixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- np: Number of patches
- ctps: Number of channels * temporal_patch_size * patch_size *
patch_size
- nv: Number of videos
- c: Number of channels
- ps: Patch size
- ni: Number of images
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["pixel_values_videos"]
pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]
pixel_values_videos: Annotated[torch.Tensor,
TensorShape("b", "np", 3, "ps", "ps")]
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
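
For dimension bookkeeping only (and assuming the underlying data is simply the old flattened layout reshaped, which I have not verified against the Keye processor), a short torch sketch of how the previous ("np", "cps"/"ctps") shape relates to the new ("b", "np", 3, "ps", "ps") one:

import torch

# Made-up sizes, for illustration only.
b, num_patches, patch_size = 1, 16, 14

# Old schema: patches flattened to (np, c * ps * ps).
flat = torch.randn(num_patches, 3 * patch_size * patch_size)

# New schema: the same number of elements viewed as (b, np, c, ps, ps).
unflat = flat.view(b, num_patches, 3, patch_size, patch_size)
assert unflat.shape == (b, num_patches, 3, patch_size, patch_size)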


6 changes: 3 additions & 3 deletions vllm/transformers_utils/processors/deepseek_vl2.py
@@ -256,15 +256,15 @@ def process_one(
def __call__(
self,
*,
prompt: str,
text: str,
images: list[Image.Image],
inference_mode: bool = True,
**kwargs,
):
"""

Args:
prompt (str): the formatted prompt;
text (str): the formatted prompt;
images (list[ImageType]): the list of images;
inference_mode (bool): if True, then remove the last eos token;
**kwargs:
@@ -278,7 +278,7 @@ def __call__(
"""

prepare = self.process_one(
prompt=prompt,
prompt=text,
images=images,
inference_mode=inference_mode,
)
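
A quick usage sketch of the renamed keyword — `processor`, the prompt string, and the image below are placeholders; only the keyword names (`text`, `images`, `inference_mode`) come from the signature above:

from PIL import Image

# `processor` is assumed to be an instance of this DeepSeek-VL2 processor.
image = Image.new("RGB", (384, 384))

# After this change the HF-style keyword is `text`, not `prompt`.
outputs = processor(
    text="<image>\nDescribe this image.",
    images=[image],
    inference_mode=True,
)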