diff --git a/.env.example b/.env.example
index 29f50432..17141464 100644
--- a/.env.example
+++ b/.env.example
@@ -51,6 +51,18 @@ COSY_TTS_KEY=
 # ElevenLabs TTS key
 ELEVENLABS_TTS_KEY=
 
+# Extension: litellm
+# Provider credentials are configured via environment variables; refer to https://docs.litellm.ai/docs/providers
+# For example:
+# OpenAI
+# OPENAI_API_KEY=
+# OPENAI_API_BASE=
+# AWS Bedrock
+# AWS_ACCESS_KEY_ID=
+# AWS_SECRET_ACCESS_KEY=
+# AWS_REGION_NAME=
+LITELLM_MODEL=gpt-4o-mini
+
 # Extension: openai_chatgpt
 # OpenAI API key
 OPENAI_API_KEY=
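Note: litellm resolves provider credentials directly from the environment, which is why only LITELLM_MODEL needs an explicit value above. A minimal sketch of the resulting call path (illustrative; `litellm.completion` and the environment variable names come from litellm's documentation, and the fallback model mirrors the default above):

```python
import os

import litellm  # pinned as litellm==1.42.12 in this extension's requirements.txt

# Provider credentials come from the environment, e.g.:
#   export OPENAI_API_KEY=...                                                  (OpenAI)
#   export AWS_ACCESS_KEY_ID=... AWS_SECRET_ACCESS_KEY=... AWS_REGION_NAME=... (Bedrock)
response = litellm.completion(
    model=os.environ.get("LITELLM_MODEL", "gpt-4o-mini"),
    messages=[{"role": "user", "content": "hello"}],
)
print(response.choices[0].message.content)
```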
diff --git a/agents/addon/extension/elevenlabs_tts_python/__init__.py b/agents/addon/extension/elevenlabs_tts_python/__init__.py
index 8cf7e25f..d7fe7e64 100644
--- a/agents/addon/extension/elevenlabs_tts_python/__init__.py
+++ b/agents/addon/extension/elevenlabs_tts_python/__init__.py
@@ -1,5 +1,6 @@
 from . import elevenlabs_tts_addon
+from .extension import EXTENSION_NAME
 from .log import logger
 
-logger.info("elevenlabs_tts_python extension loaded")
+logger.info(f"{EXTENSION_NAME} extension loaded")
diff --git a/agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts.py b/agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts.py
index 6fe1b72c..282024fc 100644
--- a/agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts.py
+++ b/agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts.py
@@ -42,9 +42,7 @@ def default_elevenlabs_tts_config() -> ElevenlabsTTSConfig:
 class ElevenlabsTTS:
     def __init__(self, config: ElevenlabsTTSConfig) -> None:
         self.config = config
-        self.client = ElevenLabs(
-            api_key=config.api_key, timeout=config.request_timeout_seconds
-        )
+        self.client = ElevenLabs(api_key=config.api_key, timeout=config.request_timeout_seconds)
 
     def text_to_speech_stream(self, text: str) -> Iterator[bytes]:
         audio_stream = self.client.generate(
diff --git a/agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts_addon.py b/agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts_addon.py
index 896e04d9..502fe867 100644
--- a/agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts_addon.py
+++ b/agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts_addon.py
@@ -11,10 +11,11 @@
     register_addon_as_extension,
     RteEnv,
 )
+from .extension import EXTENSION_NAME
 from .log import logger
 
 
-@register_addon_as_extension("elevenlabs_tts_python")
+@register_addon_as_extension(EXTENSION_NAME)
 class ElevenlabsTTSExtensionAddon(Addon):
     def on_create_instance(self, rte: RteEnv, addon_name: str, context) -> None:
         logger.info("on_create_instance")
diff --git a/agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts_extension.py b/agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts_extension.py
index 1feadb44..52e8ca8a 100644
--- a/agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts_extension.py
+++ b/agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts_extension.py
@@ -11,15 +11,12 @@
 import time
 
 from rte import (
-    Addon,
     Extension,
-    register_addon_as_extension,
     RteEnv,
     Cmd,
     CmdResult,
     StatusCode,
     Data,
-    MetadataInfo,
 )
 from .elevenlabs_tts import default_elevenlabs_tts_config, ElevenlabsTTS
 from .pcm import PcmConfig, Pcm
@@ -62,9 +59,7 @@ def on_start(self, rte: RteEnv) -> None:
         try:
             elevenlabs_tts_config.api_key = rte.get_property_string(PROPERTY_API_KEY)
         except Exception as e:
-            logger.warning(
-                f"on_start get_property_string {PROPERTY_API_KEY} error: {e}"
-            )
+            logger.warning(f"on_start get_property_string {PROPERTY_API_KEY} error: {e}")
             return
 
         try:
@@ -72,58 +67,36 @@ def on_start(self, rte: RteEnv) -> None:
             if len(model_id) > 0:
                 elevenlabs_tts_config.model_id = model_id
         except Exception as e:
-            logger.warning(
-                f"on_start get_property_string {PROPERTY_MODEL_ID} error: {e}"
-            )
+            logger.warning(f"on_start get_property_string {PROPERTY_MODEL_ID} error: {e}")
 
         try:
-            optimize_streaming_latency = rte.get_property_int(
-                PROPERTY_OPTIMIZE_STREAMING_LATENCY
-            )
+            optimize_streaming_latency = rte.get_property_int(PROPERTY_OPTIMIZE_STREAMING_LATENCY)
             if optimize_streaming_latency > 0:
-                elevenlabs_tts_config.optimize_streaming_latency = (
-                    optimize_streaming_latency
-                )
+                elevenlabs_tts_config.optimize_streaming_latency = optimize_streaming_latency
         except Exception as e:
-            logger.warning(
-                f"on_start get_property_int {PROPERTY_OPTIMIZE_STREAMING_LATENCY} error: {e}"
-            )
+            logger.warning(f"on_start get_property_int {PROPERTY_OPTIMIZE_STREAMING_LATENCY} error: {e}")
 
         try:
-            request_timeout_seconds = rte.get_property_int(
-                PROPERTY_REQUEST_TIMEOUT_SECONDS
-            )
+            request_timeout_seconds = rte.get_property_int(PROPERTY_REQUEST_TIMEOUT_SECONDS)
             if request_timeout_seconds > 0:
                 elevenlabs_tts_config.request_timeout_seconds = request_timeout_seconds
         except Exception as e:
-            logger.warning(
-                f"on_start get_property_int {PROPERTY_REQUEST_TIMEOUT_SECONDS} error: {e}"
-            )
+            logger.warning(f"on_start get_property_int {PROPERTY_REQUEST_TIMEOUT_SECONDS} error: {e}")
 
         try:
-            elevenlabs_tts_config.similarity_boost = rte.get_property_float(
-                PROPERTY_SIMILARITY_BOOST
-            )
+            elevenlabs_tts_config.similarity_boost = rte.get_property_float(PROPERTY_SIMILARITY_BOOST)
         except Exception as e:
-            logger.warning(
-                f"on_start get_property_float {PROPERTY_SIMILARITY_BOOST} error: {e}"
-            )
+            logger.warning(f"on_start get_property_float {PROPERTY_SIMILARITY_BOOST} error: {e}")
 
         try:
-            elevenlabs_tts_config.speaker_boost = rte.get_property_bool(
-                PROPERTY_SPEAKER_BOOST
-            )
+            elevenlabs_tts_config.speaker_boost = rte.get_property_bool(PROPERTY_SPEAKER_BOOST)
         except Exception as e:
-            logger.warning(
-                f"on_start get_property_bool {PROPERTY_SPEAKER_BOOST} error: {e}"
-            )
+            logger.warning(f"on_start get_property_bool {PROPERTY_SPEAKER_BOOST} error: {e}")
 
         try:
             elevenlabs_tts_config.stability = rte.get_property_float(PROPERTY_STABILITY)
         except Exception as e:
-            logger.warning(
-                f"on_start get_property_float {PROPERTY_STABILITY} error: {e}"
-            )
+            logger.warning(f"on_start get_property_float {PROPERTY_STABILITY} error: {e}")
 
         try:
             elevenlabs_tts_config.style = rte.get_property_float(PROPERTY_STYLE)
@@ -133,9 +106,7 @@ def on_start(self, rte: RteEnv) -> None:
 
         # create elevenlabsTTS instance
         self.elevenlabs_tts = ElevenlabsTTS(elevenlabs_tts_config)
-        logger.info(
-            f"ElevenlabsTTS succeed with model_id: {self.elevenlabs_tts.config.model_id}, VoiceId: {self.elevenlabs_tts.config.voice_id}"
-        )
+        logger.info(f"ElevenlabsTTS succeed with model_id: {self.elevenlabs_tts.config.model_id}, VoiceId: {self.elevenlabs_tts.config.voice_id}")
 
         # create pcm instance
         self.pcm = Pcm(PcmConfig())
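Note: the hunks above collapse black-style call wrapping into single lines, but every optional property still repeats the same try/except shape. A hypothetical helper (not part of this PR) that expresses the pattern once, assuming the `logger` from this extension's log.py:

```python
def read_optional_property(rte, getter_name: str, prop: str, default=None):
    """Fetch an optional rte property; fall back to `default` on any error.

    `getter_name` is one of the runtime accessors used above, e.g.
    "get_property_string", "get_property_int", "get_property_float", "get_property_bool".
    """
    try:
        value = getattr(rte, getter_name)(prop)
        return value if value else default
    except Exception as e:  # missing property, wrong type, etc.
        logger.warning(f"on_start {getter_name} {prop} error: {e}")
        return default

# e.g.: config.model_id = read_optional_property(rte, "get_property_string", "model_id", config.model_id)
```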
@@ -186,9 +157,7 @@ def on_data(self, rte: RteEnv, data: Data) -> None:
         try:
             text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)
         except Exception as e:
-            logger.warning(
-                f"on_data get_property_string {DATA_IN_TEXT_DATA_PROPERTY_TEXT} error: {e}"
-            )
+            logger.warning(f"on_data get_property_string {DATA_IN_TEXT_DATA_PROPERTY_TEXT} error: {e}")
             return
 
         if len(text) == 0:
@@ -207,9 +176,7 @@ def process_text_queue(self, rte: RteEnv):
             logger.debug(f"process_text_queue, text: [{msg.text}]")
 
             if msg.received_ts < self.outdate_ts:
-                logger.info(
-                    f"textChan interrupt and flushing for input text: [{msg.text}], received_ts: {msg.received_ts}, outdate_ts: {self.outdate_ts}"
-                )
+                logger.info(f"textChan interrupt and flushing for input text: [{msg.text}], received_ts: {msg.received_ts}, outdate_ts: {self.outdate_ts}")
                 continue
 
             start_time = time.time()
@@ -224,9 +191,7 @@ def process_text_queue(self, rte: RteEnv):
 
             for chunk in self.pcm.read_pcm_stream(audio_stream, self.pcm_frame_size):
                 if msg.received_ts < self.outdate_ts:
-                    logger.info(
-                        f"textChan interrupt and flushing for input text: [{msg.text}], received_ts: {msg.received_ts}, outdate_ts: {self.outdate_ts}"
-                    )
+                    logger.info(f"textChan interrupt and flushing for input text: [{msg.text}], received_ts: {msg.received_ts}, outdate_ts: {self.outdate_ts}")
                     break
 
                 if not chunk:
@@ -238,9 +203,7 @@ def process_text_queue(self, rte: RteEnv):
                 pcm_frame_read += n
 
                 if pcm_frame_read != self.pcm.get_pcm_frame_size():
-                    logger.debug(
-                        f"the number of bytes read is [{pcm_frame_read}] inconsistent with pcm frame size",
-                    )
+                    logger.debug(f"the number of bytes read is [{pcm_frame_read}] inconsistent with pcm frame size")
                     continue
 
                 self.pcm.send(rte, buf)
@@ -250,28 +213,16 @@ def process_text_queue(self, rte: RteEnv):
 
                 if first_frame_latency == 0:
                     first_frame_latency = int((time.time() - start_time) * 1000)
-                    logger.info(
-                        f"first frame available for text: [{msg.text}], received_ts: {msg.received_ts}, first_frame_latency: {first_frame_latency}ms",
-                    )
+                    logger.info(f"first frame available for text: [{msg.text}], received_ts: {msg.received_ts}, first_frame_latency: {first_frame_latency}ms")
 
                 logger.debug(f"sending pcm data, text: [{msg.text}]")
 
             if pcm_frame_read > 0:
                 self.pcm.send(rte, buf)
                 sent_frames += 1
-                logger.info(
-                    f"sending pcm remain data, text: [{msg.text}], pcm_frame_read: {pcm_frame_read}"
-                )
+                logger.info(f"sending pcm remain data, text: [{msg.text}], pcm_frame_read: {pcm_frame_read}")
 
             finish_latency = int((time.time() - start_time) * 1000)
-            logger.info(
-                f"send pcm data finished, text: [{msg.text}], received_ts: {msg.received_ts}, read_bytes: {read_bytes}, sent_frames: {sent_frames}, \
-                    first_frame_latency: {first_frame_latency}ms, finish_latency: {finish_latency}ms"
-            )
+            logger.info(f"send pcm data finished, text: [{msg.text}], received_ts: {msg.received_ts}, read_bytes: {read_bytes}, sent_frames: {sent_frames}, first_frame_latency: {first_frame_latency}ms, finish_latency: {finish_latency}ms")
-
-
-@register_addon_as_extension("elevenlabs_tts_python")
-class ElevenlabsTTSExtensionAddon(Addon):
-    def on_create_instance(self, rte: RteEnv, addon_name: str, context) -> None:
-        logger.info("on_create_instance")
-        rte.on_create_instance_done(ElevenlabsTTSExtension(addon_name), context)
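Note: the interrupt handling above is a timestamp race: a flush bumps `outdate_ts`, and the worker drops any queued text (and stops mid-stream) whose `received_ts` predates it. A stripped-down, runnable sketch of the same idea, independent of the rte runtime:

```python
import time

outdate_ts = 0.0  # bumped whenever a flush cmd arrives

def flush() -> None:
    global outdate_ts
    outdate_ts = time.time()

def synthesize(text: str, received_ts: float, audio_chunks) -> list:
    sent = []
    if received_ts < outdate_ts:
        return sent            # stale before we started: drop it
    for chunk in audio_chunks:
        if received_ts < outdate_ts:
            break              # flushed mid-stream: stop sending frames
        sent.append(chunk)     # stands in for self.pcm.send(rte, buf)
    return sent

ts = time.time()
flush()  # anything queued before this point is now stale
assert synthesize("hello", ts, [b"\x00" * 320]) == []
```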
diff --git a/agents/addon/extension/elevenlabs_tts_python/extension.py b/agents/addon/extension/elevenlabs_tts_python/extension.py
new file mode 100644
index 00000000..b634e157
--- /dev/null
+++ b/agents/addon/extension/elevenlabs_tts_python/extension.py
@@ -0,0 +1 @@
+EXTENSION_NAME = "elevenlabs_tts_python"
diff --git a/agents/addon/extension/elevenlabs_tts_python/log.py b/agents/addon/extension/elevenlabs_tts_python/log.py
index 54f870f3..fad21710 100644
--- a/agents/addon/extension/elevenlabs_tts_python/log.py
+++ b/agents/addon/extension/elevenlabs_tts_python/log.py
@@ -1,11 +1,10 @@
 import logging
+from .extension import EXTENSION_NAME
 
-logger = logging.getLogger("elevenlabs_tts_python")
+logger = logging.getLogger(EXTENSION_NAME)
 logger.setLevel(logging.INFO)
 
-formatter = logging.Formatter(
-    "%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s"
-)
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s")
 
 console_handler = logging.StreamHandler()
 console_handler.setFormatter(formatter)
diff --git a/agents/addon/extension/elevenlabs_tts_python/pcm.py b/agents/addon/extension/elevenlabs_tts_python/pcm.py
index 67a60f21..713b74b1 100644
--- a/agents/addon/extension/elevenlabs_tts_python/pcm.py
+++ b/agents/addon/extension/elevenlabs_tts_python/pcm.py
@@ -22,9 +22,7 @@ def get_pcm_frame(self, buf: memoryview) -> PcmFrame:
         frame.set_number_of_channels(self.config.num_channels)
         frame.set_timestamp(self.config.timestamp)
         frame.set_data_fmt(PcmFrameDataFmt.INTERLEAVE)
-        frame.set_samples_per_channel(
-            self.config.samples_per_channel // self.config.channel
-        )
+        frame.set_samples_per_channel(self.config.samples_per_channel // self.config.channel)
 
         frame.alloc_buf(self.get_pcm_frame_size())
         frame_buf = frame.lock_buf()
@@ -35,24 +33,19 @@ def get_pcm_frame(self, buf: memoryview) -> PcmFrame:
         return frame
 
     def get_pcm_frame_size(self) -> int:
-        return (
-            self.config.samples_per_channel
-            * self.config.channel
-            * self.config.bytes_per_sample
-        )
+        return (self.config.samples_per_channel * self.config.channel * self.config.bytes_per_sample)
 
     def new_buf(self) -> bytearray:
         return bytearray(self.get_pcm_frame_size())
 
-    def read_pcm_stream(
-        self, stream: Iterator[bytes], chunk_size: int
-    ) -> Iterator[bytes]:
+    def read_pcm_stream(self, stream: Iterator[bytes], chunk_size: int) -> Iterator[bytes]:
         chunk = b""
         for data in stream:
             chunk += data
             while len(chunk) >= chunk_size:
                 yield chunk[:chunk_size]
                 chunk = chunk[chunk_size:]
+
         if chunk:
             yield chunk
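Note: `get_pcm_frame_size` is just samples_per_channel × channels × bytes_per_sample, and `read_pcm_stream` re-chunks the ElevenLabs byte stream into frames of exactly that size. For illustration only — the real `PcmConfig` defaults live in pcm.py and are not shown in this diff — 16-bit mono audio at an assumed 16 kHz cut into assumed 10 ms frames gives 320-byte frames:

```python
sample_rate = 16000          # Hz — assumed for illustration
frame_ms = 10                # frame duration — assumed
channels = 1                 # mono
bytes_per_sample = 2         # 16-bit PCM

samples_per_channel = sample_rate * frame_ms // 1000   # 160 samples
frame_size = samples_per_channel * channels * bytes_per_sample
assert frame_size == 320     # bytes per frame handed to the runtime
```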
diff --git a/agents/addon/extension/litellm_python/__init__.py b/agents/addon/extension/litellm_python/__init__.py
new file mode 100644
index 00000000..642a2664
--- /dev/null
+++ b/agents/addon/extension/litellm_python/__init__.py
@@ -0,0 +1,6 @@
+from . import litellm_addon
+from .extension import EXTENSION_NAME
+from .log import logger
+
+
+logger.info(f"{EXTENSION_NAME} extension loaded")
diff --git a/agents/addon/extension/litellm_python/extension.py b/agents/addon/extension/litellm_python/extension.py
new file mode 100644
index 00000000..a04c95ad
--- /dev/null
+++ b/agents/addon/extension/litellm_python/extension.py
@@ -0,0 +1 @@
+EXTENSION_NAME = "litellm_python"
diff --git a/agents/addon/extension/litellm_python/litellm.py b/agents/addon/extension/litellm_python/litellm.py
new file mode 100644
index 00000000..b056bc69
--- /dev/null
+++ b/agents/addon/extension/litellm_python/litellm.py
@@ -0,0 +1,79 @@
+import litellm
+import random
+from typing import Dict, List, Optional
+
+
+class LiteLLMConfig:
+    def __init__(self,
+                 api_key: str,
+                 base_url: str,
+                 frequency_penalty: float,
+                 max_tokens: int,
+                 model: str,
+                 presence_penalty: float,
+                 prompt: str,
+                 provider: str,
+                 temperature: float,
+                 top_p: float,
+                 seed: Optional[int] = None):
+        self.api_key = api_key
+        self.base_url = base_url
+        self.frequency_penalty = frequency_penalty
+        self.max_tokens = max_tokens
+        self.model = model
+        self.presence_penalty = presence_penalty
+        self.prompt = prompt
+        self.provider = provider
+        self.seed = seed if seed is not None else random.randint(0, 10000)
+        self.temperature = temperature
+        self.top_p = top_p
+
+    @classmethod
+    def default_config(cls):
+        return cls(
+            api_key="",
+            base_url="",
+            max_tokens=512,
+            model="gpt-4o-mini",
+            frequency_penalty=0.9,
+            presence_penalty=0.9,
+            prompt="You are a voice assistant who talks in a conversational way and can chat with me like my friends. I will speak to you in English or Chinese, and you will answer in the corrected and improved version of my text with the language I use. Don’t talk like a robot, instead I would like you to talk like a real human with emotions. I will use your answer for text-to-speech, so don’t return me any meaningless characters. I want you to be helpful, when I’m asking you for advice, give me precise, practical and useful advice instead of being vague. When giving me a list of options, express the options in a narrative way instead of bullet points.",
+            provider="",
+            seed=random.randint(0, 10000),
+            temperature=0.1,
+            top_p=1.0
+        )
+
+
+class LiteLLM:
+    def __init__(self, config: LiteLLMConfig):
+        self.config = config
+
+    def get_chat_completions_stream(self, messages: List[Dict[str, str]]):
+        kwargs = {
+            "api_key": self.config.api_key,
+            "base_url": self.config.base_url,
+            "custom_llm_provider": self.config.provider,
+            "frequency_penalty": self.config.frequency_penalty,
+            "max_tokens": self.config.max_tokens,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": self.config.prompt,
+                },
+                *messages,
+            ],
+            "model": self.config.model,
+            "presence_penalty": self.config.presence_penalty,
+            "seed": self.config.seed,
+            "stream": True,
+            "temperature": self.config.temperature,
+            "top_p": self.config.top_p,
+        }
+
+        try:
+            response = litellm.completion(**kwargs)
+
+            return response
+        except Exception as e:
+            raise Exception(f"get_chat_completions_stream failed, err: {e}") from e
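Note: `get_chat_completions_stream` returns the raw litellm stream (`stream=True`), so callers iterate OpenAI-style chunks. A usage sketch, assuming provider credentials are already exported in the environment:

```python
config = LiteLLMConfig.default_config()
config.model = "gpt-4o-mini"

llm = LiteLLM(config)
stream = llm.get_chat_completions_stream([{"role": "user", "content": "hi"}])

for chunk in stream:
    # deltas can be empty or None (e.g. the final chunk), mirroring the
    # check done in litellm_extension.py below
    if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
        print(chunk.choices[0].delta.content, end="", flush=True)
```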
diff --git a/agents/addon/extension/litellm_python/litellm_addon.py b/agents/addon/extension/litellm_python/litellm_addon.py
new file mode 100644
index 00000000..460ae642
--- /dev/null
+++ b/agents/addon/extension/litellm_python/litellm_addon.py
@@ -0,0 +1,23 @@
+#
+#
+# Agora Real Time Engagement
+# Created by XinHui Li in 2024.
+# Copyright (c) 2024 Agora IO. All rights reserved.
+#
+#
+from rte import (
+    Addon,
+    register_addon_as_extension,
+    RteEnv,
+)
+from .extension import EXTENSION_NAME
+from .log import logger
+from .litellm_extension import LiteLLMExtension
+
+
+@register_addon_as_extension(EXTENSION_NAME)
+class LiteLLMExtensionAddon(Addon):
+    def on_create_instance(self, rte: RteEnv, addon_name: str, context) -> None:
+        logger.info("on_create_instance")
+
+        rte.on_create_instance_done(LiteLLMExtension(addon_name), context)
diff --git a/agents/addon/extension/litellm_python/litellm_extension.py b/agents/addon/extension/litellm_python/litellm_extension.py
new file mode 100644
index 00000000..fb007856
--- /dev/null
+++ b/agents/addon/extension/litellm_python/litellm_extension.py
@@ -0,0 +1,229 @@
+#
+#
+# Agora Real Time Engagement
+# Created by XinHui Li in 2024.
+# Copyright (c) 2024 Agora IO. All rights reserved.
+#
+#
+from threading import Thread
+from rte import (
+    Extension,
+    RteEnv,
+    Cmd,
+    Data,
+    StatusCode,
+    CmdResult,
+)
+from .litellm import LiteLLM, LiteLLMConfig
+from .log import logger
+from .utils import get_micro_ts, parse_sentence
+
+
+CMD_IN_FLUSH = "flush"
+CMD_OUT_FLUSH = "flush"
+DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"
+DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
+DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
+DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT = "end_of_segment"
+
+PROPERTY_API_KEY = "api_key"  # Required
+PROPERTY_BASE_URL = "base_url"  # Optional
+PROPERTY_FREQUENCY_PENALTY = "frequency_penalty"  # Optional
+PROPERTY_GREETING = "greeting"  # Optional
+PROPERTY_MAX_MEMORY_LENGTH = "max_memory_length"  # Optional
+PROPERTY_MAX_TOKENS = "max_tokens"  # Optional
+PROPERTY_MODEL = "model"  # Optional
+PROPERTY_PRESENCE_PENALTY = "presence_penalty"  # Optional
+PROPERTY_PROMPT = "prompt"  # Optional
+PROPERTY_PROVIDER = "provider"  # Optional
+PROPERTY_TEMPERATURE = "temperature"  # Optional
+PROPERTY_TOP_P = "top_p"  # Optional
+
+
+class LiteLLMExtension(Extension):
+    memory = []
+    max_memory_length = 10
+    outdate_ts = 0
+    litellm = None
+
+    def on_start(self, rte: RteEnv) -> None:
+        logger.info("LiteLLMExtension on_start")
+        # Prepare configuration
+        litellm_config = LiteLLMConfig.default_config()
+
+        for key in [PROPERTY_API_KEY, PROPERTY_GREETING, PROPERTY_MODEL, PROPERTY_PROMPT]:
+            try:
+                val = rte.get_property_string(key)
+                if val:
+                    setattr(litellm_config, key, val)
+            except Exception as e:
+                logger.warning(f"get_property_string optional {key} failed, err: {e}")
+
+        for key in [PROPERTY_FREQUENCY_PENALTY, PROPERTY_PRESENCE_PENALTY, PROPERTY_TEMPERATURE, PROPERTY_TOP_P]:
+            try:
+                setattr(litellm_config, key, float(rte.get_property_float(key)))
+            except Exception as e:
+                logger.warning(f"get_property_float optional {key} failed, err: {e}")
+
+        for key in [PROPERTY_MAX_MEMORY_LENGTH, PROPERTY_MAX_TOKENS]:
+            try:
+                setattr(litellm_config, key, int(rte.get_property_int(key)))
+            except Exception as e:
+                logger.warning(f"get_property_int optional {key} failed, err: {e}")
+
+        # Create LiteLLM instance
+        self.litellm = LiteLLM(litellm_config)
+        logger.info(f"newLiteLLM succeed with max_tokens: {litellm_config.max_tokens}, model: {litellm_config.model}")
+
+        # Send greeting if available
+        greeting = None
+        try:
+            greeting = rte.get_property_string(PROPERTY_GREETING)
+        except Exception as e:
+            logger.warning(f"get_property_string optional {PROPERTY_GREETING} failed, err: {e}")
+        if greeting:
+            try:
+                output_data = Data.create("text_data")
+                output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, greeting)
+                output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True)
+                rte.send_data(output_data)
+                logger.info(f"greeting [{greeting}] sent")
+            except Exception as e:
+                logger.error(f"greeting [{greeting}] send failed, err: {e}")
+
+        rte.on_start_done()
+
+    def on_stop(self, rte: RteEnv) -> None:
+        logger.info("LiteLLMExtension on_stop")
+        rte.on_stop_done()
+
+    def on_cmd(self, rte: RteEnv, cmd: Cmd) -> None:
+        logger.info("LiteLLMExtension on_cmd")
+        cmd_json = cmd.to_json()
+        logger.info(f"LiteLLMExtension on_cmd json: {cmd_json}")
+
+        cmd_name = cmd.get_name()
+
+        if cmd_name == CMD_IN_FLUSH:
+            self.outdate_ts = get_micro_ts()
+            cmd_out = Cmd.create(CMD_OUT_FLUSH)
+            rte.send_cmd(cmd_out, None)
+            logger.info("LiteLLMExtension on_cmd sent flush")
+        else:
+            logger.info(f"LiteLLMExtension on_cmd unknown cmd: {cmd_name}")
+            cmd_result = CmdResult.create(StatusCode.ERROR)
+            cmd_result.set_property_string("detail", "unknown cmd")
+            rte.return_result(cmd_result, cmd)
+            return
+
+        cmd_result = CmdResult.create(StatusCode.OK)
+        cmd_result.set_property_string("detail", "success")
+        rte.return_result(cmd_result, cmd)
+
+    def on_data(self, rte: RteEnv, data: Data) -> None:
+        """
+        on_data receives data from the rte graph.
+        currently supported data:
+          - name: text_data
+            example:
+            {"name": "text_data", "properties": {"text": "hello", "is_final": true}}
+        """
+        logger.info("LiteLLMExtension on_data")
+
+        # Assume 'data' is an object from which we can get properties
+        try:
+            is_final = data.get_property_bool(DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
+            if not is_final:
+                logger.info("ignore non-final input")
+                return
+        except Exception as e:
+            logger.error(f"on_data get_property_bool {DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL} failed, err: {e}")
+            return
+
+        # Get input text
+        try:
+            input_text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)
+            if not input_text:
+                logger.info("ignore empty text")
+                return
+            logger.info(f"on_data input text: [{input_text}]")
+        except Exception as e:
+            logger.error(f"on_data get_property_string {DATA_IN_TEXT_DATA_PROPERTY_TEXT} failed, err: {e}")
+            return
+
+        # Prepare memory
+        if len(self.memory) > self.max_memory_length:
+            self.memory.pop(0)
+        self.memory.append({"role": "user", "content": input_text})
+
+        def chat_completions_stream_worker(start_time, input_text, memory):
+            try:
+                logger.info(f"chat_completions_stream_worker for input text: [{input_text}] memory: {memory}")
+
+                # Get result from AI
+                resp = self.litellm.get_chat_completions_stream(memory)
+                if resp is None:
+                    logger.info(f"chat_completions_stream_worker for input text: [{input_text}] failed")
+                    return
+
+                sentence = ""
+                full_content = ""
+                first_sentence_sent = False
+
+                for chat_completions in resp:
+                    if start_time < self.outdate_ts:
+                        logger.info(f"chat_completions_stream_worker recv interrupt and flushing for input text: [{input_text}], startTs: {start_time}, outdateTs: {self.outdate_ts}")
+                        break
+
+                    if len(chat_completions.choices) > 0 and chat_completions.choices[0].delta.content is not None:
+                        content = chat_completions.choices[0].delta.content
+                    else:
+                        content = ""
+
+                    full_content += content
+
+                    while True:
+                        sentence, content, sentence_is_final = parse_sentence(sentence, content)
+
+                        if len(sentence) == 0 or not sentence_is_final:
+                            logger.info(f"sentence {sentence} is empty or not final")
+                            break
+
+                        logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] got sentence: [{sentence}]")
+
+                        # send sentence
+                        try:
+                            output_data = Data.create("text_data")
+                            output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
+                            output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, False)
+                            rte.send_data(output_data)
+                            logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] sent sentence [{sentence}]")
+                        except Exception as e:
+                            logger.error(f"chat_completions_stream_worker recv for input text: [{input_text}] send sentence [{sentence}] failed, err: {e}")
+                            break
+
+                        sentence = ""
+                        if not first_sentence_sent:
+                            first_sentence_sent = True
+                            logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] first sentence sent, first_sentence_latency {(get_micro_ts() - start_time) / 1000}ms")
+
+                # remember response as assistant content in memory
+                memory.append({"role": "assistant", "content": full_content})
+
+                # send end of segment
+                try:
+                    output_data = Data.create("text_data")
+                    output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
+                    output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True)
+                    rte.send_data(output_data)
+                    logger.info(f"chat_completions_stream_worker for input text: [{input_text}] end of segment with sentence [{sentence}] sent")
+                except Exception as e:
+                    logger.error(f"chat_completions_stream_worker for input text: [{input_text}] end of segment with sentence [{sentence}] send failed, err: {e}")
+
+            except Exception as e:
+                logger.error(f"chat_completions_stream_worker for input text: [{input_text}] failed, err: {e}")
+
+        # Start thread to request and read responses from LiteLLM
+        start_time = get_micro_ts()
+        thread = Thread(
+            target=chat_completions_stream_worker,
+            args=(start_time, input_text, self.memory),
+        )
+        thread.start()
+        logger.info("LiteLLMExtension on_data end")
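Note: `memory` is a rolling window shared across turns; the eviction check runs before append, so the list can briefly hold `max_memory_length + 1` messages (user and assistant entries both count toward the cap). A stand-alone sketch of that behaviour:

```python
memory = []
MAX_MEMORY_LENGTH = 10

def remember(role: str, content: str) -> None:
    if len(memory) > MAX_MEMORY_LENGTH:  # evict oldest only once the cap is exceeded
        memory.pop(0)
    memory.append({"role": role, "content": content})

for turn in range(24):
    remember("user", f"question {turn}")
    remember("assistant", f"answer {turn}")

assert len(memory) == MAX_MEMORY_LENGTH + 1  # check-before-append allows one extra entry
```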
logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] sent sentence [{sentence}]") + except Exception as e: + logger.error(f"chat_completions_stream_worker recv for input text: [{input_text}] send sentence [{sentence}] failed, err: {e}") + break + + sentence = "" + if not first_sentence_sent: + first_sentence_sent = True + logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] first sentence sent, first_sentence_latency {get_micro_ts() - start_time}ms") + + # remember response as assistant content in memory + memory.append({"role": "assistant", "content": full_content}) + + # send end of segment + try: + output_data = Data.create("text_data") + output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence) + output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True) + rte.send_data(output_data) + logger.info(f"chat_completions_stream_worker for input text: [{input_text}] end of segment with sentence [{sentence}] sent") + except Exception as e: + logger.error(f"chat_completions_stream_worker for input text: [{input_text}] end of segment with sentence [{sentence}] send failed, err: {e}") + + except Exception as e: + logger.error(f"chat_completions_stream_worker for input text: [{input_text}] failed, err: {e}") + + # Start thread to request and read responses from LiteLLM + start_time = get_micro_ts() + thread = Thread( + target=chat_completions_stream_worker, + args=(start_time, input_text, self.memory), + ) + thread.start() + logger.info(f"LiteLLMExtension on_data end") diff --git a/agents/addon/extension/litellm_python/log.py b/agents/addon/extension/litellm_python/log.py new file mode 100644 index 00000000..fad21710 --- /dev/null +++ b/agents/addon/extension/litellm_python/log.py @@ -0,0 +1,12 @@ +import logging +from .extension import EXTENSION_NAME + +logger = logging.getLogger(EXTENSION_NAME) +logger.setLevel(logging.INFO) + +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s") + +console_handler = logging.StreamHandler() +console_handler.setFormatter(formatter) + +logger.addHandler(console_handler) diff --git a/agents/addon/extension/litellm_python/manifest.json b/agents/addon/extension/litellm_python/manifest.json new file mode 100644 index 00000000..16fb4429 --- /dev/null +++ b/agents/addon/extension/litellm_python/manifest.json @@ -0,0 +1,83 @@ +{ + "type": "extension", + "name": "litellm_python", + "version": "0.1.0", + "language": "python", + "dependencies": [ + { + "type": "system", + "name": "rte_runtime_python", + "version": "0.4.0" + } + ], + "api": { + "property": { + "api_key": { + "type": "string" + }, + "base_url": { + "type": "string" + }, + "frequency_penalty": { + "type": "float64" + }, + "greeting": { + "type": "string" + }, + "max_memory_length": { + "type": "int64" + }, + "max_tokens": { + "type": "int64" + }, + "model": { + "type": "string" + }, + "presence_penalty": { + "type": "float64" + }, + "prompt": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "temperature": { + "type": "float64" + }, + "top_p": { + "type": "float64" + } + }, + "data_in": [ + { + "name": "text_data", + "property": { + "text": { + "type": "string" + } + } + } + ], + "data_out": [ + { + "name": "text_data", + "property": { + "text": { + "type": "string" + } + } + } + ], + "cmd_in": [ + { + "name": "flush" + } + ], + "cmd_out": [ + { + "name": "flush" + } + ] + } +} diff --git 
diff --git a/agents/addon/extension/polly_tts/__init__.py b/agents/addon/extension/polly_tts/__init__.py
index 0f60ee72..50d3bf2c 100644
--- a/agents/addon/extension/polly_tts/__init__.py
+++ b/agents/addon/extension/polly_tts/__init__.py
@@ -1,3 +1,5 @@
-from . import main
+from . import polly_tts_addon
+from .extension import EXTENSION_NAME
+from .log import logger
 
-print("polly_tts_python extension loaded")
+logger.info(f"{EXTENSION_NAME} extension loaded")
diff --git a/agents/addon/extension/polly_tts/extension.py b/agents/addon/extension/polly_tts/extension.py
new file mode 100644
index 00000000..63cb3dab
--- /dev/null
+++ b/agents/addon/extension/polly_tts/extension.py
@@ -0,0 +1 @@
+EXTENSION_NAME = "polly_tts"
diff --git a/agents/addon/extension/polly_tts/log.py b/agents/addon/extension/polly_tts/log.py
index dd7d68cf..fad21710 100644
--- a/agents/addon/extension/polly_tts/log.py
+++ b/agents/addon/extension/polly_tts/log.py
@@ -1,11 +1,10 @@
 import logging
+from .extension import EXTENSION_NAME
 
-logger = logging.getLogger("polly_tts_python")
+logger = logging.getLogger(EXTENSION_NAME)
 logger.setLevel(logging.INFO)
 
-formatter = logging.Formatter(
-    "%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s"
-)
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s")
 
 console_handler = logging.StreamHandler()
 console_handler.setFormatter(formatter)
diff --git a/agents/addon/extension/polly_tts/polly_tts_addon.py b/agents/addon/extension/polly_tts/polly_tts_addon.py
new file mode 100644
index 00000000..92047df3
--- /dev/null
+++ b/agents/addon/extension/polly_tts/polly_tts_addon.py
@@ -0,0 +1,15 @@
+from rte import (
+    Addon,
+    register_addon_as_extension,
+    RteEnv,
+)
+from .extension import EXTENSION_NAME
+from .log import logger
+from .polly_tts_extension import PollyTTSExtension
+
+
+@register_addon_as_extension(EXTENSION_NAME)
+class PollyTTSExtensionAddon(Addon):
+    def on_create_instance(self, rte: RteEnv, addon_name: str, context) -> None:
+        logger.info("on_create_instance")
+        rte.on_create_instance_done(PollyTTSExtension(addon_name), context)
diff --git a/agents/addon/extension/polly_tts/main.py b/agents/addon/extension/polly_tts/polly_tts_extension.py
similarity index 80%
rename from agents/addon/extension/polly_tts/main.py
rename to agents/addon/extension/polly_tts/polly_tts_extension.py
index 2fea0cee..63576ee9 100644
--- a/agents/addon/extension/polly_tts/main.py
+++ b/agents/addon/extension/polly_tts/polly_tts_extension.py
@@ -1,11 +1,9 @@
-from rte_runtime_python import (
-    Addon,
+from rte import (
     Extension,
-    register_addon_as_extension,
-    Rte,
+    RteEnv,
     Cmd,
     PcmFrame,
-    RTE_PCM_FRAME_DATA_FMT,
+    PcmFrameDataFmt,
     Data,
     StatusCode,
     CmdResult,
@@ -25,8 +23,8 @@
 PROPERTY_ACCESS_KEY = "access_key"  # Optional
 PROPERTY_SECRET_KEY = "secret_key"  # Optional
 PROPERTY_ENGINE = 'engine'  # Optional
-PROPERTY_VOICE = 'voice' # Optional
-PROPERTY_SAMPLE_RATE = 'sample_rate'# Optional
+PROPERTY_VOICE = 'voice'  # Optional
+PROPERTY_SAMPLE_RATE = 'sample_rate'  # Optional
 PROPERTY_LANG_CODE = 'lang_code'  # Optional
 
 
@@ -44,19 +42,19 @@ def __init__(self, name: str):
         self.number_of_channels = 1
 
     def on_init(
-        self, rte: Rte, manifest: MetadataInfo, property: MetadataInfo
+        self, rte: RteEnv, manifest: MetadataInfo, property: MetadataInfo
     ) -> None:
         logger.info("PollyTTSExtension on_init")
         rte.on_init_done(manifest, property)
 
-    def on_start(self, rte: Rte) -> None:
+    def on_start(self, rte: RteEnv) -> None:
         logger.info("PollyTTSExtension on_start")
 
         polly_config = PollyConfig.default_config()
 
         for optional_param in [PROPERTY_REGION, PROPERTY_ENGINE, PROPERTY_VOICE,
-                               PROPERTY_SAMPLE_RATE, PROPERTY_LANG_CODE,
-                               PROPERTY_ACCESS_KEY, PROPERTY_SECRET_KEY ]:
+                               PROPERTY_SAMPLE_RATE, PROPERTY_LANG_CODE,
+                               PROPERTY_ACCESS_KEY, PROPERTY_SECRET_KEY]:
             try:
                 value = rte.get_property_string(optional_param).strip()
                 if value:
@@ -71,7 +69,7 @@ def on_start(self, rte: Rte) -> None:
         self.thread.start()
         rte.on_start_done()
 
-    def on_stop(self, rte: Rte) -> None:
+    def on_stop(self, rte: RteEnv) -> None:
         logger.info("PollyTTSExtension on_stop")
 
         self.stopped = True
@@ -83,7 +81,6 @@ def on_stop(self, rte: Rte) -> None:
     def need_interrupt(self, ts: datetime.time) -> bool:
         return (self.outdateTs - ts).total_seconds() > 1
 
-
     def __get_frame(self, data: bytes) -> PcmFrame:
         sample_rate = int(self.polly.config.sample_rate)
 
@@ -92,17 +89,17 @@ def __get_frame(self, data: bytes) -> PcmFrame:
         f.set_bytes_per_sample(2)
         f.set_number_of_channels(1)
 
-        f.set_data_fmt(RTE_PCM_FRAME_DATA_FMT.RTE_PCM_FRAME_DATA_FMT_INTERLEAVE)
+        f.set_data_fmt(PcmFrameDataFmt.INTERLEAVE)
         f.set_samples_per_channel(sample_rate // 100)
 
         f.alloc_buf(self.frame_size)
         buff = f.lock_buf()
         if len(data) < self.frame_size:
-            buff[:] = bytes(self.frame_size) #fill with 0
+            buff[:] = bytes(self.frame_size)  # fill with 0
         buff[:len(data)] = data
         f.unlock_buf(buff)
         return f
 
-    def async_polly_handler(self, rte: Rte):
+    def async_polly_handler(self, rte: RteEnv):
         while not self.stopped:
             value = self.queue.get()
             if value is None:
@@ -123,8 +120,8 @@ def async_polly_handler(self, rte: Rte):
                     f = self.__get_frame(chunk)
                     rte.send_pcm_frame(f)
             except Exception as e:
-                    logger.exception(e)
-                    logger.exception(traceback.format_exc())
+                logger.exception(e)
+                logger.exception(traceback.format_exc())
 
     def flush(self):
         logger.info("PollyTTSExtension flush")
@@ -132,7 +129,7 @@ def flush(self):
             self.queue.get()
         self.queue.put(("", datetime.now()))
 
-    def on_data(self, rte: Rte, data: Data) -> None:
+    def on_data(self, rte: RteEnv, data: Data) -> None:
         logger.info("PollyTTSExtension on_data")
 
         inputText = data.get_property_string("text")
         if len(inputText) == 0:
@@ -144,7 +141,7 @@ def on_data(self, rte: Rte, data: Data) -> None:
         logger.info("on data %s %d", inputText, is_end)
         self.queue.put((inputText, datetime.now()))
 
-    def on_cmd(self, rte: Rte, cmd: Cmd) -> None:
+    def on_cmd(self, rte: RteEnv, cmd: Cmd) -> None:
         logger.info("PollyTTSExtension on_cmd")
         cmd_json = cmd.to_json()
         logger.info("PollyTTSExtension on_cmd json: %s" + cmd_json)
@@ -161,9 +158,3 @@ def on_cmd(self, rte: Rte, cmd: Cmd) -> None:
         cmd_result = CmdResult.create(StatusCode.OK)
         cmd_result.set_property_string("detail", "success")
         rte.return_result(cmd_result, cmd)
-
-@register_addon_as_extension("polly_tts")
-class PollyTTSExtensionAddon(Addon):
-    def on_create_instance(self, rte: Rte, addon_name: str, context) -> None:
-        logger.info("on_create_instance")
-        rte.on_create_instance_done(PollyTTSExtension(addon_name), context)
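Note: `need_interrupt` above treats a queued item as stale only when the last flush happened more than a second after the item was enqueued — unlike the litellm/elevenlabs extensions, which compare raw timestamps directly. A self-contained check of that threshold:

```python
from datetime import datetime, timedelta

def need_interrupt(outdate_ts: datetime, item_ts: datetime) -> bool:
    # mirrors PollyTTSExtension.need_interrupt: stale iff flushed >1s after enqueue
    return (outdate_ts - item_ts).total_seconds() > 1

enqueued = datetime.now()
assert not need_interrupt(enqueued + timedelta(milliseconds=500), enqueued)
assert need_interrupt(enqueued + timedelta(seconds=2), enqueued)
```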
diff --git a/agents/property.json.example b/agents/property.json.example
index e3af5b44..c84ad926 100644
--- a/agents/property.json.example
+++ b/agents/property.json.example
@@ -1165,6 +1165,228 @@
         ]
       }
     ]
+  },
+  {
+    "name": "va.litellm.azure",
+    "auto_start": true,
+    "nodes": [
+      {
+        "type": "extension",
+        "extension_group": "default",
+        "addon": "agora_rtc",
+        "name": "agora_rtc",
+        "property": {
+          "app_id": "",
+          "token": "",
+          "channel": "astra_agents_test",
+          "stream_id": 1234,
+          "remote_stream_id": 123,
+          "subscribe_audio": true,
+          "publish_audio": true,
+          "publish_data": true,
+          "enable_agora_asr": true,
+          "agora_asr_vendor_name": "microsoft",
+          "agora_asr_language": "en-US",
+          "agora_asr_vendor_key": "",
+          "agora_asr_vendor_region": "",
+          "agora_asr_session_control_file_path": "session_control.conf"
+        }
+      },
+      {
+        "type": "extension",
+        "extension_group": "default",
+        "addon": "interrupt_detector",
+        "name": "interrupt_detector"
+      },
+      {
+        "type": "extension",
+        "extension_group": "llm",
+        "addon": "litellm_python",
+        "name": "litellm",
+        "property": {
+          "api_key": "",
+          "base_url": "",
+          "greeting": "ASTRA agent connected. How can I help you today?",
+          "frequency_penalty": 0.9,
+          "max_memory_length": 10,
+          "max_tokens": 512,
+          "model": "gpt-3.5-turbo",
+          "presence_penalty": 0.9,
+          "prompt": "",
+          "provider": "",
+          "temperature": 0.1,
+          "top_p": 1.0
+        }
+      },
+      {
+        "type": "extension",
+        "extension_group": "tts",
+        "addon": "azure_tts",
+        "name": "azure_tts",
+        "property": {
+          "azure_subscription_key": "",
+          "azure_subscription_region": "",
+          "azure_synthesis_voice_name": "en-US-JaneNeural"
+        }
+      },
+      {
+        "type": "extension",
+        "extension_group": "transcriber",
+        "addon": "chat_transcriber",
+        "name": "chat_transcriber"
+      },
+      {
+        "type": "extension_group",
+        "addon": "default_extension_group",
+        "name": "default"
+      },
+      {
+        "type": "extension_group",
+        "addon": "default_extension_group",
+        "name": "llm"
+      },
+      {
+        "type": "extension_group",
+        "addon": "default_extension_group",
+        "name": "tts"
+      },
+      {
+        "type": "extension_group",
+        "addon": "default_extension_group",
+        "name": "transcriber"
+      }
+    ],
+    "connections": [
+      {
+        "extension_group": "default",
+        "extension": "agora_rtc",
+        "data": [
+          {
+            "name": "text_data",
+            "dest": [
+              {
+                "extension_group": "default",
+                "extension": "interrupt_detector"
+              },
+              {
+                "extension_group": "llm",
+                "extension": "litellm"
+              },
+              {
+                "extension_group": "transcriber",
+                "extension": "chat_transcriber"
+              }
+            ]
+          }
+        ]
+      },
+      {
+        "extension_group": "llm",
+        "extension": "litellm",
+        "data": [
+          {
+            "name": "text_data",
+            "dest": [
+              {
+                "extension_group": "tts",
+                "extension": "azure_tts"
+              },
+              {
+                "extension_group": "transcriber",
+                "extension": "chat_transcriber",
+                "cmd_conversions": [
+                  {
+                    "cmd": {
+                      "type": "per_property",
+                      "keep_original": true,
+                      "rules": [
+                        {
+                          "path": "is_final",
+                          "type": "fixed_value",
+                          "value": "bool(true)"
+                        },
+                        {
+                          "path": "stream_id",
+                          "type": "fixed_value",
+                          "value": "uint32(999)"
+                        }
+                      ]
+                    }
+                  }
+                ]
+              }
+            ]
+          }
+        ],
+        "cmd": [
+          {
+            "name": "flush",
+            "dest": [
+              {
+                "extension_group": "tts",
+                "extension": "azure_tts"
+              }
+            ]
+          }
+        ]
+      },
+      {
+        "extension_group": "tts",
+        "extension": "azure_tts",
+        "pcm_frame": [
+          {
+            "name": "pcm_frame",
+            "dest": [
+              {
+                "extension_group": "default",
+                "extension": "agora_rtc"
+              }
+            ]
+          }
+        ],
+        "cmd": [
+          {
+            "name": "flush",
+            "dest": [
+              {
+                "extension_group": "default",
+                "extension": "agora_rtc"
+              }
+            ]
+          }
+        ]
+      },
+      {
+        "extension_group": "transcriber",
+        "extension": "chat_transcriber",
+        "data": [
+          {
+            "name": "data",
+            "dest": [
+              {
+                "extension_group": "default",
+                "extension": "agora_rtc"
+              }
+            ]
+          }
+        ]
+      },
+      {
+        "extension_group": "default",
+        "extension": "interrupt_detector",
+        "cmd": [
+          {
+            "name": "flush",
+            "dest": [
+              {
+                "extension_group": "llm",
+                "extension": "litellm"
+              }
+            ]
+          }
+        ]
+      }
+    ]
   }
 ]
"extension_group": "tts", + "extension": "azure_tts" + } + ] + } + ] + }, + { + "extension_group": "tts", + "extension": "azure_tts", + "pcm_frame": [ + { + "name": "pcm_frame", + "dest": [ + { + "extension_group": "default", + "extension": "agora_rtc" + } + ] + } + ], + "cmd": [ + { + "name": "flush", + "dest": [ + { + "extension_group": "default", + "extension": "agora_rtc" + } + ] + } + ] + }, + { + "extension_group": "transcriber", + "extension": "chat_transcriber", + "data": [ + { + "name": "data", + "dest": [ + { + "extension_group": "default", + "extension": "agora_rtc" + } + ] + } + ] + }, + { + "extension_group": "default", + "extension": "interrupt_detector", + "cmd": [ + { + "name": "flush", + "dest": [ + { + "extension_group": "llm", + "extension": "litellm" + } + ] + } + ] + } + ] } ] } diff --git a/docker-compose.yml b/docker-compose.yml index 409b5ca6..86b95e91 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -27,6 +27,7 @@ services: AZURE_TTS_REGION: ${AZURE_TTS_REGION} COSY_TTS_KEY: ${COSY_TTS_KEY} ELEVENLABS_TTS_KEY: ${ELEVENLABS_TTS_KEY} + LITELLM_MODEL: ${LITELLM_MODEL} OPENAI_API_KEY: ${OPENAI_API_KEY} OPENAI_BASE_URL: ${OPENAI_BASE_URL} OPENAI_MODEL: ${OPENAI_MODEL} diff --git a/server/internal/config.go b/server/internal/config.go index 3c034fe4..a0480166 100644 --- a/server/internal/config.go +++ b/server/internal/config.go @@ -14,6 +14,7 @@ const ( extensionNameAzureTTS = "azure_tts" extensionNameCosyTTS = "cosy_tts" extensionNameElevenlabsTTS = "elevenlabs_tts" + extensionNameLiteLLM = "litellm" extensionNameOpenaiChatgpt = "openai_chatgpt" extensionNamePollyTTS = "polly_tts" extensionNameQwenLLM = "qwen_llm" @@ -74,6 +75,15 @@ var ( "OPENAI_API_KEY": { {ExtensionName: extensionNameOpenaiChatgpt, Property: "api_key"}, }, + "LITELLM_API_KEY": { + {ExtensionName: extensionNameLiteLLM, Property: "api_key"}, + }, + "LITELLM_MODEL": { + {ExtensionName: extensionNameLiteLLM, Property: "model"}, + }, + "LITELLM_PROVIDER": { + {ExtensionName: extensionNameLiteLLM, Property: "provider"}, + }, "OPENAI_BASE_URL": { {ExtensionName: extensionNameOpenaiChatgpt, Property: "base_url"}, },