diff --git a/homeassistant/components/nest/__init__.py b/homeassistant/components/nest/__init__.py index 1262ff190c94cf..794b481618e7b0 100644 --- a/homeassistant/components/nest/__init__.py +++ b/homeassistant/components/nest/__init__.py @@ -4,13 +4,13 @@ from abc import ABC, abstractmethod import asyncio -from collections.abc import Awaitable, Callable from http import HTTPStatus import logging from aiohttp import ClientError, ClientResponseError, web from google_nest_sdm.camera_traits import CameraClipPreviewTrait from google_nest_sdm.device import Device +from google_nest_sdm.device_manager import DeviceManager from google_nest_sdm.event import EventMessage from google_nest_sdm.event_media import Media from google_nest_sdm.exceptions import ( @@ -71,7 +71,7 @@ async_get_media_source_devices, async_get_transcoder, ) -from .types import NestConfigEntry, NestData +from .types import DevicesAddedListener, NestConfigEntry, NestData _LOGGER = logging.getLogger(__name__) @@ -124,19 +124,17 @@ class SignalUpdateCallback: def __init__( self, hass: HomeAssistant, - config_reload_cb: Callable[[], Awaitable[None]], config_entry: NestConfigEntry, ) -> None: """Initialize EventCallback.""" self._hass = hass - self._config_reload_cb = config_reload_cb self._config_entry = config_entry + self._device_listeners: list[DevicesAddedListener] = [] + self._known_devices: dict[str, Device] = {} + self._device_manager: DeviceManager | None = None async def async_handle_event(self, event_message: EventMessage) -> None: """Process an incoming EventMessage.""" - if event_message.relation_update: - _LOGGER.info("Devices or homes have changed; Need reload to take effect") - return if not event_message.resource_update_name: return device_id = event_message.resource_update_name @@ -187,6 +185,59 @@ def _supported_traits(self, device_id: str) -> list[str]: return [] return list(device.traits) + def set_device_manager(self, device_manager: DeviceManager) -> None: + """Set the device manager and register for device changes.""" + self._device_manager = device_manager + device_manager.set_change_callback(self._devices_updated_cb) + self._update_devices(self._device_manager.devices) + + async def _devices_updated_cb(self) -> None: + """Handle callback when devices are updated.""" + _LOGGER.debug("Devices updated callback invoked") + if self._device_manager is None: + _LOGGER.debug("No device manager available") + return + self._update_devices(self._device_manager.devices) + + def register_devices_listener(self, listener: DevicesAddedListener) -> None: + """Add a listener for device changes.""" + self._device_listeners.append(listener) + # Immediately notify about existing devices + listener(list(self._known_devices.values())) + + def _update_devices(self, devices: dict[str, Device]) -> None: + """Update the set of devices and notify listeners of changes. + + This is invoked when the set of devices changes with the entire set of + devices, and will notify listeners about any newly added devices and + remove devices from the device registry that are no longer present. + """ + added_devices = [] + for device_id, device in devices.items(): + if device_id in self._known_devices: + continue + added_devices.append(device) + self._known_devices[device_id] = device + if added_devices: + _LOGGER.debug("Adding new devices: %s", added_devices) + for listener in self._device_listeners: + listener(added_devices) + + # Remove any device entries that are no longer present + device_registry = dr.async_get(self._hass) + device_entries = dr.async_entries_for_config_entry( + device_registry, self._config_entry.entry_id + ) + for device_entry in device_entries: + device_id = next(iter(device_entry.identifiers))[1] + if device_id in devices: + continue + _LOGGER.info("Removing stale device entry '%s'", device_id) + device_registry.async_update_device( + device_id=device_entry.id, + remove_config_entry_id=self._config_entry.entry_id, + ) + async def async_setup_entry(hass: HomeAssistant, entry: NestConfigEntry) -> bool: """Set up Nest from a config entry with dispatch between old/new flows.""" @@ -225,10 +276,11 @@ async def async_setup_entry(hass: HomeAssistant, entry: NestConfigEntry) -> bool subscriber.cache_policy.store = await async_get_media_event_store(hass, subscriber) subscriber.cache_policy.transcoder = await async_get_transcoder(hass) - async def async_config_reload() -> None: - await hass.config_entries.async_reload(entry.entry_id) - - update_callback = SignalUpdateCallback(hass, async_config_reload, entry) + # The device manager has a single change callback. When the change + # callback is invoked, we update the DeviceListener with the current + # set of devices which will notify any registered listeners with the + # changes. + update_callback = SignalUpdateCallback(hass, entry) subscriber.set_update_callback(update_callback.async_handle_event) try: unsub = await subscriber.start_async() @@ -270,10 +322,13 @@ def on_hass_stop(_: Event) -> None: hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop) ) + update_callback.set_device_manager(device_manager) + entry.async_on_unload(unsub) entry.runtime_data = NestData( subscriber=subscriber, device_manager=device_manager, + register_devices_listener=update_callback.register_devices_listener, ) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) diff --git a/homeassistant/components/nest/camera.py b/homeassistant/components/nest/camera.py index f5985da9ff88c4..4b5bee127d09a3 100644 --- a/homeassistant/components/nest/camera.py +++ b/homeassistant/components/nest/camera.py @@ -57,16 +57,19 @@ async def async_setup_entry( ) -> None: """Set up the cameras.""" - entities: list[NestCameraBaseEntity] = [] - for device in entry.runtime_data.device_manager.devices.values(): - if (live_stream := device.traits.get(CameraLiveStreamTrait.NAME)) is None: - continue - if StreamingProtocol.WEB_RTC in live_stream.supported_protocols: - entities.append(NestWebRTCEntity(device)) - elif StreamingProtocol.RTSP in live_stream.supported_protocols: - entities.append(NestRTSPEntity(device)) - - async_add_entities(entities) + def devices_added(devices: list[Device]) -> None: + entities: list[NestCameraBaseEntity] = [] + for device in devices: + if (live_stream := device.traits.get(CameraLiveStreamTrait.NAME)) is None: + continue + if StreamingProtocol.WEB_RTC in live_stream.supported_protocols: + entities.append(NestWebRTCEntity(device)) + elif StreamingProtocol.RTSP in live_stream.supported_protocols: + entities.append(NestRTSPEntity(device)) + + async_add_entities(entities) + + entry.runtime_data.register_devices_listener(devices_added) class StreamRefresh: diff --git a/homeassistant/components/nest/climate.py b/homeassistant/components/nest/climate.py index 25f39704393acf..cf1e67ad887feb 100644 --- a/homeassistant/components/nest/climate.py +++ b/homeassistant/components/nest/climate.py @@ -82,11 +82,14 @@ async def async_setup_entry( ) -> None: """Set up the client entities.""" - async_add_entities( - ThermostatEntity(device) - for device in entry.runtime_data.device_manager.devices.values() - if ThermostatHvacTrait.NAME in device.traits - ) + def devices_added(devices: list[Device]) -> None: + async_add_entities( + ThermostatEntity(device) + for device in devices + if ThermostatHvacTrait.NAME in device.traits + ) + + entry.runtime_data.register_devices_listener(devices_added) class ThermostatEntity(ClimateEntity): diff --git a/homeassistant/components/nest/quality_scale.yaml b/homeassistant/components/nest/quality_scale.yaml index a91b957e2f2818..83282067d377fe 100644 --- a/homeassistant/components/nest/quality_scale.yaml +++ b/homeassistant/components/nest/quality_scale.yaml @@ -53,16 +53,16 @@ rules: entity-disabled-by-default: todo discovery: todo exception-translations: todo - devices: todo + devices: done docs-supported-devices: todo icon-translations: todo docs-known-limitations: todo - stale-devices: todo + stale-devices: done docs-supported-functions: todo repair-issues: todo reconfiguration-flow: todo entity-category: todo - dynamic-devices: todo + dynamic-devices: done docs-troubleshooting: todo diagnostics: todo docs-use-cases: todo diff --git a/homeassistant/components/nest/sensor.py b/homeassistant/components/nest/sensor.py index a6fda48fe87792..553068bb8b2540 100644 --- a/homeassistant/components/nest/sensor.py +++ b/homeassistant/components/nest/sensor.py @@ -37,13 +37,16 @@ async def async_setup_entry( ) -> None: """Set up the sensors.""" - entities: list[SensorEntity] = [] - for device in entry.runtime_data.device_manager.devices.values(): - if TemperatureTrait.NAME in device.traits: - entities.append(TemperatureSensor(device)) - if HumidityTrait.NAME in device.traits: - entities.append(HumiditySensor(device)) - async_add_entities(entities) + def devices_added(devices: list[Device]) -> None: + entities: list[SensorEntity] = [] + for device in devices: + if TemperatureTrait.NAME in device.traits: + entities.append(TemperatureSensor(device)) + if HumidityTrait.NAME in device.traits: + entities.append(HumiditySensor(device)) + async_add_entities(entities) + + entry.runtime_data.register_devices_listener(devices_added) class SensorBase(SensorEntity): diff --git a/homeassistant/components/nest/types.py b/homeassistant/components/nest/types.py index bd6cd5cd88701d..e682a1f10db92a 100644 --- a/homeassistant/components/nest/types.py +++ b/homeassistant/components/nest/types.py @@ -1,12 +1,16 @@ """Type definitions for Nest.""" +from collections.abc import Callable from dataclasses import dataclass +from google_nest_sdm.device import Device from google_nest_sdm.device_manager import DeviceManager from google_nest_sdm.google_nest_subscriber import GoogleNestSubscriber from homeassistant.config_entries import ConfigEntry +type DevicesAddedListener = Callable[[list[Device]], None] + @dataclass class NestData: @@ -14,6 +18,7 @@ class NestData: subscriber: GoogleNestSubscriber device_manager: DeviceManager + register_devices_listener: Callable[[DevicesAddedListener], None] type NestConfigEntry = ConfigEntry[NestData] diff --git a/tests/components/nest/test_events.py b/tests/components/nest/test_events.py index d4ad81bd4e834f..a7090c8e0e2a52 100644 --- a/tests/components/nest/test_events.py +++ b/tests/components/nest/test_events.py @@ -519,8 +519,8 @@ async def test_structure_update_event( assert not events assert entity_registry.async_get("camera.front") - # Currently need a manual reload to detect the new entity - assert not entity_registry.async_get("camera.back") + # New entity is now registered automatically when the event arrives + assert entity_registry.async_get("camera.back") @pytest.mark.parametrize( diff --git a/tests/components/nest/test_init.py b/tests/components/nest/test_init.py index b1839a4ae58616..6effa34fa52432 100644 --- a/tests/components/nest/test_init.py +++ b/tests/components/nest/test_init.py @@ -28,12 +28,16 @@ from homeassistant.components.nest.const import OAUTH2_TOKEN from homeassistant.config_entries import ConfigEntryState from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr +from homeassistant.util.dt import utcnow from .common import ( PROJECT_ID, SUBSCRIBER_ID, TEST_CONFIG_NEW_SUBSCRIPTION, + CreateDevice, PlatformSetup, + create_nest_event, ) from tests.test_util.aiohttp import AiohttpClientMocker @@ -348,3 +352,97 @@ async def test_migrate_unique_id( assert config_entry.state is ConfigEntryState.LOADED assert config_entry.unique_id == PROJECT_ID + + +async def test_add_devices( + hass: HomeAssistant, + setup_platform: PlatformSetup, + create_device: CreateDevice, + subscriber: AsyncMock, + device_registry: dr.DeviceRegistry, +) -> None: + """Test that adding devices after initial setup works.""" + device_id1 = "enterprises/project-id/devices/device-id" + traits = { + "sdm.devices.traits.Temperature": { + "ambientTemperatureCelsius": 25.1, + }, + } + create_device.create(raw_traits=traits, raw_data={"name": device_id1}) + await setup_platform() + + device_entries = dr.async_entries_for_config_entry( + device_registry, hass.config_entries.async_entries(DOMAIN)[0].entry_id + ) + assert len(device_entries) == 1 + + # Add a second device and trigger a notification to refresh + device_id2 = "enterprises/project-id/devices/device-id-2" + create_device.create(raw_traits=traits, raw_data={"name": device_id2}) + + event_message = create_nest_event( + { + "eventId": "some-event-id", + "timestamp": utcnow().isoformat(timespec="seconds"), + "relationUpdate": { + "type": "UPDATED", + "subject": "some-subject", + "object": "some-object", + }, + }, + ) + await subscriber.async_receive_event(event_message) + await hass.async_block_till_done() + await hass.async_block_till_done() + + device_entries = dr.async_entries_for_config_entry( + device_registry, hass.config_entries.async_entries(DOMAIN)[0].entry_id + ) + assert len(device_entries) == 2 + + +async def test_stale_device_cleanup( + hass: HomeAssistant, + setup_platform: PlatformSetup, + create_device: CreateDevice, + subscriber: AsyncMock, + device_registry: dr.DeviceRegistry, +) -> None: + """Test that stale devices are removed.""" + # Device #1 will be returned by the API. + device_id1 = "enterprises/project-id/devices/device-id" + device_registry.async_get_or_create( + config_entry_id=hass.config_entries.async_entries(DOMAIN)[0].entry_id, + identifiers={(DOMAIN, device_id1)}, + manufacturer="Google Nest", + ) + create_device.create( + raw_traits={ + "sdm.devices.traits.Temperature": { + "ambientTemperatureCelsius": 25.1, + }, + }, + raw_data={"name": device_id1}, + ) + + # Device #2 is stale and should be removed. + device_registry.async_get_or_create( + config_entry_id=hass.config_entries.async_entries(DOMAIN)[0].entry_id, + identifiers={(DOMAIN, "enterprises/project-id/devices/device-id-stale")}, + manufacturer="Google Nest", + ) + + # Verify both devices are registered before setup. + device_entries = dr.async_entries_for_config_entry( + device_registry, hass.config_entries.async_entries(DOMAIN)[0].entry_id + ) + assert len(device_entries) == 2 + + # Setup should remove the stale device. + await setup_platform() + + device_entries = dr.async_entries_for_config_entry( + device_registry, hass.config_entries.async_entries(DOMAIN)[0].entry_id + ) + assert len(device_entries) == 1 + assert device_entries[0].identifiers == {(DOMAIN, device_id1)}