Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions homeassistant/components/assist_pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ async def async_pipeline_from_audio_stream(
event_callback: PipelineEventCallback,
stt_metadata: stt.SpeechMetadata,
stt_stream: AsyncIterable[bytes],
wake_word_phrase: str | None = None,
pipeline_id: str | None = None,
conversation_id: str | None = None,
tts_audio_output: str | None = None,
Expand All @@ -101,6 +102,7 @@ async def async_pipeline_from_audio_stream(
device_id=device_id,
stt_metadata=stt_metadata,
stt_stream=stt_stream,
wake_word_phrase=wake_word_phrase,
run=PipelineRun(
hass,
context=context,
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/assist_pipeline/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@
CONF_DEBUG_RECORDING_DIR = "debug_recording_dir"

DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
DEFAULT_WAKE_WORD_COOLDOWN = 2 # seconds
WAKE_WORD_COOLDOWN = 2 # seconds

EVENT_RECORDING = f"{DOMAIN}_recording"
11 changes: 11 additions & 0 deletions homeassistant/components/assist_pipeline/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ class SpeechToTextError(PipelineError):
"""Error in speech-to-text portion of pipeline."""


class DuplicateWakeUpDetectedError(WakeWordDetectionError):
    """Raised when several voice assistants wake up at once for the same wake word."""

    def __init__(self, wake_up_phrase: str) -> None:
        """Build the error code and message for the given wake-up phrase."""
        code = "duplicate_wake_up_detected"
        message = f"Duplicate wake-up detected for {wake_up_phrase}"
        # Positional (code, message) order is kept exactly as the base expects.
        super().__init__(code, message)


class IntentRecognitionError(PipelineError):
"""Error in intent recognition portion of pipeline."""

Expand Down
46 changes: 36 additions & 10 deletions homeassistant/components/assist_pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@
CONF_DEBUG_RECORDING_DIR,
DATA_CONFIG,
DATA_LAST_WAKE_UP,
DEFAULT_WAKE_WORD_COOLDOWN,
DOMAIN,
WAKE_WORD_COOLDOWN,
)
from .error import (
DuplicateWakeUpDetectedError,
IntentRecognitionError,
PipelineError,
PipelineNotFound,
Expand Down Expand Up @@ -453,9 +454,6 @@ class WakeWordSettings:
audio_seconds_to_buffer: float = 0
"""Seconds of audio to buffer before detection and forward to STT."""

cooldown_seconds: float = DEFAULT_WAKE_WORD_COOLDOWN
"""Seconds after a wake word detection where other detections are ignored."""


@dataclass(frozen=True)
class AudioSettings:
Expand Down Expand Up @@ -742,16 +740,22 @@ async def wake_word_detection(
wake_word_output: dict[str, Any] = {}
else:
# Avoid duplicate detections by checking cooldown
wake_up_key = f"{self.wake_word_entity_id}.{result.wake_word_id}"
last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get(wake_up_key)
last_wake_up = self.hass.data[DATA_LAST_WAKE_UP].get(
result.wake_word_phrase
)
if last_wake_up is not None:
sec_since_last_wake_up = time.monotonic() - last_wake_up
if sec_since_last_wake_up < wake_word_settings.cooldown_seconds:
_LOGGER.debug("Duplicate wake word detection occurred")
raise WakeWordDetectionAborted
if sec_since_last_wake_up < WAKE_WORD_COOLDOWN:
_LOGGER.debug(
"Duplicate wake word detection occurred for %s",
result.wake_word_phrase,
)
raise DuplicateWakeUpDetectedError(result.wake_word_phrase)

# Record last wake up time to block duplicate detections
self.hass.data[DATA_LAST_WAKE_UP][wake_up_key] = time.monotonic()
self.hass.data[DATA_LAST_WAKE_UP][
result.wake_word_phrase
] = time.monotonic()

if result.queued_audio:
# Add audio that was pending at detection.
Expand Down Expand Up @@ -1308,6 +1312,9 @@ class PipelineInput:
stt_stream: AsyncIterable[bytes] | None = None
"""Input audio for stt. Required when start_stage = stt."""

wake_word_phrase: str | None = None
"""Optional key used to de-duplicate wake-ups for local wake word detection."""

intent_input: str | None = None
"""Input for conversation agent. Required when start_stage = intent."""

Expand Down Expand Up @@ -1352,6 +1359,25 @@ async def execute(self) -> None:
assert self.stt_metadata is not None
assert stt_processed_stream is not None

if self.wake_word_phrase is not None:
# Avoid duplicate wake-ups by checking cooldown
last_wake_up = self.run.hass.data[DATA_LAST_WAKE_UP].get(
self.wake_word_phrase
)
if last_wake_up is not None:
sec_since_last_wake_up = time.monotonic() - last_wake_up
if sec_since_last_wake_up < WAKE_WORD_COOLDOWN:
_LOGGER.debug(
"Speech-to-text cancelled to avoid duplicate wake-up for %s",
self.wake_word_phrase,
)
raise DuplicateWakeUpDetectedError(self.wake_word_phrase)

# Record last wake up time to block duplicate detections
self.run.hass.data[DATA_LAST_WAKE_UP][
self.wake_word_phrase
] = time.monotonic()

stt_input_stream = stt_processed_stream

if stt_audio_buffer:
Expand Down
11 changes: 10 additions & 1 deletion homeassistant/components/assist_pipeline/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,12 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
extra=vol.ALLOW_EXTRA,
),
PipelineStage.STT: vol.Schema(
{vol.Required("input"): {vol.Required("sample_rate"): int}},
{
vol.Required("input"): {
vol.Required("sample_rate"): int,
vol.Optional("wake_word_phrase"): str,
}
},
extra=vol.ALLOW_EXTRA,
),
PipelineStage.INTENT: vol.Schema(
Expand Down Expand Up @@ -149,12 +154,15 @@ async def websocket_run(
msg_input = msg["input"]
audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
incoming_sample_rate = msg_input["sample_rate"]
wake_word_phrase: str | None = None

if start_stage == PipelineStage.WAKE_WORD:
wake_word_settings = WakeWordSettings(
timeout=msg["input"].get("timeout", DEFAULT_WAKE_WORD_TIMEOUT),
audio_seconds_to_buffer=msg_input.get("audio_seconds_to_buffer", 0),
)
elif start_stage == PipelineStage.STT:
wake_word_phrase = msg["input"].get("wake_word_phrase")

async def stt_stream() -> AsyncGenerator[bytes, None]:
state = None
Expand Down Expand Up @@ -189,6 +197,7 @@ def handle_binary(
channel=stt.AudioChannels.CHANNEL_MONO,
)
input_args["stt_stream"] = stt_stream()
input_args["wake_word_phrase"] = wake_word_phrase

# Audio settings
audio_settings = AudioSettings(
Expand Down
9 changes: 9 additions & 0 deletions homeassistant/components/wake_word/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@ class WakeWord:
"""Wake word model."""

id: str
"""Id of wake word model"""

name: str
"""Name of wake word model"""

phrase: str | None = None
"""Wake word phrase used to trigger model"""


@dataclass
Expand All @@ -17,6 +23,9 @@ class DetectionResult:
wake_word_id: str
"""Id of detected wake word"""

wake_word_phrase: str
"""Normalized phrase for the detected wake word"""

timestamp: int | None
"""Timestamp of audio chunk with detected wake word"""

Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/wyoming/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
"dependencies": ["assist_pipeline"],
"documentation": "https://www.home-assistant.io/integrations/wyoming",
"iot_class": "local_push",
"requirements": ["wyoming==1.5.2"],
"requirements": ["wyoming==1.5.3"],
"zeroconf": ["_wyoming._tcp.local."]
}
47 changes: 44 additions & 3 deletions homeassistant/components/wyoming/satellite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Support for Wyoming satellite services."""

import asyncio
from collections.abc import AsyncGenerator
import io
Expand All @@ -10,6 +11,7 @@
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient
from wyoming.error import Error
from wyoming.info import Describe, Info
from wyoming.ping import Ping, Pong
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import PauseSatellite, RunSatellite
Expand Down Expand Up @@ -86,7 +88,9 @@ async def run(self) -> None:
await self._connect_and_loop()
except asyncio.CancelledError:
raise # don't restart
except Exception: # pylint: disable=broad-exception-caught
except Exception as err: # pylint: disable=broad-exception-caught
_LOGGER.debug("%s: %s", err.__class__.__name__, str(err))

# Ensure sensor is off (before restart)
self.device.set_is_active(False)

Expand Down Expand Up @@ -197,6 +201,8 @@ async def _connect_and_loop(self) -> None:
async def _run_pipeline_loop(self) -> None:
"""Run a pipeline one or more times."""
assert self._client is not None
client_info: Info | None = None
wake_word_phrase: str | None = None
run_pipeline: RunPipeline | None = None
send_ping = True

Expand All @@ -209,6 +215,9 @@ async def _run_pipeline_loop(self) -> None:
)
pending = {pipeline_ended_task, client_event_task}

# Update info from satellite
await self._client.write_event(Describe().event())

while self.is_running and (not self.device.is_muted):
if send_ping:
# Ensure satellite is still connected
Expand All @@ -230,6 +239,9 @@ async def _run_pipeline_loop(self) -> None:
)
pending.add(pipeline_ended_task)

# Clear last wake word detection
wake_word_phrase = None

if (run_pipeline is not None) and run_pipeline.restart_on_end:
# Automatically restart pipeline.
# Used with "always on" streaming satellites.
Expand All @@ -253,7 +265,7 @@ async def _run_pipeline_loop(self) -> None:
elif RunPipeline.is_type(client_event.type):
# Satellite requested pipeline run
run_pipeline = RunPipeline.from_event(client_event)
self._run_pipeline_once(run_pipeline)
self._run_pipeline_once(run_pipeline, wake_word_phrase)
elif (
AudioChunk.is_type(client_event.type) and self._is_pipeline_running
):
Expand All @@ -265,6 +277,32 @@ async def _run_pipeline_loop(self) -> None:
# Stop pipeline
_LOGGER.debug("Client requested pipeline to stop")
self._audio_queue.put_nowait(b"")
elif Info.is_type(client_event.type):
client_info = Info.from_event(client_event)
_LOGGER.debug("Updated client info: %s", client_info)
elif Detection.is_type(client_event.type):
detection = Detection.from_event(client_event)
wake_word_phrase = detection.name

# Resolve wake word name/id to phrase if info is available.
#
# This allows us to deconflict multiple satellite wake-ups
# with the same wake word.
if (client_info is not None) and (client_info.wake is not None):
found_phrase = False
for wake_service in client_info.wake:
for wake_model in wake_service.models:
if wake_model.name == detection.name:
wake_word_phrase = (
wake_model.phrase or wake_model.name
)
found_phrase = True
break

if found_phrase:
break

_LOGGER.debug("Client detected wake word: %s", wake_word_phrase)
else:
_LOGGER.debug("Unexpected event from satellite: %s", client_event)

Expand All @@ -274,7 +312,9 @@ async def _run_pipeline_loop(self) -> None:
)
pending.add(client_event_task)

def _run_pipeline_once(self, run_pipeline: RunPipeline) -> None:
def _run_pipeline_once(
self, run_pipeline: RunPipeline, wake_word_phrase: str | None = None
) -> None:
"""Run a pipeline once."""
_LOGGER.debug("Received run information: %s", run_pipeline)

Expand Down Expand Up @@ -332,6 +372,7 @@ def _run_pipeline_once(self, run_pipeline: RunPipeline) -> None:
volume_multiplier=self.device.volume_multiplier,
),
device_id=self.device.device_id,
wake_word_phrase=wake_word_phrase,
),
name="wyoming satellite pipeline",
)
Expand Down
23 changes: 21 additions & 2 deletions homeassistant/components/wyoming/wake_word.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Support for Wyoming wake-word-detection services."""

import asyncio
from collections.abc import AsyncIterable
import logging
Expand Down Expand Up @@ -49,7 +50,9 @@ def __init__(
wake_service = service.info.wake[0]

self._supported_wake_words = [
wake_word.WakeWord(id=ww.name, name=ww.description or ww.name)
wake_word.WakeWord(
id=ww.name, name=ww.description or ww.name, phrase=ww.phrase
)
for ww in wake_service.models
]
self._attr_name = wake_service.name
Expand All @@ -64,7 +67,11 @@ async def get_supported_wake_words(self) -> list[wake_word.WakeWord]:
if info is not None:
wake_service = info.wake[0]
self._supported_wake_words = [
wake_word.WakeWord(id=ww.name, name=ww.description or ww.name)
wake_word.WakeWord(
id=ww.name,
name=ww.description or ww.name,
phrase=ww.phrase,
)
for ww in wake_service.models
]

Expand Down Expand Up @@ -140,6 +147,7 @@ async def next_chunk():

return wake_word.DetectionResult(
wake_word_id=detection.name,
wake_word_phrase=self._get_phrase(detection.name),
timestamp=detection.timestamp,
queued_audio=queued_audio,
)
Expand Down Expand Up @@ -183,3 +191,14 @@ async def next_chunk():
_LOGGER.exception("Error processing audio stream: %s", err)

return None

def _get_phrase(self, model_id: str) -> str:
    """Return the wake word phrase registered for a model id.

    Only supported wake words that carry a non-empty phrase are considered;
    the first one whose id equals ``model_id`` wins. When none matches, the
    model id itself is returned as the fallback phrase.
    """
    matching_phrases = (
        candidate.phrase
        for candidate in self._supported_wake_words
        if candidate.phrase and candidate.id == model_id
    )
    return next(matching_phrases, model_id)
2 changes: 1 addition & 1 deletion requirements_all.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2863,7 +2863,7 @@ wled==0.17.0
wolf-comm==0.0.4

# homeassistant.components.wyoming
wyoming==1.5.2
wyoming==1.5.3

# homeassistant.components.xbox
xbox-webapi==2.0.11
Expand Down
2 changes: 1 addition & 1 deletion requirements_test_all.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2195,7 +2195,7 @@ wled==0.17.0
wolf-comm==0.0.4

# homeassistant.components.wyoming
wyoming==1.5.2
wyoming==1.5.3

# homeassistant.components.xbox
xbox-webapi==2.0.11
Expand Down
Loading