Feat add litellm extension (#165)
* feat(): add litellm extension

* chore(): remove duplicate register_addon_as_extension and simplify extension name

* chore(): upgrade to be compatible with the new version of RTE

* chore(): modify log

* chore(): modify log

* fix(): fix api_key

* chore(): optimize litellm

* chore(): modify comment

* chore(): modify litellm usage
sunshinexcode authored Aug 7, 2024
1 parent 8fa589f commit 239dcf0
Showing 25 changed files with 771 additions and 121 deletions.
12 changes: 12 additions & 0 deletions .env.example
@@ -51,6 +51,18 @@ COSY_TTS_KEY=
# ElevenLabs TTS key
ELEVENLABS_TTS_KEY=

# Extension: litellm
# Using Environment Variables, refer to https://docs.litellm.ai/docs/providers
# For example:
# OpenAI
# OPENAI_API_KEY=<your-api-key>
# OPENAI_API_BASE=<openai-api-base>
# AWS Bedrock
# AWS_ACCESS_KEY_ID=<your-aws-access-key-id>
# AWS_SECRET_ACCESS_KEY=<your-aws-secret-access-key>
# AWS_REGION_NAME=<aws-region-name>
LITELLM_MODEL=gpt-4o-mini

# Extension: openai_chatgpt
# OpenAI API key
OPENAI_API_KEY=
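As context for the new block above: litellm resolves provider credentials entirely from environment variables, so the extension only needs a model name. A minimal, illustrative sketch of that behavior (not part of this commit; assumes OPENAI_API_KEY is exported as in .env.example):

import os

import litellm

# litellm picks up provider credentials (e.g. OPENAI_API_KEY) from the
# environment, per https://docs.litellm.ai/docs/providers; only the model
# name is passed explicitly.
response = litellm.completion(
    model=os.environ.get("LITELLM_MODEL", "gpt-4o-mini"),
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)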
3 changes: 2 additions & 1 deletion agents/addon/extension/elevenlabs_tts_python/__init__.py
@@ -1,5 +1,6 @@
from . import elevenlabs_tts_addon
from .extension import EXTENSION_NAME
from .log import logger


logger.info("elevenlabs_tts_python extension loaded")
logger.info(f"{EXTENSION_NAME} extension loaded")
agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts.py
@@ -42,9 +42,7 @@ def default_elevenlabs_tts_config() -> ElevenlabsTTSConfig:
class ElevenlabsTTS:
def __init__(self, config: ElevenlabsTTSConfig) -> None:
self.config = config
self.client = ElevenLabs(
api_key=config.api_key, timeout=config.request_timeout_seconds
)
self.client = ElevenLabs(api_key=config.api_key, timeout=config.request_timeout_seconds)

def text_to_speech_stream(self, text: str) -> Iterator[bytes]:
audio_stream = self.client.generate(
agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts_addon.py
@@ -11,10 +11,11 @@
register_addon_as_extension,
RteEnv,
)
from .extension import EXTENSION_NAME
from .log import logger


@register_addon_as_extension("elevenlabs_tts_python")
@register_addon_as_extension(EXTENSION_NAME)
class ElevenlabsTTSExtensionAddon(Addon):
def on_create_instance(self, rte: RteEnv, addon_name: str, context) -> None:
logger.info("on_create_instance")
agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts_extension.py
@@ -11,15 +11,12 @@
import time

from rte import (
Addon,
Extension,
register_addon_as_extension,
RteEnv,
Cmd,
CmdResult,
StatusCode,
Data,
MetadataInfo,
)
from .elevenlabs_tts import default_elevenlabs_tts_config, ElevenlabsTTS
from .pcm import PcmConfig, Pcm
@@ -62,68 +59,44 @@ def on_start(self, rte: RteEnv) -> None:
try:
elevenlabs_tts_config.api_key = rte.get_property_string(PROPERTY_API_KEY)
except Exception as e:
logger.warning(
f"on_start get_property_string {PROPERTY_API_KEY} error: {e}"
)
logger.warning(f"on_start get_property_string {PROPERTY_API_KEY} error: {e}")
return

try:
model_id = rte.get_property_string(PROPERTY_MODEL_ID)
if len(model_id) > 0:
elevenlabs_tts_config.model_id = model_id
except Exception as e:
logger.warning(
f"on_start get_property_string {PROPERTY_MODEL_ID} error: {e}"
)
logger.warning(f"on_start get_property_string {PROPERTY_MODEL_ID} error: {e}")

try:
optimize_streaming_latency = rte.get_property_int(
PROPERTY_OPTIMIZE_STREAMING_LATENCY
)
optimize_streaming_latency = rte.get_property_int(PROPERTY_OPTIMIZE_STREAMING_LATENCY)
if optimize_streaming_latency > 0:
elevenlabs_tts_config.optimize_streaming_latency = (
optimize_streaming_latency
)
elevenlabs_tts_config.optimize_streaming_latency = optimize_streaming_latency
except Exception as e:
logger.warning(
f"on_start get_property_int {PROPERTY_OPTIMIZE_STREAMING_LATENCY} error: {e}"
)
logger.warning(f"on_start get_property_int {PROPERTY_OPTIMIZE_STREAMING_LATENCY} error: {e}")

try:
request_timeout_seconds = rte.get_property_int(
PROPERTY_REQUEST_TIMEOUT_SECONDS
)
request_timeout_seconds = rte.get_property_int(PROPERTY_REQUEST_TIMEOUT_SECONDS)
if request_timeout_seconds > 0:
elevenlabs_tts_config.request_timeout_seconds = request_timeout_seconds
except Exception as e:
logger.warning(
f"on_start get_property_int {PROPERTY_REQUEST_TIMEOUT_SECONDS} error: {e}"
)
logger.warning(f"on_start get_property_int {PROPERTY_REQUEST_TIMEOUT_SECONDS} error: {e}")

try:
elevenlabs_tts_config.similarity_boost = rte.get_property_float(
PROPERTY_SIMILARITY_BOOST
)
elevenlabs_tts_config.similarity_boost = rte.get_property_float(PROPERTY_SIMILARITY_BOOST)
except Exception as e:
logger.warning(
f"on_start get_property_float {PROPERTY_SIMILARITY_BOOST} error: {e}"
)
logger.warning(f"on_start get_property_float {PROPERTY_SIMILARITY_BOOST} error: {e}")

try:
elevenlabs_tts_config.speaker_boost = rte.get_property_bool(
PROPERTY_SPEAKER_BOOST
)
elevenlabs_tts_config.speaker_boost = rte.get_property_bool(PROPERTY_SPEAKER_BOOST)
except Exception as e:
logger.warning(
f"on_start get_property_bool {PROPERTY_SPEAKER_BOOST} error: {e}"
)
logger.warning(f"on_start get_property_bool {PROPERTY_SPEAKER_BOOST} error: {e}")

try:
elevenlabs_tts_config.stability = rte.get_property_float(PROPERTY_STABILITY)
except Exception as e:
logger.warning(
f"on_start get_property_float {PROPERTY_STABILITY} error: {e}"
)
logger.warning(f"on_start get_property_float {PROPERTY_STABILITY} error: {e}")

try:
elevenlabs_tts_config.style = rte.get_property_float(PROPERTY_STYLE)
@@ -133,9 +106,7 @@ def on_start(self, rte: RteEnv) -> None:
# create elevenlabsTTS instance
self.elevenlabs_tts = ElevenlabsTTS(elevenlabs_tts_config)

logger.info(
f"ElevenlabsTTS succeed with model_id: {self.elevenlabs_tts.config.model_id}, VoiceId: {self.elevenlabs_tts.config.voice_id}"
)
logger.info(f"ElevenlabsTTS succeed with model_id: {self.elevenlabs_tts.config.model_id}, VoiceId: {self.elevenlabs_tts.config.voice_id}")

# create pcm instance
self.pcm = Pcm(PcmConfig())
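Every try/except in on_start above follows the same shape: read an optional property, keep the default and log a warning on failure. A hypothetical helper (not in this commit) that names the pattern:

# Hypothetical helper, not part of this diff: read an optional RTE property,
# warning and returning a default instead of raising.
def get_optional_property(rte, getter, key, default):
    try:
        return getter(rte, key)
    except Exception as e:
        logger.warning(f"on_start {getter.__name__} {key} error: {e}")
        return default

# e.g.:
# model_id = get_optional_property(
#     rte, RteEnv.get_property_string, PROPERTY_MODEL_ID, elevenlabs_tts_config.model_id
# )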
@@ -186,9 +157,7 @@ def on_data(self, rte: RteEnv, data: Data) -> None:
try:
text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)
except Exception as e:
logger.warning(
f"on_data get_property_string {DATA_IN_TEXT_DATA_PROPERTY_TEXT} error: {e}"
)
logger.warning(f"on_data get_property_string {DATA_IN_TEXT_DATA_PROPERTY_TEXT} error: {e}")
return

if len(text) == 0:
@@ -207,9 +176,7 @@ def process_text_queue(self, rte: RteEnv):
logger.debug(f"process_text_queue, text: [{msg.text}]")

if msg.received_ts < self.outdate_ts:
logger.info(
f"textChan interrupt and flushing for input text: [{msg.text}], received_ts: {msg.received_ts}, outdate_ts: {self.outdate_ts}"
)
logger.info(f"textChan interrupt and flushing for input text: [{msg.text}], received_ts: {msg.received_ts}, outdate_ts: {self.outdate_ts}")
continue

start_time = time.time()
@@ -224,9 +191,7 @@

for chunk in self.pcm.read_pcm_stream(audio_stream, self.pcm_frame_size):
if msg.received_ts < self.outdate_ts:
logger.info(
f"textChan interrupt and flushing for input text: [{msg.text}], received_ts: {msg.received_ts}, outdate_ts: {self.outdate_ts}"
)
logger.info(f"textChan interrupt and flushing for input text: [{msg.text}], received_ts: {msg.received_ts}, outdate_ts: {self.outdate_ts}")
break

if not chunk:
@@ -238,9 +203,7 @@
pcm_frame_read += n

if pcm_frame_read != self.pcm.get_pcm_frame_size():
logger.debug(
f"the number of bytes read is [{pcm_frame_read}] inconsistent with pcm frame size",
)
logger.debug(f"the number of bytes read is [{pcm_frame_read}] inconsistent with pcm frame size")
continue

self.pcm.send(rte, buf)
@@ -250,28 +213,16 @@

if first_frame_latency == 0:
first_frame_latency = int((time.time() - start_time) * 1000)
logger.info(
f"first frame available for text: [{msg.text}], received_ts: {msg.received_ts}, first_frame_latency: {first_frame_latency}ms",
)
logger.info(f"first frame available for text: [{msg.text}], received_ts: {msg.received_ts}, first_frame_latency: {first_frame_latency}ms")

logger.debug(f"sending pcm data, text: [{msg.text}]")

if pcm_frame_read > 0:
self.pcm.send(rte, buf)
sent_frames += 1
logger.info(
f"sending pcm remain data, text: [{msg.text}], pcm_frame_read: {pcm_frame_read}"
)
logger.info(f"sending pcm remain data, text: [{msg.text}], pcm_frame_read: {pcm_frame_read}")

finish_latency = int((time.time() - start_time) * 1000)
logger.info(
f"send pcm data finished, text: [{msg.text}], received_ts: {msg.received_ts}, read_bytes: {read_bytes}, sent_frames: {sent_frames}, \
first_frame_latency: {first_frame_latency}ms, finish_latency: {finish_latency}ms"
)
logger.info(f"send pcm data finished, text: [{msg.text}], received_ts: {msg.received_ts}, read_bytes: {read_bytes}, sent_frames: {sent_frames}, first_frame_latency: {first_frame_latency}ms, finish_latency: {finish_latency}ms")


@register_addon_as_extension("elevenlabs_tts_python")
class ElevenlabsTTSExtensionAddon(Addon):
def on_create_instance(self, rte: RteEnv, addon_name: str, context) -> None:
logger.info("on_create_instance")
rte.on_create_instance_done(ElevenlabsTTSExtension(addon_name), context)
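The received_ts / outdate_ts checks above implement interruption: a flush bumps outdate_ts, and any queued text older than that is dropped instead of synthesized. A standalone sketch of the idea (names and structure are illustrative, not this extension's actual API):

import time
from dataclasses import dataclass, field
from queue import Queue


@dataclass
class InterruptibleQueue:
    queue: Queue = field(default_factory=Queue)
    outdate_ts: float = 0.0

    def flush(self) -> None:
        # Everything enqueued before this instant becomes stale.
        self.outdate_ts = time.time()

    def process_one(self) -> None:
        msg = self.queue.get()  # msg carries .text and .received_ts
        if msg.received_ts < self.outdate_ts:
            return  # stale: skip TTS for this message
        ...  # otherwise run TTS and stream PCM frames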
1 change: 1 addition & 0 deletions agents/addon/extension/elevenlabs_tts_python/extension.py
@@ -0,0 +1 @@
EXTENSION_NAME = "elevenlabs_tts_python"
7 changes: 3 additions & 4 deletions agents/addon/extension/elevenlabs_tts_python/log.py
@@ -1,11 +1,10 @@
import logging
from .extension import EXTENSION_NAME

logger = logging.getLogger("elevenlabs_tts_python")
logger = logging.getLogger(EXTENSION_NAME)
logger.setLevel(logging.INFO)

formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s"
)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s")

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
15 changes: 4 additions & 11 deletions agents/addon/extension/elevenlabs_tts_python/pcm.py
@@ -22,9 +22,7 @@ def get_pcm_frame(self, buf: memoryview) -> PcmFrame:
frame.set_number_of_channels(self.config.num_channels)
frame.set_timestamp(self.config.timestamp)
frame.set_data_fmt(PcmFrameDataFmt.INTERLEAVE)
frame.set_samples_per_channel(
self.config.samples_per_channel // self.config.channel
)
frame.set_samples_per_channel(self.config.samples_per_channel // self.config.channel)

frame.alloc_buf(self.get_pcm_frame_size())
frame_buf = frame.lock_buf()
@@ -35,24 +33,19 @@ def get_pcm_frame(self, buf: memoryview) -> PcmFrame:
return frame

def get_pcm_frame_size(self) -> int:
return (
self.config.samples_per_channel
* self.config.channel
* self.config.bytes_per_sample
)
return (self.config.samples_per_channel * self.config.channel * self.config.bytes_per_sample)

def new_buf(self) -> bytearray:
return bytearray(self.get_pcm_frame_size())

def read_pcm_stream(
self, stream: Iterator[bytes], chunk_size: int
) -> Iterator[bytes]:
def read_pcm_stream(self, stream: Iterator[bytes], chunk_size: int) -> Iterator[bytes]:
chunk = b""
for data in stream:
chunk += data
while len(chunk) >= chunk_size:
yield chunk[:chunk_size]
chunk = chunk[chunk_size:]

if chunk:
yield chunk

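For reference, the reflowed generator above is pure and easy to check in isolation. A short, self-contained usage sketch (the byte blobs are made up):

from typing import Iterator


def read_pcm_stream(stream: Iterator[bytes], chunk_size: int) -> Iterator[bytes]:
    # Accumulate incoming bytes, emit fixed-size chunks, and flush
    # whatever remains as a final short chunk.
    chunk = b""
    for data in stream:
        chunk += data
        while len(chunk) >= chunk_size:
            yield chunk[:chunk_size]
            chunk = chunk[chunk_size:]
    if chunk:
        yield chunk


blobs = iter([b"abcdef", b"gh", b"ijk"])
assert list(read_pcm_stream(blobs, 4)) == [b"abcd", b"efgh", b"ijk"]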
6 changes: 6 additions & 0 deletions agents/addon/extension/litellm_python/__init__.py
@@ -0,0 +1,6 @@
from . import litellm_addon
from .extension import EXTENSION_NAME
from .log import logger


logger.info(f"{EXTENSION_NAME} extension loaded")
1 change: 1 addition & 0 deletions agents/addon/extension/litellm_python/extension.py
@@ -0,0 +1 @@
EXTENSION_NAME = "litellm_python"
79 changes: 79 additions & 0 deletions agents/addon/extension/litellm_python/litellm.py
@@ -0,0 +1,79 @@
import litellm
import random
from typing import Dict, List, Optional


class LiteLLMConfig:
def __init__(self,
api_key: str,
base_url: str,
frequency_penalty: float,
max_tokens: int,
model: str,
presence_penalty: float,
prompt: str,
provider: str,
temperature: float,
top_p: float,
seed: Optional[int] = None,):
self.api_key = api_key
self.base_url = base_url
self.frequency_penalty = frequency_penalty
self.max_tokens = max_tokens
self.model = model
self.presence_penalty = presence_penalty
self.prompt = prompt
self.provider = provider
self.seed = seed if seed is not None else random.randint(0, 10000)
self.temperature = temperature
self.top_p = top_p

@classmethod
def default_config(cls):
return cls(
api_key="",
base_url="",
max_tokens=512,
model="gpt-4o-mini",
frequency_penalty=0.9,
presence_penalty=0.9,
prompt="You are a voice assistant who talks in a conversational way and can chat with me like my friends. I will speak to you in English or Chinese, and you will answer in the corrected and improved version of my text with the language I use. Don’t talk like a robot, instead I would like you to talk like a real human with emotions. I will use your answer for text-to-speech, so don’t return me any meaningless characters. I want you to be helpful, when I’m asking you for advice, give me precise, practical and useful advice instead of being vague. When giving me a list of options, express the options in a narrative way instead of bullet points.",
provider="",
seed=random.randint(0, 10000),
temperature=0.1,
top_p=1.0
)


class LiteLLM:
def __init__(self, config: LiteLLMConfig):
self.config = config

def get_chat_completions_stream(self, messages: List[Dict[str, str]]):
kwargs = {
"api_key": self.config.api_key,
"base_url": self.config.base_url,
"custom_llm_provider": self.config.provider,
"frequency_penalty": self.config.frequency_penalty,
"max_tokens": self.config.max_tokens,
"messages": [
{
"role": "system",
"content": self.config.prompt,
},
*messages,
],
"model": self.config.model,
"presence_penalty": self.config.presence_penalty,
"seed": self.config.seed,
"stream": True,
"temperature": self.config.temperature,
"top_p": self.config.top_p,
}

try:
response = litellm.completion(**kwargs)

return response
except Exception as e:
raise Exception(f"get_chat_completions_stream failed, err: {e}")
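A hedged usage sketch for the new class: the defaults come from default_config() above, and the loop assumes litellm's OpenAI-style streaming chunks (the delta handling is an assumption, not shown in this commit):

# Usage sketch; assumes OPENAI_API_KEY is exported as in .env.example.
config = LiteLLMConfig.default_config()
config.model = "gpt-4o-mini"

llm = LiteLLM(config)
stream = llm.get_chat_completions_stream(
    [{"role": "user", "content": "Hi there"}]
)

# litellm yields OpenAI-style chunks when stream=True; delta.content can be None.
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)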