Add vision capability #2025

Merged: 30 commits, Mar 24, 2024
Commits
1be5865
Add vision capability
BeibinLi Mar 9, 2024
7f7e746
Configurate: description_prompt
BeibinLi Mar 9, 2024
8ff43fa
Print warning instead of raising issues for type
BeibinLi Mar 9, 2024
a3cad84
Merge branch 'main' of https://github.com/microsoft/autogen into visi…
BeibinLi Mar 13, 2024
24e0fec
Skip vision capability test if dependencies not installed
BeibinLi Mar 13, 2024
72a9373
Append "vision" to agent's system message when enabled VisionCapability
BeibinLi Mar 13, 2024
937a591
GPT-4V notebook update with ConversableAgent
BeibinLi Mar 13, 2024
8e6bac9
Clean GPT-4V notebook
BeibinLi Mar 13, 2024
9e5e720
Add vision capability test to workflow
BeibinLi Mar 13, 2024
225dbc6
Lint import
BeibinLi Mar 13, 2024
1639849
Merge branch 'main' into vision_capability
BeibinLi Mar 13, 2024
75353d0
Merge branch 'main' into vision_capability
sonichi Mar 14, 2024
264cf13
Merge branch 'main' into vision_capability
BeibinLi Mar 14, 2024
a7c27ca
Update system message for vision capability
BeibinLi Mar 14, 2024
8fc695d
Merge branch 'main' into vision_capability
BeibinLi Mar 14, 2024
9017b35
Add a `custom_caption_func` to VisionCapability
BeibinLi Mar 15, 2024
19958bc
Add custom function example for vision capability
BeibinLi Mar 15, 2024
a04db95
Skip test Vision capability custom func
BeibinLi Mar 15, 2024
cfdc4fc
GPT-4V notebook metadata to website
BeibinLi Mar 15, 2024
ac037bd
Merge branch 'main' into vision_capability
BeibinLi Mar 15, 2024
8ca5895
Merge branch 'main' into vision_capability
BeibinLi Mar 16, 2024
a65649e
Merge branch 'main' into vision_capability
BeibinLi Mar 17, 2024
74bc66d
Merge branch 'main' of https://github.com/microsoft/autogen into visi…
BeibinLi Mar 18, 2024
6da32f7
Merge branch 'vision_capability' of https://github.com/microsoft/auto…
BeibinLi Mar 18, 2024
27fa933
Remove redundant files
BeibinLi Mar 18, 2024
cf2d6c3
Merge branch 'main' of https://github.com/microsoft/autogen into visi…
BeibinLi Mar 21, 2024
6e776bf
The custom caption function takes more inputs now
BeibinLi Mar 22, 2024
400d2cf
Add a more complex example of custom caption func
BeibinLi Mar 22, 2024
2f1518c
Remove trailing space
BeibinLi Mar 22, 2024
6bce05d
Merge branch 'main' into vision_capability
sonichi Mar 23, 2024
2 changes: 1 addition & 1 deletion .github/workflows/contrib-tests.yml
@@ -247,7 +247,7 @@ jobs:
- name: Coverage
run: |
pip install coverage>=5.3
coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py test/agentchat/contrib/capabilities/test_image_generation_capability.py --skip-openai
coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py test/agentchat/contrib/capabilities/test_image_generation_capability.py test/agentchat/contrib/capabilities/test_vision_capability.py --skip-openai
coverage xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
211 changes: 211 additions & 0 deletions autogen/agentchat/contrib/capabilities/vision_capability.py
@@ -0,0 +1,211 @@
import copy
from typing import Callable, Dict, List, Optional, Union

from autogen.agentchat.assistant_agent import ConversableAgent
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
from autogen.agentchat.contrib.img_utils import (
convert_base64_to_data_uri,
get_image_data,
get_pil_image,
gpt4v_formatter,
message_formatter_pil_to_b64,
)
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
from autogen.agentchat.conversable_agent import colored
from autogen.code_utils import content_str
from autogen.oai.client import OpenAIWrapper

DEFAULT_DESCRIPTION_PROMPT = (
"Write a detailed caption for this image. "
"Pay special attention to any details that might be useful or relevant "
"to the ongoing conversation."
)


class VisionCapability(AgentCapability):
"""We can add vision capability to regular ConversableAgent, even if the agent does not have the multimodal capability,
such as GPT-3.5-turbo agent, Llama, Orca, or Mistral agents. This vision capability will invoke a LMM client to describe
the image (captioning) before sending the information to the agent's actual client.

The vision capability will hook to the ConversableAgent's `process_last_received_message`.

Some technical details:
When the agent (who has the vision capability) received an message, it will:
1. _process_received_message:
a. _append_oai_message
2. generate_reply: if the agent is a MultimodalAgent, it will also use the image tag.
a. hook process_last_received_message (NOTE: this is where the vision capability will be hooked to.)
b. hook process_all_messages_before_reply
3. send:
a. hook process_message_before_send
b. _append_oai_message
"""

def __init__(
self,
lmm_config: Dict,
description_prompt: Optional[str] = DEFAULT_DESCRIPTION_PROMPT,
custom_caption_func: Optional[Callable] = None,
) -> None:
"""
Initializes a new instance, setting up the configuration for interacting with
a large multimodal model (LMM) client and specifying optional parameters for image
description and captioning.

Args:
lmm_config (Dict): Configuration for the LMM client, which is used to call
the LMM service for describing the image. This must be a dictionary containing
the necessary configuration parameters. If `lmm_config` is False or an empty dictionary,
no LMM client is created, and a valid `custom_caption_func` must be supplied instead
(otherwise an AssertionError is raised).
description_prompt (Optional[str], optional): The prompt to use for generating
descriptions of the image. This parameter allows customization of the
prompt passed to the LMM service. Defaults to `DEFAULT_DESCRIPTION_PROMPT` if not provided.
custom_caption_func (Callable, optional): A callable that, if provided, will be used
to generate captions for images. This allows for custom captioning logic outside
of the standard LMM service interaction.
The callable should take three parameters:
1. an image URL (or local path)
2. image_data (a PIL image)
3. lmm_client (for calling the remote LMM)
and return the description as a string.
If not provided, captioning relies on the LMM client configured via `lmm_config`.
If provided, the default `self._get_image_caption` method is skipped.

Raises:
AssertionError: If neither a valid `lmm_config` nor a `custom_caption_func` is provided,
an AssertionError is raised to indicate that the Vision Capability requires
one of these to be valid for operation.
"""
self._lmm_config = lmm_config
self._description_prompt = description_prompt
self._parent_agent = None

if lmm_config:
self._lmm_client = OpenAIWrapper(**lmm_config)
else:
self._lmm_client = None

self._custom_caption_func = custom_caption_func
assert (
self._lmm_config or custom_caption_func
), "Vision Capability requires a valid lmm_config or custom_caption_func."

def add_to_agent(self, agent: ConversableAgent) -> None:
self._parent_agent = agent

# Append extra info to the system message.
agent.update_system_message(agent.system_message + "\nYou've been given the ability to interpret images.")

# Register a hook for processing the last message.
agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)

def process_last_received_message(self, content: Union[str, List[dict]]) -> str:
"""
Processes the last received message content by normalizing and augmenting it
with descriptions of any included images. The function supports input content
as either a string or a list of dictionaries, where each dictionary represents
a content item (e.g., text, image). If the content contains image URLs, it
fetches the image data, generates a caption for each image, and inserts the
caption into the augmented content.

String content is first normalized into the GPT-4V multimodal format (text and
image_url items) so that image references can be located; each image is then
captioned, and the result is flattened back into a single text string with the
captions appended after their images. This makes the content usable by agents
whose clients cannot interpret images directly.

Args:
content (Union[str, List[dict]]): The last received message content, which
can be a plain text string or a list of dictionaries representing
different types of content items (e.g., text, image_url).

Returns:
str: The augmented message content

Raises:
AssertionError: If an item in the content list is not a dictionary.

Examples:
Assuming `self._get_image_caption(img_data)` returns
"A beautiful sunset over the mountains" for the image.

- Input as String:
content = "Check out this cool photo!"
Output: "Check out this cool photo!"
(Content is a string without an image, remains unchanged.)

- Input as String, with image location:
content = "What's weather in this cool photo: <img http://example.com/photo.jpg>"
Output: "What's weather in this cool photo: <img http://example.com/photo.jpg> in case you can not see, the caption of this image is:
A beautiful sunset over the mountains\n"
(Caption added after the image)

- Input as List with Text Only:
content = [{"type": "text", "text": "Here's an interesting fact."}]
Output: "Here's an interesting fact."
(No images in the content, it remains unchanged.)

- Input as List with Image URL:
content = [
{"type": "text", "text": "What's weather in this cool photo:"},
{"type": "image_url", "image_url": {"url": "http://example.com/photo.jpg"}}
]
Output: "What's weather in this cool photo: <img http://example.com/photo.jpg> in case you can not see, the caption of this image is:
A beautiful sunset over the mountains\n"
(Caption added after the image)
"""
content = copy.deepcopy(content)  # work on a copy so the original message content is not mutated
# normalize the content into the gpt-4v format for multimodal
# we want to keep the URL format to keep it concise.
if isinstance(content, str):
content = gpt4v_formatter(content, img_format="url")

aug_content: str = ""
for item in content:
assert isinstance(item, dict)
if item["type"] == "text":
aug_content += item["text"]
elif item["type"] == "image_url":
img_url = item["image_url"]["url"]
img_caption = ""

if self._custom_caption_func:
img_caption = self._custom_caption_func(img_url, get_pil_image(img_url), self._lmm_client)
elif self._lmm_client:
img_data = get_image_data(img_url)
img_caption = self._get_image_caption(img_data)
else:
img_caption = ""

aug_content += f"<img {img_url}> in case you can not see, the caption of this image is: {img_caption}\n"
else:
print(f"Warning: the input type should either be `test` or `image_url`. Skip {item['type']} here.")

return aug_content

def _get_image_caption(self, img_data: str) -> str:
"""
Args:
img_data (str): base64 encoded image data.
Returns:
str: caption for the given image.
"""
response = self._lmm_client.create(
context=None,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": self._description_prompt},
{
"type": "image_url",
"image_url": {
"url": convert_base64_to_data_uri(img_data),
},
},
],
}
],
)
description = response.choices[0].message.content
return content_str(description)
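
To make the new API concrete, below is a minimal usage sketch pieced together from the docstrings above. It is not part of the diff; the model names, API keys, and the caption function body are illustrative assumptions.

# Usage sketch (not part of the diff); model names and API keys are placeholders.
from autogen import ConversableAgent
from autogen.agentchat.contrib.capabilities.vision_capability import VisionCapability

lmm_config = {
    "config_list": [{"model": "gpt-4-vision-preview", "api_key": "sk-placeholder"}],
    "temperature": 0.5,
    "max_tokens": 300,
}

agent = ConversableAgent(
    name="assistant",
    llm_config={"config_list": [{"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"}]},
)

def my_caption(image_url, pil_image, lmm_client):
    # Hypothetical custom captioner; must return the description as a string.
    width, height = pil_image.size
    return f"Image at {image_url}, size {width}x{height}."

# Either rely on the LMM client built from lmm_config, or pass custom_caption_func to bypass it.
VisionCapability(lmm_config=lmm_config, custom_caption_func=my_caption).add_to_agent(agent)

After `add_to_agent`, images referenced in incoming messages are captioned and the caption text is appended to the message before the agent's own (possibly non-multimodal) client sees it.
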
6 changes: 6 additions & 0 deletions autogen/agentchat/contrib/img_utils.py
@@ -24,6 +24,12 @@ def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
# Already a PIL Image object
return image_file

# Remove surrounding quotes if present
if image_file.startswith('"') and image_file.endswith('"'):
image_file = image_file[1:-1]
if image_file.startswith("'") and image_file.endswith("'"):
image_file = image_file[1:-1]

if image_file.startswith("http://") or image_file.startswith("https://"):
# A URL file
response = requests.get(image_file)
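
A small sketch of what the quote-stripping above enables; it assumes a local file named photo.jpg, which is not part of the PR.

from autogen.agentchat.contrib.img_utils import get_pil_image

# Both calls resolve to the same file: surrounding quotes (single or double) are stripped first.
img_plain = get_pil_image("photo.jpg")      # assumes photo.jpg exists in the working directory
img_quoted = get_pil_image('"photo.jpg"')
assert img_plain.size == img_quoted.size
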
43 changes: 22 additions & 21 deletions autogen/agentchat/conversable_agent.py
@@ -5,37 +5,35 @@
import json
import logging
import re
import warnings
from collections import defaultdict
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
import warnings

from openai import BadRequestError

from autogen.exception_utils import InvalidCarryOverType, SenderRequired

from ..coding.base import CodeExecutor
from ..coding.factory import CodeExecutorFactory
from ..formatting_utils import colored

from ..oai.client import OpenAIWrapper, ModelClient
from ..runtime_logging import logging_enabled, log_new_agent
from .._pydantic import model_dump
from ..cache.cache import Cache
from ..code_utils import (
UNKNOWN,
content_str,
check_can_use_docker_or_throw,
content_str,
decide_use_docker,
execute_code,
extract_code,
infer_lang,
)
from .utils import gather_usage_summary, consolidate_chat_info
from .chat import ChatResult, initiate_chats, a_initiate_chats


from ..coding.base import CodeExecutor
from ..coding.factory import CodeExecutorFactory
from ..formatting_utils import colored
from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
from ..oai.client import ModelClient, OpenAIWrapper
from ..runtime_logging import log_new_agent, logging_enabled
from .agent import Agent, LLMAgent
from .._pydantic import model_dump
from .chat import ChatResult, a_initiate_chats, initiate_chats
from .utils import consolidate_chat_info, gather_usage_summary

__all__ = ("ConversableAgent",)

@@ -2603,22 +2601,25 @@ def process_last_received_message(self, messages):
return messages # Last message contains a context key.
if "content" not in last_message:
return messages # Last message has no content.
user_text = last_message["content"]
if not isinstance(user_text, str):
return messages # Last message content is not a string. TODO: Multimodal agents will use a dict here.
if user_text == "exit":

user_content = last_message["content"]
if not isinstance(user_content, str) and not isinstance(user_content, list):
# if the user_content is a string, it is for regular LLM
# if the user_content is a list, it should follow the multimodal LMM format.
return messages
if user_content == "exit":
return messages # Last message is an exit command.

# Call each hook (in order of registration) to process the user's message.
processed_user_text = user_text
processed_user_content = user_content
for hook in hook_list:
processed_user_text = hook(processed_user_text)
if processed_user_text == user_text:
processed_user_content = hook(processed_user_content)
if processed_user_content == user_content:
return messages # No hooks actually modified the user's message.

# Replace the last user message with the expanded one.
messages = messages.copy()
messages[-1]["content"] = processed_user_text
messages[-1]["content"] = processed_user_content
return messages

def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
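
The change above widens the hook input from a plain string to either a string or a multimodal content list. Below is a hedged sketch of a hook that handles both shapes; the redaction logic is illustrative and not from this PR.

from autogen import ConversableAgent

def redact_hook(content):
    # content is either a plain string (regular LLM) or a list of dicts in the multimodal LMM format.
    if isinstance(content, str):
        return content.replace("secret", "[redacted]")
    return [
        {**item, "text": item["text"].replace("secret", "[redacted]")} if item.get("type") == "text" else item
        for item in content
    ]

agent = ConversableAgent(name="helper", llm_config=False)
agent.register_hook(hookable_method="process_last_received_message", hook=redact_hook)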