+import copy
+from typing import Callable, Dict, List, Optional, Union
+
+from autogen.agentchat.assistant_agent import ConversableAgent
+from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
+from autogen.agentchat.contrib.img_utils import (
+    convert_base64_to_data_uri,
+    get_image_data,
+    get_pil_image,
+    gpt4v_formatter,
+)
+from autogen.code_utils import content_str
+from autogen.oai.client import OpenAIWrapper
+
+DEFAULT_DESCRIPTION_PROMPT = (
+    "Write a detailed caption for this image. "
+    "Pay special attention to any details that might be useful or relevant "
+    "to the ongoing conversation."
+)
+
+
+class VisionCapability(AgentCapability):
+    """Adds vision capability to a regular ConversableAgent even when the agent's model is not
+    multimodal, such as a GPT-3.5-turbo, Llama, Orca, or Mistral agent. The capability invokes an
+    LMM client to describe (caption) each image before the information is sent to the agent's
+    actual client.
+
+    The vision capability hooks into the ConversableAgent's `process_last_received_message`.
+
+    Some technical details:
+    When the agent (which has the vision capability) receives a message, it will:
+    1. _process_received_message:
+        a. _append_oai_message
+    2. generate_reply: if the agent is a MultimodalAgent, it will also use the image tag.
+        a. hook process_last_received_message (NOTE: this is where the vision capability is hooked in.)
+        b. hook process_all_messages_before_reply
+    3. send:
+        a. hook process_message_before_send
+        b. _append_oai_message
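+
+    Example (a minimal usage sketch; the contents of `lmm_config` and `llm_config` are
+    illustrative, not a required schema):
+
+        agent = ConversableAgent(name="assistant", llm_config={"config_list": [...]})
+        vision = VisionCapability(lmm_config={"config_list": [{"model": "gpt-4-vision-preview"}]})
+        vision.add_to_agent(agent)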
+    """
+
+    def __init__(
+        self,
+        lmm_config: Dict,
+        description_prompt: Optional[str] = DEFAULT_DESCRIPTION_PROMPT,
+        custom_caption_func: Optional[Callable] = None,
+    ) -> None:
+        """
+        Initializes a new instance, setting up the configuration for interacting with
+        a Large Multimodal Model (LMM) client and specifying optional parameters for image
+        description and captioning.
+
+        Args:
+            lmm_config (Dict): Configuration for the LMM client, which is used to call
+                the LMM service for describing the image. This must be a dictionary containing
+                the necessary configuration parameters. If `lmm_config` is False or an empty
+                dictionary, it is considered invalid and initialization raises an AssertionError.
+            description_prompt (Optional[str], optional): The prompt to use for generating
+                descriptions of the image. This parameter allows customization of the
+                prompt passed to the LMM service. Defaults to `DEFAULT_DESCRIPTION_PROMPT`.
+            custom_caption_func (Callable, optional): A callable that, if provided, will be used
+                to generate captions for images, allowing custom captioning logic outside
+                of the standard LMM service interaction.
+                The callable should take three parameters as input:
+                    1. an image URL (or local file path)
+                    2. image_data (a PIL image)
+                    3. lmm_client (to call the remote LMM)
+                and return a description as a string.
+                If not provided, captioning relies on the LMM client configured via `lmm_config`.
+                If provided, the default self._get_image_caption method is not run.
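+
+                A minimal sketch of such a callable (the body is illustrative only):
+
+                    def my_caption_func(image_url, image_data, lmm_client) -> str:
+                        # e.g., consult a cache or call a local captioning model here
+                        return f"Image at {image_url}, size {image_data.size}"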
+
+        Raises:
+            AssertionError: If neither a valid `lmm_config` nor a `custom_caption_func` is
+                provided, since the Vision Capability requires one of them to operate.
+        """
+        self._lmm_config = lmm_config
+        self._description_prompt = description_prompt
+        self._parent_agent = None
+
+        if lmm_config:
+            self._lmm_client = OpenAIWrapper(**lmm_config)
+        else:
+            self._lmm_client = None
+
+        self._custom_caption_func = custom_caption_func
+        assert (
+            self._lmm_config or custom_caption_func
+        ), "Vision Capability requires a valid lmm_config or custom_caption_func."
+
+    def add_to_agent(self, agent: ConversableAgent) -> None:
+        self._parent_agent = agent
+
+        # Append extra info to the system message.
+        agent.update_system_message(agent.system_message + "\nYou've been given the ability to interpret images.")
+
+        # Register a hook for processing the last message.
+        agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)
+
+    def process_last_received_message(self, content: Union[str, List[dict]]) -> str:
+        """
+        Processes the last received message content by normalizing and augmenting it
+        with descriptions of any included images. The function supports input content
+        as either a string or a list of dictionaries, where each dictionary represents
+        a content item (e.g., text, image). If the content contains image URLs, it
+        fetches the image data, generates a caption for each image, and inserts the
+        caption into the augmented content.
+
+        The function transforms the content into a format compatible with GPT-4V
+        multimodal inputs, specifically by converting `<img ...>` tags in plain-string
+        content into structured image_url items and appending a text caption for each
+        image. This allows for a more accessible presentation of the content,
+        especially in contexts where images cannot be displayed directly.
+
+        Args:
+            content (Union[str, List[dict]]): The last received message content, which
+                can be a plain text string or a list of dictionaries representing
+                different types of content items (e.g., text, image_url).
+
+        Returns:
+            str: The augmented message content.
+
+        Raises:
+            AssertionError: If an item in the content list is not a dictionary.
+
+        Examples:
+            Assume `self._get_image_caption(img_data)` returns
+            "A beautiful sunset over the mountains" for the image.
+
+            - Input as String:
+                content = "Check out this cool photo!"
+                Output: "Check out this cool photo!"
+                (Content is a string without an image, so it remains unchanged.)
+
+            - Input as String, with image location:
+                content = "What's the weather in this cool photo: <img http://example.com/photo.jpg>"
+                Output: "What's the weather in this cool photo: <img http://example.com/photo.jpg> in case you cannot see, the caption of this image is:
+                A beautiful sunset over the mountains\n"
+                (Caption added after the image.)
+
+            - Input as List with Text Only:
+                content = [{"type": "text", "text": "Here's an interesting fact."}]
+                Output: "Here's an interesting fact."
+                (No images in the content, so it remains unchanged.)
+
+            - Input as List with Image URL:
+                content = [
+                    {"type": "text", "text": "What's the weather in this cool photo:"},
+                    {"type": "image_url", "image_url": {"url": "http://example.com/photo.jpg"}}
+                ]
+                Output: "What's the weather in this cool photo: <img http://example.com/photo.jpg> in case you cannot see, the caption of this image is:
+                A beautiful sunset over the mountains\n"
+                (Caption added after the image.)
+        """
+        # Work on a copy so the caller's message content is not mutated.
+        content = copy.deepcopy(content)
+        # Normalize the content into the GPT-4V format for multimodal input;
+        # we keep the URL format to stay concise.
+        if isinstance(content, str):
+            content = gpt4v_formatter(content, img_format="url")
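+        # After normalization, `content` is a list of dicts, roughly of this shape
+        # (shown for illustration):
+        # [{"type": "text", "text": "What's the weather in this cool photo: "},
+        #  {"type": "image_url", "image_url": {"url": "http://example.com/photo.jpg"}}]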
+
+        aug_content: str = ""
+        for item in content:
+            assert isinstance(item, dict)
+            if item["type"] == "text":
+                aug_content += item["text"]
+            elif item["type"] == "image_url":
+                img_url = item["image_url"]["url"]
+                img_caption = ""
+
+                if self._custom_caption_func:
+                    img_caption = self._custom_caption_func(img_url, get_pil_image(img_url), self._lmm_client)
+                elif self._lmm_client:
+                    img_data = get_image_data(img_url)
+                    img_caption = self._get_image_caption(img_data)
+
+                aug_content += f"<img {img_url}> in case you cannot see, the caption of this image is: {img_caption}\n"
+            else:
+                print(f"Warning: the item type should be either `text` or `image_url`. Skipping {item['type']} here.")
+
+        return aug_content
+
+    def _get_image_caption(self, img_data: str) -> str:
+        """
+        Args:
+            img_data (str): base64 encoded image data.
+
+        Returns:
+            str: caption for the given image.
+        """
+        response = self._lmm_client.create(
+            context=None,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": self._description_prompt},
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": convert_base64_to_data_uri(img_data),
+                            },
+                        },
+                    ],
+                }
+            ],
+        )
+        description = response.choices[0].message.content
+        return content_str(description)