Skip to content

Commit

Permalink
fix: fix all default graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
plutoless committed Dec 23, 2024
1 parent 69a3b44 commit ae17c15
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 34 deletions.
60 changes: 60 additions & 0 deletions agents/examples/default/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,66 @@
"type": "extension",
"name": "deepgram_asr_python",
"version": "=0.1.0"
},
{
"type": "extension",
"name": "vision_tool_python",
"version": "=0.1.0"
},
{
"type": "extension",
"name": "vision_analyze_tool_python",
"version": "=0.1.0"
},
{
"type": "extension",
"name": "transcribe_asr_python",
"version": "=0.1.0"
},
{
"type": "extension",
"name": "gemini_llm_python",
"version": "=0.1.0"
},
{
"type": "extension",
"name": "bedrock_llm_python",
"version": "=0.1.0"
},
{
"type": "extension",
"name": "polly_tts",
"version": "=0.1.0"
},
{
"type": "extension",
"name": "minimax_tts_python",
"version": "=0.1.0"
},
{
"type": "extension",
"name": "minimax_v2v_python",
"version": "=0.1.0"
},
{
"type": "extension",
"name": "cosy_tts_python",
"version": "=0.1.0"
},
{
"type": "extension",
"name": "elevenlabs_tts_python",
"version": "=0.1.0"
},
{
"type": "extension",
"name": "dify_python",
"version": "=0.1.0"
},
{
"type": "extension",
"name": "gemini_v2v_python",
"version": "=0.1.0"
}
]
}
2 changes: 1 addition & 1 deletion agents/ten_packages/extension/cartesia_tts/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def on_request_tts(
audio_stream = await self.client.text_to_speech_stream(input_text)

async for audio_data in audio_stream:
self.send_audio_out(ten_env, audio_data["audio"])
await self.send_audio_out(ten_env, audio_data["audio"])

async def on_cancel_tts(self, ten_env: AsyncTenEnv) -> None:
return await super().on_cancel_tts(ten_env)
2 changes: 1 addition & 1 deletion agents/ten_packages/extension/cosy_tts_python/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def _process_audio_data(self, ten_env: AsyncTenEnv) -> None:
if audio_data is None:
break

self.send_audio_out(ten_env, audio_data)
await self.send_audio_out(ten_env, audio_data)

async def on_request_tts(
self, ten_env: AsyncTenEnv, input_text: str, end_of_segment: bool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@

from dataclasses import dataclass
from typing import AsyncIterator
from elevenlabs import Voice, VoiceSettings
from elevenlabs.client import AsyncElevenLabs

from ten_ai_base.config import BaseConfig


Expand All @@ -30,11 +27,18 @@ class ElevenLabsTTSConfig(BaseConfig):
class ElevenLabsTTS:
def __init__(self, config: ElevenLabsTTSConfig) -> None:
self.config = config
self.client = AsyncElevenLabs(
api_key=config.api_key, timeout=config.request_timeout_seconds
)
self.client = None

def text_to_speech_stream(self, text: str) -> AsyncIterator[bytes]:
# to avoid circular import issue when using openai with 11labs
from elevenlabs.client import AsyncElevenLabs
from elevenlabs import Voice, VoiceSettings

if not self.client:
self.client = AsyncElevenLabs(
api_key=self.config.api_key, timeout=self.config.request_timeout_seconds
)

return self.client.generate(
text=text,
model=self.config.model_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def on_request_tts(
audio_stream = await self.client.text_to_speech_stream(input_text)
ten_env.log_info(f"on_request_tts: {input_text}")
async for audio_data in audio_stream:
self.send_audio_out(ten_env, audio_data)
await self.send_audio_out(ten_env, audio_data)
ten_env.log_info(f"on_request_tts: {input_text} done")

async def on_cancel_tts(self, ten_env: AsyncTenEnv) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
elevenlabs>=1.13.0
elevenlabs>=1.50.0
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ def on_stop(self, ten: TenEnv) -> None:

def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None:
logger.info("GeminiLLMExtension on_cmd")
cmd_json = cmd.to_json()
logger.info(f"GeminiLLMExtension on_cmd json: {cmd_json}")
cmd_name = cmd.get_name()
logger.info(f"GeminiLLMExtension on_cmd json: {cmd_name}")

cmd_name = cmd.get_name()

Expand Down
1 change: 0 additions & 1 deletion agents/ten_packages/extension/gemini_v2v_python/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ ten_package("gemini_v2v_python") {
"__init__.py",
"addon.py",
"extension.py",
"log.py",
"manifest.json",
"property.json",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def on_request_tts(
try:
data = self.client.get(ten_env, input_text)
async for frame in data:
self.send_audio_out(
await self.send_audio_out(
ten_env, frame, sample_rate=self.client.config.sample_rate
)
except Exception:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (p *openaiChatGPTExtension) OnStart(tenEnv ten.TenEnv) {
outputData, _ := ten.NewData("text_data")
outputData.SetProperty(dataOutTextDataPropertyText, greeting)
outputData.SetProperty(dataOutTextDataPropertyTextEndOfSegment, true)
if err := tenEnv.SendData(outputData); err != nil {
if err := tenEnv.SendData(outputData, nil); err != nil {
slog.Error(fmt.Sprintf("greeting [%s] send failed, err: %v", greeting, err), logTag)
} else {
slog.Info(fmt.Sprintf("greeting [%s] sent", greeting), logTag)
Expand All @@ -210,7 +210,7 @@ func (p *openaiChatGPTExtension) OnCmd(
if err != nil {
slog.Error(fmt.Sprintf("OnCmd get name failed, err: %v", err), logTag)
cmdResult, _ := ten.NewCmdResult(ten.StatusCodeError)
tenEnv.ReturnResult(cmdResult, cmd)
tenEnv.ReturnResult(cmdResult, cmd, nil)
return
}
slog.Info(fmt.Sprintf("OnCmd %s", cmdInFlush), logTag)
Expand All @@ -226,20 +226,20 @@ func (p *openaiChatGPTExtension) OnCmd(
if err != nil {
slog.Error(fmt.Sprintf("new cmd %s failed, err: %v", cmdOutFlush, err), logTag)
cmdResult, _ := ten.NewCmdResult(ten.StatusCodeError)
tenEnv.ReturnResult(cmdResult, cmd)
tenEnv.ReturnResult(cmdResult, cmd, nil)
return
}
if err := tenEnv.SendCmd(outCmd, nil); err != nil {
slog.Error(fmt.Sprintf("send cmd %s failed, err: %v", cmdOutFlush, err), logTag)
cmdResult, _ := ten.NewCmdResult(ten.StatusCodeError)
tenEnv.ReturnResult(cmdResult, cmd)
tenEnv.ReturnResult(cmdResult, cmd, nil)
return
} else {
slog.Info(fmt.Sprintf("cmd %s sent", cmdOutFlush), logTag)
}
}
cmdResult, _ := ten.NewCmdResult(ten.StatusCodeOk)
tenEnv.ReturnResult(cmdResult, cmd)
tenEnv.ReturnResult(cmdResult, cmd, nil)
}

// OnData receives data from ten graph.
Expand Down Expand Up @@ -351,7 +351,7 @@ func (p *openaiChatGPTExtension) OnData(
}
outputData.SetProperty(dataOutTextDataPropertyText, sentence)
outputData.SetProperty(dataOutTextDataPropertyTextEndOfSegment, false)
if err := tenEnv.SendData(outputData); err != nil {
if err := tenEnv.SendData(outputData, nil); err != nil {
slog.Error(fmt.Sprintf("GetChatCompletionsStream recv for input text: [%s] send sentence [%s] failed, err: %v", inputText, sentence, err), logTag)
break
} else {
Expand All @@ -377,7 +377,7 @@ func (p *openaiChatGPTExtension) OnData(
outputData, _ := ten.NewData("text_data")
outputData.SetProperty(dataOutTextDataPropertyText, sentence)
outputData.SetProperty(dataOutTextDataPropertyTextEndOfSegment, true)
if err := tenEnv.SendData(outputData); err != nil {
if err := tenEnv.SendData(outputData, nil); err != nil {
slog.Error(fmt.Sprintf("GetChatCompletionsStream for input text: [%s] end of segment with sentence [%s] send failed, err: %v", inputText, sentence, err), logTag)
} else {
slog.Info(fmt.Sprintf("GetChatCompletionsStream for input text: [%s] end of segment with sentence [%s] sent", inputText, sentence), logTag)
Expand Down
1 change: 0 additions & 1 deletion agents/ten_packages/extension/polly_tts/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,5 @@ ten_package("polly_tts") {
"extension.py",
"manifest.json",
"property.json",
"tests",
]
}
2 changes: 1 addition & 1 deletion agents/ten_packages/extension/polly_tts/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def on_request_tts(
try:
data = self.client.text_to_speech_stream(ten_env, input_text)
async for frame in data:
self.send_audio_out(
await self.send_audio_out(
ten_env, frame, sample_rate=self.client.config.sample_rate
)
except Exception:
Expand Down
36 changes: 25 additions & 11 deletions agents/ten_packages/system/ten_ai_base/interface/ten_ai_base/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from ten.audio_frame import AudioFrame, AudioFrameDataFmt
from ten.cmd import Cmd
from ten.cmd_result import CmdResult, StatusCode
from ten_ai_base.const import CMD_IN_FLUSH, CMD_OUT_FLUSH, DATA_IN_PROPERTY_END_OF_SEGMENT, DATA_IN_PROPERTY_TEXT
from ten_ai_base.const import (
CMD_IN_FLUSH,
CMD_OUT_FLUSH,
DATA_IN_PROPERTY_END_OF_SEGMENT,
DATA_IN_PROPERTY_TEXT,
)
from ten_ai_base.types import TTSPcmOptions
from .helper import AsyncQueue, PCMWriter, get_property_bool, get_property_string

Expand All @@ -28,14 +33,15 @@ class AsyncTTSBaseExtension(AsyncExtension, ABC):
Use begin_send_audio_out, send_audio_out, end_send_audio_out to send the audio data to the output.
Override on_request_tts to implement the TTS logic.
"""

# Create the queue for message processing

def __init__(self, name: str):
super().__init__(name)
self.queue = AsyncQueue()
self.current_task = None
self.loop_task = None
self.leftover_bytes = b''
self.leftover_bytes = b""

async def on_init(self, ten_env: AsyncTenEnv) -> None:
await super().on_init(ten_env)
Expand Down Expand Up @@ -66,7 +72,7 @@ async def on_cmd(self, async_ten_env: AsyncTenEnv, cmd: Cmd) -> None:
status_code, detail = StatusCode.OK, "success"
cmd_result = CmdResult.create(status_code)
cmd_result.set_property_string("detail", detail)
async_ten_env.return_result(cmd_result, cmd)
await async_ten_env.return_result(cmd_result, cmd)

async def on_data(self, async_ten_env: AsyncTenEnv, data: Data) -> None:
# Get the necessary properties
Expand All @@ -91,7 +97,9 @@ async def flush_input_items(self, ten_env: AsyncTenEnv):
ten_env.log_info("Cancelling the current task during flush.")
self.current_task.cancel()

def send_audio_out(self, ten_env: AsyncTenEnv, audio_data: bytes, **args: TTSPcmOptions) -> None:
async def send_audio_out(
self, ten_env: AsyncTenEnv, audio_data: bytes, **args: TTSPcmOptions
) -> None:
"""Send a chunk of PCM audio data to the output as an audio frame."""
sample_rate = args.get("sample_rate", 16000)
bytes_per_sample = args.get("bytes_per_sample", 2)
Expand All @@ -103,29 +111,35 @@ def send_audio_out(self, ten_env: AsyncTenEnv, audio_data: bytes, **args: TTSPcm
# Check if combined_data length is not a whole multiple of the frame size (bytes_per_sample * number_of_channels)
if len(combined_data) % (bytes_per_sample * number_of_channels) != 0:
# Save the last incomplete frame
valid_length = len(combined_data) - (len(combined_data) % (bytes_per_sample * number_of_channels))
valid_length = len(combined_data) - (
len(combined_data) % (bytes_per_sample * number_of_channels)
)
self.leftover_bytes = combined_data[valid_length:]
combined_data = combined_data[:valid_length]
else:
self.leftover_bytes = b''
self.leftover_bytes = b""

if combined_data:
f = AudioFrame.create("pcm_frame")
f.set_sample_rate(sample_rate)
f.set_bytes_per_sample(bytes_per_sample)
f.set_number_of_channels(number_of_channels)
f.set_data_fmt(AudioFrameDataFmt.INTERLEAVE)
f.set_samples_per_channel(len(combined_data) // (bytes_per_sample * number_of_channels))
f.set_samples_per_channel(
len(combined_data) // (bytes_per_sample * number_of_channels)
)
f.alloc_buf(len(combined_data))
buff = f.lock_buf()
buff[:] = combined_data
f.unlock_buf(buff)
ten_env.send_audio_frame(f)
await ten_env.send_audio_frame(f)
except Exception as e:
ten_env.log_error(f"error send audio frame, {traceback.format_exc()}")

@abstractmethod
async def on_request_tts(self, ten_env: AsyncTenEnv, input_text: str, end_of_segment: bool) -> None:
async def on_request_tts(
self, ten_env: AsyncTenEnv, input_text: str, end_of_segment: bool
) -> None:
"""
Called when a new input item is available in the queue. Override this method to implement the TTS request logic.
Use send_audio_out to send the audio data to the output when the audio data is ready.
Expand All @@ -137,7 +151,6 @@ async def on_cancel_tts(self, ten_env: AsyncTenEnv) -> None:
"""Called when the TTS request is cancelled."""
pass


async def _process_queue(self, ten_env: AsyncTenEnv):
"""Asynchronously process queue items one by one."""
while True:
Expand All @@ -146,7 +159,8 @@ async def _process_queue(self, ten_env: AsyncTenEnv):

try:
self.current_task = asyncio.create_task(
self.on_request_tts(ten_env, text, end_of_segment))
self.on_request_tts(ten_env, text, end_of_segment)
)
await self.current_task # Wait for the current task to finish or be cancelled
except asyncio.CancelledError:
ten_env.log_info(f"Task cancelled: {text}")
Expand Down

0 comments on commit ae17c15

Please sign in to comment.