73 changes: 73 additions & 0 deletions tests/models/multimodal/processing/test_mllama4.py
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mllama's multimodal preprocessing and profiling."""
import pytest
from torch import prod
from transformers import Llama4Config

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.profiling import MultiModalProfiler

from ...utils import build_model_context


@pytest.mark.parametrize("model_id", ["meta-llama/Llama-Guard-4-12B"])
@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072])
@pytest.mark.parametrize("max_num_seqs", [1, 2, 8])
def test_profiling(
model_id: str,
max_model_len: int,
max_num_seqs: int,
):

model_config_kwargs = {
"max_model_len": max_model_len,
}
ctx = build_model_context(
model_id,
model_config_kwargs=model_config_kwargs,
limit_mm_per_prompt={"image": 1},
)

mm_config = ctx.get_mm_config()
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
profiler = MultiModalProfiler(processor)

decoder_dummy_data = profiler.get_decoder_dummy_data(
max_model_len,
mm_counts=mm_config.limit_per_prompt,
)
dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
max_model_len,
mm_counts=mm_config.limit_per_prompt,
)

hf_config = ctx.get_hf_config(Llama4Config)

mm_kwargs = processor.apply(
prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)["mm_kwargs"]

image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
downsample_ratio = int(
round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)))
tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio
chunks_per_image = prod(mm_kwargs["patches_per_image"])
total_num_patches = chunks_per_image * tokens_per_patch
num_tiles = mm_kwargs["aspect_ratios"][0][0] * mm_kwargs["aspect_ratios"][
0][1] # x-y seperator tokens
total_tokens = total_num_patches.item() + num_tiles.item(
) + 3 # image start, image, image end

profiled_tokens = profiler.get_mm_max_tokens(
max_model_len,
mm_counts=mm_config.limit_per_prompt,
)

assert total_tokens == profiled_tokens["image"]
assert total_tokens == sum(
placeholder.length for placeholder in
decoder_dummy_data.multi_modal_placeholders["image"])
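Note: the expected-token arithmetic asserted above can be followed with plain integers. The sketch below is illustrative only; the vision-config values (image_size 336, patch_size 14, pixel_shuffle_ratio 0.5) and the tiling are assumptions chosen for a worked example, not values read from the Llama-Guard-4 checkpoint.

# Worked example of the expected image-token count (all numbers assumed).
image_size = 336            # assumed vision_config.image_size
patch_size = 14             # assumed vision_config.patch_size
pixel_shuffle_ratio = 0.5   # assumed vision_config.pixel_shuffle_ratio

downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))           # 4
tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio  # 576 // 4 = 144

chunks_per_image = 5        # assumed value of prod(mm_kwargs["patches_per_image"])
ratio_h, ratio_w = 2, 2     # assumed mm_kwargs["aspect_ratios"][0]

total_num_patches = chunks_per_image * tokens_per_patch  # 5 * 144 = 720
num_tiles = ratio_h * ratio_w                            # 4 tile separator tokens
total_tokens = total_num_patches + num_tiles + 3         # +3: image start, image, image end
print(total_tokens)                                      # 727 under these assumptions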
4 changes: 3 additions & 1 deletion tests/models/registry.py
@@ -337,7 +337,9 @@ def check_available_online(
trust_remote_code=True,
v0_only=True),
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
-                                                      max_model_len=10240),
+                                                      max_model_len=10240,
+                                                      extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"},  # noqa: E501
+                                                      ),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501
"mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}), # noqa: E501
2 changes: 1 addition & 1 deletion vllm/multimodal/profiling.py
@@ -179,7 +179,7 @@ def _get_mm_num_tokens(
placeholders_by_modality = mm_inputs["mm_placeholders"]

return {
-            modality: sum(item.get_num_embeds() for item in placeholders)
+            modality: sum(item.length for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}

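For context on this one-line change: the profiler now sums each placeholder's full length rather than its embedded-token count. The stand-in class below is only a sketch of the assumed semantics (the real PlaceholderRange definition is not part of this diff); it illustrates how the two quantities can differ when an is_embed mask marks some positions as non-embedding tokens.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakePlaceholder:
    # Stand-in for a placeholder range: `length` prompt positions starting
    # at `offset`; `is_embed` optionally marks which of them carry embeddings.
    offset: int
    length: int
    is_embed: Optional[List[bool]] = None

    def get_num_embeds(self) -> int:
        return self.length if self.is_embed is None else sum(self.is_embed)

p = FakePlaceholder(offset=10, length=5, is_embed=[True, False, True, True, False])
print(p.length, p.get_num_embeds())  # 5 3 -- the profiler now counts all 5 positions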