                                                        UserMessage)
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from PIL import Image
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

 from vllm.config import ModelConfig
 from vllm.inputs import InputProcessingContext
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
 from vllm.multimodal.inputs import MultiModalInputs
 from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
-from vllm.transformers_utils.tokenizer import (MistralTokenizer,
-                                               cached_tokenizer_from_config)
+from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
+                                               cached_tokenizer_from_config,
+                                               encode_tokens)

 from ....multimodal.utils import random_audio, random_image, random_video
 from ...registry import HF_EXAMPLE_MODELS
@@ -28,7 +28,6 @@ def _test_processing_correctness(
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
-    ignore_mm_keys: Optional[set[str]] = None,
 ):
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
     model_info.check_available_online(on_fail="skip")
@@ -99,10 +98,23 @@ def _test_processing_correctness(
         }

         mm_counts = {k: len(vs) for k, vs in mm_data.items()}
-        prompt = dummy_inputs.get_dummy_processor_inputs(
-            model_config.max_model_len,
-            mm_counts,
-        ).prompt_text
+
+        # Mistral chat outputs tokens directly, rather than text prompts
+        if isinstance(tokenizer, MistralTokenizer):
+            images = mm_data.get("image", [])
+            request = ChatCompletionRequest(messages=[
+                UserMessage(content=[
+                    TextChunk(text=""),
+                    *(ImageChunk(image=image) for image in images),
+                ]),
+            ])
+            res = tokenizer.mistral.encode_chat_completion(request)
+            prompt = res.tokens
+        else:
+            prompt = dummy_inputs.get_dummy_processor_inputs(
+                model_config.max_model_len,
+                mm_counts,
+            ).prompt

         # Drop unnecessary keys and test single -> multi conversion
         if rng.rand() < simplify_rate:
@@ -112,124 +124,66 @@ def _test_processing_correctness(
             elif len(mm_data[k]) == 1:
                 mm_data[k] = mm_data[k][0]

-        if isinstance(tokenizer, MistralTokenizer):
-            _test_processing_correctness_mistral(
-                model_config,
-                tokenizer,
-                prompt,
-                mm_data,
-                baseline_processor,
-                cached_processor,
-                batch_idx,
-                ignore_mm_keys=ignore_mm_keys,
-            )
-        else:
-            _test_processing_correctness_hf(
-                model_config,
-                tokenizer,
-                prompt,
-                mm_data,
-                baseline_processor,
-                cached_processor,
-                batch_idx,
-                ignore_mm_keys=ignore_mm_keys,
-            )
-
-
-def _test_processing_correctness_hf(
+        _test_processing_correctness_one(
+            model_config,
+            tokenizer,
+            prompt,
+            mm_data,
+            baseline_processor,
+            cached_processor,
+            batch_idx,
+        )
+
+
+# For some multimodal models, the tokenizer always adds bos_token at the
+# beginning of the prompt by default, causing the hf_processor to output
+# incorrect token ids. So we need to use `add_special_tokens=False` here
+# to leave bos_token to be added by the processor.
+_ADD_SPECIAL_TOKENS_OVERRIDES = {
+    "mllama": False,
+    "ovis": False,
+    "ultravox": False,
+    "whisper": False,
+}
+
+_IGNORE_MM_KEYS = {
+    # In Ultravox, the audio_features can be different depending on padding.
+    # The slight difference should not be a problem though, since the
+    # attention_mask lets us ignore the difference.
+    "ultravox": {"audio_features"},
+}
+
+
+def _test_processing_correctness_one(
     model_config: ModelConfig,
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    prompt: str,
+    tokenizer: AnyTokenizer,
+    prompt: Union[str, list[int]],
     mm_data: MultiModalDataDict,
     baseline_processor: BaseMultiModalProcessor,
     cached_processor: BaseMultiModalProcessor,
     batch_idx: int,
-    ignore_mm_keys: Optional[set[str]] = None,
 ):
-    if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox",
-                                             "whisper"):
-        # For some multimodal models, tokenizer will always add bos_token
-        # at the beginning of prompt by default, causing hf_processor outputs
-        # incorrect token ids. So we need use `add_special_tokens=False` here
-        # to leave bos_token to be added by the processor.
-        token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
+    model_type = model_config.hf_config.model_type
+    ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
+
+    if isinstance(prompt, str):
+        text_prompt = prompt
+        token_prompt = encode_tokens(
+            tokenizer,
+            prompt,
+            add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type),
+        )
     else:
-        token_prompt = tokenizer.encode(prompt)
-
-    baseline_result = baseline_processor.apply(
-        prompt,
-        mm_data=mm_data,
-        hf_processor_mm_kwargs={},
-    )
-    cached_result = cached_processor.apply(
-        prompt,
-        mm_data=mm_data,
-        hf_processor_mm_kwargs={},
-    )
-
-    _assert_inputs_equal(
-        baseline_result,
-        cached_result,
-        ignore_mm_keys=ignore_mm_keys,
-        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
-    )
+        # Mistral does not support decode_tokens with skip_special_tokens=False
+        text_prompt = None
+        token_prompt = prompt

     baseline_tokenized_result = baseline_processor.apply(
         token_prompt,
         mm_data=mm_data,
         hf_processor_mm_kwargs={},
     )

