Skip to content
Draft
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
048b007
feat(TTS): add WebSocket support and configurable protocol options
gokuljs Dec 25, 2025
6b86b51
feat(TTS): implement JSONSynthesizeStream for enhanced audio synthesi…
gokuljs Dec 25, 2025
6316b53
add segment and additional options for enhanced audio synthesis conf…
gokuljs Dec 25, 2025
aeb66ea
refactor(TTS): clean up whitespace in tts.py for improved readability
gokuljs Dec 25, 2025
f7872ac
feat(TTS): enhance JSONSynthesizeStream with WebSocket error handling…
gokuljs Dec 26, 2025
ea3d7fd
feat(TTS): add WebSocket message handling and buffer management to JS…
gokuljs Dec 26, 2025
a9845b4
fix(TTS): update JSONSynthesizeStream to handle 'chunk' type messages…
gokuljs Dec 26, 2025
3aebd5e
feat(TTS): initialize WebSocket response handling in JSONSynthesizeSt…
gokuljs Dec 26, 2025
d6c26ba
refactor(TTS): improve WebSocket message handling and buffer manageme…
gokuljs Dec 26, 2025
68aa43c
refactor(TTS): utilize configurable model timeout in JSONSynthesizeSt…
gokuljs Dec 26, 2025
5f650ce
fix(TTS): prevent use of unsupported Arcana model for WebSocket strea…
gokuljs Dec 26, 2025
96649d0
fix(TTS): clarify error message for unsupported Arcana model in JSONS…
gokuljs Dec 26, 2025
f81d670
refactor(TTS): streamline WebSocket message handling in JSONSynthesiz…
gokuljs Jan 5, 2026
d2f0579
refactor(TTS): remove unused WebSocket text URL and related parameter…
gokuljs Jan 8, 2026
54c7336
refactor(TTS): reorganize imports and remove unused WebSocket text UR…
gokuljs Jan 8, 2026
37dbfea
refactor(TTS): update parameter type hinting in _build_ws_url method …
gokuljs Jan 8, 2026
514a526
refactor(TTS): update WebSocket timeout handling and streamline task …
gokuljs Jan 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 209 additions & 5 deletions livekit-plugins/livekit-plugins-rime/livekit/plugins/rime/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
from __future__ import annotations

import asyncio
import base64
import json
import os
from dataclasses import dataclass, replace
from typing import Literal
from urllib.parse import urlencode

import aiohttp

Expand All @@ -34,6 +38,7 @@
NotGivenOr,
)
from livekit.agents.utils import is_given
from livekit.agents.voice.io import TimedString

from .langs import TTSLangs
from .log import logger
Expand All @@ -42,13 +47,15 @@
# arcana can take as long as 80% of the total duration of the audio it's synthesizing.
ARCANA_MODEL_TIMEOUT = 60 * 4
MISTV2_MODEL_TIMEOUT = 30
RIME_BASE_URL = "https://users.rime.ai/v1/rime-tts"
RIME_BASE_URL = "https://users.rime.ai/v1/rime-tts" # http
RIME_WS_JSON_URL = "wss://users.rime.ai/ws2" # ws_json


@dataclass
class _TTSOptions:
model: TTSModels | str
speaker: str
segment: NotGivenOr[str] = NOT_GIVEN
arcana_options: _ArcanaOptions | None = None
mistv2_options: _Mistv2Options | None = None

Expand All @@ -71,6 +78,10 @@ class _Mistv2Options:
reduce_latency: NotGivenOr[bool] = NOT_GIVEN
pause_between_brackets: NotGivenOr[bool] = NOT_GIVEN
phonemize_between_brackets: NotGivenOr[bool] = NOT_GIVEN
# websocket specific
no_text_normalization: NotGivenOr[bool] = NOT_GIVEN
inline_speed_alpha: NotGivenOr[str] = NOT_GIVEN
save_oovs: NotGivenOr[bool] = NOT_GIVEN


