diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 598fd5762985..1a6596328728 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -581,7 +581,8 @@ steps:
     - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
     - pip freeze | grep -E 'torch'
     - pytest -v -s models/multimodal/processing
-    - pytest -v -s --ignore models/multimodal/generation/test_whisper.py models/multimodal -m core_model
+    - pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model
+    - pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn"
     - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
 
 - label: Multi-Modal Models Test (Extended) 1
diff --git a/tests/conftest.py b/tests/conftest.py
index 67f0e7424038..3f3790cab8d3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -775,7 +775,7 @@ def __init__(
         tokenizer_mode: str = "auto",
         trust_remote_code: bool = True,
         seed: Optional[int] = 0,
-        max_model_len: int = 1024,
+        max_model_len: Optional[int] = 1024,
         dtype: str = "auto",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
diff --git a/tests/models/multimodal/test_tensor_schema.py b/tests/models/multimodal/test_tensor_schema.py
new file mode 100644
index 000000000000..bdc62b1d2682
--- /dev/null
+++ b/tests/models/multimodal/test_tensor_schema.py
@@ -0,0 +1,199 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from functools import partial
+from typing import Any
+from unittest.mock import patch
+
+import pytest
+from transformers import PretrainedConfig
+
+from vllm.config import ModelConfig
+from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
+from vllm.inputs import InputProcessingContext
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
+from vllm.multimodal.processing import BaseMultiModalProcessor
+from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.utils import GiB_bytes, set_default_torch_num_threads
+from vllm.v1.core.kv_cache_utils import get_kv_cache_config
+from vllm.v1.engine.core import EngineCore as V1EngineCore
+
+from ...conftest import VllmRunner
+from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
+
+ARCH_TO_SKIP = {
+    "MolmoForCausalLM": "incompatible requirements",
+    "MiniMaxVL01ForConditionalGeneration": "broken model",
+}
+
+
+def create_batched_mm_kwargs(
+    model_config: ModelConfig,
+    processor: BaseMultiModalProcessor,
+) -> MultiModalKwargs:
+    processing_info = processor.info
+    dummy_inputs = processor.dummy_inputs
+    supported_mm_limits = processing_info.get_supported_mm_limits()
+    mm_counts = {
+        modality: 3 if limit is None else limit
+        for modality, limit in supported_mm_limits.items()
+    }
+    processor_inputs = dummy_inputs.get_dummy_processor_inputs(
+        seq_len=model_config.max_model_len,
+        mm_counts=mm_counts,
+    )
+    mm_kwargs = processor.apply(
+        prompt=processor_inputs.prompt,
+        mm_data=processor_inputs.mm_data,
+        hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
+        tokenization_kwargs=processor_inputs.tokenization_kwargs,
+    )["mm_kwargs"]
+    mm_kwargs = MultiModalKwargs.batch([mm_kwargs])
+    return mm_kwargs
+
+
+# Avoid OOM and reduce initialization time by only using 1 layer
+def hf_overrides(hf_config: PretrainedConfig,
+                 exist_overrides: dict[str, Any]) -> PretrainedConfig:
+    hf_config.update(exist_overrides)
+    text_config = hf_config.get_text_config()
+    # Ensure at least 2 expert per group
+    # Since `grouped_topk` assumes top-2
+    n_group = getattr(text_config, 'n_group', None)
+    num_experts = n_group * 2 if n_group is not None else 2
+    # we use three layers for Gemma-3n to check
+    # both normal layer and kv_shared_layer
+    text_config.update({
+        "num_layers": 1,
+        "num_hidden_layers": 1,
+        "num_experts": num_experts,
+        "num_experts_per_tok": 2,
+        "num_local_experts": num_experts,
+        # Otherwise there will not be any expert layers
+        "first_k_dense_replace": 0,
+        # To avoid OOM on DeepSeek-V3
+        "n_routed_experts": num_experts,
+        # For Gemma-3n
+        "num_kv_shared_layers": 1,
+    })
+    if hasattr(hf_config, "vision_config"):
+        hf_config.vision_config.update({
+            "num_layers": 1,
+            "num_hidden_layers": 1,
+        })
+    # e.g.: ibm-granite/granite-speech-3.3-2b
+    if hasattr(hf_config, "encoder_config"):
+        hf_config.encoder_config.update({
+            "num_layers": 1,
+            "num_hidden_layers": 1,
+        })
+    # e.g.: Qwen/Qwen2-Audio-7B-Instruct
+    if hasattr(hf_config, "audio_config"):
+        hf_config.audio_config.update({
+            "num_layers": 1,
+            "num_hidden_layers": 1,
+            "encoder_layers": 1,
+        })
+    return hf_config
+
+
+@pytest.mark.core_model
+@pytest.mark.parametrize("model_arch", list(_MULTIMODAL_EXAMPLE_MODELS.keys()))
+def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
+                             monkeypatch):
+    if model_arch in ARCH_TO_SKIP:
+        pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")
+
+    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
+    model_info.check_available_online(on_fail="skip")
+
+    model_id = model_info.default
+
+    hf_overrides_fn = partial(hf_overrides,
+                              exist_overrides=model_info.hf_overrides)
+
+    model_config = ModelConfig(
+        model_id,
+        tokenizer=model_info.tokenizer or model_id,
+        tokenizer_mode=model_info.tokenizer_mode,
+        revision=model_info.revision,
+        trust_remote_code=model_info.trust_remote_code,
+        hf_overrides=model_info.hf_overrides,
+    )
+    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
+    factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
+
+    if not any(
+            hasattr(model_cls, f"_parse_and_validate_{m}_input")
+            for m in ["image", "video", "audio"]):
+        pytest.skip(f"{model_arch} does not support tensor schema validation.")
+
+    ctx = InputProcessingContext(
+        model_config,
+        tokenizer=cached_tokenizer_from_config(model_config),
+    )
+    processing_info = factories.info(ctx)
+    supported_mm_limits = processing_info.get_supported_mm_limits()
+    limit_mm_per_prompt = {
+        modality: 3 if limit is None else limit
+        for modality, limit in supported_mm_limits.items()
+    }
+
+    # Avoid calling model.forward()
+    def _initialize_kv_caches_v0(self) -> None:
+        self.cache_config.num_gpu_blocks = 0
+        self.cache_config.num_cpu_blocks = 0
+
+    def _initialize_kv_caches_v1(self, vllm_config):
+        kv_cache_specs = self.model_executor.get_kv_cache_specs()
+        scheduler_kv_cache_config = get_kv_cache_config(
+            vllm_config,
+            kv_cache_specs[0],
+            10 * GiB_bytes,
+        )
+
+        # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
+        return 1, 0, scheduler_kv_cache_config
+
+    with (patch.object(V0LLMEngine, "_initialize_kv_caches",
+                       _initialize_kv_caches_v0),
+          patch.object(V1EngineCore, "_initialize_kv_caches",
+                       _initialize_kv_caches_v1), monkeypatch.context() as m):
+        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+        if model_info.v0_only:
+            m.setenv("VLLM_USE_V1", "0")
+
+        with (
+                set_default_torch_num_threads(1),
+                vllm_runner(
+                    model_id,
+                    tokenizer_name=model_info.tokenizer,
+                    tokenizer_mode=model_info.tokenizer_mode,
+                    revision=model_info.revision,
+                    trust_remote_code=model_info.trust_remote_code,
+                    max_model_len=model_info.max_model_len,
+                    load_format="dummy",
+                    hf_overrides=hf_overrides_fn,
+                    limit_mm_per_prompt=limit_mm_per_prompt,
+                    enforce_eager=True,
+                ) as vllm_model,
+        ):
+            model_config = vllm_model.llm.llm_engine.model_config
+            llm_engine = vllm_model.llm.llm_engine
+
+            if hasattr(llm_engine, "processor"):
+                # v1 processor
+                mm_registry = llm_engine.processor.mm_registry
+            else:
+                # v0 input_preprocessor
+                mm_registry = llm_engine.input_preprocessor.mm_registry
+
+            processor = mm_registry.create_processor(model_config)
+            mm_kwargs = create_batched_mm_kwargs(model_config, processor)
+
+            def validate_model_input(model):
+                for modality in ("audio", "image", "video"):
+                    method_name = f"_parse_and_validate_{modality}_input"
+                    if hasattr(model, method_name):
+                        getattr(model, method_name)(**mm_kwargs)
+
+            vllm_model.apply_model(validate_model_input)
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 806342a57dfa..1c8677492409 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -382,6 +382,7 @@ def check_available_online(
                                          min_transformers_version="4.54",
                                          is_available_online=False),  # noqa: E501
     "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
