
Commit 2f109f5

BeibinLi and sonichi authored
Add vision capability (#2025)
* Add vision capability
* Configurate: description_prompt
* Print warning instead of raising issues for type
* Skip vision capability test if dependencies not installed
* Append "vision" to agent's system message when enabled VisionCapability
* GPT-4V notebook update with ConversableAgent
* Clean GPT-4V notebook
* Add vision capability test to workflow
* Lint import
* Update system message for vision capability
* Add a `custom_caption_func` to VisionCapability
* Add custom function example for vision capability
* Skip test Vision capability custom func
* GPT-4V notebook metadata to website
* Remove redundant files
* The custom caption function takes more inputs now
* Add a more complex example of custom caption func
* Remove trailing space

---------

Co-authored-by: Chi Wang <[email protected]>
1 parent 212722c commit 2f109f5

File tree

6 files changed: +878 -215 lines changed


.github/workflows/contrib-tests.yml

+1-1
@@ -247,7 +247,7 @@ jobs:
       - name: Coverage
         run: |
           pip install coverage>=5.3
-          coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py test/agentchat/contrib/capabilities/test_image_generation_capability.py --skip-openai
+          coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py test/agentchat/contrib/capabilities/test_image_generation_capability.py test/agentchat/contrib/capabilities/test_vision_capability.py --skip-openai
           coverage xml
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v3
autogen/agentchat/contrib/capabilities/vision_capability.py

+211

@@ -0,0 +1,211 @@
import copy
from typing import Callable, Dict, List, Optional, Union

from autogen.agentchat.assistant_agent import ConversableAgent
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
from autogen.agentchat.contrib.img_utils import (
    convert_base64_to_data_uri,
    get_image_data,
    get_pil_image,
    gpt4v_formatter,
    message_formatter_pil_to_b64,
)
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
from autogen.agentchat.conversable_agent import colored
from autogen.code_utils import content_str
from autogen.oai.client import OpenAIWrapper

DEFAULT_DESCRIPTION_PROMPT = (
    "Write a detailed caption for this image. "
    "Pay special attention to any details that might be useful or relevant "
    "to the ongoing conversation."
)


class VisionCapability(AgentCapability):
    """We can add vision capability to a regular ConversableAgent, even if the agent itself is not multimodal
    (e.g., GPT-3.5-turbo, Llama, Orca, or Mistral agents). This vision capability will invoke an LMM client to
    describe the image (captioning) before sending the information to the agent's actual client.

    The vision capability hooks into the ConversableAgent's `process_last_received_message`.

    Some technical details:
    When the agent (who has the vision capability) receives a message, it will:
    1. _process_received_message:
        a. _append_oai_message
    2. generate_reply: if the agent is a MultimodalAgent, it will also use the image tag.
        a. hook process_last_received_message (NOTE: this is where the vision capability is hooked in.)
        b. hook process_all_messages_before_reply
    3. send:
        a. hook process_message_before_send
        b. _append_oai_message
    """

    def __init__(
        self,
        lmm_config: Dict,
        description_prompt: Optional[str] = DEFAULT_DESCRIPTION_PROMPT,
        custom_caption_func: Callable = None,
    ) -> None:
        """
        Initializes a new instance, setting up the configuration for interacting with
        a Large Multimodal Model (LMM) client and specifying optional parameters for
        image description and captioning.

        Args:
            lmm_config (Dict): Configuration for the LMM client, which is used to call
                the LMM service for describing the image. This must be a dictionary containing
                the necessary configuration parameters. If `lmm_config` is False or an empty dictionary,
                it is considered invalid, and initialization will assert.
            description_prompt (Optional[str], optional): The prompt to use for generating
                descriptions of the image. This parameter allows customization of the
                prompt passed to the LMM service. Defaults to `DEFAULT_DESCRIPTION_PROMPT` if not provided.
            custom_caption_func (Callable, optional): A callable that, if provided, will be used
                to generate captions for images. This allows for custom captioning logic outside
                of the standard LMM service interaction.
                The callable should take three parameters as input:
                    1. an image URL (or local location)
                    2. image_data (a PIL image)
                    3. lmm_client (to call a remote LMM)
                and then return a description (as a string).
                If not provided, captioning will rely on the LMM client configured via `lmm_config`.
                If provided, the default `self._get_image_caption` method will not be run.

        Raises:
            AssertionError: If neither a valid `lmm_config` nor a `custom_caption_func` is provided,
                an AssertionError is raised to indicate that the Vision Capability requires
                one of these to be valid for operation.
        """
        self._lmm_config = lmm_config
        self._description_prompt = description_prompt
        self._parent_agent = None

        if lmm_config:
            self._lmm_client = OpenAIWrapper(**lmm_config)
        else:
            self._lmm_client = None

        self._custom_caption_func = custom_caption_func
        assert (
            self._lmm_config or custom_caption_func
        ), "Vision Capability requires a valid lmm_config or custom_caption_func."

    def add_to_agent(self, agent: ConversableAgent) -> None:
        self._parent_agent = agent

        # Append extra info to the system message.
        agent.update_system_message(agent.system_message + "\nYou've been given the ability to interpret images.")

        # Register a hook for processing the last message.
        agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)

    def process_last_received_message(self, content: Union[str, List[dict]]) -> str:
        """
        Processes the last received message content by normalizing and augmenting it
        with descriptions of any included images. The function supports input content
        as either a string or a list of dictionaries, where each dictionary represents
        a content item (e.g., text, image). If the content contains image URLs, it
        fetches the image data, generates a caption for each image, and inserts the
        caption into the augmented content.

        The function aims to transform the content into a format compatible with GPT-4V
        multimodal inputs, specifically by formatting strings into PIL-compatible
        images if needed and appending text descriptions for images. This allows for
        a more accessible presentation of the content, especially in contexts where
        images cannot be displayed directly.

        Args:
            content (Union[str, List[dict]]): The last received message content, which
                can be a plain text string or a list of dictionaries representing
                different types of content items (e.g., text, image_url).

        Returns:
            str: The augmented message content.

        Raises:
            AssertionError: If an item in the content list is not a dictionary.

        Examples:
            Assuming `self._get_image_caption(img_data)` returns
            "A beautiful sunset over the mountains" for the image.

            - Input as String:
                content = "Check out this cool photo!"
                Output: "Check out this cool photo!"
                (Content is a string without an image; it remains unchanged.)

            - Input as String, with image location:
                content = "What's the weather in this cool photo: <img http://example.com/photo.jpg>"
                Output: "What's the weather in this cool photo: <img http://example.com/photo.jpg> in case you can not see, the caption of this image is: A beautiful sunset over the mountains\n"
                (Caption added after the image.)

            - Input as List with Text Only:
                content = [{"type": "text", "text": "Here's an interesting fact."}]
                Output: "Here's an interesting fact."
                (No images in the content; it remains unchanged.)

            - Input as List with Image URL:
                content = [
                    {"type": "text", "text": "What's the weather in this cool photo:"},
                    {"type": "image_url", "image_url": {"url": "http://example.com/photo.jpg"}}
                ]
                Output: "What's the weather in this cool photo: <img http://example.com/photo.jpg> in case you can not see, the caption of this image is: A beautiful sunset over the mountains\n"
                (Caption added after the image.)
        """
        content = copy.deepcopy(content)  # work on a copy so the caller's message is not mutated
        # Normalize the content into the gpt-4v format for multimodal.
        # We want to keep the URL format to keep it concise.
        if isinstance(content, str):
            content = gpt4v_formatter(content, img_format="url")

        aug_content: str = ""
        for item in content:
            assert isinstance(item, dict)
            if item["type"] == "text":
                aug_content += item["text"]
            elif item["type"] == "image_url":
                img_url = item["image_url"]["url"]
                img_caption = ""

                if self._custom_caption_func:
                    img_caption = self._custom_caption_func(img_url, get_pil_image(img_url), self._lmm_client)
                elif self._lmm_client:
                    img_data = get_image_data(img_url)
                    img_caption = self._get_image_caption(img_data)
                else:
                    img_caption = ""

                aug_content += f"<img {img_url}> in case you can not see, the caption of this image is: {img_caption}\n"
            else:
                print(f"Warning: the input type should either be `text` or `image_url`. Skip {item['type']} here.")

        return aug_content

    def _get_image_caption(self, img_data: str) -> str:
        """
        Args:
            img_data (str): base64 encoded image data.
        Returns:
            str: caption for the given image.
        """
        response = self._lmm_client.create(
            context=None,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": self._description_prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": convert_base64_to_data_uri(img_data),
                            },
                        },
                    ],
                }
            ],
        )
        description = response.choices[0].message.content
        return content_str(description)
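
For orientation (not part of the diff): a minimal sketch of attaching the new capability to a text-only agent. The agent name and the caption function below are illustrative placeholders, not code from this commit; a real setup would typically pass an `lmm_config` dict pointing at a GPT-4V-capable model instead of (or in addition to) a custom function.

from autogen import ConversableAgent
from autogen.agentchat.contrib.capabilities.vision_capability import VisionCapability

# Hypothetical captioner: takes (image URL, PIL image, lmm_client) and returns a string.
def describe_image(img_url, pil_image, lmm_client):
    # A real implementation could call lmm_client here; this stub just reports size.
    return f"an image of size {pil_image.size} loaded from {img_url}"

agent = ConversableAgent(name="assistant", llm_config=False)

# Either lmm_config or custom_caption_func must be provided (the __init__ asserts this).
vision = VisionCapability(lmm_config=None, custom_caption_func=describe_image)
vision.add_to_agent(agent)
# From now on, <img ...> tags in incoming messages are replaced by text captions
# before the agent's own (text-only) client sees them.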

autogen/agentchat/contrib/img_utils.py

+6
@@ -24,6 +24,12 @@ def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
         # Already a PIL Image object
         return image_file

+    # Remove surrounding quotes if present
+    if image_file.startswith('"') and image_file.endswith('"'):
+        image_file = image_file[1:-1]
+    if image_file.startswith("'") and image_file.endswith("'"):
+        image_file = image_file[1:-1]
+
     if image_file.startswith("http://") or image_file.startswith("https://"):
         # A URL file
         response = requests.get(image_file)
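
A quick hedged illustration of the quote-stripping above (the path is hypothetical and must exist locally for these calls to succeed):

from autogen.agentchat.contrib.img_utils import get_pil_image

# After this change, all three calls resolve to the same file; previously the
# quoted forms would fail, since the quotes were treated as part of the path.
img1 = get_pil_image("photos/sunset.png")
img2 = get_pil_image('"photos/sunset.png"')  # surrounding double quotes stripped
img3 = get_pil_image("'photos/sunset.png'")  # surrounding single quotes stripped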

autogen/agentchat/conversable_agent.py

+22-21
@@ -5,37 +5,35 @@
 import json
 import logging
 import re
+import warnings
 from collections import defaultdict
 from functools import partial
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
-import warnings
+
 from openai import BadRequestError

 from autogen.exception_utils import InvalidCarryOverType, SenderRequired

-from ..coding.base import CodeExecutor
-from ..coding.factory import CodeExecutorFactory
-from ..formatting_utils import colored
-
-from ..oai.client import OpenAIWrapper, ModelClient
-from ..runtime_logging import logging_enabled, log_new_agent
+from .._pydantic import model_dump
 from ..cache.cache import Cache
 from ..code_utils import (
     UNKNOWN,
-    content_str,
     check_can_use_docker_or_throw,
+    content_str,
     decide_use_docker,
     execute_code,
     extract_code,
     infer_lang,
 )
-from .utils import gather_usage_summary, consolidate_chat_info
-from .chat import ChatResult, initiate_chats, a_initiate_chats
-
-
+from ..coding.base import CodeExecutor
+from ..coding.factory import CodeExecutorFactory
+from ..formatting_utils import colored
 from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
+from ..oai.client import ModelClient, OpenAIWrapper
+from ..runtime_logging import log_new_agent, logging_enabled
 from .agent import Agent, LLMAgent
-from .._pydantic import model_dump
+from .chat import ChatResult, a_initiate_chats, initiate_chats
+from .utils import consolidate_chat_info, gather_usage_summary

 __all__ = ("ConversableAgent",)

@@ -2603,22 +2601,25 @@ def process_last_received_message(self, messages):
             return messages  # Last message contains a context key.
         if "content" not in last_message:
             return messages  # Last message has no content.
-        user_text = last_message["content"]
-        if not isinstance(user_text, str):
-            return messages  # Last message content is not a string. TODO: Multimodal agents will use a dict here.
-        if user_text == "exit":
+
+        user_content = last_message["content"]
+        if not isinstance(user_content, str) and not isinstance(user_content, list):
+            # if the user_content is a string, it is for regular LLM
+            # if the user_content is a list, it should follow the multimodal LMM format.
+            return messages
+        if user_content == "exit":
             return messages  # Last message is an exit command.

         # Call each hook (in order of registration) to process the user's message.
-        processed_user_text = user_text
+        processed_user_content = user_content
         for hook in hook_list:
-            processed_user_text = hook(processed_user_text)
-        if processed_user_text == user_text:
+            processed_user_content = hook(processed_user_content)
+        if processed_user_content == user_content:
             return messages  # No hooks actually modified the user's message.

         # Replace the last user message with the expanded one.
         messages = messages.copy()
-        messages[-1]["content"] = processed_user_text
+        messages[-1]["content"] = processed_user_content
         return messages

     def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
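
A hedged sketch of what this relaxation means for hooks: a hook registered on `process_last_received_message` may now receive either a plain string or a multimodal content list, so a robust hook should handle both shapes. The hook below is illustrative, not part of this commit; the item keys follow the multimodal format shown in the diff above.

from typing import List, Union

def example_hook(content: Union[str, List[dict]]) -> str:
    # Before this change, hooks only ever saw `str` content.
    if isinstance(content, str):
        return content  # regular LLM message
    # Multimodal LMM format: a list of {"type": "text" | "image_url", ...} items.
    parts = []
    for item in content:
        if item["type"] == "text":
            parts.append(item["text"])
        elif item["type"] == "image_url":
            parts.append(f"<img {item['image_url']['url']}>")
    return " ".join(parts)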
