[Outdate] Add vision capability #1926

Closed · wants to merge 14 commits
2 changes: 1 addition & 1 deletion .github/workflows/contrib-tests.yml
@@ -246,7 +246,7 @@ jobs:
- name: Coverage
run: |
pip install coverage>=5.3
coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py --skip-openai
coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py test/agentchat/contrib/capabilities/test_vision_capability.py --skip-openai
coverage xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
185 changes: 185 additions & 0 deletions autogen/agentchat/contrib/capabilities/vision_capability.py
@@ -0,0 +1,185 @@
import copy
from typing import Dict, List, Optional, Union

from autogen.agentchat.assistant_agent import ConversableAgent
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
from autogen.agentchat.contrib.img_utils import (
convert_base64_to_data_uri,
get_image_data,
get_pil_image,
gpt4v_formatter,
message_formatter_pil_to_b64,
)
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
from autogen.agentchat.conversable_agent import colored
from autogen.code_utils import content_str
from autogen.oai.client import OpenAIWrapper

DEFAULT_DESCRIPTION_PROMPT = (
"Write a detailed caption for this image. "
"Pay special attention to any details that might be useful or relevant "
"to the ongoing conversation."
)


class VisionCapability(AgentCapability):
"""We can add vision capability to regular ConversableAgent, even if the agent does not have the multimodal capability,
such as GPT-3.5-turbo agent, Llama, Orca, or Mistral agents. This vision capability will invoke a LMM client to describe
the image (captioning) before sending the information to the agent's actual client.

The vision capability will hook to the ConversableAgent's `process_last_received_message`.

Some technical details:
When the agent (who has the vision capability) received an message, it will:
1. _process_received_message:
a. _append_oai_message
2. generate_reply: if the agent is a MultimodalAgent, it will also use the image tag.
a. hook process_last_received_message (NOTE: this is where the vision capability will be hooked to.)
b. hook process_all_messages_before_reply
3. send:
a. hook process_message_before_send
b. _append_oai_message
"""

def __init__(
self,
lmm_config: Dict,
description_prompt: Optional[str] = DEFAULT_DESCRIPTION_PROMPT,
) -> None:
"""
Args:
lmm_config (dict): LMM (multimodal) client configuration,
which will be used to call the LMM to describe the image.
description_prompt (str, optional): The prompt to use for describing the image.
"""
assert lmm_config, "Vision Capability requires a valid lmm_config."
self._lmm_config = lmm_config
self._description_prompt = description_prompt
self._parent_agent = None
self._lmm_client = OpenAIWrapper(**lmm_config)

def add_to_agent(self, agent: ConversableAgent) -> None:
if isinstance(agent, MultimodalConversableAgent):
print(
colored(
"Warning: This agent is already a multimodal agent. The vision capability will not be added.",
"yellow",
)
)
return # do nothing

self._parent_agent = agent

# Append extra info to the system message.
agent.update_system_message(agent.system_message + "\nYou've been given the ability to interpret images.")

# Register a hook for processing the last message.
agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)

# Was an lmm_config passed to the constructor?
if self._lmm_config is None:
# No. Use the agent's lmm_config.
self._lmm_config = agent.lmm_config
assert self._lmm_config, "Vision Capability requires a valid lmm_config."

def process_last_received_message(self, content: Union[str, List[dict]]) -> str:
"""
Processes the last received message content by normalizing and augmenting it
with descriptions of any included images. The function supports input content
as either a string or a list of dictionaries, where each dictionary represents
a content item (e.g., text, image). If the content contains image URLs, it
fetches the image data, generates a caption for each image, and inserts the
caption into the augmented content.

The function first normalizes plain-string content into the GPT-4V multimodal
list format (parsing any <img ...> tags while keeping image URLs as-is) and then
flattens it back into a single string, appending a text caption after each image.
This allows a more accessible presentation of the content, especially in contexts
where images cannot be displayed directly.

Args:
content (Union[str, List[dict]]): The last received message content, which
can be a plain text string or a list of dictionaries representing
different types of content items (e.g., text, image_url).

Returns:
str: The augmented message content

Raises:
AssertionError: If an item in the content list is not a dictionary.

Examples:
Assuming `self._get_image_caption(img_data)` returns
"A beautiful sunset over the mountains" for the image.

- Input as String:
content = "Check out this cool photo!"
Output: "Check out this cool photo!"
(Content is a string without an image; it remains unchanged.)

- Input as String, with image location:
content = "What's the weather in this cool photo: <img http://example.com/photo.jpg>"
Output: "What's the weather in this cool photo: <img http://example.com/photo.jpg> in case you can not see, the caption of this image is:
A beautiful sunset over the mountains\n"
(Caption added after the image.)

- Input as List with Text Only:
content = [{"type": "text", "text": "Here's an interesting fact."}]
Output: "Here's an interesting fact."
(No images in the content; it remains unchanged.)

- Input as List with Image URL:
content = [
{"type": "text", "text": "What's the weather in this cool photo:"},
{"type": "image_url", "image_url": {"url": "http://example.com/photo.jpg"}}
]
Output: "What's the weather in this cool photo: <img http://example.com/photo.jpg> in case you can not see, the caption of this image is:
A beautiful sunset over the mountains\n"
(Caption added after the image.)
"""
content = copy.deepcopy(content)
# Normalize the content into the GPT-4V multimodal format;
# keep the URL form so the result stays concise.
if isinstance(content, str):
content = gpt4v_formatter(content, img_format="url")

aug_content: str = ""
for item in content:
assert isinstance(item, dict)
if item["type"] == "text":
aug_content += item["text"]
elif item["type"] == "image_url":
img_data = get_image_data(item["image_url"]["url"])
img_caption = self._get_image_caption(img_data)
aug_content += f'<img {item["image_url"]["url"]}> in case you can not see, the caption of this image is: {img_caption}\n'
else:
print(f"Warning: the input type should be either `text` or `image_url`. Skipping {item['type']} here.")

return aug_content

def _get_image_caption(self, img_data: str) -> str:
"""
Args:
img_data (str): base64 encoded image data.
Returns:
str: caption for the given image.
"""
response = self._lmm_client.create(
context=None,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": self._description_prompt},
{
"type": "image_url",
"image_url": {
"url": convert_base64_to_data_uri(img_data),
},
},
],
}
],
)
description = response.choices[0].message.content
return content_str(description)
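For reference, a minimal usage sketch of the new capability; the model names, API key placeholder, and sampling values below are illustrative assumptions, not part of this diff:

# Minimal usage sketch for VisionCapability (illustrative config values).
from autogen import ConversableAgent
from autogen.agentchat.contrib.capabilities.vision_capability import VisionCapability

# lmm_config points the capability at a multimodal model used only for captioning.
lmm_config = {
    "config_list": [{"model": "gpt-4-vision-preview", "api_key": "YOUR_API_KEY"}],
    "temperature": 0.5,
    "max_tokens": 300,
}

# The agent itself can stay on a text-only model such as GPT-3.5-turbo.
assistant = ConversableAgent(
    name="assistant",
    llm_config={"config_list": [{"model": "gpt-3.5-turbo", "api_key": "YOUR_API_KEY"}]},
)

# add_to_agent registers the process_last_received_message hook, so any <img ...>
# tag in an incoming message is captioned by the LMM before the agent replies.
VisionCapability(lmm_config=lmm_config).add_to_agent(assistant)

With this wiring, the text-only assistant receives the original message plus the generated caption, as described in process_last_received_message above.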
6 changes: 6 additions & 0 deletions autogen/agentchat/contrib/img_utils.py
@@ -24,6 +24,12 @@ def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
# Already a PIL Image object
return image_file

# Remove surrounding quotes, if present
if image_file.startswith('"') and image_file.endswith('"'):
image_file = image_file[1:-1]
if image_file.startswith("'") and image_file.endswith("'"):
image_file = image_file[1:-1]

if image_file.startswith("http://") or image_file.startswith("https://"):
# A URL file
response = requests.get(image_file)
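A quick sketch of what the new quote stripping in get_pil_image allows; the file name is hypothetical:

from autogen.agentchat.contrib.img_utils import get_pil_image

# Both calls now load the same image: surrounding quotes (for example, from a
# copied shell path or an LLM-emitted <img "..."> tag) are removed before the
# path or URL is resolved.
img_plain = get_pil_image("photo.png")
img_quoted = get_pil_image('"photo.png"')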
41 changes: 21 additions & 20 deletions autogen/agentchat/conversable_agent.py
@@ -5,36 +5,34 @@
import json
import logging
import re
import warnings
from collections import defaultdict
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
import warnings

from openai import BadRequestError

from autogen.exception_utils import InvalidCarryOverType, SenderRequired

from ..coding.base import CodeExecutor
from ..coding.factory import CodeExecutorFactory

from ..oai.client import OpenAIWrapper, ModelClient
from ..runtime_logging import logging_enabled, log_new_agent
from .._pydantic import model_dump
from ..cache.cache import Cache
from ..code_utils import (
UNKNOWN,
content_str,
check_can_use_docker_or_throw,
content_str,
decide_use_docker,
execute_code,
extract_code,
infer_lang,
)
from .utils import gather_usage_summary, consolidate_chat_info
from .chat import ChatResult, initiate_chats, a_initiate_chats


from ..coding.base import CodeExecutor
from ..coding.factory import CodeExecutorFactory
from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
from ..oai.client import ModelClient, OpenAIWrapper
from ..runtime_logging import log_new_agent, logging_enabled
from .agent import Agent, LLMAgent
from .._pydantic import model_dump
from .chat import ChatResult, a_initiate_chats, initiate_chats
from .utils import consolidate_chat_info, gather_usage_summary

try:
from termcolor import colored
@@ -2595,22 +2593,25 @@ def process_last_received_message(self, messages):
return messages # Last message contains a context key.
if "content" not in last_message:
return messages # Last message has no content.
user_text = last_message["content"]
if not isinstance(user_text, str):
return messages # Last message content is not a string. TODO: Multimodal agents will use a dict here.
if user_text == "exit":

user_content = last_message["content"]
if not isinstance(user_content, (str, list)):
# A plain string is the regular LLM format; a list should follow the multimodal LMM format.
return messages
if user_content == "exit":
return messages # Last message is an exit command.

# Call each hook (in order of registration) to process the user's message.
processed_user_text = user_text
processed_user_content = user_content
for hook in hook_list:
processed_user_text = hook(processed_user_text)
if processed_user_text == user_text:
processed_user_content = hook(processed_user_content)
if processed_user_content == user_content:
return messages # No hooks actually modified the user's message.

# Replace the last user message with the expanded one.
messages = messages.copy()
messages[-1]["content"] = processed_user_text
messages[-1]["content"] = processed_user_content
return messages

def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
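A small sketch of how a process_last_received_message hook now has to handle both content formats; the agent and hook names are illustrative:

from autogen import ConversableAgent

agent = ConversableAgent(name="assistant", llm_config=False)

def annotate_last_message(content):
    # After this change, content may be a plain string or a multimodal list of dicts.
    if isinstance(content, str):
        return "[incoming] " + content
    # Leave multimodal lists to capabilities such as VisionCapability.
    return content

agent.register_hook(hookable_method="process_last_received_message", hook=annotate_last_message)

Hooks that assume a plain string should keep the isinstance check, since list-formatted multimodal content is no longer filtered out before the hooks run.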