Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refact glue and use base config #424

Merged
merged 2 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 38 additions & 45 deletions agents/ten_packages/extension/deepgram_asr_python/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,34 @@
import asyncio

from deepgram import AsyncListenWebSocketClient, DeepgramClientOptions, LiveTranscriptionEvents, LiveOptions
from dataclasses import dataclass

from .config import DeepgramConfig

PROPERTY_API_KEY = "api_key" # Required
PROPERTY_LANG = "language" # Optional
PROPERTY_MODEL = "model" # Optional
PROPERTY_SAMPLE_RATE = "sample_rate" # Optional
from ten_ai_base import BaseConfig

DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID = "stream_id"
DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT = "end_of_segment"

@dataclass
class DeepgramASRConfig(BaseConfig):
api_key: str = ""
language: str = "en-US"
model: str = "nova-2"
sample_rate: int = 16000

channels: int = 1
encoding: str = 'linear16'
interim_results: bool = True
punctuate: bool = True

class DeepgramASRExtension(AsyncExtension):
def __init__(self, name: str):
super().__init__(name)

self.stopped = False
self.deepgram_client : AsyncListenWebSocketClient = None
self.deepgram_config : DeepgramConfig = None
self.client : AsyncListenWebSocketClient = None
self.config : DeepgramASRConfig = None
self.ten_env : AsyncTenEnv = None

async def on_init(self, ten_env: AsyncTenEnv) -> None:
Expand All @@ -41,30 +49,15 @@ async def on_start(self, ten_env: AsyncTenEnv) -> None:
self.loop = asyncio.get_event_loop()
self.ten_env = ten_env

self.deepgram_config = DeepgramConfig.default_config()
self.config = DeepgramASRConfig.create(ten_env=ten_env)
ten_env.log_info(f"config: {self.config}")

try:
self.deepgram_config.api_key = ten_env.get_property_string(PROPERTY_API_KEY).strip()
except Exception as e:
ten_env.log_error(f"get property {PROPERTY_API_KEY} error: {e}")
if not self.config.api_key:
ten_env.log_error(f"get property api_key")
return

for optional_param in [
PROPERTY_LANG,
PROPERTY_MODEL,
PROPERTY_SAMPLE_RATE,
]:
try:
value = ten_env.get_property_string(optional_param).strip()
if value:
self.deepgram_config.__setattr__(optional_param, value)
except Exception as err:
ten_env.log_debug(
f"get property optional {optional_param} failed, err: {err}. Using default value: {self.deepgram_config.__getattribute__(optional_param)}"
)

