diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index c87e5ccba25c44..ab047616766a9f 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -31,6 +31,7 @@ device_registry as dr, entity_registry as er, ) +from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, async_dispatcher_send, @@ -44,7 +45,7 @@ ) from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.reload import async_setup_reload_service -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from . import ( DATA_MQTT, @@ -54,6 +55,7 @@ debug_info, subscription, ) +from ..mqtt import publish from .const import ( ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_PAYLOAD, @@ -520,8 +522,113 @@ async def cleanup_device_registry( ) +class MqttDiscoveryDeviceUpdateService: + """Add support for auto discovery for platforms without an entity.""" + + def __init__( + self, + hass: HomeAssistant, + log_name: str, + discovery_data: dict[str, Any] | None = None, + device_id: str | None = None, + config_entry: ConfigEntry | None = None, + ) -> None: + """Initialize the update service.""" + + # Only activate update service id the parent class has a discover hash + if discovery_data is None: + return + + self.hass = hass + discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] + discovery_topic = discovery_data[ATTR_DISCOVERY_TOPIC] + _device_removed: bool = False + + async def async_discovery_update( + discovery_payload: DiscoveryInfoType | None, + ) -> None: + """Handle discovery update.""" + if not discovery_payload: + # unregister the service through auto discovery + async_dispatcher_send( + hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None + ) + await _async_tear_down() + return + + # update the service through auto discovery + await self.async_discovery_update(discovery_payload) + _LOGGER.debug( + "%s %s updated has been processed", + log_name, + discovery_hash, + ) + async_dispatcher_send( + hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None + ) + + async def async_device_removed(event): + """Handle the removal of a device.""" + nonlocal _device_removed + event_device_id = event.data["device_id"] + if ( + event.data["action"] != "remove" + or event_device_id != device_id + or _device_removed + ): + return + _device_removed = True + # Clear the discovery topic so the service is not rediscovered after a restart + publish(hass, discovery_topic, "", retain=True) + await _async_tear_down() + + async def _async_tear_down() -> None: + """Handle the removal of the service.""" + nonlocal _device_removed, self + await self.async_tear_down() + # remove the service for auto discovery updates and clean up the device registry + if not _device_removed and config_entry: + _device_removed = True + await cleanup_device_registry(hass, device_id, config_entry.entry_id) + clear_discovery_hash(hass, discovery_hash) + _remove_discovery() + _LOGGER.info( + "%s %s has been removed", + log_name, + discovery_hash, + ) + del self + + _remove_discovery = async_dispatcher_connect( + hass, + MQTT_DISCOVERY_UPDATED.format(discovery_hash), + async_discovery_update, + ) + if device_id is not None: + self._remove_device_updated = hass.bus.async_listen( + EVENT_DEVICE_REGISTRY_UPDATED, async_device_removed + ) + async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) + _LOGGER.info( + "%s %s has been initialized", + log_name, + discovery_hash, + ) + + async def async_tear_down(self) -> None: + """Handle the cleanup of platform specific parts.""" + raise NotImplementedError() + + async def async_discovery_update( + self, + discovery_payload: DiscoveryInfoType, + ) -> None: + """Update the configuration through discovery.""" + raise NotImplementedError() + + class MqttDiscoveryUpdate(Entity): - """Mixin used to handle updated discovery message.""" + """Mixin used to handle updated discovery message for entity based platforms.""" def __init__(self, discovery_data, discovery_update=None) -> None: """Initialize the discovery update mixin.""" diff --git a/homeassistant/components/mqtt/notify.py b/homeassistant/components/mqtt/notify.py index 9ba341aab0daa3..61a9db8b2de512 100644 --- a/homeassistant/components/mqtt/notify.py +++ b/homeassistant/components/mqtt/notify.py @@ -13,11 +13,7 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers import device_registry as dr import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED -from homeassistant.helpers.dispatcher import ( - async_dispatcher_connect, - async_dispatcher_send, -) +from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.template import Template @@ -36,11 +32,11 @@ CONF_RETAIN, DOMAIN, ) -from .discovery import MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_UPDATED, clear_discovery_hash +from .discovery import MQTT_DISCOVERY_DONE, clear_discovery_hash from .mixins import ( MQTT_ENTITY_DEVICE_INFO_SCHEMA, + MqttDiscoveryDeviceUpdateService, async_setup_entry_helper, - cleanup_device_registry, device_info_from_config, ) @@ -69,6 +65,8 @@ extra=vol.REMOVE_EXTRA, ) +LOG_NAME = "Notify service" + _LOGGER = logging.getLogger(__name__) @@ -155,7 +153,7 @@ async def _async_setup_notify( config, config_entry, device_id, - discovery_hash, + discovery_data, ) hass.data[MQTT_NOTIFY_SERVICES_SETUP][service_name] = service @@ -184,86 +182,9 @@ async def async_get_service( return service -class MqttNotificationServiceUpdater: - """Add support for auto discovery updates.""" - - def __init__(self, hass: HomeAssistant, service: MqttNotificationService) -> None: - """Initialize the update service.""" - - async def async_discovery_update( - discovery_payload: DiscoveryInfoType | None, - ) -> None: - """Handle discovery update.""" - if not discovery_payload: - # unregister notify service through auto discovery - async_dispatcher_send( - hass, MQTT_DISCOVERY_DONE.format(service.discovery_hash), None - ) - await async_tear_down_service() - return - - # update notify service through auto discovery - await service.async_update_service(discovery_payload) - _LOGGER.debug( - "Notify service %s updated has been processed", - service.discovery_hash, - ) - async_dispatcher_send( - hass, MQTT_DISCOVERY_DONE.format(service.discovery_hash), None - ) - - async def async_device_removed(event): - """Handle the removal of a device.""" - device_id = event.data["device_id"] - if ( - event.data["action"] != "remove" - or device_id != service.device_id - or self._device_removed - ): - return - self._device_removed = True - await async_tear_down_service() - - async def async_tear_down_service(): - """Handle the removal of the service.""" - services = hass.data[MQTT_NOTIFY_SERVICES_SETUP] - if self._service.service_name in services.keys(): - del services[self._service.service_name] - if not self._device_removed and service.config_entry: - self._device_removed = True - await cleanup_device_registry( - hass, service.device_id, service.config_entry.entry_id - ) - clear_discovery_hash(hass, service.discovery_hash) - self._remove_discovery() - await service.async_unregister_services() - _LOGGER.info( - "Notify service %s has been removed", - service.discovery_hash, - ) - del self._service - - self._service = service - self._remove_discovery = async_dispatcher_connect( - hass, - MQTT_DISCOVERY_UPDATED.format(service.discovery_hash), - async_discovery_update, - ) - if service.device_id: - self._remove_device_updated = hass.bus.async_listen( - EVENT_DEVICE_REGISTRY_UPDATED, async_device_removed - ) - self._device_removed = False - async_dispatcher_send( - hass, MQTT_DISCOVERY_DONE.format(service.discovery_hash), None - ) - _LOGGER.info( - "Notify service %s has been initialized", - service.discovery_hash, - ) - - -class MqttNotificationService(notify.BaseNotificationService): +class MqttNotificationService( + MqttDiscoveryDeviceUpdateService, notify.BaseNotificationService +): """Implement the notification service for MQTT.""" def __init__( @@ -272,21 +193,19 @@ def __init__( service_config: MqttNotificationConfig, config_entry: ConfigEntry | None = None, device_id: str | None = None, - discovery_hash: tuple | None = None, + discovery_data: dict[str, Any] | None = None, ) -> None: """Initialize the service.""" self.hass = hass self._config = service_config + self._config_entry = config_entry self._commmand_template = MqttCommandTemplate( service_config.get(CONF_COMMAND_TEMPLATE), hass=hass ) self._device_id = device_id - self._discovery_hash = discovery_hash - self._config_entry = config_entry self._service_name = slugify(service_config[CONF_NAME]) - - self._updater = ( - MqttNotificationServiceUpdater(hass, self) if discovery_hash else None + MqttDiscoveryDeviceUpdateService.__init__( + self, hass, LOG_NAME, discovery_data, device_id, config_entry ) @property @@ -294,22 +213,17 @@ def device_id(self) -> str | None: """Return the device ID.""" return self._device_id - @property - def config_entry(self) -> ConfigEntry | None: - """Return the config_entry.""" - return self._config_entry - - @property - def discovery_hash(self) -> tuple | None: - """Return the discovery hash.""" - return self._discovery_hash - @property def service_name(self) -> str: - """Return the service ma,e.""" + """Return the service name.""" return self._service_name - async def async_update_service( + @property + def targets(self) -> dict[str, str]: + """Return a dictionary of registered targets.""" + return {target: target for target in self._config[CONF_TARGETS]} + + async def async_discovery_update( self, discovery_payload: DiscoveryInfoType, ) -> None: @@ -342,10 +256,12 @@ async def async_update_service( ) _update_device(self.hass, self._config_entry, config) - @property - def targets(self) -> dict[str, str]: - """Return a dictionary of registered targets.""" - return {target: target for target in self._config[CONF_TARGETS]} + async def async_tear_down(self) -> None: + """Cleanup when the service is removed.""" + await self.async_unregister_services() + services = self.hass.data[MQTT_NOTIFY_SERVICES_SETUP] + if self._service_name in services: + del services[self._service_name] async def async_send_message(self, message: str = "", **kwargs): """Build and send a MQTT message.""" diff --git a/homeassistant/components/mqtt/tag.py b/homeassistant/components/mqtt/tag.py index a2541c064c0238..da418a3a9c30a5 100644 --- a/homeassistant/components/mqtt/tag.py +++ b/homeassistant/components/mqtt/tag.py @@ -7,35 +7,25 @@ from homeassistant.const import CONF_DEVICE, CONF_PLATFORM, CONF_VALUE_TEMPLATE from homeassistant.helpers import device_registry as dr import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED -from homeassistant.helpers.dispatcher import ( - async_dispatcher_connect, - async_dispatcher_send, -) +from homeassistant.helpers.typing import DiscoveryInfoType from . import MqttValueTemplate, subscription from .. import mqtt -from .const import ( - ATTR_DISCOVERY_HASH, - ATTR_DISCOVERY_TOPIC, - CONF_QOS, - CONF_TOPIC, - DOMAIN, -) -from .discovery import MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_UPDATED, clear_discovery_hash +from .const import ATTR_DISCOVERY_HASH, CONF_QOS, CONF_TOPIC, DOMAIN from .mixins import ( CONF_CONNECTIONS, CONF_IDENTIFIERS, MQTT_ENTITY_DEVICE_INFO_SCHEMA, - async_removed_from_device, + MqttDiscoveryDeviceUpdateService, async_setup_entry_helper, - cleanup_device_registry, device_info_from_config, ) from .util import valid_subscribe_topic _LOGGER = logging.getLogger(__name__) +LOG_NAME = "Tag" + TAG = "tag" TAGS = "mqtt_tags" @@ -88,7 +78,7 @@ async def async_setup_tag(hass, config, config_entry, discovery_data): config_entry, ) - await tag_scanner.setup() + await tag_scanner.subscribe_topics() if device_id: hass.data[TAGS][device_id][discovery_id] = tag_scanner @@ -101,7 +91,7 @@ def async_has_tags(hass, device_id): return hass.data[TAGS][device_id] != {} -class MQTTTagScanner: +class MQTTTagScanner(MqttDiscoveryDeviceUpdateService): """MQTT Tag scanner.""" def __init__(self, hass, config, device_id, discovery_data, config_entry): @@ -118,33 +108,27 @@ def __init__(self, hass, config, device_id, discovery_data, config_entry): self._setup_from_config(config) - async def discovery_update(self, payload): - """Handle discovery update.""" + MqttDiscoveryDeviceUpdateService.__init__( + self, hass, LOG_NAME, discovery_data, device_id, config_entry + ) + + async def async_discovery_update( + self, + discovery_payload: DiscoveryInfoType, + ) -> None: + """Update the configuration through discovery.""" discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] _LOGGER.info( - "Got update for tag scanner with hash: %s '%s'", discovery_hash, payload - ) - if not payload: - # Empty payload: Remove tag scanner - _LOGGER.info("Removing tag scanner: %s", discovery_hash) - self.tear_down() - if self.device_id: - await cleanup_device_registry( - self.hass, self.device_id, self._config_entry.entry_id - ) - else: - # Non-empty payload: Update tag scanner - _LOGGER.info("Updating tag scanner: %s", discovery_hash) - config = PLATFORM_SCHEMA(payload) - self._config = config - if self.device_id: - _update_device(self.hass, self._config_entry, config) - self._setup_from_config(config) - await self.subscribe_topics() - - async_dispatcher_send( - self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None + "Got update for tag scanner with hash: %s '%s'", + discovery_hash, + discovery_payload, ) + config = PLATFORM_SCHEMA(discovery_payload) + self._config = config + if self.device_id: + _update_device(self.hass, self._config_entry, config) + self._setup_from_config(config) + await self.subscribe_topics() def _setup_from_config(self, config): self._value_template = MqttValueTemplate( @@ -152,23 +136,6 @@ def _setup_from_config(self, config): hass=self.hass, ).async_render_with_possible_json_value - async def setup(self): - """Set up the MQTT tag scanner.""" - discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] - await self.subscribe_topics() - if self.device_id: - self._remove_device_updated = self.hass.bus.async_listen( - EVENT_DEVICE_REGISTRY_UPDATED, self.device_updated - ) - self._remove_discovery = async_dispatcher_connect( - self.hass, - MQTT_DISCOVERY_UPDATED.format(discovery_hash), - self.discovery_update, - ) - async_dispatcher_send( - self.hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None - ) - async def subscribe_topics(self): """Subscribe to MQTT topics.""" @@ -195,31 +162,10 @@ async def tag_scanned(msg): ) await subscription.async_subscribe_topics(self.hass, self._sub_state) - async def device_updated(self, event): - """Handle the update or removal of a device.""" - if not async_removed_from_device( - self.hass, event, self.device_id, self._config_entry.entry_id - ): - return - - # Stop subscribing to discovery updates to not trigger when we clear the - # discovery topic - self.tear_down() - - # Clear the discovery topic so the entity is not rediscovered after a restart - discovery_topic = self.discovery_data[ATTR_DISCOVERY_TOPIC] - mqtt.publish(self.hass, discovery_topic, "", retain=True) - - def tear_down(self): + async def async_tear_down(self): """Cleanup tag scanner.""" discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH] discovery_id = discovery_hash[1] - - clear_discovery_hash(self.hass, discovery_hash) - if self.device_id: - self._remove_device_updated() - self._remove_discovery() - self._sub_state = subscription.async_unsubscribe_topics( self.hass, self._sub_state ) diff --git a/tests/components/mqtt/test_notify.py b/tests/components/mqtt/test_notify.py index 33a32d858af6ac..f28ed821875fec 100644 --- a/tests/components/mqtt/test_notify.py +++ b/tests/components/mqtt/test_notify.py @@ -704,6 +704,12 @@ async def test_discovery_with_device_removal(hass, mqtt_mock, caplog, device_reg in caplog.text ) + # Test if discovery topic is cleared with retained flag + mqtt_mock.async_publish.assert_called_once_with( + f"homeassistant/{notify.DOMAIN}/{service_name1}/config", "", 0, True + ) + mqtt_mock.async_publish.reset_mock() + async def test_publishing_with_custom_encoding(hass, mqtt_mock, caplog): """Test publishing MQTT payload with different encoding via discovery and configuration."""