Feat add litellm extension #165

Merged · 9 commits · Aug 7, 2024
12 changes: 12 additions & 0 deletions .env.example
@@ -51,6 +51,18 @@ COSY_TTS_KEY=
# ElevenLabs TTS key
ELEVENLABS_TTS_KEY=

+# Extension: litellm
+# Using Environment Variables, refer to https://docs.litellm.ai/docs/providers
+# For example:
+# OpenAI
+# OPENAI_API_KEY=<your-api-key>
+# OPENAI_API_BASE=<openai-api-base>
+# AWS Bedrock
+# AWS_ACCESS_KEY_ID=<your-aws-access-key-id>
+# AWS_SECRET_ACCESS_KEY=<your-aws-secret-access-key>
+# AWS_REGION_NAME=<aws-region-name>
+LITELLM_MODEL=gpt-4o-mini
+
# Extension: openai_chatgpt
# OpenAI API key
OPENAI_API_KEY=
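As a quick sanity check of the block above: litellm resolves provider credentials straight from these environment variables, so a minimal sketch needs only the model name (assumes `litellm` is installed and the OpenAI variables are exported as shown; not part of this diff):

import os
import litellm

# Provider keys (OPENAI_API_KEY, AWS_* ...) are read from the environment;
# only the model is passed explicitly, mirroring LITELLM_MODEL above.
response = litellm.completion(
    model=os.environ.get("LITELLM_MODEL", "gpt-4o-mini"),
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
)
print(response.choices[0].message.content)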
3 changes: 2 additions & 1 deletion agents/addon/extension/elevenlabs_tts_python/__init__.py
@@ -1,5 +1,6 @@
from . import elevenlabs_tts_addon
+from .extension import EXTENSION_NAME
from .log import logger


-logger.info("elevenlabs_tts_python extension loaded")
+logger.info(f"{EXTENSION_NAME} extension loaded")
agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts.py
@@ -42,9 +42,7 @@ def default_elevenlabs_tts_config() -> ElevenlabsTTSConfig:
class ElevenlabsTTS:
    def __init__(self, config: ElevenlabsTTSConfig) -> None:
        self.config = config
-        self.client = ElevenLabs(
-            api_key=config.api_key, timeout=config.request_timeout_seconds
-        )
+        self.client = ElevenLabs(api_key=config.api_key, timeout=config.request_timeout_seconds)

    def text_to_speech_stream(self, text: str) -> Iterator[bytes]:
        audio_stream = self.client.generate(
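For context on the one-line client construction above, the wrapper is driven roughly like this (illustrative sketch; the flat import and the key name are taken from this PR's tree and .env.example, not from documented API):

import os

from elevenlabs_tts import ElevenlabsTTS, default_elevenlabs_tts_config

config = default_elevenlabs_tts_config()
config.api_key = os.environ["ELEVENLABS_TTS_KEY"]  # key name from .env.example
tts = ElevenlabsTTS(config)

# text_to_speech_stream yields raw audio bytes as they arrive
for chunk in tts.text_to_speech_stream("Hello there!"):
    pass  # feed each chunk to a player or, as in this extension, a PCM re-chunker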
agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts_addon.py
@@ -11,10 +11,11 @@
    register_addon_as_extension,
    RteEnv,
)
+from .extension import EXTENSION_NAME
from .log import logger


-@register_addon_as_extension("elevenlabs_tts_python")
+@register_addon_as_extension(EXTENSION_NAME)
class ElevenlabsTTSExtensionAddon(Addon):
    def on_create_instance(self, rte: RteEnv, addon_name: str, context) -> None:
        logger.info("on_create_instance")
agents/addon/extension/elevenlabs_tts_python/elevenlabs_tts_extension.py
@@ -11,15 +11,12 @@
import time

from rte import (
-    Addon,
    Extension,
-    register_addon_as_extension,
    RteEnv,
    Cmd,
    CmdResult,
    StatusCode,
    Data,
-    MetadataInfo,
)
from .elevenlabs_tts import default_elevenlabs_tts_config, ElevenlabsTTS
from .pcm import PcmConfig, Pcm
@@ -62,68 +59,44 @@ def on_start(self, rte: RteEnv) -> None:
        try:
            elevenlabs_tts_config.api_key = rte.get_property_string(PROPERTY_API_KEY)
        except Exception as e:
-            logger.warning(
-                f"on_start get_property_string {PROPERTY_API_KEY} error: {e}"
-            )
+            logger.warning(f"on_start get_property_string {PROPERTY_API_KEY} error: {e}")
            return

        try:
            model_id = rte.get_property_string(PROPERTY_MODEL_ID)
            if len(model_id) > 0:
                elevenlabs_tts_config.model_id = model_id
        except Exception as e:
-            logger.warning(
-                f"on_start get_property_string {PROPERTY_MODEL_ID} error: {e}"
-            )
+            logger.warning(f"on_start get_property_string {PROPERTY_MODEL_ID} error: {e}")

        try:
-            optimize_streaming_latency = rte.get_property_int(
-                PROPERTY_OPTIMIZE_STREAMING_LATENCY
-            )
+            optimize_streaming_latency = rte.get_property_int(PROPERTY_OPTIMIZE_STREAMING_LATENCY)
            if optimize_streaming_latency > 0:
-                elevenlabs_tts_config.optimize_streaming_latency = (
-                    optimize_streaming_latency
-                )
+                elevenlabs_tts_config.optimize_streaming_latency = optimize_streaming_latency
        except Exception as e:
-            logger.warning(
-                f"on_start get_property_int {PROPERTY_OPTIMIZE_STREAMING_LATENCY} error: {e}"
-            )
+            logger.warning(f"on_start get_property_int {PROPERTY_OPTIMIZE_STREAMING_LATENCY} error: {e}")

        try:
-            request_timeout_seconds = rte.get_property_int(
-                PROPERTY_REQUEST_TIMEOUT_SECONDS
-            )
+            request_timeout_seconds = rte.get_property_int(PROPERTY_REQUEST_TIMEOUT_SECONDS)
            if request_timeout_seconds > 0:
                elevenlabs_tts_config.request_timeout_seconds = request_timeout_seconds
        except Exception as e:
-            logger.warning(
-                f"on_start get_property_int {PROPERTY_REQUEST_TIMEOUT_SECONDS} error: {e}"
-            )
+            logger.warning(f"on_start get_property_int {PROPERTY_REQUEST_TIMEOUT_SECONDS} error: {e}")

        try:
-            elevenlabs_tts_config.similarity_boost = rte.get_property_float(
-                PROPERTY_SIMILARITY_BOOST
-            )
+            elevenlabs_tts_config.similarity_boost = rte.get_property_float(PROPERTY_SIMILARITY_BOOST)
        except Exception as e:
-            logger.warning(
-                f"on_start get_property_float {PROPERTY_SIMILARITY_BOOST} error: {e}"
-            )
+            logger.warning(f"on_start get_property_float {PROPERTY_SIMILARITY_BOOST} error: {e}")

        try:
-            elevenlabs_tts_config.speaker_boost = rte.get_property_bool(
-                PROPERTY_SPEAKER_BOOST
-            )
+            elevenlabs_tts_config.speaker_boost = rte.get_property_bool(PROPERTY_SPEAKER_BOOST)
        except Exception as e:
-            logger.warning(
-                f"on_start get_property_bool {PROPERTY_SPEAKER_BOOST} error: {e}"
-            )
+            logger.warning(f"on_start get_property_bool {PROPERTY_SPEAKER_BOOST} error: {e}")

        try:
            elevenlabs_tts_config.stability = rte.get_property_float(PROPERTY_STABILITY)
        except Exception as e:
-            logger.warning(
-                f"on_start get_property_float {PROPERTY_STABILITY} error: {e}"
-            )
+            logger.warning(f"on_start get_property_float {PROPERTY_STABILITY} error: {e}")

        try:
            elevenlabs_tts_config.style = rte.get_property_float(PROPERTY_STYLE)
@@ -133,9 +106,7 @@ def on_start(self, rte: RteEnv) -> None:
        # create elevenlabsTTS instance
        self.elevenlabs_tts = ElevenlabsTTS(elevenlabs_tts_config)

-        logger.info(
-            f"ElevenlabsTTS succeed with model_id: {self.elevenlabs_tts.config.model_id}, VoiceId: {self.elevenlabs_tts.config.voice_id}"
-        )
+        logger.info(f"ElevenlabsTTS succeed with model_id: {self.elevenlabs_tts.config.model_id}, VoiceId: {self.elevenlabs_tts.config.voice_id}")

        # create pcm instance
        self.pcm = Pcm(PcmConfig())
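Every property read in on_start above follows the same try/except-and-warn shape; if one wanted to flatten the repetition, a small helper would do (hypothetical sketch, not part of this PR):

def apply_property(getter, name, apply):
    # Read one optional rte property; warn and keep the default on failure.
    try:
        apply(getter(name))
    except Exception as e:
        logger.warning(f"on_start {name} error: {e}")

# e.g.:
# apply_property(rte.get_property_float, PROPERTY_STABILITY,
#                lambda v: setattr(elevenlabs_tts_config, "stability", v))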
@@ -186,9 +157,7 @@ def on_data(self, rte: RteEnv, data: Data) -> None:
        try:
            text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)
        except Exception as e:
-            logger.warning(
-                f"on_data get_property_string {DATA_IN_TEXT_DATA_PROPERTY_TEXT} error: {e}"
-            )
+            logger.warning(f"on_data get_property_string {DATA_IN_TEXT_DATA_PROPERTY_TEXT} error: {e}")
            return

        if len(text) == 0:
@@ -207,9 +176,7 @@ def process_text_queue(self, rte: RteEnv):
logger.debug(f"process_text_queue, text: [{msg.text}]")

if msg.received_ts < self.outdate_ts:
logger.info(
f"textChan interrupt and flushing for input text: [{msg.text}], received_ts: {msg.received_ts}, outdate_ts: {self.outdate_ts}"
)
logger.info(f"textChan interrupt and flushing for input text: [{msg.text}], received_ts: {msg.received_ts}, outdate_ts: {self.outdate_ts}")
continue

start_time = time.time()
@@ -224,9 +191,7 @@ def process_text_queue(self, rte: RteEnv):

            for chunk in self.pcm.read_pcm_stream(audio_stream, self.pcm_frame_size):
                if msg.received_ts < self.outdate_ts:
-                    logger.info(
-                        f"textChan interrupt and flushing for input text: [{msg.text}], received_ts: {msg.received_ts}, outdate_ts: {self.outdate_ts}"
-                    )
+                    logger.info(f"textChan interrupt and flushing for input text: [{msg.text}], received_ts: {msg.received_ts}, outdate_ts: {self.outdate_ts}")
                    break

                if not chunk:
@@ -238,9 +203,7 @@ def process_text_queue(self, rte: RteEnv):
                pcm_frame_read += n

                if pcm_frame_read != self.pcm.get_pcm_frame_size():
-                    logger.debug(
-                        f"the number of bytes read is [{pcm_frame_read}] inconsistent with pcm frame size",
-                    )
+                    logger.debug(f"the number of bytes read is [{pcm_frame_read}] inconsistent with pcm frame size")
                    continue

                self.pcm.send(rte, buf)
@@ -250,28 +213,16 @@ def process_text_queue(self, rte: RteEnv):

                if first_frame_latency == 0:
                    first_frame_latency = int((time.time() - start_time) * 1000)
-                    logger.info(
-                        f"first frame available for text: [{msg.text}], received_ts: {msg.received_ts}, first_frame_latency: {first_frame_latency}ms",
-                    )
+                    logger.info(f"first frame available for text: [{msg.text}], received_ts: {msg.received_ts}, first_frame_latency: {first_frame_latency}ms")

                logger.debug(f"sending pcm data, text: [{msg.text}]")

            if pcm_frame_read > 0:
                self.pcm.send(rte, buf)
                sent_frames += 1
-                logger.info(
-                    f"sending pcm remain data, text: [{msg.text}], pcm_frame_read: {pcm_frame_read}"
-                )
+                logger.info(f"sending pcm remain data, text: [{msg.text}], pcm_frame_read: {pcm_frame_read}")

            finish_latency = int((time.time() - start_time) * 1000)
-            logger.info(
-                f"send pcm data finished, text: [{msg.text}], received_ts: {msg.received_ts}, read_bytes: {read_bytes}, sent_frames: {sent_frames}, \
-                first_frame_latency: {first_frame_latency}ms, finish_latency: {finish_latency}ms"
-            )
+            logger.info(f"send pcm data finished, text: [{msg.text}], received_ts: {msg.received_ts}, read_bytes: {read_bytes}, sent_frames: {sent_frames}, \
+                first_frame_latency: {first_frame_latency}ms, finish_latency: {finish_latency}ms")


-@register_addon_as_extension("elevenlabs_tts_python")
-class ElevenlabsTTSExtensionAddon(Addon):
-    def on_create_instance(self, rte: RteEnv, addon_name: str, context) -> None:
-        logger.info("on_create_instance")
-        rte.on_create_instance_done(ElevenlabsTTSExtension(addon_name), context)
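The registration block removed here now lives in elevenlabs_tts_addon.py. Separately, the msg.received_ts < self.outdate_ts checks above are the interruption mechanism: a flush marks everything queued before it as stale, which is then skipped before synthesis or cut off mid-stream. Condensed view of that contract (illustrative; the flush handler itself is outside this diff):

# somewhere in the command path, a flush bumps the watermark:
self.outdate_ts = get_current_ts()  # hypothetical timestamp helper

# process_text_queue then drops stale work at both checkpoints:
if msg.received_ts < self.outdate_ts:
    continue  # before synthesis; `break` inside the pcm loop mid-stream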
1 change: 1 addition & 0 deletions agents/addon/extension/elevenlabs_tts_python/extension.py
@@ -0,0 +1 @@
+EXTENSION_NAME = "elevenlabs_tts_python"
7 changes: 3 additions & 4 deletions agents/addon/extension/elevenlabs_tts_python/log.py
@@ -1,11 +1,10 @@
import logging
+from .extension import EXTENSION_NAME

-logger = logging.getLogger("elevenlabs_tts_python")
+logger = logging.getLogger(EXTENSION_NAME)
logger.setLevel(logging.INFO)

-formatter = logging.Formatter(
-    "%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s"
-)
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s")

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
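For reference, the single-line Formatter above emits records shaped like this (illustrative values):

2024-08-07 10:15:42,123 - elevenlabs_tts_python - INFO - 4242 - [elevenlabs_tts_extension.py:85] - on_start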
15 changes: 4 additions & 11 deletions agents/addon/extension/elevenlabs_tts_python/pcm.py
@@ -22,9 +22,7 @@ def get_pcm_frame(self, buf: memoryview) -> PcmFrame:
        frame.set_number_of_channels(self.config.num_channels)
        frame.set_timestamp(self.config.timestamp)
        frame.set_data_fmt(PcmFrameDataFmt.INTERLEAVE)
-        frame.set_samples_per_channel(
-            self.config.samples_per_channel // self.config.channel
-        )
+        frame.set_samples_per_channel(self.config.samples_per_channel // self.config.channel)

        frame.alloc_buf(self.get_pcm_frame_size())
        frame_buf = frame.lock_buf()
@@ -35,24 +33,19 @@ def get_pcm_frame(self, buf: memoryview) -> PcmFrame:
        return frame

    def get_pcm_frame_size(self) -> int:
-        return (
-            self.config.samples_per_channel
-            * self.config.channel
-            * self.config.bytes_per_sample
-        )
+        return (self.config.samples_per_channel * self.config.channel * self.config.bytes_per_sample)

    def new_buf(self) -> bytearray:
        return bytearray(self.get_pcm_frame_size())

-    def read_pcm_stream(
-        self, stream: Iterator[bytes], chunk_size: int
-    ) -> Iterator[bytes]:
+    def read_pcm_stream(self, stream: Iterator[bytes], chunk_size: int) -> Iterator[bytes]:
        chunk = b""
        for data in stream:
            chunk += data
            while len(chunk) >= chunk_size:
                yield chunk[:chunk_size]
                chunk = chunk[chunk_size:]

        if chunk:
            yield chunk
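Two notes on the reflowed methods above. get_pcm_frame_size is samples_per_channel * channel * bytes_per_sample; with illustrative values of 160 samples, 1 channel and 2 bytes per sample (PcmConfig's defaults are outside this diff), that is 320 bytes per frame. And read_pcm_stream re-chunks an arbitrary byte stream into fixed-size frames plus a short tail, which a standalone copy of its logic makes easy to see:

def rechunk(stream, chunk_size):
    # same accumulate-and-slice logic as read_pcm_stream above
    chunk = b""
    for data in stream:
        chunk += data
        while len(chunk) >= chunk_size:
            yield chunk[:chunk_size]
            chunk = chunk[chunk_size:]
    if chunk:
        yield chunk

sizes = [len(c) for c in rechunk(iter([b"\x00" * 100, b"\x00" * 250]), 128)]
assert sizes == [128, 128, 94]  # two full frames, then the remainder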
6 changes: 6 additions & 0 deletions agents/addon/extension/litellm_python/__init__.py
@@ -0,0 +1,6 @@
+from . import litellm_addon
+from .extension import EXTENSION_NAME
+from .log import logger
+
+
+logger.info(f"{EXTENSION_NAME} extension loaded")
1 change: 1 addition & 0 deletions agents/addon/extension/litellm_python/extension.py
@@ -0,0 +1 @@
+EXTENSION_NAME = "litellm_python"
79 changes: 79 additions & 0 deletions agents/addon/extension/litellm_python/litellm.py
@@ -0,0 +1,79 @@
+import litellm
+import random
+from typing import Dict, List, Optional
+
+
+class LiteLLMConfig:
+    def __init__(self,
+                 api_key: str,
+                 base_url: str,
+                 frequency_penalty: float,
+                 max_tokens: int,
+                 model: str,
+                 presence_penalty: float,
+                 prompt: str,
+                 provider: str,
+                 temperature: float,
+                 top_p: float,
+                 seed: Optional[int] = None,):
+        self.api_key = api_key
+        self.base_url = base_url
+        self.frequency_penalty = frequency_penalty
+        self.max_tokens = max_tokens
+        self.model = model
+        self.presence_penalty = presence_penalty
+        self.prompt = prompt
+        self.provider = provider
+        self.seed = seed if seed is not None else random.randint(0, 10000)
+        self.temperature = temperature
+        self.top_p = top_p
+
+    @classmethod
+    def default_config(cls):
+        return cls(
+            api_key="",
+            base_url="",
+            max_tokens=512,
+            model="gpt-4o-mini",
+            frequency_penalty=0.9,
+            presence_penalty=0.9,
+            prompt="You are a voice assistant who talks in a conversational way and can chat with me like my friends. I will speak to you in English or Chinese, and you will answer in the corrected and improved version of my text with the language I use. Don’t talk like a robot, instead I would like you to talk like a real human with emotions. I will use your answer for text-to-speech, so don’t return me any meaningless characters. I want you to be helpful, when I’m asking you for advice, give me precise, practical and useful advice instead of being vague. When giving me a list of options, express the options in a narrative way instead of bullet points.",
+            provider="",
+            seed=random.randint(0, 10000),
+            temperature=0.1,
+            top_p=1.0
+        )
+
+
+class LiteLLM:
+    def __init__(self, config: LiteLLMConfig):
+        self.config = config
+
+    def get_chat_completions_stream(self, messages: List[Dict[str, str]]):
+        kwargs = {
+            "api_key": self.config.api_key,
+            "base_url": self.config.base_url,
+            "custom_llm_provider": self.config.provider,
+            "frequency_penalty": self.config.frequency_penalty,
+            "max_tokens": self.config.max_tokens,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": self.config.prompt,
+                },
+                *messages,
+            ],
+            "model": self.config.model,
+            "presence_penalty": self.config.presence_penalty,
+            "seed": self.config.seed,
+            "stream": True,
+            "temperature": self.config.temperature,
+            "top_p": self.config.top_p,
+        }
+
+        try:
+            response = litellm.completion(**kwargs)
+
+            return response
+        except Exception as e:
+            raise Exception(f"get_chat_completions_stream failed, err: {e}")
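Downstream, this wrapper yields OpenAI-style streaming chunks; a minimal driver might look like the following (sketch only: assumes the env vars from .env.example are set, that the empty api_key/base_url/provider defaults fall through to litellm's environment-based routing, and litellm's usual choices[0].delta chunk shape):

import os

config = LiteLLMConfig.default_config()
config.model = os.environ.get("LITELLM_MODEL", config.model)
llm = LiteLLM(config)

for chunk in llm.get_chat_completions_stream([{"role": "user", "content": "Hi!"}]):
    delta = chunk.choices[0].delta
    if getattr(delta, "content", None):
        print(delta.content, end="", flush=True)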