Commit 0a629b5

shenoyvvarun authored and vadiklyutiy committed
[Tests] Fixing bug inside MultiModalProfiler. (vllm-project#21842)
Signed-off-by: Varun Shenoy <[email protected]>
1 parent 79e53ee commit 0a629b5

2 files changed (+70, -1 lines)

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for mllama's multimodal preprocessing and profiling."""
+import pytest
+from torch import prod
+from transformers import Llama4Config
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.profiling import MultiModalProfiler
+
+from ...utils import build_model_context
+
+
+@pytest.mark.parametrize("model_id", ["meta-llama/Llama-Guard-4-12B"])
+@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072])
+def test_profiling(model_id: str, max_model_len: int):
+    model_config_kwargs = {
+        "max_model_len": max_model_len,
+    }
+    ctx = build_model_context(
+        model_id,
+        model_config_kwargs=model_config_kwargs,
+        limit_mm_per_prompt={"image": 1},
+    )
+
+    mm_config = ctx.get_mm_config()
+    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+    profiler = MultiModalProfiler(processor)
+
+    decoder_dummy_data = profiler.get_decoder_dummy_data(
+        max_model_len,
+        mm_counts=mm_config.limit_per_prompt,
+    )
+    dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
+        max_model_len,
+        mm_counts=mm_config.limit_per_prompt,
+    )
+
+    hf_config = ctx.get_hf_config(Llama4Config)
+
+    mm_kwargs = processor.apply(
+        prompt=dummy_mm_data.prompt,
+        mm_data=dummy_mm_data.mm_data,
+        hf_processor_mm_kwargs=dict(),
+    )["mm_kwargs"]
+
+    image_size = hf_config.vision_config.image_size
+    patch_size = hf_config.vision_config.patch_size
+    downsample_ratio = int(
+        round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)))
+    tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio
+    chunks_per_image = prod(mm_kwargs["patches_per_image"])
+    total_num_patches = chunks_per_image * tokens_per_patch
+    num_tiles = mm_kwargs["aspect_ratios"][0][0] * mm_kwargs["aspect_ratios"][
+        0][1]  # x-y separator tokens
+    total_tokens = total_num_patches.item() + num_tiles.item(
+    ) + 3  # image start, image, image end
+
+    profiled_tokens = profiler.get_mm_max_contiguous_tokens(
+        max_model_len,
+        mm_counts=mm_config.limit_per_prompt,
+    )
+
+    assert total_tokens == profiled_tokens["image"]
+    assert total_tokens == sum(
+        placeholder.length for placeholder in
+        decoder_dummy_data.multi_modal_placeholders["image"])
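
As a sanity check on the arithmetic in the test, the expected per-image token count can be worked out by hand. The sketch below mirrors the test's formula but plugs in assumed values for the Llama 4 vision config (image_size=336, patch_size=14, pixel_shuffle_ratio=0.5) and a hypothetical single chunk with a 1x1 aspect ratio; these numbers are illustrative, not read from the checkpoint as the test does.

# Hand-computed version of the token count the test derives (illustrative
# sketch only; the real test reads these values from the HF config).
image_size = 336              # assumed chunk resolution
patch_size = 14               # assumed ViT patch size
pixel_shuffle_ratio = 0.5     # assumed pixel-shuffle factor

downsample_ratio = int(round(1.0 / pixel_shuffle_ratio**2))               # 4
tokens_per_patch = ((image_size // patch_size) ** 2) // downsample_ratio  # 576 // 4 = 144

chunks_per_image = 1          # hypothetical: a single image chunk
num_separator_tokens = 1 * 1  # 1x1 aspect ratio -> one x-y separator token
total_tokens = chunks_per_image * tokens_per_patch + num_separator_tokens + 3
print(total_tokens)           # 148: 144 patch tokens + 1 separator + start/image/end

Under those assumptions the profiler would be expected to report 148 contiguous image tokens; the assertions in the test compare the analogous quantity against get_mm_max_contiguous_tokens and against the dummy-data placeholder lengths.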

tests/models/registry.py

Lines changed: 3 additions & 1 deletion
@@ -391,7 +391,9 @@ def check_available_online(
                                  extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},  # noqa: E501
                                  trust_remote_code=True),
    "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
-                                                     max_model_len=10240),
+                                                     max_model_len=10240,
+                                                     extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"},  # noqa: E501
+                                                     ),
    "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
                                                     extras={"mistral": "mistral-community/pixtral-12b",  # noqa: E501
                                                             "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}),  # noqa: E501
