Commit c60e8fd

FEAT: support qwen2-vl-instruct (#2205)
1 parent bcfedf8 commit c60e8fd

11 files changed: +313 -7 lines changed

.github/workflows/python.yaml (+1)

@@ -135,6 +135,7 @@ jobs:
           pip install tensorizer
           pip install eva-decord
           pip install jj-pytorchvideo
+          pip install qwen-vl-utils
         working-directory: .

       - name: Test with pytest

setup.cfg (+2)

@@ -127,6 +127,7 @@ all =
     loguru # For Fish Speech
     natsort # For Fish Speech
     loralib # For Fish Speech
+    qwen-vl-utils # For qwen2-vl
 intel =
     torch==2.1.0a0
     intel_extension_for_pytorch==2.1.10+xpu
@@ -151,6 +152,7 @@ transformers =
     peft
     eva-decord # For video in VL
     jj-pytorchvideo # For CogVLM2-video
+    qwen-vl-utils # For qwen2-vl
 vllm =
     vllm>=0.2.6
 sglang =

xinference/core/worker.py (+6 -6)

@@ -73,15 +73,15 @@ def __init__(
         self._supervisor_ref: Optional[xo.ActorRefType] = None
         self._main_pool = main_pool
         self._main_pool.recover_sub_pool = self.recover_sub_pool
-        self._status_guard_ref: xo.ActorRefType["StatusGuardActor"] = (  # type: ignore
-            None
-        )
+        self._status_guard_ref: xo.ActorRefType[
+            "StatusGuardActor"
+        ] = None  # type: ignore
         self._event_collector_ref: xo.ActorRefType[  # type: ignore
             EventCollectorActor
         ] = None
-        self._cache_tracker_ref: xo.ActorRefType[CacheTrackerActor] = (  # type: ignore
-            None
-        )
+        self._cache_tracker_ref: xo.ActorRefType[
+            CacheTrackerActor
+        ] = None  # type: ignore

         # internal states.
         # temporary placeholder during model launch process:
xinference/deploy/docker/requirements.txt (+1)

@@ -70,6 +70,7 @@ jj-pytorchvideo # For CogVLM2-video
 loguru # For Fish Speech
 natsort # For Fish Speech
 loralib # For Fish Speech
+qwen-vl-utils # For qwen2-vl

 # sglang
 outlines>=0.0.44

xinference/deploy/docker/requirements_cpu.txt (+1)

@@ -65,3 +65,4 @@ jj-pytorchvideo # For CogVLM2-video
 loguru # For Fish Speech
 natsort # For Fish Speech
 loralib # For Fish Speech
+qwen-vl-utils # For qwen2-vl

xinference/model/llm/__init__.py (+2)

@@ -142,6 +142,7 @@ def _install():
     from .transformers.internlm2 import Internlm2PytorchChatModel
     from .transformers.minicpmv25 import MiniCPMV25Model
     from .transformers.minicpmv26 import MiniCPMV26Model
+    from .transformers.qwen2_vl import Qwen2VLChatModel
     from .transformers.qwen_vl import QwenVLChatModel
     from .transformers.yi_vl import YiVLChatModel
     from .vllm.core import VLLMChatModel, VLLMModel, VLLMVisionModel
@@ -171,6 +172,7 @@ def _install():
             PytorchChatModel,
             Internlm2PytorchChatModel,
             QwenVLChatModel,
+            Qwen2VLChatModel,
             YiVLChatModel,
             DeepSeekVLChatModel,
             InternVLChatModel,

xinference/model/llm/llm_family.json (+46)

@@ -6805,5 +6805,51 @@
     "stop": [
       "</s>"
     ]
+  },
+  {
+    "version":1,
+    "context_length":32768,
+    "model_name":"qwen2-vl-instruct",
+    "model_lang":[
+      "en",
+      "zh"
+    ],
+    "model_ability":[
+      "chat",
+      "vision"
+    ],
+    "model_description":"Qwen2-VL: To See the World More Clearly.Qwen2-VL is the latest version of the vision language models in the Qwen model familities.",
+    "model_specs":[
+      {
+        "model_format":"pytorch",
+        "model_size_in_billions":2,
+        "quantizations":[
+          "none"
+        ],
+        "model_id":"Qwen/Qwen2-VL-2B-Instruct",
+        "model_revision":"096da3b96240e3d66d35be0e5ccbe282eea8d6b1"
+      },
+      {
+        "model_format":"pytorch",
+        "model_size_in_billions":7,
+        "quantizations":[
+          "none"
+        ],
+        "model_id":"Qwen/Qwen2-VL-7B-Instruct",
+        "model_revision":"6010982c1010c3b222fa98afc81575f124aa9bd6"
+      }
+    ],
+    "prompt_style":{
+      "style_name":"QWEN",
+      "system_prompt":"You are a helpful assistant",
+      "roles":[
+        "user",
+        "assistant"
+      ],
+      "stop": [
+        "<|im_end|>",
+        "<|endoftext|>"
+      ]
+    }
   }
 ]
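
With this entry registered, qwen2-vl-instruct can be launched like any other built-in family. Below is a minimal sketch using the Python client; the endpoint address is a placeholder, and depending on the installed Xinference version the model_engine argument may or may not be required:

    from xinference.client import Client

    # Hypothetical endpoint; point this at your running Xinference supervisor.
    client = Client("http://127.0.0.1:9997")

    # Launch the 7B spec registered above; "transformers" is the engine this commit targets.
    model_uid = client.launch_model(
        model_name="qwen2-vl-instruct",
        model_engine="transformers",
        model_format="pytorch",
        model_size_in_billions=7,
        quantization="none",
    )
    print(model_uid)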

xinference/model/llm/llm_family_modelscope.json (+44)

@@ -4508,5 +4508,49 @@
       160133,
       160132
     ]
+  },
+  {
+    "version": 1,
+    "context_length": 32768,
+    "model_name": "qwen2-vl-instruct",
+    "model_lang": [
+      "en",
+      "zh"
+    ],
+    "model_ability": [
+      "chat",
+      "vision"
+    ],
+    "model_description": "Qwen2-VL: To See the World More Clearly.Qwen2-VL is the latest version of the vision language models in the Qwen model familities.",
+    "model_specs": [
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 2,
+        "quantizations": [
+          "none"
+        ],
+        "model_hub": "modelscope",
+        "model_id": "qwen/Qwen2-VL-2B-Instruct",
+        "model_revision": "master"
+      },
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 7,
+        "quantizations": [
+          "none"
+        ],
+        "model_hub": "modelscope",
+        "model_id": "qwen/Qwen2-VL-7B-Instruct",
+        "model_revision": "master"
+      }
+    ],
+    "prompt_style": {
+      "style_name": "QWEN",
+      "system_prompt": "You are a helpful assistant",
+      "roles": [
+        "user",
+        "assistant"
+      ]
+    }
   }
 ]
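
The ModelScope entries mirror the Hugging Face ones (same family name, pytorch format, 2B and 7B sizes) but point model_id at the qwen namespace on ModelScope with model_revision pinned to master. A sketch of selecting the ModelScope source, assuming the XINFERENCE_MODEL_SRC environment variable documented for recent Xinference releases behaves as expected:

    import os

    # Assumption: set before starting Xinference (or launching a model) so that
    # qwen2-vl-instruct weights are fetched from ModelScope instead of Hugging Face.
    os.environ["XINFERENCE_MODEL_SRC"] = "modelscope"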

xinference/model/llm/transformers/core.py (+1)

@@ -64,6 +64,7 @@
     "MiniCPM-Llama3-V-2_5",
     "MiniCPM-V-2.6",
     "glm-4v",
+    "qwen2-vl-instruct",
 ]

xinference/model/llm/transformers/qwen2_vl.py (new file, +208)

@@ -0,0 +1,208 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import uuid
from typing import Iterator, List, Optional, Union

from ....model.utils import select_device
from ....types import (
    ChatCompletion,
    ChatCompletionChunk,
    ChatCompletionMessage,
    CompletionChunk,
)
from ..llm_family import LLMFamilyV1, LLMSpecV1
from ..utils import generate_chat_completion, generate_completion_chunk
from .core import PytorchChatModel, PytorchGenerateConfig

logger = logging.getLogger(__name__)


class Qwen2VLChatModel(PytorchChatModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._tokenizer = None
        self._model = None
        self._device = None
        self._processor = None

    @classmethod
    def match(
        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        llm_family = model_family.model_family or model_family.model_name
        if "qwen2-vl-instruct".lower() in llm_family.lower():
            return True
        return False

    def load(self):
        from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

        device = self._pytorch_model_config.get("device", "auto")
        device = select_device(device)
        self._device = device
        # for multiple GPU, set back to auto to make multiple devices work
        device = "auto" if device == "cuda" else device

        self._processor = AutoProcessor.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        self._tokenizer = self._processor.tokenizer
        self._model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.model_path, device_map=device, trust_remote_code=True
        ).eval()

    def _transform_messages(
        self,
        messages: List[ChatCompletionMessage],
    ):
        transformed_messages = []
        for msg in messages:
            new_content = []
            role = msg["role"]
            content = msg["content"]
            if isinstance(content, str):
                new_content.append({"type": "text", "text": content})
            elif isinstance(content, List):
                for item in content:  # type: ignore
                    if "text" in item:
                        new_content.append({"type": "text", "text": item["text"]})
                    elif "image_url" in item:
                        new_content.append(
                            {"type": "image", "image": item["image_url"]["url"]}
                        )
                    elif "video_url" in item:
                        new_content.append(
                            {"type": "video", "video": item["video_url"]["url"]}
                        )
            new_message = {"role": role, "content": new_content}
            transformed_messages.append(new_message)

        return transformed_messages

    def chat(
        self,
        messages: List[ChatCompletionMessage],  # type: ignore
        generate_config: Optional[PytorchGenerateConfig] = None,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        messages = self._transform_messages(messages)

        generate_config = generate_config if generate_config else {}

        stream = generate_config.get("stream", False) if generate_config else False

        if stream:
            it = self._generate_stream(messages, generate_config)
            return self._to_chat_completion_chunks(it)
        else:
            c = self._generate(messages, generate_config)
            return c

    def _generate(
        self, messages: List, config: PytorchGenerateConfig = {}
    ) -> ChatCompletion:
        from qwen_vl_utils import process_vision_info

        # Preparation for inference
        text = self._processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self._processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to("cuda")

        # Inference: Generation of the output
        generated_ids = self._model.generate(
            **inputs,
            max_new_tokens=config.get("max_tokens", 512),
            temperature=config.get("temperature", 1),
        )
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self._processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return generate_chat_completion(self.model_uid, output_text)

    def _generate_stream(
        self, messages: List, config: PytorchGenerateConfig = {}
    ) -> Iterator[CompletionChunk]:
        from threading import Thread

        from qwen_vl_utils import process_vision_info
        from transformers import TextIteratorStreamer

        text = self._processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self._processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self._model.device)

        tokenizer = self._tokenizer
        streamer = TextIteratorStreamer(
            tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
        )

        gen_kwargs = {
            "max_new_tokens": config.get("max_tokens", 512),
            "temperature": config.get("temperature", 1),
            "streamer": streamer,
            **inputs,
        }

        thread = Thread(target=self._model.generate, kwargs=gen_kwargs)
        thread.start()

        completion_id = str(uuid.uuid1())
        for new_text in streamer:
            yield generate_completion_chunk(
                chunk_text=new_text,
                finish_reason=None,
                chunk_id=completion_id,
                model_uid=self.model_uid,
                prompt_tokens=-1,
                completion_tokens=-1,
                total_tokens=-1,
                has_choice=True,
                has_content=True,
            )

        yield generate_completion_chunk(
            chunk_text=None,
            finish_reason="stop",
            chunk_id=completion_id,
            model_uid=self.model_uid,
            prompt_tokens=-1,
            completion_tokens=-1,
            total_tokens=-1,
            has_choice=True,
            has_content=False,
        )
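
The chat() method above takes OpenAI-style messages; _transform_messages rewrites image_url / video_url parts into the "image" / "video" items that qwen_vl_utils.process_vision_info consumes. Below is a hedged sketch of exercising the launched model through Xinference's OpenAI-compatible endpoint; the host, port, model identifier, and image URL are placeholders:

    import openai

    # Point the OpenAI SDK at the local Xinference server's /v1 endpoint.
    client = openai.OpenAI(api_key="not-used", base_url="http://127.0.0.1:9997/v1")

    response = client.chat.completions.create(
        model="qwen2-vl-instruct",  # or the uid returned by launch_model
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is in this picture?"},
                    # Illustrative URL; any reachable image works.
                    {"type": "image_url", "image_url": {"url": "https://example.com/demo.jpg"}},
                ],
            }
        ],
        max_tokens=512,
    )
    print(response.choices[0].message.content)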
