
[Hotfix][Pixtral] Fix multiple images bugs #8415

Merged
merged 16 commits on Sep 12, 2024
177 changes: 135 additions & 42 deletions tests/models/test_pixtral.py
@@ -2,63 +2,156 @@

Run `pytest tests/models/test_pixtral.py`.
"""
import uuid
Contributor Author: Add many more aggressive tests

from typing import Any, Dict, List

import pytest
from mistral_common.protocol.instruct.messages import ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk

from vllm.sampling_params import SamplingParams
from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
from vllm.multimodal import MultiModalDataBuiltins

pytestmark = pytest.mark.vlm

MODELS = ["mistralai/Pixtral-12B-2409"]
IMG_URLS = [
"https://picsum.photos/id/237/400/300",
"https://picsum.photos/id/231/200/300",
"https://picsum.photos/id/27/500/500",
"https://picsum.photos/id/17/150/600",
]
PROMPT = "Describe each image in one short sentence."


def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
return [{
"role":
"user",
"content": [{
"type": "text",
"text": PROMPT,
}] + [{
"type": "image_url",
"image_url": {
"url": url
}
} for url in urls]
}]


def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
msg = _create_msg_format(urls)

tokenizer = MistralTokenizer.from_model("pixtral")

request = ChatCompletionRequest(messages=msg) # type: ignore[type-var]
tokenized = tokenizer.encode_chat_completion(request)

engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens)

images = []
for chunk in request.messages[0].content:
if isinstance(chunk, ImageURLChunk):
images.append(image_from_chunk(chunk))

mm_data = MultiModalDataBuiltins(image=images)
engine_inputs["multi_modal_data"] = mm_data

return engine_inputs


MSGS = [
_create_msg_format(IMG_URLS[:1]),
_create_msg_format(IMG_URLS[:2]),
_create_msg_format(IMG_URLS)
]
ENGINE_INPUTS = [
_create_engine_inputs(IMG_URLS[:1]),
_create_engine_inputs(IMG_URLS[:2]),
_create_engine_inputs(IMG_URLS)
]

EXPECTED = [
"The image shows a black dog sitting on a wooden surface.", # noqa
"1. A black dog with floppy ears sits attentively on a wooden surface.\n2. A vast mountain range with rugged peaks stretches under a cloudy sky.", # noqa
"1. A black dog sits attentively on a wooden floor.\n2. A vast mountain range stretches across the horizon under a cloudy sky.\n3. Surfers wait for waves in the ocean at sunset.\n4. A winding gravel path leads through a lush green park." # noqa
]

SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0)
LIMIT_MM_PER_PROMPT = dict(image=4)


@pytest.mark.skip(
reason=
"Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_model_len", [8192, 65536])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
def test_chat(
vllm_runner,
example_prompts,
max_model_len: int,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
image_urls = [
"https://picsum.photos/id/237/200/300",
"https://picsum.photos/seed/picsum/200/300"
]
expected = [
"The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.", # noqa
"The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset." # noqa
]
prompt = "Describe the image in one short sentence."

sampling_params = SamplingParams(max_tokens=512, temperature=0.0)

with vllm_runner(model, dtype=dtype,
tokenizer_mode="mistral") as vllm_model:

for i, image_url in enumerate(image_urls):
messages = [
{
"role":
"user",
"content": [{
"type": "text",
"text": prompt
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}]
},
]

outputs = vllm_model.model.chat(messages,
sampling_params=sampling_params)
assert outputs[0].outputs[0].text == expected[i]

with vllm_runner(model,
dtype=dtype,
tokenizer_mode="mistral",
enable_chunked_prefill=False,
max_model_len=max_model_len,
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT) as vllm_model:
results = []
for msg in MSGS:
outputs = vllm_model.model.chat(msg,
sampling_params=SAMPLING_PARAMS)

results.append(outputs[0].outputs[0].text)

assert results == EXPECTED


@pytest.mark.skip(
reason=
"Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_model_engine(model: str, dtype: str) -> None:
args = EngineArgs(
model=model,
tokenizer_mode="mistral",
enable_chunked_prefill=False,
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
dtype=dtype,
)
engine = LLMEngine.from_engine_args(args)

engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)

results = []
count = 0
while True:
out = engine.step()
count += 1
for request_output in out:
if request_output.finished:
results.append(request_output.outputs[0].text)

if count == 2:
engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
SAMPLING_PARAMS)
if not engine.has_unfinished_requests():
break

assert results[0] == EXPECTED[0]
    # The result is a tiny bit different, but that is expected: different
    # kernels are executed and flash attention is not deterministic.
assert results[
1] == "1. A black dog with floppy ears sits attentively on a wooden surface.\n2. A vast mountain range stretches across the horizon under a cloudy sky." # noqa
assert results[2] == EXPECTED[2]
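
For readers reproducing this outside the test harness, the multi-image flow exercised by `test_chat` above can be driven through vLLM's public `LLM.chat` API. The sketch below reuses the model name, image URLs, prompt, sampling params, and `limit_mm_per_prompt` from the test; it is an illustration, not part of this diff.

```python
from vllm import LLM, SamplingParams

# Sketch only: mirrors test_chat above; Pixtral-12B needs a large (A100-class) GPU.
llm = LLM(
    model="mistralai/Pixtral-12B-2409",
    tokenizer_mode="mistral",
    limit_mm_per_prompt={"image": 4},  # allow up to four images per prompt
)

image_urls = [
    "https://picsum.photos/id/237/400/300",
    "https://picsum.photos/id/231/200/300",
]
messages = [{
    "role": "user",
    "content": [{"type": "text", "text": "Describe each image in one short sentence."}]
    + [{"type": "image_url", "image_url": {"url": url}} for url in image_urls],
}]

outputs = llm.chat(messages,
                   sampling_params=SamplingParams(max_tokens=512, temperature=0.0))
print(outputs[0].outputs[0].text)
```
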
83 changes: 49 additions & 34 deletions vllm/model_executor/models/pixtral.py
@@ -1,4 +1,3 @@
import math
from array import array
from dataclasses import dataclass, fields
from itertools import tee
@@ -15,11 +14,12 @@

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
@@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)
mm_encoder = tokenizer.instruct.mm_encoder

mm_config = ctx.model_config.multimodal_config
max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1)
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
patch_size = mm_encoder.mm_config.image_patch_size
image_token_id = mm_encoder.special_ids.img

# approximate image size
size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)
mm_config = ctx.model_config.multimodal_config
num_images = mm_config.limit_per_prompt.get("image", 1)

# dummy size
size = 256
image = Image.new("RGB", (size, size), color=0)
img_chunk = ImageChunk(image=image)

tokens = mm_encoder(img_chunk).tokens
token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
tokens)
image_feature_size = (size**2) // (patch_size**2)

num_image_tokens = image_feature_size * num_images

token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * num_image_tokens
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - num_image_tokens)

seq_data = SequenceData(token_ids)
mm_data = {"image": max_num_images_per_request * [image]}
mm_data = {"image": num_images * [image]}
return seq_data, mm_data
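
As a sanity check on the dummy-data math in this hunk: with the fixed dummy size of 256 and a 16-pixel patch size (the 16 here is an assumption for illustration; the real value is read from `mm_encoder.mm_config.image_patch_size`), each dummy image contributes 256 placeholder tokens.

```python
# Worked example of the token accounting above (patch_size=16 is assumed).
size = 256                                             # dummy image edge length
patch_size = 16                                        # assumed image_patch_size
image_feature_size = (size ** 2) // (patch_size ** 2)  # 65536 // 256 = 256
num_images = 4                                         # e.g. limit_mm_per_prompt = {"image": 4}
num_image_tokens = image_feature_size * num_images     # 1024 image placeholder tokens
print(image_feature_size, num_image_tokens)            # -> 256 1024
```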


@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
return MultiModalInputs({"images": images})


def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: Optional[List[torch.Tensor]],
image_id: int) -> torch.Tensor:
text_locations = input_ids != image_id
image_locations = input_ids == image_id

seq_len = input_ids.shape[0]
def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is not None and "image" in multi_modal_data:
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)

N_txt = text_locations.sum().item()
_, D_txt = inputs_embeds.shape
N_img, D_img = image_features.shape
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img

assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
"to image features dim {D_img}")
assert (seq_len == N_txt +
N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
f"{(N_txt, N_img, image_locations.sum().item())}")
if image_token_id not in llm_inputs['prompt_token_ids']:
raise ValueError(
(f"You've passed {llm_inputs=} without {image_token_id=}"
" Make sure to process your input via mistral_common's"
" tokenizer or pass a chat completion request. For more"
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."))

inputs_embeds[image_locations, :] = image_features
return inputs_embeds
return llm_inputs
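
The processor above only enforces that an image-bearing prompt was tokenized with image placeholder tokens; in isolation the guard reduces to a membership check. A standalone sketch of that check, using a made-up token id rather than Pixtral's real one:

```python
# Illustrative only: 99 stands in for the image token id that the real
# processor reads from mm_encoder.special_ids.img.
def check_has_image_token(prompt_token_ids: list, image_token_id: int = 99) -> None:
    if image_token_id not in prompt_token_ids:
        raise ValueError(
            "Prompt carries image data but no image tokens; tokenize the prompt "
            "with mistral_common or pass a chat completion request "
            "(see https://github.com/vllm-project/vllm/issues/8411).")

check_has_image_token([1, 99, 99, 2])   # passes: image tokens present
# check_has_image_token([1, 2, 3])      # would raise ValueError
```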


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):

def __init__(self,
@@ -201,11 +206,21 @@ def _parse_and_validate_image_input(
return None

if isinstance(images, torch.Tensor):
# always take last images
images = [images[-1][i] for i in range(images.size(1))]
# if passed as batch take all images
N, B, C, W, H = images.shape
images = images.reshape(N * B, C, W, H)
images = [images[i] for i in range(images.size(0))]
elif isinstance(images, list):
# always take last images
images = [images[-1][i] for i in range(len(images[0]))]
# if passed as list flatten lists of tensors
flatten_images = []
for imgs_per_req in images:
imgs_per_req = [
imgs_per_req[i] for i in range(imgs_per_req.size(0))
] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req

flatten_images.extend(imgs_per_req)

images = flatten_images

return images
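
To make the behavior change in `_parse_and_validate_image_input` concrete: the old code kept only the last request's images from a batched tensor, while the new code flattens every image from every request. A toy sketch of the new path, with made-up shapes:

```python
import torch

# Toy batch: 2 requests (N), 3 images each (B), tiny 3x4x4 "images" (C, W, H).
images = torch.zeros(2, 3, 3, 4, 4)

N, B, C, W, H = images.shape
flat = images.reshape(N * B, C, W, H)
image_list = [flat[i] for i in range(flat.size(0))]

assert len(image_list) == 6  # all six images survive, not just the last request's three
```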