+                                      trust_remote_code=True,
                                       extras={"2b": "h2oai/h2ovl-mississippi-2b"},  # noqa: E501
                                       max_transformers_version="4.48",  # noqa: E501
                                       transformers_version_reason="HF model is not compatible."),  # noqa: E501
@@ -431,6 +432,9 @@ def check_available_online(
                                          trust_remote_code=True),
     "Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1",  # noqa: E501
                                 trust_remote_code=True),
+    "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
+                            extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
+                                    "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}),  # noqa: E501
     "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224",  # noqa: E501
                                          extras={"v2": "google/paligemma2-3b-ft-docci-448"}),  # noqa: E501
     "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
@@ -438,9 +442,6 @@ def check_available_online(
                                         max_transformers_version="4.48",
                                         transformers_version_reason="Use of deprecated imports which have been removed.",  # noqa: E501
                                         extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}),  # noqa: E501
-    "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
-                            extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
-                                    "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}),  # noqa: E501
     "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
                                          trust_remote_code=True),
     "Phi4MultimodalForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",  # noqa: E501
diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index 531018625478..e0acca75d9dd 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -51,13 +51,14 @@ class DeepseekVL2ImagePixelInputs(TensorSchema):
     """
     Dimensions:
         - bn: Batch size * number of images
+        - p: Number of patches
         - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """
    type: Literal["pixel_values"]
    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
