Commit
431 feature support fashionai to support interrupt for video (#450)
* fix: fix env parameters

* feat: update fashion.ai
plutoless authored Dec 10, 2024
1 parent 47ed694 commit eea03e2
Showing 3 changed files with 90 additions and 105 deletions.
81 changes: 30 additions & 51 deletions agents/examples/experimental/property.json
@@ -26,17 +26,33 @@
"name": "text_data"
}
],
"cmd": [
{
"name": "on_user_joined",
"dest": [
{
"extension_group": "chatgpt",
"extension": "openai_chatgpt"
}
]
},
{
"name": "on_user_left",
"dest": [
{
"extension_group": "chatgpt",
"extension": "openai_chatgpt"
}
]
}
],
"extension": "agora_rtc",
"extension_group": "default"
},
{
"cmd": [
{
"dest": [
{
"extension": "azure_tts",
"extension_group": "tts"
},
{
"extension": "fashionai",
"extension_group": "default"
@@ -63,32 +79,6 @@
"extension": "openai_chatgpt",
"extension_group": "chatgpt"
},
{
"audio_frame": [
{
"dest": [
{
"extension": "agora_rtc",
"extension_group": "default"
}
],
"name": "pcm_frame"
}
],
"cmd": [
{
"dest": [
{
"extension": "agora_rtc",
"extension_group": "default"
}
],
"name": "flush"
}
],
"extension": "azure_tts",
"extension_group": "tts"
},
{
"data": [
{
@@ -128,10 +118,10 @@
"property": {
"agora_asr_language": "en-US",
"agora_asr_session_control_file_path": "session_control.conf",
"agora_asr_vendor_key": "$AZURE_STT_KEY",
"agora_asr_vendor_key": "${env:AZURE_STT_KEY}",
"agora_asr_vendor_name": "microsoft",
"agora_asr_vendor_region": "$AZURE_STT_REGION",
"app_id": "$AGORA_APP_ID",
"agora_asr_vendor_region": "${env:AZURE_STT_REGION}",
"app_id": "${env:AGORA_APP_ID}",
"channel": "ten_agent_test",
"enable_agora_asr": true,
"publish_audio": true,
@@ -150,7 +140,7 @@
"type": "extension"
},
{
"addon": "openai_chatgpt",
"addon": "openai_chatgpt_python",
"extension_group": "chatgpt",
"name": "openai_chatgpt",
"property": {
@@ -166,17 +156,6 @@
},
"type": "extension"
},
{
"addon": "azure_tts",
"extension_group": "tts",
"name": "azure_tts",
"property": {
"azure_subscription_key": "$AZURE_TTS_KEY",
"azure_subscription_region": "$AZURE_TTS_REGION",
"azure_synthesis_voice_name": "en-US-JaneNeural"
},
"type": "extension"
},
{
"addon": "message_collector",
"extension_group": "transcriber",
Expand All @@ -188,11 +167,11 @@
"extension_group": "default",
"name": "fashionai",
"property": {
"app_id": "$AGORA_APP_ID",
"channel": "ten_agents_test",
"app_id": "${env:AGORA_APP_ID}",
"channel": "ten_agent_test",
"stream_id": 12345,
"token": "",
"service_id": "agora"
"token": "<agora_token>",
"service_id": "agoramultimodel"
},
"type": "extension"
}
@@ -513,7 +492,7 @@
"property": {
"app_id": "${env:AGORA_APP_ID}",
"token": "<agora_token>",
"channel": "ten_agents_test",
"channel": "ten_agent_test",
"stream_id": 1234,
"remote_stream_id": 123,
"subscribe_audio": true,
@@ -3627,7 +3606,7 @@
"property": {
"app_id": "${env:AGORA_APP_ID}",
"token": "<agora_token>",
"channel": "ten_agents_test",
"channel": "ten_agent_test",
"stream_id": 1234,
"remote_stream_id": 123,
"subscribe_audio": true,
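The property.json changes above replace the bare `$VAR` references with `${env:...}` placeholders (and switch the chat extension to the openai_chatgpt_python addon), so secrets such as AGORA_APP_ID and AZURE_STT_KEY are read from the environment. As a rough illustration of that placeholder style, such values can be expanded with a small helper like the one below; resolve_env is illustrative only and is not the TEN runtime's actual substitution code.

import os
import re

# Matches placeholders of the form ${env:NAME}.
_ENV_PATTERN = re.compile(r"\$\{env:([A-Za-z_][A-Za-z0-9_]*)\}")

def resolve_env(value: str) -> str:
    """Replace ${env:NAME} placeholders with values from the environment."""
    return _ENV_PATTERN.sub(lambda m: os.environ.get(m.group(1), ""), value)

# Example: resolve_env('"app_id": "${env:AGORA_APP_ID}"') fills in the app id
# from the AGORA_APP_ID environment variable (empty string if unset).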
76 changes: 24 additions & 52 deletions agents/ten_packages/extension/fashionai/src/extension.py
@@ -5,39 +5,38 @@
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
import traceback
from ten import (
AudioFrame,
VideoFrame,
Extension,
TenEnv,
AsyncTenEnv,
Cmd,
StatusCode,
CmdResult,
Data,
)
from ten.async_extension import AsyncExtension

from .log import logger
import asyncio
from .fashionai_client import FashionAIClient
import threading
from datetime import datetime

class FashionAIExtension(Extension):
class FashionAIExtension(AsyncExtension):
app_id = ""
token = ""
channel = ""
stream_id = 0
service_id = "agora"

def on_init(self, ten_env: TenEnv) -> None:
async def on_init(self, ten_env: AsyncTenEnv) -> None:
logger.info("FASHION_AI on_init *********************************************************")
self.stopped = False
self.queue = asyncio.Queue(maxsize=3000)
self.threadWebsocketLoop = None

ten_env.on_init_done()

def on_start(self, ten_env: TenEnv) -> None:
async def on_start(self, ten_env: AsyncTenEnv) -> None:
logger.info("FASHION_AI on_start *********************************************************")

# TODO: read properties, initialize resources
@@ -54,55 +53,42 @@ def on_start(self, ten_env: TenEnv) -> None:

if len(self.token) > 0:
self.app_id = self.token
self.client = FashionAIClient("wss://ingress.service.fasionai.com/websocket/node7/server1", self.service_id)

def thread_target():
self.threadWebsocketLoop = asyncio.new_event_loop()
asyncio.set_event_loop(self.threadWebsocketLoop)
self.threadWebsocketLoop.run_until_complete(self.init_fashionai(self.app_id, self.channel, self.stream_id))

self.threadWebsocket = threading.Thread(target=thread_target)
self.threadWebsocket.start()
self.client = FashionAIClient("wss://ingress.service.fasionai.com/websocket/node5/agoramultimodel2", self.service_id)
asyncio.create_task(self.process_input_text())
await self.init_fashionai(self.app_id, self.channel, self.stream_id)

ten_env.on_start_done()

def on_stop(self, ten_env: TenEnv) -> None:
async def on_stop(self, ten_env: AsyncTenEnv) -> None:
logger.info("FASHION_AI on_stop")
self.stopped = True
asyncio.run_coroutine_threadsafe(self.queue.put(None), self.threadWebsocketLoop)
asyncio.run_coroutine_threadsafe(self.flush(), self.threadWebsocketLoop)

self.threadWebsocket.join()

ten_env.on_stop_done()
await self.queue.put(None)

def on_deinit(self, ten_env: TenEnv) -> None:
async def on_deinit(self, ten_env: AsyncTenEnv) -> None:
logger.info("FASHION_AI on_deinit")
ten_env.on_deinit_done()

def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
cmd_name = cmd.get_name()
logger.info("FASHION_AI on_cmd name {}".format(cmd_name))

# TODO: process cmd
if cmd_name == "flush":
self.outdate_ts = datetime.now()
try:
asyncio.run_coroutine_threadsafe(
self.flush(), self.threadWebsocketLoop
).result(timeout=0.1)
await self.client.send_interrupt()

except Exception as e:
ten_env.return_result(CmdResult.create(StatusCode.ERROR), cmd)
logger.warning(f"flush err: {traceback.format_exc()}")

cmd_out = Cmd.create("flush")
ten_env.send_cmd(cmd_out, lambda ten, result: logger.info("send_cmd flush done"))
await ten_env.send_cmd(cmd_out)
# ten_env.send_cmd(cmd_out, lambda ten, result: logger.info("send_cmd flush done"))
else:
logger.info("unknown cmd {}".format(cmd_name))

logger.info("FASHION_AI on_cmd done")
cmd_result = CmdResult.create(StatusCode.OK)
ten_env.return_result(cmd_result, cmd)

def on_data(self, ten_env: TenEnv, data: Data) -> None:
async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None:
# TODO: process data
inputText = data.get_property_string("text")
if len(inputText) == 0:
@@ -111,33 +97,27 @@ def on_data(self, ten_env: TenEnv, data: Data) -> None:

logger.info("FASHION_AI on data %s", inputText)
try:
future = asyncio.run_coroutine_threadsafe(
self.queue.put(inputText), self.threadWebsocketLoop
)
future.result(timeout=0.1)
await self.queue.put(inputText)
except asyncio.TimeoutError:
logger.warning(f"FASHION_AI put inputText={inputText} queue timed out")
except Exception as e:
logger.warning(f"FASHION_AI put inputText={inputText} queue err: {e}")
logger.info("FASHION_AI send_inputText %s", inputText)

pass

def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
async def on_audio_frame(self, ten_env: AsyncTenEnv, audio_frame: AudioFrame) -> None:
# TODO: process pcm frame
pass

def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None:
async def on_video_frame(self, ten_env: AsyncTenEnv, video_frame: VideoFrame) -> None:
# TODO: process image frame
pass

async def init_fashionai(self, app_id, channel, stream_id):
await self.client.connect()
await self.client.stream_start(app_id, channel, stream_id)
await self.client.render_start()
await self.async_polly_handler()

async def async_polly_handler(self):
async def process_input_text(self):
while True:
inputText = await self.queue.get()
if inputText is None:
@@ -151,11 +131,3 @@ async def async_polly_handler(self):
await self.client.send_inputText(inputText)
except Exception as e:
logger.exception(e)

async def flush(self):
logger.info("FASHION_AI flush")
while not self.queue.empty():
value = await self.queue.get()
if value is None:
break
logger.info(f"Flushing value: {value}")
38 changes: 36 additions & 2 deletions agents/ten_packages/extension/fashionai/src/fashionai_client.py
@@ -13,10 +13,23 @@ def __init__(self, uri, service_id):
self.uri = uri
self.websocket = None
self.service_id = service_id
self.cancelled = False

async def connect(self):
ssl_context = ssl._create_unverified_context()
self.websocket = await websockets.connect(self.uri, ssl=ssl_context)
asyncio.create_task(self.listen()) # Start listening immediately after connection

async def listen(self):
"""Continuously listen for incoming messages."""
if self.websocket is not None:
try:
async for message in self.websocket:
logger.info(f"FASHION_AI Received: {message}")
# await self.handle_message(message)
except websockets.exceptions.ConnectionClosedError as e:
logger.info(f"FASHION_AI Connection closed with error: {e}")
await self.reconnect()

async def stream_start(self, app_id, channel, stream_id):
await self.send_message(
@@ -29,6 +42,15 @@ async def stream_start(self, app_id, channel, stream_id):
"signal": "STREAM_START",
}
)

async def stream_stop(self):
await self.send_message(
{
"request_id": str(uuid.uuid4()),
"service_id": self.service_id,
"signal": "STREAM_STOP",
}
)

async def render_start(self):
await self.send_message(
@@ -38,8 +60,11 @@ async def render_start(self):
"signal": "RENDER_START",
}
)
self.cancelled = False

async def send_inputText(self, inputText):
if self.cancelled:
await self.render_start()
await self.send_message(
{
"request_id": str(uuid.uuid4()),
@@ -49,14 +74,23 @@ async def send_inputText(self, inputText):
}
)

async def send_interrupt(self):
await self.send_message(
{
"service_id": self.service_id,
"signal": "RENDER_CANCEL",
}
)
self.cancelled = True


async def send_message(self, message):
if self.websocket is not None:
try:
await self.websocket.send(json.dumps(message))
logger.info(f"FASHION_AI Sent: {message}")
response = await asyncio.wait_for(self.websocket.recv(), timeout=2)
logger.info(f"FASHION_AI Received: {response}")
# response = await asyncio.wait_for(self.websocket.recv(), timeout=2)
# logger.info(f"FASHION_AI Received: {response}")
except websockets.exceptions.ConnectionClosedError as e:
logger.info(f"FASHION_AI Connection closed with error: {e}")
await self.reconnect()
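The fashionai_client.py changes above add a background listen() task, a STREAM_STOP signal, and interrupt support: send_interrupt() sends RENDER_CANCEL and sets a cancelled flag, and the next send_inputText() re-issues RENDER_START before speaking again. A rough call sequence, assuming the constructor arguments used in extension.py and placeholder channel credentials, might look like this:

import asyncio
from fashionai_client import FashionAIClient  # adjust the import path as needed

async def demo():
    # Endpoint and service_id mirror the values used in extension.py;
    # app_id, channel and stream_id below are placeholders.
    client = FashionAIClient(
        "wss://ingress.service.fasionai.com/websocket/node5/agoramultimodel2",
        "agoramultimodel",
    )
    await client.connect()                     # opens the socket and starts the listen() task
    await client.stream_start("<app_id>", "ten_agent_test", 12345)
    await client.render_start()
    await client.send_inputText("Hello, welcome to the show.")
    await client.send_interrupt()              # RENDER_CANCEL; the client marks itself cancelled
    await client.send_inputText("New topic.")  # re-issues RENDER_START because of the cancel flag
    await client.stream_stop()

# asyncio.run(demo())  # requires network access to the FashionAI endpoint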
