Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
NotGivenOr,
)
from livekit.agents.utils import AudioBuffer, is_given
from speechmatics.rt import ( # type: ignore
from speechmatics.rt import (
AsyncClient,
AudioEncoding,
AudioFormat,
Expand Down Expand Up @@ -71,6 +71,7 @@ class STTOptions:
additional_vocab: list[AdditionalVocabEntry] = dataclasses.field(default_factory=list)
punctuation_overrides: dict = dataclasses.field(default_factory=dict)
diarization_sensitivity: float = 0.5
max_speakers: int | None = None
speaker_active_format: str = "{text}"
speaker_passive_format: str = "{text}"
prefer_current_speaker: bool = False
Expand Down Expand Up @@ -98,6 +99,7 @@ def __init__(
additional_vocab: NotGivenOr[list[AdditionalVocabEntry]] = NOT_GIVEN,
punctuation_overrides: NotGivenOr[dict] = NOT_GIVEN,
diarization_sensitivity: float = 0.5,
max_speakers: NotGivenOr[int] = NOT_GIVEN,
speaker_active_format: str = "{text}",
speaker_passive_format: str = "{text}",
prefer_current_speaker: bool = False,
Expand Down Expand Up @@ -172,6 +174,11 @@ def __init__(
the sensitivity of diarization and helps when two or more speakers have similar voices.
Defaults to 0.5.

max_speakers (int): Maximum number of speakers to detect during diarization. When set,
the STT engine will limit the number of unique speakers identified in the transcription.
This is useful for scenarios where you know the maximum number of participants (e.g.,
2-person interviews, small group meetings). Optional.

speaker_active_format (str): Formatter for active speaker ID. This formatter is used
to format the text output for individual speakers and ensures that the context is
clear for language models further down the pipeline. The attributes `text` and
Expand Down Expand Up @@ -245,20 +252,26 @@ def __init__(

config: TranscriptionConfig = transcription_config
language = language if is_given(language) else config.language
output_locale = output_locale if is_given(output_locale) else config.output_locale
domain = domain if is_given(domain) else config.domain
operating_point = operating_point or config.operating_point
if not is_given(output_locale) and config.output_locale is not None:
output_locale = config.output_locale
if not is_given(domain) and config.domain is not None:
domain = config.domain
enable_diarization = enable_diarization or config.diarization == "speaker"
enable_partials = enable_partials or config.enable_partials
max_delay = max_delay or config.max_delay
additional_vocab = (
additional_vocab if is_given(additional_vocab) else config.additional_vocab
)
punctuation_overrides = (
punctuation_overrides
if is_given(punctuation_overrides)
else config.punctuation_overrides
)
if not is_given(additional_vocab) and config.additional_vocab is not None:
additional_vocab = [
AdditionalVocabEntry(content=k, sounds_like=v)
for k, v in config.additional_vocab.items()
]
if not is_given(punctuation_overrides) and config.punctuation_overrides is not None:
punctuation_overrides = config.punctuation_overrides
# Extract max_speakers from speaker_diarization_config if present
if (
not is_given(max_speakers)
and (dz_cfg := config.speaker_diarization_config)
and hasattr(dz_cfg, "max_speakers")
and dz_cfg.max_speakers is not None
):
max_speakers = dz_cfg.max_speakers

if is_given(audio_settings):
logger.warning(
Expand All @@ -267,7 +280,7 @@ def __init__(

audio: AudioSettings = audio_settings
sample_rate = sample_rate or audio.sample_rate
audio_encoding = audio_encoding or audio.encoding
audio_encoding = audio_encoding or AudioEncoding(audio.encoding)

self._stt_options = STTOptions(
operating_point=operating_point,
Expand All @@ -282,6 +295,7 @@ def __init__(
additional_vocab=additional_vocab if is_given(additional_vocab) else [],
punctuation_overrides=punctuation_overrides if is_given(punctuation_overrides) else {},
diarization_sensitivity=diarization_sensitivity,
max_speakers=max_speakers if is_given(max_speakers) else None,
speaker_active_format=speaker_active_format,
speaker_passive_format=speaker_passive_format,
prefer_current_speaker=prefer_current_speaker,
Expand Down Expand Up @@ -370,26 +384,30 @@ def _process_config(self) -> None:
)

if self._stt_options.additional_vocab:
transcription_config.additional_vocab = [
{
"content": e.content,
"sounds_like": e.sounds_like,
}
# API expects list of dicts, not dict format
transcription_config.additional_vocab = [ # type: ignore
{"content": e.content, "sounds_like": e.sounds_like}
for e in self._stt_options.additional_vocab
]

if self._stt_options.enable_diarization:
dz_cfg: dict[str, Any] = {}
if self._stt_options.diarization_sensitivity is not None:
dz_cfg["speaker_sensitivity"] = self._stt_options.diarization_sensitivity
if self._stt_options.prefer_current_speaker is not None:
dz_cfg["prefer_current_speaker"] = self._stt_options.prefer_current_speaker
# Use dict for speaker diarization config to support all fields including speakers
dz_cfg: dict[str, Any] = {
"speaker_sensitivity": self._stt_options.diarization_sensitivity,
"prefer_current_speaker": self._stt_options.prefer_current_speaker,
}

# Add max_speakers if provided
if self._stt_options.max_speakers is not None:
dz_cfg["max_speakers"] = self._stt_options.max_speakers

# Add speakers mapping from known speakers
if self._stt_options.known_speakers:
dz_cfg["speakers"] = {
s.label: s.speaker_identifiers for s in self._stt_options.known_speakers
}
if dz_cfg:
transcription_config.speaker_diarization_config = dz_cfg

transcription_config.speaker_diarization_config = dz_cfg # type: ignore[assignment]
if (
self._stt_options.end_of_utterance_silence_trigger
and self._stt_options.end_of_utterance_mode == EndOfUtteranceMode.FIXED
Expand Down Expand Up @@ -461,23 +479,23 @@ async def _run(self) -> None:

opts = self._stt._stt_options

@self._client.on(ServerMessageType.RECOGNITION_STARTED) # type: ignore
@self._client.on(ServerMessageType.RECOGNITION_STARTED)
def _evt_on_recognition_started(message: dict[str, Any]) -> None:
logger.debug("Recognition started", extra={"data": message})

if opts.enable_partials:

@self._client.on(ServerMessageType.ADD_PARTIAL_TRANSCRIPT) # type: ignore
@self._client.on(ServerMessageType.ADD_PARTIAL_TRANSCRIPT)
def _evt_on_partial_transcript(message: dict[str, Any]) -> None:
self._handle_transcript(message, is_final=False)

@self._client.on(ServerMessageType.ADD_TRANSCRIPT) # type: ignore
@self._client.on(ServerMessageType.ADD_TRANSCRIPT)
def _evt_on_final_transcript(message: dict[str, Any]) -> None:
self._handle_transcript(message, is_final=True)

if opts.end_of_utterance_mode == EndOfUtteranceMode.FIXED:

@self._client.on(ServerMessageType.END_OF_UTTERANCE) # type: ignore
@self._client.on(ServerMessageType.END_OF_UTTERANCE)
def _evt_on_end_of_utterance(message: dict[str, Any]) -> None:
self._handle_end_of_utterance()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any

from livekit.agents.stt import SpeechData
from speechmatics.rt import TranscriptionConfig # type: ignore
from speechmatics.rt import TranscriptionConfig

__all__ = ["TranscriptionConfig"]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from urllib.parse import urlencode

from speechmatics.rt import ( # type: ignore
__version__ as sdk_version,
)
from speechmatics.rt import __version__ as sdk_version

from .version import __version__ as lk_version

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3 :: Only",
]

dependencies = ["livekit-agents>=1.2.14", "speechmatics-rt>=0.4.0"]

[project.urls]
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,7 @@ ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "smithy_core.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "speechmatics.*"
follow_untyped_imports = true