Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@
See https://docs.livekit.io/agents/integrations/stt/soniox/ for more information.
"""

from .stt import STT, STTOptions
from .stt import STT, ContextGeneralItem, ContextObject, ContextTranslationTerm, STTOptions
from .version import __version__

# Public names re-exported by this package.
__all__ = [
    "STT",
    "STTOptions",
    "ContextObject",
    "ContextGeneralItem",
    "ContextTranslationTerm",
    "__version__",
]


from livekit.agents import Plugin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import json
import os
import time
from dataclasses import dataclass
from dataclasses import asdict, dataclass

import aiohttp

Expand Down Expand Up @@ -56,22 +56,47 @@ def is_end_token(token: dict) -> bool:
return token.get("text") in (END_TOKEN, FINALIZED_TOKEN)


@dataclass
class ContextGeneralItem:
    """A key/value pair used as an entry of ``ContextObject.general``."""

    key: str
    value: str


@dataclass
class ContextTranslationTerm:
    """A source/target pair used as an entry of ``ContextObject.translation_terms``."""

    source: str
    target: str


@dataclass
class ContextObject:
    """Context object for models with context_version 2, for Soniox stt-rt-v3-preview and higher.

    Every field is optional and defaults to ``None``.

    Learn more about context in the documentation:
    https://soniox.com/docs/stt/concepts/context
    """

    # Key/value context entries.
    general: list[ContextGeneralItem] | None = None
    # Free-form context text.
    text: str | None = None
    # Plain list of terms (presumably vocabulary hints; see the Soniox docs).
    terms: list[str] | None = None
    # Source/target term pairs.
    translation_terms: list[ContextTranslationTerm] | None = None


@dataclass
class STTOptions:
    """Configuration options for Soniox Speech-to-Text service."""

    model: str | None = "stt-rt-preview"

    language_hints: list[str] | None = None
    # Either a plain string or a structured ContextObject.
    context: ContextObject | str | None = None

    num_channels: int = 1
    sample_rate: int = 16000

    enable_speaker_diarization: bool = False
    enable_language_identification: bool = True

    enable_non_final_tokens: bool = True
    max_non_final_tokens_duration_ms: int | None = None

    client_reference_id: str | None = None


Expand Down Expand Up @@ -176,6 +201,10 @@ async def _connect_ws(self):
# If VAD was passed, disable endpoint detection, otherwise enable it.
enable_endpoint_detection = not self._stt._vad_stream

context = self._stt._params.context
if isinstance(context, ContextObject):
context = asdict(context)

# Create initial config object.
config = {
"api_key": self._stt._api_key,
Expand All @@ -185,9 +214,8 @@ async def _connect_ws(self):
"enable_endpoint_detection": enable_endpoint_detection,
"sample_rate": self._stt._params.sample_rate,
"language_hints": self._stt._params.language_hints,
"context": self._stt._params.context,
"enable_non_final_tokens": self._stt._params.enable_non_final_tokens,
"max_non_final_tokens_duration_ms": self._stt._params.max_non_final_tokens_duration_ms,
"context": context,
"enable_speaker_diarization": self._stt._params.enable_speaker_diarization,
"enable_language_identification": self._stt._params.enable_language_identification,
"client_reference_id": self._stt._params.client_reference_id,
}
Expand Down