NUM_CHANNELS = 1
Expand All @@ -81,9 +92,12 @@ def __init__(
self,
*,
base_url: str = RIME_BASE_URL,
ws_json_url: str = RIME_WS_JSON_URL,
protocol: Literal["http", "ws_json"] = "http",
model: TTSModels | str = "arcana",
speaker: NotGivenOr[ArcanaVoices | str] = NOT_GIVEN,
lang: TTSLangs | str = "eng",
segment: NotGivenOr[str] = NOT_GIVEN,
# Arcana options
repetition_penalty: NotGivenOr[float] = NOT_GIVEN,
temperature: NotGivenOr[float] = NOT_GIVEN,
Expand All @@ -95,12 +109,15 @@ def __init__(
reduce_latency: NotGivenOr[bool] = NOT_GIVEN,
pause_between_brackets: NotGivenOr[bool] = NOT_GIVEN,
phonemize_between_brackets: NotGivenOr[bool] = NOT_GIVEN,
no_text_normalization: NotGivenOr[bool] = NOT_GIVEN,
inline_speed_alpha: NotGivenOr[str] = NOT_GIVEN,
save_oovs: NotGivenOr[bool] = NOT_GIVEN,
api_key: NotGivenOr[str] = NOT_GIVEN,
http_session: aiohttp.ClientSession | None = None,
) -> None:
super().__init__(
capabilities=tts.TTSCapabilities(
streaming=False,
streaming=protocol != "http",
),
sample_rate=sample_rate,
num_channels=NUM_CHANNELS,
Expand All @@ -120,6 +137,7 @@ def __init__(
self._opts = _TTSOptions(
model=model,
speaker=speaker,
segment=segment,
)
if model == "arcana":
self._opts.arcana_options = _ArcanaOptions(
Expand All @@ -138,9 +156,14 @@ def __init__(
reduce_latency=reduce_latency,
pause_between_brackets=pause_between_brackets,
phonemize_between_brackets=phonemize_between_brackets,
no_text_normalization=no_text_normalization,
inline_speed_alpha=inline_speed_alpha,
save_oovs=save_oovs,
)
self._session = http_session
self._base_url = base_url
self._ws_json_url = ws_json_url
self._protocol = protocol

self._total_timeout = ARCANA_MODEL_TIMEOUT if model == "arcana" else MISTV2_MODEL_TIMEOUT

Expand All @@ -159,10 +182,18 @@ def _ensure_session(self) -> aiohttp.ClientSession:
return self._session

def synthesize(
self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
self,
text: str,
*,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> ChunkedStream:
return ChunkedStream(tts=self, input_text=text, conn_options=conn_options)

def stream(
self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
) -> JSONSynthesizeStream:
return JSONSynthesizeStream(tts=self, conn_options=conn_options)

def update_options(
self,
*,
Expand All @@ -181,12 +212,20 @@ def update_options(
pause_between_brackets: NotGivenOr[bool] = NOT_GIVEN,
phonemize_between_brackets: NotGivenOr[bool] = NOT_GIVEN,
base_url: NotGivenOr[str] = NOT_GIVEN,
ws_json_url: NotGivenOr[str] = NOT_GIVEN,
segment: NotGivenOr[str] = NOT_GIVEN,
no_text_normalization: NotGivenOr[bool] = NOT_GIVEN,
save_oovs: NotGivenOr[bool] = NOT_GIVEN,
inline_speed_alpha: NotGivenOr[str] = NOT_GIVEN,
) -> None:
if is_given(base_url):
self._base_url = base_url
if is_given(ws_json_url):
self._ws_json_url = ws_json_url
if is_given(segment):
self._opts.segment = segment
if is_given(model):
self._opts.model = model

if model == "arcana" and self._opts.arcana_options is None:
self._opts.arcana_options = _ArcanaOptions()
elif model == "mistv2" and self._opts.mistv2_options is None:
Expand Down Expand Up @@ -222,6 +261,12 @@ def update_options(
self._opts.mistv2_options.pause_between_brackets = pause_between_brackets
if is_given(phonemize_between_brackets):
self._opts.mistv2_options.phonemize_between_brackets = phonemize_between_brackets
if is_given(no_text_normalization):
self._opts.mistv2_options.no_text_normalization = no_text_normalization
if is_given(inline_speed_alpha):
self._opts.mistv2_options.inline_speed_alpha = inline_speed_alpha
if is_given(save_oovs):
self._opts.mistv2_options.save_oovs = save_oovs


class ChunkedStream(tts.ChunkedStream):
Expand Down Expand Up @@ -280,7 +325,8 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
},
json=payload,
timeout=aiohttp.ClientTimeout(
total=self._tts._total_timeout, sock_connect=self._conn_options.timeout
total=self._tts._total_timeout,
sock_connect=self._conn_options.timeout,
),
) as resp:
resp.raise_for_status()
Expand Down Expand Up @@ -308,3 +354,161 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
) from None
except Exception as e:
raise APIConnectionError() from e


class JSONSynthesizeStream(tts.SynthesizeStream):
def __init__(self, tts: TTS, conn_options: APIConnectOptions) -> None:
super().__init__(tts=tts, conn_options=conn_options)
self._tts: TTS = tts
self._opts = replace(tts._opts)
self._ws: aiohttp.ClientWebSocketResponse | None = None
if self._opts.model == "arcana":
raise ValueError(
"The Arcana model is not supported for JSON WebSocket streaming. Please switch to the 'mistv2' model."
)

def _build_ws_url(self) -> str:
params: dict[str, str | int | float | bool] = {
"modelId": self._opts.model,
"speaker": self._opts.speaker,
"audioFormat": "pcm",
}
if is_given(self._opts.segment):
params["segment"] = self._opts.segment
elif self._opts.model == "mistv2":
mistv2_opts = self._opts.mistv2_options
assert mistv2_opts is not None
if is_given(mistv2_opts.lang):
params["lang"] = mistv2_opts.lang
if is_given(mistv2_opts.sample_rate):
params["samplingRate"] = mistv2_opts.sample_rate
if is_given(mistv2_opts.speed_alpha):
params["speedAlpha"] = mistv2_opts.speed_alpha
if is_given(mistv2_opts.reduce_latency):
params["reduceLatency"] = mistv2_opts.reduce_latency
if is_given(mistv2_opts.pause_between_brackets):
params["pauseBetweenBrackets"] = mistv2_opts.pause_between_brackets
if is_given(mistv2_opts.phonemize_between_brackets):
params["phonemizeBetweenBrackets"] = mistv2_opts.phonemize_between_brackets
if is_given(mistv2_opts.no_text_normalization):
params["noTextNormalization"] = mistv2_opts.no_text_normalization
if is_given(mistv2_opts.inline_speed_alpha):
params["inlineSpeedAlpha"] = mistv2_opts.inline_speed_alpha
if is_given(mistv2_opts.save_oovs):
params["saveOovs"] = mistv2_opts.save_oovs
return f"{self._tts._ws_json_url}?{urlencode(params)}"

async def clear_buffer(self) -> None:
"""Send clear operation to discard buffered text"""
if self._ws and not self._ws.closed:
await self._ws.send_str(json.dumps({"operation": "clear"}))

async def aclose(self) -> None:
"""Close the stream and send EOS if needed"""
if self._ws and not self._ws.closed:
await self._ws.send_str(json.dumps({"operation": "eos"}))
await super().aclose()

async def _send_task(self, ws: aiohttp.ClientWebSocketResponse) -> None:
try:
async for input_data in self._input_ch:
if isinstance(input_data, str):
await ws.send_str(json.dumps({"text": input_data}))
await ws.send_str(json.dumps({"operation": "flush"}))
elif isinstance(input_data, self._FlushSentinel):
await ws.send_str(json.dumps({"operation": "flush"}))
except Exception as e:
logger.error("Rime WebSocket send task failed: %s", e)
raise APIConnectionError(f"Send task failed: {e}") from e

async def _recv_task(
self, ws: aiohttp.ClientWebSocketResponse, output_emitter: tts.AudioEmitter
) -> None:
segment_started = False

while True:
msg = await ws.receive()
if msg.type in (
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSING,
):
if segment_started:
output_emitter.end_segment()
output_emitter.end_input()
return

if msg.type != aiohttp.WSMsgType.TEXT:
logger.warning("Unexpected Rime message type: %s", msg.type)
continue

try:
data = json.loads(msg.data)
except json.JSONDecodeError:
logger.warning("Invalid JSON from Rime: %s", msg.data)
continue

if data.get("type") == "chunk":
if not segment_started:
segment_id = data.get("contextId") or utils.shortuuid()
output_emitter.start_segment(segment_id=segment_id)
segment_started = True
audio_data = base64.b64decode(data["data"])
output_emitter.push(audio_data)
elif data.get("type") == "done":
if segment_started:
output_emitter.end_segment()
output_emitter.end_input()
break
elif data.get("type") == "timestamps":
word_timestamps = data.get("word_timestamps", {})
words = word_timestamps.get("words", [])
starts = word_timestamps.get("start", [])
ends = word_timestamps.get("end", [])

timed_words = []
for word, start, end in zip(words, starts, ends):
timed_words.append(TimedString(text=word, start_time=start, end_time=end))
if timed_words:
output_emitter.push_timed_transcript(timed_words)
elif data.get("type") == "error":
logger.error(f"Rime error: {data.get('message')}")
if segment_started:
output_emitter.end_segment()
output_emitter.end_input()
raise APIStatusError(f"Rime error: {data.get('message')}")

async def _run(self, output_emitter: tts.AudioEmitter) -> None:
request_id = utils.shortuuid()
output_emitter.initialize(
request_id=request_id,
sample_rate=self._tts.sample_rate,
num_channels=NUM_CHANNELS,
mime_type="audio/pcm",
stream=True,
)
ws_url = self._build_ws_url()

send_task = None
recv_task = None

try:
async with self._tts._ensure_session().ws_connect(
ws_url,
headers={"Authorization": f"Bearer {self._tts._api_key}"},
timeout=aiohttp.ClientWSTimeout(ws_close=self._tts._total_timeout),
) as ws:
self._ws = ws
send_task = asyncio.create_task(self._send_task(ws))
recv_task = asyncio.create_task(self._recv_task(ws, output_emitter))
await asyncio.gather(send_task, recv_task)
except asyncio.TimeoutError:
raise APITimeoutError() from None
except aiohttp.ClientError as e:
raise APIConnectionError() from e
except Exception as e:
raise APIConnectionError() from e
finally:
tasks_to_cancel = [t for t in [send_task, recv_task] if t is not None]
if tasks_to_cancel:
await utils.aio.gracefully_cancel(*tasks_to_cancel)