feat(): add gemini llm extension (#199)
sunshinexcode authored Aug 14, 2024
1 parent 3e348b6 commit 7a92c95
Showing 12 changed files with 658 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .env.example
@@ -51,6 +51,10 @@ COSY_TTS_KEY=
# ElevenLabs TTS key
ELEVENLABS_TTS_KEY=

# Extension: gemini_llm
# Gemini API key
GEMINI_API_KEY=

# Extension: litellm
# Using Environment Variables, refer to https://docs.litellm.ai/docs/providers
# For example:
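For local development, the new variable is read from .env like the other keys above. A minimal sketch of wiring it to the Gemini client, assuming python-dotenv is installed (neither the snippet nor the dependency is part of this commit):

import os

import google.generativeai as genai
from dotenv import load_dotenv

# Load .env from the current working directory, then hand the key to the
# google.generativeai client. GEMINI_API_KEY comes from the entry above.
load_dotenv()
genai.configure(api_key=os.environ["GEMINI_API_KEY"])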
6 changes: 6 additions & 0 deletions agents/addon/extension/gemini_llm_python/__init__.py
@@ -0,0 +1,6 @@
from . import gemini_llm_addon
from .extension import EXTENSION_NAME
from .log import logger


logger.info(f"{EXTENSION_NAME} extension loaded")
1 change: 1 addition & 0 deletions agents/addon/extension/gemini_llm_python/extension.py
@@ -0,0 +1 @@
EXTENSION_NAME = "gemini_llm_python"
54 changes: 54 additions & 0 deletions agents/addon/extension/gemini_llm_python/gemini_llm.py
@@ -0,0 +1,54 @@
from typing import Dict, List
import google.generativeai as genai


class GeminiLLMConfig:
def __init__(self,
api_key: str,
max_output_tokens: int,
model: str,
prompt: str,
temperature: float,
top_k: int,
top_p: float):
self.api_key = api_key
self.max_output_tokens = max_output_tokens
self.model = model
self.prompt = prompt
self.temperature = temperature
self.top_k = top_k
self.top_p = top_p

@classmethod
def default_config(cls):
return cls(
api_key="",
max_output_tokens=512,
model="gemini-1.0-pro-latest",
prompt="You are a voice assistant who talks in a conversational way and can chat with me like my friends. I will speak to you in English or Chinese, and you will answer in the corrected and improved version of my text with the language I use. Don’t talk like a robot, instead I would like you to talk like a real human with emotions. I will use your answer for text-to-speech, so don’t return me any meaningless characters. I want you to be helpful, when I’m asking you for advice, give me precise, practical and useful advice instead of being vague. When giving me a list of options, express the options in a narrative way instead of bullet points.",
temperature=0.1,
top_k=40,
top_p=0.95,
)


class GeminiLLM:
def __init__(self, config: GeminiLLMConfig):
self.config = config
genai.configure(api_key=self.config.api_key)
self.model = genai.GenerativeModel(self.config.model)

def get_chat_completions_stream(self, messages: List[Dict[str, str]]):
try:
chat = self.model.start_chat(history=messages[0:-1])
response = chat.send_message((self.config.prompt, messages[-1].get("parts")),
generation_config=genai.types.GenerationConfig(
max_output_tokens=self.config.max_output_tokens,
temperature=self.config.temperature,
top_k=self.config.top_k,
top_p=self.config.top_p),
stream=True)

return response
        except Exception as e:
            raise RuntimeError(f"get_chat_completions_stream failed, err: {e}") from e
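Taken on its own, the wrapper above can be exercised directly. A hedged usage sketch (not part of this commit) that assumes a valid GEMINI_API_KEY in the environment and is run from the extension directory:

import os

from gemini_llm import GeminiLLM, GeminiLLMConfig

# Build the default config, inject the key, and stream one completion.
config = GeminiLLMConfig.default_config()
config.api_key = os.environ["GEMINI_API_KEY"]
llm = GeminiLLM(config)

# History uses the google.generativeai chat format: role + parts.
messages = [{"role": "user", "parts": "Say hello in one short sentence."}]
for chunk in llm.get_chat_completions_stream(messages):
    print(chunk.text, end="", flush=True)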
23 changes: 23 additions & 0 deletions agents/addon/extension/gemini_llm_python/gemini_llm_addon.py
@@ -0,0 +1,23 @@
#
#
# Agora Real Time Engagement
# Created by XinHui Li in 2024.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
from rte import (
Addon,
register_addon_as_extension,
RteEnv,
)
from .extension import EXTENSION_NAME
from .log import logger
from .gemini_llm_extension import GeminiLLMExtension


@register_addon_as_extension(EXTENSION_NAME)
class GeminiLLMExtensionAddon(Addon):
def on_create_instance(self, rte: RteEnv, addon_name: str, context) -> None:
logger.info("on_create_instance")

rte.on_create_instance_done(GeminiLLMExtension(addon_name), context)
240 changes: 240 additions & 0 deletions agents/addon/extension/gemini_llm_python/gemini_llm_extension.py
@@ -0,0 +1,240 @@
#
#
# Agora Real Time Engagement
# Created by XinHui Li in 2024.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
from threading import Thread
from rte import (
Extension,
RteEnv,
Cmd,
Data,
StatusCode,
CmdResult,
)
from .gemini_llm import GeminiLLM, GeminiLLMConfig
from .log import logger
from .utils import get_micro_ts, parse_sentence


CMD_IN_FLUSH = "flush"
CMD_OUT_FLUSH = "flush"
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT = "end_of_segment"

PROPERTY_API_KEY = "api_key" # Required
PROPERTY_GREETING = "greeting" # Optional
PROPERTY_MAX_MEMORY_LENGTH = "max_memory_length" # Optional
PROPERTY_MAX_OUTPUT_TOKENS = "max_output_tokens" # Optional
PROPERTY_MODEL = "model" # Optional
PROPERTY_PROMPT = "prompt" # Optional
PROPERTY_TEMPERATURE = "temperature" # Optional
PROPERTY_TOP_K = "top_k" # Optional
PROPERTY_TOP_P = "top_p" # Optional


class GeminiLLMExtension(Extension):
memory = []
max_memory_length = 10
outdate_ts = 0
gemini_llm = None

def on_start(self, rte: RteEnv) -> None:
logger.info("GeminiLLMExtension on_start")
# Prepare configuration
gemini_llm_config = GeminiLLMConfig.default_config()

        try:
            gemini_llm_config.api_key = rte.get_property_string(PROPERTY_API_KEY)
        except Exception as err:
            logger.error(f"get_property_string required {PROPERTY_API_KEY} failed, err: {err}")
            return

        for key in [PROPERTY_GREETING, PROPERTY_MODEL, PROPERTY_PROMPT]:
            try:
                val = rte.get_property_string(key)
                if val:
                    setattr(gemini_llm_config, key, val)
            except Exception as e:
                logger.warning(f"get_property_string optional {key} failed, err: {e}")

        for key in [PROPERTY_TEMPERATURE, PROPERTY_TOP_P]:
            try:
                setattr(gemini_llm_config, key, float(rte.get_property_float(key)))
            except Exception as e:
                logger.warning(f"get_property_float optional {key} failed, err: {e}")

        for key in [PROPERTY_MAX_OUTPUT_TOKENS, PROPERTY_TOP_K]:
            try:
                setattr(gemini_llm_config, key, int(rte.get_property_int(key)))
            except Exception as e:
                logger.warning(f"get_property_int optional {key} failed, err: {e}")

try:
prop_max_memory_length = rte.get_property_int(PROPERTY_MAX_MEMORY_LENGTH)
if prop_max_memory_length > 0:
self.max_memory_length = int(prop_max_memory_length)
except Exception as err:
logger.warning(f"GetProperty optional {PROPERTY_MAX_MEMORY_LENGTH} failed, err: {err}")

# Create GeminiLLM instance
self.gemini_llm = GeminiLLM(gemini_llm_config)
logger.info(f"newGeminiLLM succeed with max_output_tokens: {gemini_llm_config.max_output_tokens}, model: {gemini_llm_config.model}")

        # Send greeting if available
        try:
            greeting = rte.get_property_string(PROPERTY_GREETING)
        except Exception as e:
            greeting = ""
            logger.warning(f"get_property_string optional {PROPERTY_GREETING} failed, err: {e}")
        if greeting:
try:
output_data = Data.create("text_data")
output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, greeting)
output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True)
rte.send_data(output_data)
logger.info(f"greeting [{greeting}] sent")
except Exception as e:
logger.error(f"greeting [{greeting}] send failed, err: {e}")

rte.on_start_done()

def on_stop(self, rte: RteEnv) -> None:
logger.info("GeminiLLMExtension on_stop")
rte.on_stop_done()

def on_cmd(self, rte: RteEnv, cmd: Cmd) -> None:
logger.info("GeminiLLMExtension on_cmd")
cmd_json = cmd.to_json()
logger.info(f"GeminiLLMExtension on_cmd json: {cmd_json}")

cmd_name = cmd.get_name()

if cmd_name == CMD_IN_FLUSH:
self.outdate_ts = get_micro_ts()
cmd_out = Cmd.create(CMD_OUT_FLUSH)
rte.send_cmd(cmd_out, None)
logger.info(f"GeminiLLMExtension on_cmd sent flush")
else:
logger.info(f"GeminiLLMExtension on_cmd unknown cmd: {cmd_name}")
cmd_result = CmdResult.create(StatusCode.ERROR)
cmd_result.set_property_string("detail", "unknown cmd")
rte.return_result(cmd_result, cmd)
return

cmd_result = CmdResult.create(StatusCode.OK)
cmd_result.set_property_string("detail", "success")
rte.return_result(cmd_result, cmd)

def on_data(self, rte: RteEnv, data: Data) -> None:
"""
on_data receives data from rte graph.
current supported data:
- name: text_data
example:
            {name: text_data, properties: {text: "hello"}}
        """
        logger.info("GeminiLLMExtension on_data")

# Assume 'data' is an object from which we can get properties
try:
is_final = data.get_property_bool(DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
if not is_final:
logger.info("ignore non-final input")
return
except Exception as e:
logger.error(f"on_data get_property_bool {DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL} failed, err: {e}")
return

# Get input text
try:
input_text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)
if not input_text:
logger.info("ignore empty text")
return
logger.info(f"on_data input text: [{input_text}]")
except Exception as e:
logger.error(f"on_data get_property_string {DATA_IN_TEXT_DATA_PROPERTY_TEXT} failed, err: {e}")
return

# Prepare memory
if len(self.memory) > self.max_memory_length:
self.memory.pop(0)
self.memory.append({"role": "user", "parts": input_text})

def chat_completions_stream_worker(start_time, input_text, memory):
try:
logger.info(f"chat_completions_stream_worker for input text: [{input_text}] memory: {memory}")

# Get result from AI
resp = self.gemini_llm.get_chat_completions_stream(memory)
if resp is None:
logger.info(f"chat_completions_stream_worker for input text: [{input_text}] failed")
return

sentence = ""
full_content = ""
first_sentence_sent = False

for chat_completions in resp:
if start_time < self.outdate_ts:
logger.info(f"chat_completions_stream_worker recv interrupt and flushing for input text: [{input_text}], startTs: {start_time}, outdateTs: {self.outdate_ts}")
break

                    content = chat_completions.text if chat_completions.text is not None else ""

full_content += content

while True:
sentence, content, sentence_is_final = parse_sentence(sentence, content)

if len(sentence) == 0 or not sentence_is_final:
logger.info(f"sentence {sentence} is empty or not final")
break

logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] got sentence: [{sentence}]")

# send sentence
try:
output_data = Data.create("text_data")
output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, False)
rte.send_data(output_data)
logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] sent sentence [{sentence}]")
except Exception as e:
logger.error(f"chat_completions_stream_worker recv for input text: [{input_text}] send sentence [{sentence}] failed, err: {e}")
break

sentence = ""
if not first_sentence_sent:
first_sentence_sent = True
logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] first sentence sent, first_sentence_latency {get_micro_ts() - start_time}ms")

# remember response as assistant content in memory
memory.append({"role": "model", "parts": full_content})

# send end of segment
try:
output_data = Data.create("text_data")
output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True)
rte.send_data(output_data)
logger.info(f"chat_completions_stream_worker for input text: [{input_text}] end of segment with sentence [{sentence}] sent")
except Exception as e:
logger.error(f"chat_completions_stream_worker for input text: [{input_text}] end of segment with sentence [{sentence}] send failed, err: {e}")

except Exception as e:
logger.error(f"chat_completions_stream_worker for input text: [{input_text}] failed, err: {e}")

# Start thread to request and read responses from GeminiLLM
start_time = get_micro_ts()
thread = Thread(
target=chat_completions_stream_worker,
args=(start_time, input_text, self.memory),
)
thread.start()
logger.info(f"GeminiLLMExtension on_data end")
12 changes: 12 additions & 0 deletions agents/addon/extension/gemini_llm_python/log.py
@@ -0,0 +1,12 @@
import logging
from .extension import EXTENSION_NAME

logger = logging.getLogger(EXTENSION_NAME)
logger.setLevel(logging.INFO)

formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s")

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)

logger.addHandler(console_handler)
(The remaining 5 changed files did not load.)
