import json
import logging
import re
from typing import List, Optional, Union

import replicate
import requests

from autogen.agentchat.agent import Agent
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
from autogen.code_utils import content_str
from autogen.img_utils import llava_formater

try:
    from termcolor import colored
except ImportError:
    # Fall back to a no-op if termcolor is not installed.
    def colored(x, *args, **kwargs):
        return x


logger = logging.getLogger(__name__)

# Separator token used to delimit turns in the LLaVA prompt; it may be overridden later.
SEP = "###"

DEFAULT_LLAVA_SYS_MSG = "You are an AI agent and you can view images."

class LLaVAAgent(MultimodalConversableAgent):
    def __init__(
        self,
        name: str,
        system_message: Optional[Union[str, List]] = DEFAULT_LLAVA_SYS_MSG,
        *args,
        **kwargs,
    ):
        """
        Args:
            name (str): agent name.
            system_message (str): system message for the ChatCompletion inference.
                Please override this attribute if you want to reprogram the agent.
            **kwargs (dict): Please refer to other kwargs in
                [ConversableAgent](../conversable_agent#__init__).
        """
        super().__init__(
            name,
            system_message=system_message,
            *args,
            **kwargs,
        )

        assert self.llm_config is not None, "llm_config must be provided."
        # Register _image_reply so that LLaVA, rather than the default completion path,
        # handles incoming messages that may contain images.
        self.register_reply([Agent, None], reply_func=LLaVAAgent._image_reply, position=1)
    def _image_reply(self, messages=None, sender=None, config=None):
        # Note: the "config" argument passed in by register_reply is not used here yet.

        if messages is None and sender is None:
            error_msg = f"Either {messages=} or {sender=} must be provided."
            logger.error(error_msg)
            raise AssertionError(error_msg)

        if messages is None:
            messages = self._oai_messages[sender]

        # The prompt formats for LLaVA and GPT differ, so we assemble the LLaVA prompt manually here.
        images = []
        prompt = content_str(self.system_message) + "\n"
        for msg in messages:
            role = "Human" if msg["role"] == "user" else "Assistant"
            images += [d["image_url"]["url"] for d in msg["content"] if d["type"] == "image_url"]
            content_prompt = content_str(msg["content"])
            prompt += f"{SEP}{role}: {content_prompt}\n"
        prompt += "\n" + SEP + "Assistant: "
        # Strip the data-URI header so only the raw base64 payload is sent to the model.
        images = [re.sub("data:image/.+;base64,", "", im, count=1) for im in images]
        print(colored(prompt, "blue"))

        out = ""
        retry = 10
        while len(out) == 0 and retry > 0:
            # Image names will be inferred automatically from llava_call.
            out = llava_call_binary(
                prompt=prompt,
                images=images,
                config_list=self.llm_config["config_list"],
                temperature=self.llm_config.get("temperature", 0.5),
                max_new_tokens=self.llm_config.get("max_new_tokens", 2000),
            )
            retry -= 1

        assert out != "", "Empty response from LLaVA."

        return True, out

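
# Example (illustrative sketch, not part of the original module): instantiating the agent against
# a locally hosted LLaVA worker. The model name, port, and generation settings below are
# assumptions for demonstration only.
#
#   llava_config = {
#       "config_list": [
#           {"model": "llava-v1.5-13b", "base_url": "http://0.0.0.0:10000", "api_key": "None"},
#       ],
#       "temperature": 0.5,
#       "max_new_tokens": 500,
#   }
#   image_agent = LLaVAAgent(name="image-explainer", llm_config=llava_config)
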
def _llava_call_binary_with_config(
    prompt: str, images: list, config: dict, max_new_tokens: int = 1000, temperature: float = 0.5, seed: int = 1
):
    # A base_url pointing at 0.0.0.0 or localhost is treated as a local LLaVA worker;
    # anything else is assumed to be a remote (Replicate) deployment.
    if config["base_url"].find("0.0.0.0") >= 0 or config["base_url"].find("localhost") >= 0:
        llava_mode = "local"
    else:
        llava_mode = "remote"

    if llava_mode == "local":
        headers = {"User-Agent": "LLaVA Client"}
        pload = {
            "model": config["model"],
            "prompt": prompt,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "stop": SEP,
            "images": images,
        }

        response = requests.post(
            config["base_url"].rstrip("/") + "/worker_generate_stream", headers=headers, json=pload, stream=False
        )

        # The worker returns NUL-delimited JSON chunks; keep only the text after the last separator.
        for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode("utf-8"))
                output = data["text"].split(SEP)[-1]
    elif llava_mode == "remote":
        # The Replicate version of the model only supports one image for now.
        img = "data:image/jpeg;base64," + images[0]
        response = replicate.run(
            config["base_url"], input={"image": img, "prompt": prompt.replace("<image>", " "), "seed": seed}
        )
        # The yorickvp/llava-13b model can stream output as it's running.
        # The predict method returns an iterator, and you can iterate over that output.
        output = ""
        for item in response:
            # https://replicate.com/yorickvp/llava-13b/versions/2facb4a474a0462c15041b78b1ad70952ea46b5ec6ad29583c0b29dbd4249591/api#output-schema
            output += item

    # Remove the prompt and surrounding whitespace from the returned text.
    output = output.replace(prompt, "").strip()
    return output

def llava_call_binary(
    prompt: str, images: list, config_list: list, max_new_tokens: int = 1000, temperature: float = 0.5, seed: int = 1
):
    # TODO 1: add caching around the LLaVA call to save compute and cost.
    # TODO 2: add `seed` to ensure reproducibility. The seed is not working now.

    # Try each config in turn and return the first successful response.
    # If every config fails, this falls through and implicitly returns None.
    for config in config_list:
        try:
            return _llava_call_binary_with_config(prompt, images, config, max_new_tokens, temperature, seed)
        except Exception as e:
            print(f"Error: {e}")
            continue

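
# Illustrative config entries (assumptions, not values prescribed by this module): a local entry
# routes to the worker's /worker_generate_stream endpoint, while for a remote entry only base_url
# is used and it is passed straight to replicate.run as the model reference.
#
#   local_config = {"model": "llava-v1.5-13b", "base_url": "http://localhost:10000", "api_key": "None"}
#   remote_config = {
#       "model": "llava-13b",
#       "base_url": "yorickvp/llava-13b:<version-hash>",  # hypothetical Replicate model reference
#       "api_key": "None",
#   }
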
def llava_call(prompt: str, llm_config: dict) -> str:
    """
    Makes a call to the LLaVA service to generate text based on a given prompt.
    """

    prompt, images = llava_formater(prompt, order_image_tokens=False)

    for im in images:
        if len(im) == 0:
            raise RuntimeError("An image is empty!")

    return llava_call_binary(
        prompt,
        images,
        config_list=llm_config["config_list"],
        max_new_tokens=llm_config.get("max_new_tokens", 2000),
        temperature=llm_config.get("temperature", 0.5),
        seed=llm_config.get("seed", None),
    )
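
# Example (illustrative sketch): calling llava_call directly with an image tag embedded in the
# prompt. The image path and config values below are assumptions for demonstration only.
#
#   rst = llava_call(
#       "Describe this image in detail: <img coco.jpg>",
#       llm_config={"config_list": [local_config], "temperature": 0.3},
#   )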