Process inputs directly in apply_chat_template in image-text-to-text pipeline #35616
Merged: yonigozlan merged 15 commits into huggingface:main from yonigozlan:vectorize-input-chat-image-text-to-text-pipeline on Apr 23, 2025
Changes from 2 commits
Commits (all by yonigozlan):
- 1834fab tokenize inputs directly in apply_chat_template
- 37bb6fc refactor processing
- 4618387 revert changes processing llava
- 4352620 Merge remote-tracking branch 'upstream/main' into vectorize-input-cha…
- 86fce1f Update docs
- df09033 fix issue with str being iterable
- 12ea800 Merge branch 'main' into vectorize-input-chat-image-text-to-text-pipe…
- 9e47e34 Merge branch 'main' into vectorize-input-chat-image-text-to-text-pipe…
- ca0c94a Merge branch 'main' into vectorize-input-chat-image-text-to-text-pipe…
- 033522b Merge remote-tracking branch 'upstream/main' into vectorize-input-cha…
- 950f361 add test chat text only
- 62598e9 Merge branch 'vectorize-input-chat-image-text-to-text-pipeline' of ht…
- 094c44d Merge branch 'main' into vectorize-input-chat-image-text-to-text-pipe…
- 053f396 change function name
- b288256 Merge branch 'vectorize-input-chat-image-text-to-text-pipeline' of ht…
    @@ -57,10 +57,9 @@ def __init__(self, messages: Dict, images: Union[str, List[str], "Image.Image",
             for message in messages:
                 if not ("role" in message and "content" in message):
                     raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
    -        images = retrieve_images_in_messages(messages, images)
    +        messages = retrieve_images_in_messages(messages, images)

             self.messages = messages
    -        self.images = images


     def retrieve_images_in_messages(

Member
I feel like the function name should be changed here, as it's not really what it does anymore.

    @@ -72,43 +71,40 @@ def retrieve_images_in_messages(
         if images is None:
             images = []
         idx_images = 0
    -    retrieved_images = []
         for message in messages:
             for content in message["content"]:
    -            if isinstance(content, dict):
    -                if content.get("type") == "image":
    -                    for key in ["image", "url", "path", "base64"]:
    -                        if key in content:
    -                            retrieved_images.append(content[key])
    -                            break
    -                    else:
    -                        if idx_images < len(images):
    -                            retrieved_images.append(images[idx_images])
    -                            idx_images += 1
    -                        else:
    -                            raise ValueError(
    -                                "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
    -                            )
    -                # Add support for OpenAI/TGI chat format
    -                elif content.get("type") == "image_url":
    -                    if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
    -                        retrieved_images.append(content["image_url"]["url"])
    -                        # Rewrite content to be in the Transformers chat format
    -                        content["type"] = "image"
    -                        content["image"] = content["image_url"]["url"]
    -                        del content["image_url"]
    +            if not isinstance(content, dict):
    +                continue
    +            content_type = content.get("type")
    +            if content_type == "image":
    +                if not any(key in content for key in ["image", "url", "path", "base64"]):
    +                    if idx_images < len(images):
    +                        # Insert the image passed as argument in the chat message
    +                        content["image"] = images[idx_images]
    +                        idx_images += 1
                         else:
                             raise ValueError(
    -                            "Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
    +                            "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
                             )
    +            # Add support for OpenAI/TGI chat format
    +            elif content_type == "image_url":
    +                if isinstance(content.get("image_url"), dict) and "url" in content["image_url"]:
    +                    # Rewrite content to be in the Transformers chat format
    +                    content["type"] = "image"
    +                    content["image"] = content["image_url"]["url"]
    +                    del content["image_url"]
    +                else:
    +                    raise ValueError(
    +                        "Wrong format for 'image_url' content type. The content should have an 'image_url' dict with a 'url' key."
    +                    )

         # The number of images passed should be consistent with the number of images in the chat without an image key
         if idx_images != len(images):
             raise ValueError(
                 "The number of images in the chat messages should be the same as the number of images passed to the pipeline."
             )

    -    return retrieved_images
    +    return messages


     @add_end_docstrings(build_pipeline_init_args(has_processor=True))

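
To make the new behavior of this helper concrete, here is a minimal standalone sketch of the updated logic. It is not the library code: the name `add_images_to_messages` and the constant `IMAGE_KEYS` are hypothetical stand-ins, `images` is assumed to already be a list, and the `deepcopy` is a choice of this sketch (the function in the diff rewrites the message dicts in place).

```python
from copy import deepcopy

# Keys that already carry an image reference inside a content dict (taken from the diff).
IMAGE_KEYS = ("image", "url", "path", "base64")


def add_images_to_messages(messages, images=None):
    """Hypothetical re-implementation: place images inside the chat messages.

    Returns the rewritten messages instead of a separate list of images.
    """
    if images is None:
        images = []
    # Sketch-only choice: copy so the caller's chat is left untouched.
    messages = deepcopy(messages)
    idx_images = 0
    for message in messages:
        for content in message["content"]:
            # Plain-string content yields characters here, which are skipped.
            if not isinstance(content, dict):
                continue
            content_type = content.get("type")
            if content_type == "image":
                if not any(key in content for key in IMAGE_KEYS):
                    if idx_images >= len(images):
                        raise ValueError("More image placeholders than images passed.")
                    # Fill the placeholder with the next image passed as argument.
                    content["image"] = images[idx_images]
                    idx_images += 1
            elif content_type == "image_url":
                # OpenAI/TGI format: rewrite to the Transformers chat format.
                image_url = content.get("image_url")
                if not (isinstance(image_url, dict) and "url" in image_url):
                    raise ValueError("'image_url' content needs an 'image_url' dict with a 'url' key.")
                content["type"] = "image"
                content["image"] = image_url["url"]
                del content["image_url"]
    # Every image passed as argument must have been consumed by a placeholder.
    if idx_images != len(images):
        raise ValueError("Fewer image placeholders than images passed.")
    return messages


chat = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            {"type": "text", "text": "Describe both images."},
        ],
    }
]
result = add_images_to_messages(chat, images=["local_dog.png"])
```

In this example the bare `{"type": "image"}` placeholder receives `"local_dog.png"`, and the `image_url` entry is rewritten to `{"type": "image", "image": "https://example.com/cat.png"}`, so the chat itself carries everything the chat template needs.
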
    @@ -316,31 +312,30 @@ def __call__(
             return super().__call__({"images": images, "text": text}, **kwargs)

         def preprocess(self, inputs=None, timeout=None, continue_final_message=None, processing_kwargs=None):
    +        if isinstance(inputs, Chat):
    +            # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
    +            # because very few models support multiple separate, consecutive assistant messages
    +            if continue_final_message is None:
    +                continue_final_message = inputs.messages[-1]["role"] == "assistant"
    +            model_inputs = self.processor.apply_chat_template(
    +                inputs.messages,
    +                add_generation_prompt=not continue_final_message,
    +                continue_final_message=continue_final_message,
    +                return_tensors=self.framework,
    +                tokenize=True,
    +                return_dict=True,
    +            )
    +            model_inputs["text"] = inputs
    +            return model_inputs
             # In case we only have text inputs
             if isinstance(inputs, (list, tuple, str)):
                 images = None
                 text = inputs
                 inputs_text = inputs
             else:
    -            if isinstance(inputs, Chat):
    -                # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
    -                # because very few models support multiple separate, consecutive assistant messages
    -                if continue_final_message is None:
    -                    continue_final_message = inputs.messages[-1]["role"] == "assistant"
    -                text = self.processor.apply_chat_template(
    -                    inputs.messages,
    -                    add_generation_prompt=not continue_final_message,
    -                    continue_final_message=continue_final_message,
    -                    return_tensors=self.framework,
    -                )
    -                inputs_text = inputs
    -                images = inputs.images
    -            else:
    -                text = inputs["text"]
    -                inputs_text = inputs["text"]
    -                images = inputs["images"]
    -
    -            images = load_images(images)
    +            images = load_images(inputs["images"])
    +            text = inputs["text"]
    +            inputs_text = inputs["text"]

             # if batched text inputs, we set padding to True unless specified otherwise
             if isinstance(text, (list, tuple)) and len(text) > 1:

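
The prefill default used in `preprocess` can be illustrated in isolation. The sketch below only reproduces how the flags passed to `apply_chat_template` are derived from the chat; `build_chat_template_kwargs` is a hypothetical helper, not part of the pipeline, and `return_tensors` is left out because it depends on the framework in use.

```python
def build_chat_template_kwargs(messages, continue_final_message=None):
    """Hypothetical helper mirroring the flag logic shown in the diff."""
    # A chat ending in an assistant message is treated as a prefill by default.
    if continue_final_message is None:
        continue_final_message = messages[-1]["role"] == "assistant"
    return {
        # A generation prompt is only added when we are not continuing a message.
        "add_generation_prompt": not continue_final_message,
        "continue_final_message": continue_final_message,
        # Tokenize and return a dict so the output can be used directly as model inputs.
        "tokenize": True,
        "return_dict": True,
    }


user_turn = [{"role": "user", "content": "Describe the image."}]
prefill = user_turn + [{"role": "assistant", "content": "The image shows"}]

user_kwargs = build_chat_template_kwargs(user_turn)
prefill_kwargs = build_chat_template_kwargs(prefill)
override_kwargs = build_chat_template_kwargs(prefill, continue_final_message=False)
```

An explicit `continue_final_message` always wins over the default, which is why `override_kwargs` asks for a generation prompt even though the chat ends with an assistant message.
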
This should be `1` to work correctly with different ViT backbones. Was it causing any test failures?

Without this change, I'm getting errors on pipeline tests that used to work with llava-interleave. For example:
returns:

In the case of `llava-interleave-qwen-0.5b-hf`, I see a mismatch in `vision_feature_select_strategy` between the model config and the processor. Will fix that on the hub :)