Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/transformers/pipelines/image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ def retrieve_images_in_messages(
Retrieve and combine images from the chat and the images passed as input.
"""
if images is None:
images = []
retrieved_images = None
elif not isinstance(images, Iterable):
images = [images]
retrieved_images = []
idx_images = 0
retrieved_images = []
for message in messages:
for content in message["content"]:
if isinstance(content, dict):
Expand Down Expand Up @@ -106,7 +106,7 @@ def retrieve_images_in_messages(
)

# The number of images passed should be consistent with the number of images in the chat without an image key
if idx_images != len(images):
if images is not None and idx_images != len(images):
raise ValueError(
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
)
Expand Down Expand Up @@ -356,7 +356,8 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **p
inputs_text = inputs["text"]
images = inputs["images"]

images = load_images(images, timeout=timeout)
if images is not None:
images = load_images(images, timeout=timeout)

# if batched text inputs, we set padding to True unless specified otherwise
if isinstance(text, (list, tuple)) and len(text) > 1:
Expand Down
72 changes: 72 additions & 0 deletions tests/pipelines/test_pipelines_image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,78 @@ def run_pipeline_test(self, pipe, examples):
],
)

@require_torch
def test_small_model_pt_token_text_only(self):
    # Regression test for the text-only path of the image-text-to-text
    # pipeline: no `images` argument is ever passed, so the pipeline must
    # handle images being None throughout (see the paired change in
    # retrieve_images_in_messages / preprocess in this PR).
    pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
    text = "What is the capital of France? Assistant:"

    # Case 1: a plain string prompt with no images at all.
    outputs = pipe(text=text)
    self.assertEqual(
        outputs,
        [
            {
                "input_text": "What is the capital of France? Assistant:",
                "generated_text": "What is the capital of France? Assistant: The capital of France is Paris.",
            }
        ],
    )

    # Case 2: a batch of two chat-format conversations whose "content"
    # entries are all of type "text" (no "image" entries), again without
    # passing any images to the pipeline.
    messages = [
        [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Write a poem on Hugging Face, the company"},
                ],
            },
        ],
        [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is the capital of France?"},
                ],
            },
        ],
    ]
    outputs = pipe(text=messages)
    # Chat inputs are echoed back in "generated_text" as the full
    # conversation: the original user turn plus the assistant reply.
    self.assertEqual(
        outputs,
        [
            [
                {
                    "input_text": [
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": "Write a poem on Hugging Face, the company"}],
                        }
                    ],
                    "generated_text": [
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": "Write a poem on Hugging Face, the company"}],
                        },
                        {
                            "role": "assistant",
                            "content": "Hugging Face, a company of minds\nWith tools and services that make our lives easier\nFrom",
                        },
                    ],
                }
            ],
            [
                {
                    "input_text": [
                        {"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}
                    ],
                    "generated_text": [
                        {"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]},
                        {"role": "assistant", "content": "Paris"},
                    ],
                }
            ],
        ],
    )

@require_torch
def test_small_model_pt_token(self):
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
Expand Down