67 changes: 67 additions & 0 deletions tests/models/multimodal/processing/test_mllama4.py
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mllama's multimodal preprocessing and profiling."""
import pytest
from torch import prod
from transformers import Llama4Config

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.profiling import MultiModalProfiler

from ...utils import build_model_context


@pytest.mark.parametrize("model_id", ["meta-llama/Llama-Guard-4-12B"])
@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072])
def test_profiling(model_id: str, max_model_len: int):
    model_config_kwargs = {
        "max_model_len": max_model_len,
    }
    ctx = build_model_context(
        model_id,
        model_config_kwargs=model_config_kwargs,
        limit_mm_per_prompt={"image": 1},
    )

    mm_config = ctx.get_mm_config()
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
    profiler = MultiModalProfiler(processor)

    decoder_dummy_data = profiler.get_decoder_dummy_data(
        max_model_len,
        mm_counts=mm_config.limit_per_prompt,
    )
    dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
        max_model_len,
        mm_counts=mm_config.limit_per_prompt,
    )

    hf_config = ctx.get_hf_config(Llama4Config)

    mm_kwargs = processor.apply(
        prompt=dummy_mm_data.prompt,
        mm_data=dummy_mm_data.mm_data,
        hf_processor_mm_kwargs=dict(),
    )["mm_kwargs"]

    # Expected placeholder-token count for a single dummy image, recomputed
    # from the HF vision config and the processor outputs.
    image_size = hf_config.vision_config.image_size
    patch_size = hf_config.vision_config.patch_size
    downsample_ratio = int(
        round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)))
    tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio
    chunks_per_image = prod(mm_kwargs["patches_per_image"])
    total_num_patches = chunks_per_image * tokens_per_patch
    num_tiles = mm_kwargs["aspect_ratios"][0][0] * mm_kwargs["aspect_ratios"][
        0][1]  # x-y separator tokens
    total_tokens = total_num_patches.item() + num_tiles.item(
    ) + 3  # image start, image, image end

    profiled_tokens = profiler.get_mm_max_contiguous_tokens(
        max_model_len,
        mm_counts=mm_config.limit_per_prompt,
    )

    assert total_tokens == profiled_tokens["image"]
    assert total_tokens == sum(
        placeholder.length for placeholder in
        decoder_dummy_data.multi_modal_placeholders["image"])
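For intuition, the following is a minimal sketch of the same token arithmetic the test performs, using assumed Llama 4-style vision-config values (image_size=336, patch_size=14, pixel_shuffle_ratio=0.5) and an assumed tiling of 2x2 chunks plus one global chunk; the actual test reads these values from hf_config and the processor output:

# Hypothetical walk-through of the expected-token computation above.
# All concrete numbers here are assumptions, not values read from the model.
image_size, patch_size, pixel_shuffle_ratio = 336, 14, 0.5

downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))            # 4
tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio   # 576 // 4 = 144

chunks_per_image = 5        # assumed: 2 x 2 tiles + 1 global chunk
total_num_patches = chunks_per_image * tokens_per_patch                  # 720

num_tiles = 2 * 2           # assumed aspect_ratios == (2, 2) -> x/y separator tokens
total_tokens = total_num_patches + num_tiles + 3  # +3: image start, image, image end
print(total_tokens)         # 727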
4 changes: 3 additions & 1 deletion tests/models/registry.py
@@ -337,7 +337,9 @@ def check_available_online(
                                                      trust_remote_code=True,
                                                      v0_only=True),
    "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
                                                      max_model_len=10240),
                                                      max_model_len=10240,
                                                      extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"},  # noqa: E501
                                                      ),
    "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
                                                     extras={"mistral": "mistral-community/pixtral-12b",  # noqa: E501
                                                             "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}),  # noqa: E501
31 changes: 28 additions & 3 deletions vllm/multimodal/profiling.py
@@ -175,11 +175,14 @@ def _get_dummy_mm_inputs(
    def _get_mm_num_tokens(
        self,
        mm_inputs: MultiModalInputs,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        return {
            modality: sum(item.get_num_embeds() for item in placeholders)
            modality:
            sum(item.get_num_embeds() if mm_embeddings_only else item.length
                for item in placeholders)
            for modality, placeholders in placeholders_by_modality.items()
        }

@@ -248,11 +251,33 @@ def get_decoder_dummy_data(
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )

    def get_mm_max_tokens(
    def _get_mm_max_tokens(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)

        return self._get_mm_num_tokens(mm_inputs)
        return self._get_mm_num_tokens(mm_inputs,
                                       mm_embeddings_only=mm_embeddings_only)

    def get_mm_max_contiguous_tokens(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ):
        """
        Returns the maximum length of the multimodal (image placeholder +
        text) tokens, including any break/text tokens in between image
        embeddings.

        <im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>
        Returns 9, even though the number of image embeddings is only 6.

        This is important to take into account when profiling and
        initializing the encoder cache size.
        """

        return self._get_mm_max_tokens(seq_len,
                                       mm_counts,
                                       mm_embeddings_only=False)
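As a rough usage sketch (not part of this diff), the new entry point can be compared against the embeddings-only count; `processor` is assumed to have been created via MULTIMODAL_REGISTRY.create_processor as in the test above:

# Hypothetical usage; `processor` is assumed to already exist.
from vllm.multimodal.profiling import MultiModalProfiler

profiler = MultiModalProfiler(processor)

# Embedding positions only (placeholder tokens that receive vision embeddings).
embeds_only = profiler._get_mm_max_tokens(8192, mm_counts={"image": 1})

# Full contiguous span, which additionally counts separator/break tokens
# between image embeddings -- the value used when sizing the encoder cache.
contiguous = profiler.get_mm_max_contiguous_tokens(8192, mm_counts={"image": 1})

assert contiguous["image"] >= embeds_only["image"]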
2 changes: 1 addition & 1 deletion vllm/multimodal/registry.py
@@ -129,7 +129,7 @@ def get_max_tokens_per_item_by_modality(
        seq_len = model_config.max_model_len
        mm_limits = self.get_mm_limits_per_prompt(model_config)

        return profiler.get_mm_max_tokens(
        return profiler.get_mm_max_contiguous_tokens(
            seq_len,
            {
                modality: 1