diff --git a/tests/models/multimodal/processing/test_mllama4.py b/tests/models/multimodal/processing/test_mllama4.py
new file mode 100644
index 00000000000..f3871b60c3f
--- /dev/null
+++ b/tests/models/multimodal/processing/test_mllama4.py
@@ -0,0 +1,67 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for mllama4's multimodal preprocessing and profiling."""
+import pytest
+from torch import prod
+from transformers import Llama4Config
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.profiling import MultiModalProfiler
+
+from ...utils import build_model_context
+
+
+@pytest.mark.parametrize("model_id", ["meta-llama/Llama-Guard-4-12B"])
+@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072])
+def test_profiling(model_id: str, max_model_len: int):
+    model_config_kwargs = {
+        "max_model_len": max_model_len,
+    }
+    ctx = build_model_context(
+        model_id,
+        model_config_kwargs=model_config_kwargs,
+        limit_mm_per_prompt={"image": 1},
+    )
+
+    mm_config = ctx.get_mm_config()
+    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+    profiler = MultiModalProfiler(processor)
+
+    decoder_dummy_data = profiler.get_decoder_dummy_data(
+        max_model_len,
+        mm_counts=mm_config.limit_per_prompt,
+    )
+    dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
+        max_model_len,
+        mm_counts=mm_config.limit_per_prompt,
+    )
+
+    hf_config = ctx.get_hf_config(Llama4Config)
+
+    mm_kwargs = processor.apply(
+        prompt=dummy_mm_data.prompt,
+        mm_data=dummy_mm_data.mm_data,
+        hf_processor_mm_kwargs=dict(),
+    )["mm_kwargs"]
+
+    image_size = hf_config.vision_config.image_size
+    patch_size = hf_config.vision_config.patch_size
+    downsample_ratio = int(
+        round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)))
+    tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio
+    chunks_per_image = prod(mm_kwargs["patches_per_image"])
+    total_num_patches = chunks_per_image * tokens_per_patch
+    num_tiles = mm_kwargs["aspect_ratios"][0][0] * mm_kwargs["aspect_ratios"][
+        0][1]  # x-y separator tokens
+    total_tokens = total_num_patches.item() + num_tiles.item(
+    ) + 3  # image start, image, image end
+
+    profiled_tokens = profiler.get_mm_max_contiguous_tokens(
+        max_model_len,
+        mm_counts=mm_config.limit_per_prompt,
+    )
+
+    assert total_tokens == profiled_tokens["image"]
+    assert total_tokens == sum(
+        placeholder.length for placeholder in
+        decoder_dummy_data.multi_modal_placeholders["image"])
diff --git a/tests/models/registry.py b/tests/models/registry.py
index e6543c19734..acfe91f46cb 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -337,7 +337,9 @@ def check_available_online(
                                                       trust_remote_code=True,
                                                       v0_only=True),
     "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
-                                                      max_model_len=10240),
+                                                      max_model_len=10240,
+                                                      extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"},  # noqa: E501
+                                                      ),
     "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
                                                      extras={"mistral": "mistral-community/pixtral-12b",  # noqa: E501
                                                              "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}),  # noqa: E501
diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 1faecb7bd24..fa551a68979 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -175,11 +175,14 @@ def _get_dummy_mm_inputs(
     def _get_mm_num_tokens(
         self,
         mm_inputs: MultiModalInputs,
+        mm_embeddings_only: bool = True,
     ) -> Mapping[str, int]:
         placeholders_by_modality = mm_inputs["mm_placeholders"]
 
         return {
-            modality: sum(item.get_num_embeds() for item in placeholders)
+            modality:
+            sum(item.get_num_embeds() if mm_embeddings_only else item.length
+                for item in placeholders)
             for modality, placeholders in placeholders_by_modality.items()
         }
 
@@ -248,11 +251,33 @@ def get_decoder_dummy_data(
             multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )
 
-    def get_mm_max_tokens(
+    def _get_mm_max_tokens(
         self,
         seq_len: int,
         mm_counts: Optional[Mapping[str, int]] = None,
+        mm_embeddings_only: bool = True,
     ) -> Mapping[str, int]:
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
-        return self._get_mm_num_tokens(mm_inputs)
+        return self._get_mm_num_tokens(mm_inputs,
+                                       mm_embeddings_only=mm_embeddings_only)
+
+    def get_mm_max_contiguous_tokens(
+        self,
+        seq_len: int,
+        mm_counts: Optional[Mapping[str, int]] = None,
+    ):
+        """
+        Returns the maximum length of the multimodal (image placeholders+text)
+        tokens, including any break/text tokens in-between image embeddings.
+
+        <im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>
+        Returns 9, even when the number of image embeddings is 6.
+
+        This is important to take into account when profiling and
+        initializing the encoder cache size.
+        """
+
+        return self._get_mm_max_tokens(seq_len,
+                                       mm_counts,
+                                       mm_embeddings_only=False)
 
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index 27aaa661c35..c9a2a60afea 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -129,7 +129,7 @@ def get_max_tokens_per_item_by_modality(
         seq_len = model_config.max_model_len
         mm_limits = self.get_mm_limits_per_prompt(model_config)
 
-        return profiler.get_mm_max_tokens(
+        return profiler.get_mm_max_contiguous_tokens(
             seq_len,
             {
                 modality: 1
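The distinction the new `get_mm_max_contiguous_tokens` docstring draws between image embeddings and the full contiguous placeholder span can be sketched outside vLLM. Below is a minimal standalone illustration, assuming hypothetical helper names and illustrative Llama-4-style config numbers rather than vLLM's actual processor output; it reproduces the 6-embeddings / 9-tokens example from the docstring and mirrors the arithmetic the new test uses to predict the contiguous placeholder length for a single image.

```python
# Minimal standalone sketch -- hypothetical helpers, not vLLM's API.

def span_token_counts(is_embed: list[bool]) -> tuple[int, int]:
    """Return (num_embeddings, contiguous_length) for one placeholder span."""
    return sum(is_embed), len(is_embed)


# Docstring example: <im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>
span = [False, True, True, True, False, True, True, True, False]
assert span_token_counts(span) == (6, 9)  # 6 image embeddings, 9 contiguous tokens


def llama4_image_token_counts(image_size: int, patch_size: int,
                              pixel_shuffle_ratio: float, num_chunks: int,
                              num_separators: int) -> tuple[int, int]:
    """Mirror of the test's arithmetic; all inputs here are illustrative."""
    downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
    tokens_per_chunk = ((image_size // patch_size)**2) // downsample_ratio
    num_embeddings = num_chunks * tokens_per_chunk
    # Non-embedding tokens inside the same span: tile x-y separators plus
    # image start, image, and image end (the "+ 3" in the test).
    contiguous_length = num_embeddings + num_separators + 3
    return num_embeddings, contiguous_length


embeds, total = llama4_image_token_counts(image_size=336, patch_size=14,
                                          pixel_shuffle_ratio=0.5,
                                          num_chunks=5, num_separators=4)
print(embeds, total)  # 720 727 -> budget for 727 tokens, not 720 embeddings
```

This gap between the two counts is why `get_max_tokens_per_item_by_modality` in `vllm/multimodal/registry.py` switches to `get_mm_max_contiguous_tokens`: sizing the encoder cache by `get_num_embeds()` alone would under-allocate for models such as Llama 4 that interleave separator and boundary tokens with the image embeddings.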