self.deepgram_client = AsyncListenWebSocketClient(config=DeepgramClientOptions(
api_key=self.deepgram_config.api_key,
self.client = AsyncListenWebSocketClient(config=DeepgramClientOptions(
api_key=self.config.api_key,
options={"keepalive": "true"}
))

Expand All @@ -80,12 +73,12 @@ async def on_audio_frame(self, ten_env: AsyncTenEnv, frame: AudioFrame) -> None:
return

self.stream_id = frame.get_property_int('stream_id')
await self.deepgram_client.send(frame_buf)
await self.client.send(frame_buf)

async def on_stop(self, ten_env: AsyncTenEnv) -> None:
ten_env.log_info("on_stop")

await self.deepgram_client.finish()
await self.client.finish()

async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
cmd_json = cmd.to_json()
Expand Down Expand Up @@ -118,22 +111,22 @@ async def on_message(_, result, **kwargs):
async def on_error(_, error, **kwargs):
self.ten_env.log_error(f"deepgram event callback on_error: {error}")

self.deepgram_client.on(LiveTranscriptionEvents.Open, on_open)
self.deepgram_client.on(LiveTranscriptionEvents.Close, on_close)
self.deepgram_client.on(LiveTranscriptionEvents.Transcript, on_message)
self.deepgram_client.on(LiveTranscriptionEvents.Error, on_error)

options = LiveOptions(language=self.deepgram_config.language,
model=self.deepgram_config.model,
sample_rate=self.deepgram_config.sample_rate,
channels=self.deepgram_config.channels,
encoding=self.deepgram_config.encoding,
interim_results=self.deepgram_config.interim_results,
punctuate=self.deepgram_config.punctuate)
self.client.on(LiveTranscriptionEvents.Open, on_open)
self.client.on(LiveTranscriptionEvents.Close, on_close)
self.client.on(LiveTranscriptionEvents.Transcript, on_message)
self.client.on(LiveTranscriptionEvents.Error, on_error)

options = LiveOptions(language=self.config.language,
model=self.config.model,
sample_rate=self.config.sample_rate,
channels=self.config.channels,
encoding=self.config.encoding,
interim_results=self.config.interim_results,
punctuate=self.config.punctuate)
# connect to websocket
result = await self.deepgram_client.start(options)
result = await self.client.start(options)
if result is False:
if self.deepgram_client.status_code == 402:
if self.client.status_code == 402:
self.ten_env.log_error("Failed to connect to Deepgram - your account has run out of credits.")
else:
self.ten_env.log_error("Failed to connect to Deepgram")
Expand Down
2 changes: 0 additions & 2 deletions agents/ten_packages/extension/glue_python_async/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,4 @@
# See the LICENSE file for more information.
#
from . import addon
from .log import logger

logger.info("glue_python_async extension loaded")
3 changes: 1 addition & 2 deletions agents/ten_packages/extension/glue_python_async/addon.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,5 @@ class AsyncGlueExtensionAddon(Addon):

def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None:
from .extension import AsyncGlueExtension
from .log import logger
logger.info("AsyncGlueExtensionAddon on_create_instance")
ten_env.log_info("AsyncGlueExtensionAddon on_create_instance")
ten_env.on_create_instance_done(AsyncGlueExtension(name), context)
Original file line number Diff line number Diff line change
@@ -1,16 +1,51 @@
import os
import openai
import json
from openai import AsyncOpenAI
import traceback # Add this import
import traceback
import logging
import logging.config

from typing import List, Union
from pydantic import BaseModel, HttpUrl
from typing import List, Union, Dict, Optional
from pydantic import BaseModel, HttpUrl, ValidationError

from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi import Depends, FastAPI, HTTPException, Request
import asyncio

# Enable Pydantic debug mode
from pydantic import BaseConfig

BaseConfig.debug = True

# Set up logging
logging.config.dictConfig({
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
},
},
"handlers": {
"file": {
"level": "DEBUG",
"formatter": "default",
"class": "logging.FileHandler",
"filename": "example.log",
},
},
"loggers": {
"": {
"handlers": ["file"],
"level": "DEBUG",
"propagate": True,
},
},
})
logger = logging.getLogger(__name__)

app = FastAPI(title="Chat Completion API",
description="API for streaming chat completions with support for text, image, and audio content",
version="1.0.0")
Expand All @@ -27,79 +62,109 @@ class ImageContent(BaseModel):
image_url: HttpUrl

class AudioContent(BaseModel):
type: str = "audio"
audio_url: HttpUrl

class Message(BaseModel):
role: str
content: Union[TextContent, ImageContent, AudioContent, List[Union[TextContent, ImageContent, AudioContent]]]
type: str = "input_audio"
input_audio: Dict[str, str]

class ToolFunction(BaseModel):
name: str
description: Optional[str]
parameters: Optional[Dict]
strict: bool = False

class Tool(BaseModel):
type: str = "function"
function: ToolFunction

class ToolChoice(BaseModel):
type: str = "function"
function: Optional[Dict]

class ResponseFormat(BaseModel):
type: str = "json_schema"
json_schema: Optional[Dict[str, str]]

class SystemMessage(BaseModel):
role: str = "system"
content: Union[str, List[str]]

class UserMessage(BaseModel):
role: str = "user"
content: Union[str, List[Union[TextContent, ImageContent, AudioContent]]]

