diff --git a/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/stt.py b/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/stt.py index 30ca0e6f63..9ad4492063 100644 --- a/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/stt.py +++ b/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/stt.py @@ -34,7 +34,7 @@ NotGivenOr, ) from livekit.agents.utils import AudioBuffer, is_given -from speechmatics.rt import ( # type: ignore +from speechmatics.rt import ( AsyncClient, AudioEncoding, AudioFormat, @@ -71,6 +71,7 @@ class STTOptions: additional_vocab: list[AdditionalVocabEntry] = dataclasses.field(default_factory=list) punctuation_overrides: dict = dataclasses.field(default_factory=dict) diarization_sensitivity: float = 0.5 + max_speakers: int | None = None speaker_active_format: str = "{text}" speaker_passive_format: str = "{text}" prefer_current_speaker: bool = False @@ -98,6 +99,7 @@ def __init__( additional_vocab: NotGivenOr[list[AdditionalVocabEntry]] = NOT_GIVEN, punctuation_overrides: NotGivenOr[dict] = NOT_GIVEN, diarization_sensitivity: float = 0.5, + max_speakers: NotGivenOr[int] = NOT_GIVEN, speaker_active_format: str = "{text}", speaker_passive_format: str = "{text}", prefer_current_speaker: bool = False, @@ -172,6 +174,11 @@ def __init__( the sensitivity of diarization and helps when two or more speakers have similar voices. Defaults to 0.5. + max_speakers (int): Maximum number of speakers to detect during diarization. When set, + the STT engine will limit the number of unique speakers identified in the transcription. + This is useful for scenarios where you know the maximum number of participants (e.g., + 2-person interviews, small group meetings). Optional. + speaker_active_format (str): Formatter for active speaker ID. This formatter is used to format the text output for individual speakers and ensures that the context is clear for language models further down the pipeline. The attributes `text` and @@ -245,20 +252,26 @@ def __init__( config: TranscriptionConfig = transcription_config language = language if is_given(language) else config.language - output_locale = output_locale if is_given(output_locale) else config.output_locale - domain = domain if is_given(domain) else config.domain - operating_point = operating_point or config.operating_point + if not is_given(output_locale) and config.output_locale is not None: + output_locale = config.output_locale + if not is_given(domain) and config.domain is not None: + domain = config.domain enable_diarization = enable_diarization or config.diarization == "speaker" - enable_partials = enable_partials or config.enable_partials - max_delay = max_delay or config.max_delay - additional_vocab = ( - additional_vocab if is_given(additional_vocab) else config.additional_vocab - ) - punctuation_overrides = ( - punctuation_overrides - if is_given(punctuation_overrides) - else config.punctuation_overrides - ) + if not is_given(additional_vocab) and config.additional_vocab is not None: + additional_vocab = [ + AdditionalVocabEntry(content=k, sounds_like=v) + for k, v in config.additional_vocab.items() + ] + if not is_given(punctuation_overrides) and config.punctuation_overrides is not None: + punctuation_overrides = config.punctuation_overrides + # Extract max_speakers from speaker_diarization_config if present + if ( + not is_given(max_speakers) + and (dz_cfg := config.speaker_diarization_config) + and hasattr(dz_cfg, "max_speakers") + and dz_cfg.max_speakers is not None + ): + max_speakers = dz_cfg.max_speakers if is_given(audio_settings): logger.warning( @@ -267,7 +280,7 @@ def __init__( audio: AudioSettings = audio_settings sample_rate = sample_rate or audio.sample_rate - audio_encoding = audio_encoding or audio.encoding + audio_encoding = audio_encoding or AudioEncoding(audio.encoding) self._stt_options = STTOptions( operating_point=operating_point, @@ -282,6 +295,7 @@ def __init__( additional_vocab=additional_vocab if is_given(additional_vocab) else [], punctuation_overrides=punctuation_overrides if is_given(punctuation_overrides) else {}, diarization_sensitivity=diarization_sensitivity, + max_speakers=max_speakers if is_given(max_speakers) else None, speaker_active_format=speaker_active_format, speaker_passive_format=speaker_passive_format, prefer_current_speaker=prefer_current_speaker, @@ -370,26 +384,30 @@ def _process_config(self) -> None: ) if self._stt_options.additional_vocab: - transcription_config.additional_vocab = [ - { - "content": e.content, - "sounds_like": e.sounds_like, - } + # API expects list of dicts, not dict format + transcription_config.additional_vocab = [ # type: ignore + {"content": e.content, "sounds_like": e.sounds_like} for e in self._stt_options.additional_vocab ] if self._stt_options.enable_diarization: - dz_cfg: dict[str, Any] = {} - if self._stt_options.diarization_sensitivity is not None: - dz_cfg["speaker_sensitivity"] = self._stt_options.diarization_sensitivity - if self._stt_options.prefer_current_speaker is not None: - dz_cfg["prefer_current_speaker"] = self._stt_options.prefer_current_speaker + # Use dict for speaker diarization config to support all fields including speakers + dz_cfg: dict[str, Any] = { + "speaker_sensitivity": self._stt_options.diarization_sensitivity, + "prefer_current_speaker": self._stt_options.prefer_current_speaker, + } + + # Add max_speakers if provided + if self._stt_options.max_speakers is not None: + dz_cfg["max_speakers"] = self._stt_options.max_speakers + + # Add speakers mapping from known speakers if self._stt_options.known_speakers: dz_cfg["speakers"] = { s.label: s.speaker_identifiers for s in self._stt_options.known_speakers } - if dz_cfg: - transcription_config.speaker_diarization_config = dz_cfg + + transcription_config.speaker_diarization_config = dz_cfg # type: ignore[assignment] if ( self._stt_options.end_of_utterance_silence_trigger and self._stt_options.end_of_utterance_mode == EndOfUtteranceMode.FIXED @@ -461,23 +479,23 @@ async def _run(self) -> None: opts = self._stt._stt_options - @self._client.on(ServerMessageType.RECOGNITION_STARTED) # type: ignore + @self._client.on(ServerMessageType.RECOGNITION_STARTED) def _evt_on_recognition_started(message: dict[str, Any]) -> None: logger.debug("Recognition started", extra={"data": message}) if opts.enable_partials: - @self._client.on(ServerMessageType.ADD_PARTIAL_TRANSCRIPT) # type: ignore + @self._client.on(ServerMessageType.ADD_PARTIAL_TRANSCRIPT) def _evt_on_partial_transcript(message: dict[str, Any]) -> None: self._handle_transcript(message, is_final=False) - @self._client.on(ServerMessageType.ADD_TRANSCRIPT) # type: ignore + @self._client.on(ServerMessageType.ADD_TRANSCRIPT) def _evt_on_final_transcript(message: dict[str, Any]) -> None: self._handle_transcript(message, is_final=True) if opts.end_of_utterance_mode == EndOfUtteranceMode.FIXED: - @self._client.on(ServerMessageType.END_OF_UTTERANCE) # type: ignore + @self._client.on(ServerMessageType.END_OF_UTTERANCE) def _evt_on_end_of_utterance(message: dict[str, Any]) -> None: self._handle_end_of_utterance() diff --git a/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/types.py b/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/types.py index 4a660a81fa..d97fd6cd56 100644 --- a/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/types.py +++ b/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/types.py @@ -5,7 +5,7 @@ from typing import Any from livekit.agents.stt import SpeechData -from speechmatics.rt import TranscriptionConfig # type: ignore +from speechmatics.rt import TranscriptionConfig __all__ = ["TranscriptionConfig"] diff --git a/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/utils.py b/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/utils.py index 397a19c44c..181afd4e94 100644 --- a/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/utils.py +++ b/livekit-plugins/livekit-plugins-speechmatics/livekit/plugins/speechmatics/utils.py @@ -1,8 +1,6 @@ from urllib.parse import urlencode -from speechmatics.rt import ( # type: ignore - __version__ as sdk_version, -) +from speechmatics.rt import __version__ as sdk_version from .version import __version__ as lk_version diff --git a/livekit-plugins/livekit-plugins-speechmatics/pyproject.toml b/livekit-plugins/livekit-plugins-speechmatics/pyproject.toml index 0bb4013240..1254f5c62d 100644 --- a/livekit-plugins/livekit-plugins-speechmatics/pyproject.toml +++ b/livekit-plugins/livekit-plugins-speechmatics/pyproject.toml @@ -22,6 +22,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3 :: Only", ] + dependencies = ["livekit-agents>=1.2.14", "speechmatics-rt>=0.4.0"] [project.urls] diff --git a/pyproject.toml b/pyproject.toml index b6046f3266..e04425c023 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,3 +126,7 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = "smithy_core.*" ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "speechmatics.*" +follow_untyped_imports = true