-                    TensorShape("bn", 3, "h", "w")]
+                    TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
 
    images_spatial_crop: Annotated[torch.Tensor,
                                   TensorShape("bn", 2)]
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index 4d8aa8de0f0b..40c66c226850 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -104,13 +104,16 @@ def smart_resize(
 class KeyeImagePixelInputs(TensorSchema):
     """
     Dimensions:
+        - b: Batch size
         - np: Number of patches
-        - cps: Number of channels * patch_size * patch_size
+        - c: Number of channels
+        - ps: Patch size
         - ni: Number of images
         - g: Grid dimensions (3 for t, h, w)
     """
     type: Literal["pixel_values"]
-    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    pixel_values: Annotated[torch.Tensor,
+                            TensorShape("b", "np", 3, "ps", "ps")]
     image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
@@ -134,14 +137,16 @@ class KeyeImageEmbeddingInputs(TensorSchema):
 class KeyeVideoPixelInputs(TensorSchema):
     """
     Dimensions:
+        - b: Batch size
         - np: Number of patches
-        - ctps: Number of channels * temporal_patch_size * patch_size *
-          patch_size
-        - nv: Number of videos
+        - c: Number of channels
+        - ps: Patch size
+        - ni: Number of images
         - g: Grid dimensions (3 for t, h, w)
     """
     type: Literal["pixel_values_videos"]
-    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]
+    pixel_values_videos: Annotated[torch.Tensor,
+                                   TensorShape("b", "np", 3, "ps", "ps")]
     video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
 
 
diff --git a/vllm/transformers_utils/processors/deepseek_vl2.py b/vllm/transformers_utils/processors/deepseek_vl2.py
index b4669d12fa21..5896bde31265 100644
--- a/vllm/transformers_utils/processors/deepseek_vl2.py
+++ b/vllm/transformers_utils/processors/deepseek_vl2.py
@@ -256,7 +256,7 @@ def process_one(
     def __call__(
         self,
         *,
-        prompt: str,
+        text: str,
         images: list[Image.Image],
         inference_mode: bool = True,
         **kwargs,
@@ -264,7 +264,7 @@ def __call__(
         """
 
         Args:
-            prompt (str): the formatted prompt;
+            text (str): the formatted prompt;
             images (list[ImageType]): the list of images;
             inference_mode (bool): if True, then remove the last eos token;
             **kwargs:
@@ -278,7 +278,7 @@ def __call__(
         """
 
         prepare = self.process_one(
-            prompt=prompt,
+            prompt=text,
             images=images,
             inference_mode=inference_mode,
         )