Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions tests/models/multimodal/processing/test_mllama4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mllama's multimodal preprocessing and profiling."""
import pytest
from torch import prod
from transformers import Llama4Config

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.profiling import MultiModalProfiler

from ...utils import build_model_context


@pytest.mark.parametrize("model_id", ["meta-llama/Llama-Guard-4-12B"])
@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072])
def test_profiling(model_id: str, max_model_len: int):
model_config_kwargs = {
"max_model_len": max_model_len,
}
ctx = build_model_context(
model_id,
model_config_kwargs=model_config_kwargs,
limit_mm_per_prompt={"image": 1},
)

mm_config = ctx.get_mm_config()
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
profiler = MultiModalProfiler(processor)

decoder_dummy_data = profiler.get_decoder_dummy_data(
max_model_len,
mm_counts=mm_config.limit_per_prompt,
)
dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
max_model_len,
mm_counts=mm_config.limit_per_prompt,
)

hf_config = ctx.get_hf_config(Llama4Config)

mm_kwargs = processor.apply(
prompt=dummy_mm_data.prompt,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)["mm_kwargs"]

image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
downsample_ratio = int(
round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)))
tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio
chunks_per_image = prod(mm_kwargs["patches_per_image"])
total_num_patches = chunks_per_image * tokens_per_patch
num_tiles = mm_kwargs["aspect_ratios"][0][0] * mm_kwargs["aspect_ratios"][
0][1] # x-y seperator tokens
total_tokens = total_num_patches.item() + num_tiles.item(
) + 3 # image start, image, image end

profiled_tokens = profiler.get_mm_max_contiguous_tokens(
max_model_len,
mm_counts=mm_config.limit_per_prompt,
)

assert total_tokens == profiled_tokens["image"]
assert total_tokens == sum(
placeholder.length for placeholder in
decoder_dummy_data.multi_modal_placeholders["image"])
4 changes: 3 additions & 1 deletion tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,9 @@ def check_available_online(
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
trust_remote_code=True),
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
max_model_len=10240),
max_model_len=10240,
extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"}, # noqa: E501
),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501
"mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}), # noqa: E501
Expand Down
Loading