diff --git a/tests/models/multimodal/processing/test_mllama4.py b/tests/models/multimodal/processing/test_mllama4.py
new file mode 100644
index 00000000000..38d18b059c2
--- /dev/null
+++ b/tests/models/multimodal/processing/test_mllama4.py
@@ -0,0 +1,73 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for mllama4's multimodal preprocessing and profiling."""
+import pytest
+from torch import prod
+from transformers import Llama4Config
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.profiling import MultiModalProfiler
+
+from ...utils import build_model_context
+
+
+@pytest.mark.parametrize("model_id", ["meta-llama/Llama-Guard-4-12B"])
+@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072])
+@pytest.mark.parametrize("max_num_seqs", [1, 2, 8])
+def test_profiling(
+    model_id: str,
+    max_model_len: int,
+    max_num_seqs: int,
+):
+
+    model_config_kwargs = {
+        "max_model_len": max_model_len,
+    }
+    ctx = build_model_context(
+        model_id,
+        model_config_kwargs=model_config_kwargs,
+        limit_mm_per_prompt={"image": 1},
+    )
+
+    mm_config = ctx.get_mm_config()
+    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+    profiler = MultiModalProfiler(processor)
+
+    decoder_dummy_data = profiler.get_decoder_dummy_data(
+        max_model_len,
+        mm_counts=mm_config.limit_per_prompt,
+    )
+    dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
+        max_model_len,
+        mm_counts=mm_config.limit_per_prompt,
+    )
+
+    hf_config = ctx.get_hf_config(Llama4Config)
+
+    mm_kwargs = processor.apply(
+        prompt=dummy_mm_data.prompt,
+        mm_data=dummy_mm_data.mm_data,
+        hf_processor_mm_kwargs=dict(),
+    )["mm_kwargs"]
+
+    image_size = hf_config.vision_config.image_size
+    patch_size = hf_config.vision_config.patch_size
+    downsample_ratio = int(
+        round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)))
+    tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio
+    chunks_per_image = prod(mm_kwargs["patches_per_image"])
+    total_num_patches = chunks_per_image * tokens_per_patch
+    num_tiles = mm_kwargs["aspect_ratios"][0][0] * mm_kwargs["aspect_ratios"][
+        0][1]  # x-y separator tokens
+    total_tokens = total_num_patches.item() + num_tiles.item(
+    ) + 3  # image start, image, image end
+
+    profiled_tokens = profiler.get_mm_max_tokens(
+        max_model_len,
+        mm_counts=mm_config.limit_per_prompt,
+    )
+
+    assert total_tokens == profiled_tokens["image"]
+    assert total_tokens == sum(
+        placeholder.length for placeholder in
+        decoder_dummy_data.multi_modal_placeholders["image"])
diff --git a/tests/models/registry.py b/tests/models/registry.py
index e6543c19734..acfe91f46cb 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -337,7 +337,9 @@ def check_available_online(
                                          trust_remote_code=True,
                                          v0_only=True),
     "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
-                                         max_model_len=10240),
+                                         max_model_len=10240,
+                                         extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"},  # noqa: E501
+                                         ),
     "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
                                                      extras={"mistral": "mistral-community/pixtral-12b",  # noqa: E501
                                                              "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}),  # noqa: E501
diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 1faecb7bd24..fec84d01f8a 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -179,7 +179,7 @@ def _get_mm_num_tokens(
         placeholders_by_modality = mm_inputs["mm_placeholders"]
 
         return {
-            modality: sum(item.get_num_embeds() for item in placeholders)
+            modality: sum(item.length for item in placeholders)
             for modality, placeholders in placeholders_by_modality.items()
         }
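As a reading aid (not part of the patch), the expected-token arithmetic checked by test_profiling can be reproduced standalone: each image chunk yields (image_size // patch_size)**2 ViT patches, pixel shuffling divides that by round(1 / pixel_shuffle_ratio**2), and the prompt adds one x-y separator token per tile plus three tokens for image start, image, and image end. The sketch below is a minimal illustration only; all concrete numbers are assumptions, whereas the test reads the real values from hf_config.vision_config and the processor's mm_kwargs.

# Standalone sketch of the token count asserted in test_profiling.
# All concrete values here are illustrative assumptions, not the actual
# Llama-Guard-4 configuration.
image_size = 336           # assumed vision_config.image_size
patch_size = 14            # assumed vision_config.patch_size
pixel_shuffle_ratio = 0.5  # assumed vision_config.pixel_shuffle_ratio

downsample_ratio = int(round(1.0 / pixel_shuffle_ratio**2))              # 4
tokens_per_chunk = ((image_size // patch_size)**2) // downsample_ratio   # 576 // 4 = 144

chunks_per_image = 5   # assumed prod(mm_kwargs["patches_per_image"]), e.g. 2x2 tiles + 1 global chunk
aspect_ratio = (2, 2)  # assumed mm_kwargs["aspect_ratios"][0]

patch_tokens = chunks_per_image * tokens_per_chunk    # 720
separator_tokens = aspect_ratio[0] * aspect_ratio[1]  # one x-y separator per tile -> 4
total_tokens = patch_tokens + separator_tokens + 3    # + image start, image, image end -> 727

print(total_tokens)  # would be compared against profiled_tokens["image"] for the same inputs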