-    _assert_inputs_equal(
-        baseline_result,
-        baseline_tokenized_result,
-        ignore_mm_keys=ignore_mm_keys,
-        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
-    )
-
-    cached_tokenized_result = cached_processor.apply(
-        token_prompt,
-        mm_data=mm_data,
-        hf_processor_mm_kwargs={},
-    )
-
-    _assert_inputs_equal(
-        cached_result,
-        cached_tokenized_result,
-        ignore_mm_keys=ignore_mm_keys,
-        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
-    )
-
-
-def _test_processing_correctness_mistral(
-    model_config: ModelConfig,
-    tokenizer: MistralTokenizer,
-    prompt: str,
-    mm_data: MultiModalDataDict,
-    baseline_processor: BaseMultiModalProcessor,
-    cached_processor: BaseMultiModalProcessor,
-    batch_idx: int,
-    ignore_mm_keys: Optional[set[str]] = None,
-):
-    images = mm_data.get("image", [])
-    if not isinstance(images, list):
-        images = [images]
-
-    request = ChatCompletionRequest(messages=[
-        UserMessage(content=[
-            TextChunk(text=prompt),
-            *(ImageChunk(image=image) for image in images),
-        ]),
-    ])
-    res = tokenizer.mistral.encode_chat_completion(request)
-    token_prompt = res.tokens
-
-    # Mistral chat outputs tokens directly, rather than text prompts
-    baseline_tokenized_result = baseline_processor.apply(
-        token_prompt,
-        mm_data=mm_data,
-        hf_processor_mm_kwargs={},
-    )
-
     cached_tokenized_result = cached_processor.apply(
         token_prompt,
         mm_data=mm_data,
@@ -240,9 +194,44 @@ def _test_processing_correctness_mistral(
         baseline_tokenized_result,
         cached_tokenized_result,
         ignore_mm_keys=ignore_mm_keys,
-        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
+        msg=f"Failed ({batch_idx=}, {token_prompt=}, {mm_data=})",
     )

+    if text_prompt is not None:
+        baseline_text_result = baseline_processor.apply(
+            text_prompt,
+            mm_data=mm_data,
+            hf_processor_mm_kwargs={},
+        )
+        cached_text_result = cached_processor.apply(
+            text_prompt,
+            mm_data=mm_data,
+            hf_processor_mm_kwargs={},
+        )
+
+        _assert_inputs_equal(
+            baseline_text_result,
+            cached_text_result,
+            ignore_mm_keys=ignore_mm_keys,
+            msg=f"Failed ({batch_idx=}, {text_prompt=}, {mm_data=})",
+        )
+
+        _assert_inputs_equal(
+            baseline_text_result,
+            baseline_tokenized_result,
+            ignore_mm_keys=ignore_mm_keys,
+            msg=f"Failed ({batch_idx=}, {text_prompt=}, "
+            f"{token_prompt=}, {mm_data=})",
+        )
+
+        _assert_inputs_equal(
+            cached_text_result,
+            cached_tokenized_result,
+            ignore_mm_keys=ignore_mm_keys,
+            msg=f"Failed ({batch_idx=}, {text_prompt=}, "
+            f"{token_prompt=}, {mm_data=})",
+        )
+

 # yapf: disable
 @pytest.mark.parametrize("model_id", [
@@ -281,6 +270,7 @@ def _test_processing_correctness_mistral(
     "AIDC-AI/Ovis2-1B",
     "google/paligemma-3b-mix-224",
     "google/paligemma2-3b-ft-docci-448",
+    "microsoft/Phi-3.5-vision-instruct",
     "microsoft/Phi-4-multimodal-instruct",
     "mistralai/Pixtral-12B-2409",
     "mistral-community/pixtral-12b",
@@ -303,41 +293,6 @@ def test_processing_correctness(
     num_batches: int,
     simplify_rate: float,
 ):
-    ignore_mm_keys = None
-    if 'ultravox' in model_id:
-        # In Ultravox, the audio_features can be different depending on padding
-        # The slight difference should not be a problem though, since
-        # attention_mask lets us ignore the difference.
-        ignore_mm_keys = {"audio_features"}
-
-    _test_processing_correctness(
-        model_id,
-        hit_rate=hit_rate,
-        num_batches=num_batches,
-        simplify_rate=simplify_rate,
-        ignore_mm_keys=ignore_mm_keys,
-    )
-
-
-# yapf: disable
-@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
-@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
-@pytest.mark.parametrize("num_batches", [32])
-@pytest.mark.parametrize("simplify_rate", [1.0])
-# yapf: enable
-def test_processing_correctness_phi3v(
-    model_id: str,
-    hit_rate: float,
-    num_batches: int,
-    simplify_rate: float,
-):
-    # HACK - this is an attempted workaround for the following bug
-    # https://github.com/huggingface/transformers/issues/34307
-    from transformers import AutoImageProcessor  # noqa: F401
-    from transformers import AutoProcessor  # noqa: F401
-
-    AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
-
     _test_processing_correctness(
         model_id,
         hit_rate=hit_rate,
@@ -356,16 +311,10 @@ def _assert_inputs_equal(
     if ignore_mm_keys is None:
         ignore_mm_keys = set()

-    if msg is None:
-        assert "mm_kwargs" in a and "mm_kwargs" in b
-    else:
-        assert "mm_kwargs" in a and "mm_kwargs" in b, msg
+    assert "mm_kwargs" in a and "mm_kwargs" in b, msg

     for key in ignore_mm_keys:
         a["mm_kwargs"].pop(key, None)
         b["mm_kwargs"].pop(key, None)

-    if msg is None:
-        assert a == b
-    else:
-        assert a == b, msg
+    assert a == b, msg