Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support vision for openai #137

Merged
merged 12 commits into from
Aug 11, 2024
3 changes: 3 additions & 0 deletions agents/addon/extension/openai_chatgpt_python/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
},
"max_memory_length": {
"type": "int64"
},
"enable_vision": {
"type": "bool"
}
},
"data_in": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
from rte.image_frame import ImageFrame
from .openai_chatgpt import OpenAIChatGPT, OpenAIChatGPTConfig
from datetime import datetime
from threading import Thread
Expand All @@ -20,6 +21,10 @@
MetadataInfo,
)
from .log import logger
from base64 import b64encode
import numpy as np
from io import BytesIO
from PIL import Image


CMD_IN_FLUSH = "flush"
Expand All @@ -39,6 +44,7 @@
PROPERTY_TOP_P = "top_p" # Optional
PROPERTY_MAX_TOKENS = "max_tokens" # Optional
PROPERTY_GREETING = "greeting" # Optional
PROPERTY_ENABLE_VISION = "enable_vision" # Optional
PROPERTY_PROXY_URL = "proxy_url" # Optional
PROPERTY_MAX_MEMORY_LENGTH = "max_memory_length" # Optional

Expand Down Expand Up @@ -73,11 +79,69 @@ def parse_sentence(sentence, content):
return sentence, remain, found_punc


def yuv420_to_rgb(yuv_data, width, height):
plutoless marked this conversation as resolved.
Show resolved Hide resolved
# Calculate the size of each plane
frame_size = width * height
chroma_size = frame_size // 4

y_plane = yuv_data[0:frame_size].reshape((height, width))
u_plane = yuv_data[frame_size:frame_size + chroma_size].reshape((height // 2, width // 2))
v_plane = yuv_data[frame_size + chroma_size:].reshape((height // 2, width // 2))

u_plane = u_plane.repeat(2, axis=0).repeat(2, axis=1)
v_plane = v_plane.repeat(2, axis=0).repeat(2, axis=1)

# Ensure calculations are done in a wider data type to prevent overflow
y_plane = y_plane.astype(np.int16)
u_plane = u_plane.astype(np.int16)
v_plane = v_plane.astype(np.int16)

# Convert YUV to RGB using the standard conversion formula
r_plane = y_plane + 1.402 * (v_plane - 128)
g_plane = y_plane - 0.344136 * (u_plane - 128) - 0.714136 * (v_plane - 128)
b_plane = y_plane + 1.772 * (u_plane - 128)

# Clip values to the 0-255 range and convert to uint8
r_plane = np.clip(r_plane, 0, 255).astype(np.uint8)
g_plane = np.clip(g_plane, 0, 255).astype(np.uint8)
b_plane = np.clip(b_plane, 0, 255).astype(np.uint8)

# Stack the RGB planes into an image
rgb_image = np.stack([r_plane, g_plane, b_plane], axis=-1)

return rgb_image

def yuv2base64png(yuv_data, width, height):
# Convert YUV to RGB
rgb_image = yuv420_to_rgb(np.frombuffer(yuv_data, dtype=np.uint8), width, height)

# Convert the RGB image to a PIL Image
pil_image = Image.fromarray(rgb_image)

# Save the image to a BytesIO object in PNG format
buffered = BytesIO()
pil_image.save(buffered, format="JPEG")

# Get the byte data of the PNG image
png_image_data = buffered.getvalue()

# Convert the PNG byte data to a Base64 encoded string
base64_encoded_image = b64encode(png_image_data).decode('utf-8')

# Create the data URL
mime_type = 'image/jpeg'
base64_url = f"data:{mime_type};base64,{base64_encoded_image}"
return base64_url

class OpenAIChatGPTExtension(Extension):
memory = []
max_memory_length = 10
outdate_ts = 0
openai_chatgpt = None
enable_vision = False
image_data = None
image_width = 0
image_height = 0

def on_init(
self, rte: RteEnv, manifest: MetadataInfo, property: MetadataInfo
Expand Down Expand Up @@ -168,6 +232,11 @@ def on_start(self, rte: RteEnv) -> None:
except Exception as err:
logger.info(f"GetProperty optional {PROPERTY_GREETING} failed, err: {err}")

try:
self.enable_vision = rte.get_property_bool(PROPERTY_ENABLE_VISION)
except Exception as err:
logger.info(f"GetProperty optional {PROPERTY_ENABLE_VISION} failed, err: {err}")

try:
prop_max_memory_length = rte.get_property_int(PROPERTY_MAX_MEMORY_LENGTH)
if prop_max_memory_length > 0:
Expand Down Expand Up @@ -233,6 +302,13 @@ def on_cmd(self, rte: RteEnv, cmd: Cmd) -> None:
cmd_result.set_property_string("detail", "success")
rte.return_result(cmd_result, cmd)

def on_image_frame(self, rte_env: RteEnv, image_frame: ImageFrame) -> None:
# logger.info(f"OpenAIChatGPTExtension on_image_frame {image_frame.get_width()} {image_frame.get_height()}")
self.image_data = image_frame.get_buf()
self.image_width = image_frame.get_width()
self.image_height = image_frame.get_height()
return

def on_data(self, rte: RteEnv, data: Data) -> None:
"""
on_data receives data from rte graph.
Expand Down Expand Up @@ -271,7 +347,22 @@ def on_data(self, rte: RteEnv, data: Data) -> None:
# Prepare memory
if len(self.memory) > self.max_memory_length:
self.memory.pop(0)
self.memory.append({"role": "user", "content": input_text})
if self.image_data is not None and self.enable_vision is True:
url = yuv2base64png(self.image_data, self.image_width, self.image_height)
# logger.info(f"image url: {url}")
self.memory.append({"role": "user", "content": [
{"type": "text", "text": input_text},
{
"type": "image_url",
"image_url": {
"url": url,
}
}
]})
# clear image after use
self.image_data = None
else:
self.memory.append({"role": "user", "content": input_text})

def chat_completions_stream_worker(start_time, input_text, memory):
try:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
openai==1.35.13
requests==2.32.3
requests==2.32.3
numpy==2.0.1
pillow==10.4.0