class AssistantMessage(BaseModel):
role: str = "assistant"
content: Union[str, List[TextContent]] = None
audio: Optional[Dict[str, str]] = None
tool_calls: Optional[List[Dict]] = None

class ToolMessage(BaseModel):
role: str = "tool"
content: Union[str, List[str]]
tool_call_id: str

class ChatCompletionRequest(BaseModel):
messages: List[Message]
model: str
temperature: float = 1.0
context: Optional[Dict] = None
model: Optional[str] = None
messages: List[Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]]
response_format: Optional[ResponseFormat] = None
modalities: List[str] = ["text"]
audio: Optional[Dict[str, str]] = None
tools: Optional[List[Tool]] = None
tool_choice: Optional[Union[str, ToolChoice]] = "auto"
parallel_tool_calls: bool = True
stream: bool = True
stream_options: Optional[Dict] = None

def format_openai_messages(messages):
formatted_messages = []
for msg in messages:
if isinstance(msg.content, list):
content = []
for item in msg.content:
if item.type == "text":
content.append({"type": "text", "text": item.text})
elif item.type == "image":
content.append({"type": "image_url", "image_url": str(item.image_url)})
elif item.type == "audio":
content.append({"type": "audio_url", "audio_url": str(item.audio_url)})
else:
if msg.content.type == "text":
content = [{"type": "text", "text": msg.content.text}]
elif msg.content.type == "image":
content = [{"type": "image_url", "image_url": {"url": str(msg.content.image_url)}}]
elif msg.content.type == "audio":
content = [{"type": "audio_url", "audio_url": {"url": str(msg.content.audio_url)}}]

formatted_messages.append({"role": msg.role, "content": content})
return formatted_messages
'''
{'messages': [{'role': 'user', 'content': 'Hello. Hello. Hello.'}, {'role': 'user', 'content': 'Unprocessedable.'}], 'tools': [], 'tools_choice': 'none', 'model': 'gpt-3.5-turbo', 'stream': True}
'''

security = HTTPBearer()

def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
token = credentials.credentials
if token != os.getenv("API_TOKEN"):
logger.warning("Invalid or missing token")
raise HTTPException(status_code=403, detail="Invalid or missing token")

@app.post("/chat/completions", dependencies=[Depends(verify_token)])
async def create_chat_completion(request: ChatCompletionRequest, req: Request):
try:
messages = format_openai_messages(request.messages)
logger.debug(f"Received request: {request.json()}")
client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
response = await client.chat.completions.create(
model=request.model,
messages=messages,
temperature=request.temperature,
stream=request.stream
messages=request.messages, # Directly use request messages
tool_choice=request.tool_choice if request.tools and request.tool_choice else None,
tools=request.tools if request.tools else None,
# modalities=request.modalities,
response_format=request.response_format,
stream=request.stream,
stream_options=request.stream_options
)

async def generate():
try:
async for chunk in response:
if chunk.choices[0].delta.content is not None:
yield f"data: {chunk.choices[0].delta.content}\n\n"
yield "data: [DONE]\n\n"
except asyncio.CancelledError:
print("Request was cancelled")
raise

return StreamingResponse(generate(), media_type="text/event-stream")
if request.stream:
async def generate():
try:
async for chunk in response:
logger.info(f"Received chunk: {chunk}")
yield f"data: {json.dumps(chunk.to_dict())}\n\n"
yield "data: [DONE]\n\n"
except asyncio.CancelledError:
logger.info("Request was cancelled")
raise

return StreamingResponse(generate(), media_type="text/event-stream")
else:
result = await response
return result
except asyncio.CancelledError:
print("Request was cancelled")
logger.info("Request was cancelled")
raise HTTPException(status_code=499, detail="Request was cancelled")
except Exception as e:
traceback_str = ''.join(traceback.format_tb(e.__traceback__))
error_message = f"{str(e)}\n{traceback_str}"
print(error_message)
logger.error(error_message)
raise HTTPException(status_code=500, detail=error_message)

if __name__ == "__main__":
Expand Down
Loading