[MODEL] Qwen Multimodal Support (Qwen-VL / Qwen-VL-Chat) #8029

Merged: 35 commits, Sep 5, 2024

Changes from all commits (35 commits)
de65d94
Enable stub mm image inputs for qwen models
alex-jw-brooks Aug 19, 2024
debcc8c
Add calc for max number of qwen image tokens
alex-jw-brooks Aug 20, 2024
9803fab
Add unmodified visual qwen code, enable visual weight loading
alex-jw-brooks Aug 20, 2024
364c110
Implement model processor for qwen, fix max tokens
alex-jw-brooks Aug 26, 2024
9b4eb9a
Implement qwen input mapper and visual feature forward call
alex-jw-brooks Aug 27, 2024
a709dd9
Hacky integration of img pos / merging for qwen-vl
alex-jw-brooks Aug 27, 2024
59339d2
Add multimodal dummy data for qwen models
alex-jw-brooks Aug 27, 2024
d6e3ca4
Conditionally enable visual component to support qwen llm-only models
alex-jw-brooks Aug 28, 2024
8e27aa2
Fix hardcoded image start ID in image position extraction
alex-jw-brooks Aug 28, 2024
f535d61
Enable chat for qwen-vl
alex-jw-brooks Aug 28, 2024
b10b73c
Improve validation, add qwen single image embed support
alex-jw-brooks Aug 28, 2024
0611f19
Tentative support for multi-image embeddings in qwen
alex-jw-brooks Aug 28, 2024
541b7b5
Add example for qwen vl offline inference
alex-jw-brooks Aug 28, 2024
1db0e6d
run formatting and linting
alex-jw-brooks Aug 28, 2024
ed3e15d
Fix bug in image token input processing
alex-jw-brooks Aug 29, 2024
cadabb7
Qwen - add comments, error handling in warmup
alex-jw-brooks Aug 29, 2024
4238c2f
Fix device and dtype hack in Qwen resampler
alex-jw-brooks Aug 30, 2024
0258345
Update sequence data initialization
alex-jw-brooks Aug 30, 2024
d0b8962
Flatten bn dimension for qwen
alex-jw-brooks Aug 30, 2024
030b535
Switch qwen test to text only model
alex-jw-brooks Aug 30, 2024
27f819a
Run code formatting
alex-jw-brooks Aug 30, 2024
fcdd6f1
Add image tag standardization, multimodal qwen tests
alex-jw-brooks Sep 1, 2024
a5c1201
Remove support for <image> in qwen
alex-jw-brooks Sep 1, 2024
29a3c7f
Update docs for multimodal support
alex-jw-brooks Sep 1, 2024
7ac8ff9
Update vllm/model_executor/models/qwen.py
alex-jw-brooks Sep 3, 2024
3fe3a77
Add qwen back to llm docs
alex-jw-brooks Sep 3, 2024
a4f5400
Make qwen/minicpmv embedding utils common
alex-jw-brooks Sep 3, 2024
1989da8
Make qwenvl / minicpmv2.0 resampler common
alex-jw-brooks Sep 4, 2024
5e8409b
Fix qwen warning for image placeholders
alex-jw-brooks Sep 4, 2024
c889a69
Fix formatting, missing license, typehints
alex-jw-brooks Sep 4, 2024
2aa9549
Remove unreachable optional cross attn in qwenvl
alex-jw-brooks Sep 4, 2024
531d628
Use parallel linear layers in qwenvl mlp
alex-jw-brooks Sep 4, 2024
d19418d
Limit `max_num_seqs`
DarkLight1337 Sep 5, 2024
4f25926
Further limit
DarkLight1337 Sep 5, 2024
00c2e09
Fix dummy data seq padding for multimodal qwen
alex-jw-brooks Sep 5, 2024
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.rst
@@ -242,6 +242,11 @@ Multimodal Language Models
- Image
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
* - :code:`QWenLMHeadModel`
- Qwen
- Image
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
-
* - :code:`UltravoxModel`
- Ultravox
- Audio
15 changes: 15 additions & 0 deletions examples/offline_inference_vision_language.py
@@ -159,6 +159,20 @@ def run_blip2(question):
return llm, prompt, stop_token_ids


# Qwen
def run_qwen_vl(question):

    llm = LLM(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        max_num_seqs=5,
    )

    prompt = f"{question}Picture 1: <img></img>\n"
    stop_token_ids = None
    return llm, prompt, stop_token_ids


model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
@@ -169,6 +183,7 @@ def run_blip2(question):
"minicpmv": run_minicpmv,
"blip-2": run_blip2,
"internvl_chat": run_internvl,
"qwen_vl": run_qwen_vl,
}


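The run_qwen_vl helper above only constructs the LLM and the prompt; in the full example script the image itself is attached at generation time through multi_modal_data. A minimal sketch of how such a call could look for this model, assuming a local image file (the filename and sampling settings are illustrative, not part of the PR):

```python
from PIL import Image

from vllm import LLM, SamplingParams

# Any local RGB image works for this sketch; the filename is illustrative.
image = Image.open("stop_sign.jpg").convert("RGB")

llm = LLM(model="Qwen/Qwen-VL", trust_remote_code=True, max_num_seqs=5)

# Qwen-VL expects the "Picture N: <img></img>" placeholder in the prompt text;
# the actual pixels are supplied separately via multi_modal_data.
prompt = "What's the content of the image?Picture 1: <img></img>\n"

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```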
167 changes: 142 additions & 25 deletions tests/models/test_qwen.py
@@ -1,48 +1,165 @@
from typing import Type
import pathlib
from typing import List, Optional, Type

import pytest

from ..conftest import HfRunner, VllmRunner
from vllm.multimodal.utils import rescale_image_size

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close

models = ["qwen/qwen-vl"]
pytestmark = pytest.mark.vlm

text_only_models = [
"Qwen/Qwen-7B-Chat" # Has no visual component
]

@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("model", models)
def test_text_only_qwen_model(
multimodal_models = ["Qwen/Qwen-VL"]

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"Picture 1: <img></img>\nWhat's the content of the image?: ",
"cherry_blossom":
"Picture 1: <img></img>\nWhat is the season?: ",
})


### Tests for multimodal Qwen models
def run_test(
tmp_path: pathlib.PosixPath,
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
example_prompts,
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
# This test checks language inputs only, since the visual component
# for qwen-vl is still unsupported in VLLM. In the near-future, the
# implementation and this test will be extended to consider
# visual inputs as well.
"""Inference result should be the same between hf and vllm.

All the image fixtures for the test are under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images = [asset.pil_image for asset in image_assets]

# Export the images to a tempdir and substitute it into the hf prompt;
# the contents between <img>/</img> will be ignored by VLLM, but the
# transformers implementation for the visual transformer parses this to
# reload it in the forward call; the contents are treated as a URL or a
# local path.
for idx, asset in enumerate(image_assets):
image_tmp_path = tmp_path / f"{asset.name}.jpg"
asset.pil_image.save(image_tmp_path)
HF_IMAGE_PROMPTS[idx] = HF_IMAGE_PROMPTS[idx].replace(
"<img></img>", f"<img>{image_tmp_path}</img>")

inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

# max_model_len should be greater than image_feature_size
# Qwen encodes images into a fixed content size of 256
with vllm_runner(model,
max_model_len=300,
max_num_seqs=1,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts,
max_tokens,
num_logprobs=num_logprobs,
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]

for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", multimodal_models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets,
model, size_factors, dtype, max_tokens,
num_logprobs) -> None:
run_test(
tmp_path,
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


# Ensure that a text-only Qwen model can still be loaded and
# used for inference in VLLM without throwing.
@pytest.mark.parametrize("model", text_only_models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_text_only_qwen_model_can_be_loaded_and_run(
vllm_runner: Type[VllmRunner],
example_prompts,
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
):
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
vllm_model.generate_greedy_logprobs(
example_prompts,
max_tokens,
num_logprobs=num_logprobs,
)

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
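As an aside on the save-and-substitute loop in run_test above: the HF runner needs a real path (or URL) between the <img> tags because the transformers Qwen-VL code re-reads the image from that string, while vLLM ignores the tag contents and takes the image from multi_modal_data. A small illustration of the resulting prompts, with a purely hypothetical temp path:

```python
hf_prompt = "Picture 1: <img></img>\nWhat's the content of the image?: "

# HF side: splice in the path of the image saved to the pytest tmp dir
# (the path shown here is made up for illustration).
hf_prompt_with_path = hf_prompt.replace(
    "<img></img>", "<img>/tmp/pytest-of-user/stop_sign.jpg</img>")

# vLLM side: the empty <img></img> placeholder is kept as-is, since the
# image is passed separately through multi_modal_data.
vllm_prompt = hf_prompt
```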
2 changes: 2 additions & 0 deletions vllm/entrypoints/chat_utils.py
@@ -129,6 +129,8 @@ def add(self, modality: Literal["image", "audio"],
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type == "qwen":
return f"Picture {current_count}: <img></img>"
if model_type.startswith("llava"):
return MultiModalItemTracker._cached_token_str(
self._tokenizer,
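The chat_utils hook above is what lets the OpenAI-compatible server insert the Qwen-style placeholder for each image in a chat request. A rough sketch of a request that would exercise it, assuming a server is already running with Qwen/Qwen-VL-Chat (host, port, and image URL below are placeholders):

```python
from openai import OpenAI

# Assumes something like `vllm serve Qwen/Qwen-VL-Chat --trust-remote-code`
# is listening locally; adjust base_url as needed.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Qwen/Qwen-VL-Chat",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            # For model_type == "qwen", vLLM expands this item into the
            # "Picture 1: <img></img>" placeholder from the diff above.
            {"type": "image_url",
             "image_url": {"url": "https://example.com/cat.jpg"}},
        ],
    }],
)
print(response.choices[0].message.content)
```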