diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index 6b743997f5ee..ea590b8314e8 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -71,11 +71,11 @@ def retrieve_images_in_messages( Retrieve and combine images from the chat and the images passed as input. """ if images is None: - images = [] + retrieved_images = None elif not isinstance(images, Iterable): images = [images] + retrieved_images = [] idx_images = 0 - retrieved_images = [] for message in messages: for content in message["content"]: if isinstance(content, dict): @@ -106,7 +106,7 @@ ) # The number of images passed should be consistent with the number of images in the chat without an image key - if idx_images != len(images): + if images is not None and idx_images != len(images): raise ValueError( "The number of images in the chat messages should be the same as the number of images passed to the pipeline." ) @@ -356,7 +356,8 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **p inputs_text = inputs["text"] images = inputs["images"] - images = load_images(images, timeout=timeout) + if images is not None: + images = load_images(images, timeout=timeout) # if batched text inputs, we set padding to True unless specified otherwise if isinstance(text, (list, tuple)) and len(text) > 1: diff --git a/tests/pipelines/test_pipelines_image_text_to_text.py b/tests/pipelines/test_pipelines_image_text_to_text.py index 903e90919c2c..3895c61774c7 100644 --- a/tests/pipelines/test_pipelines_image_text_to_text.py +++ b/tests/pipelines/test_pipelines_image_text_to_text.py @@ -66,6 +66,78 @@ def run_pipeline_test(self, pipe, examples): ], ) + @require_torch + def test_small_model_pt_token_text_only(self): + pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf") + text = "What is the capital of France? Assistant:" + + outputs = pipe(text=text) + self.assertEqual( + outputs, + [ + { + "input_text": "What is the capital of France? Assistant:", + "generated_text": "What is the capital of France? Assistant: The capital of France is Paris.", + } + ], + ) + + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Write a poem on Hugging Face, the company"}, + ], + }, + ], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is the capital of France?"}, + ], + }, + ], + ] + outputs = pipe(text=messages) + self.assertEqual( + outputs, + [ + [ + { + "input_text": [ + { + "role": "user", + "content": [{"type": "text", "text": "Write a poem on Hugging Face, the company"}], + } + ], + "generated_text": [ + { + "role": "user", + "content": [{"type": "text", "text": "Write a poem on Hugging Face, the company"}], + }, + { + "role": "assistant", + "content": "Hugging Face, a company of minds\nWith tools and services that make our lives easier\nFrom", + }, + ], + } + ], + [ + { + "input_text": [ + {"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]} + ], + "generated_text": [ + {"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}, + {"role": "assistant", "content": "Paris"}, + ], + } + ], + ], + ) + @require_torch + def test_small_model_pt_token(self): + pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")