Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 109 additions & 2 deletions homeassistant/components/mqtt/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
device_registry as dr,
entity_registry as er,
)
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
Expand All @@ -44,7 +45,7 @@
)
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType

from . import (
DATA_MQTT,
Expand All @@ -54,6 +55,7 @@
debug_info,
subscription,
)
from ..mqtt import publish
from .const import (
ATTR_DISCOVERY_HASH,
ATTR_DISCOVERY_PAYLOAD,
Expand Down Expand Up @@ -520,8 +522,113 @@ async def cleanup_device_registry(
)


class MqttDiscoveryDeviceUpdateService:
"""Add support for auto discovery for platforms without an entity."""

def __init__(
self,
hass: HomeAssistant,
log_name: str,
discovery_data: dict[str, Any] | None = None,
device_id: str | None = None,
config_entry: ConfigEntry | None = None,
) -> None:
"""Initialize the update service."""

# Only activate update service id the parent class has a discover hash
if discovery_data is None:
return

self.hass = hass
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
discovery_topic = discovery_data[ATTR_DISCOVERY_TOPIC]
_device_removed: bool = False

async def async_discovery_update(
discovery_payload: DiscoveryInfoType | None,
) -> None:
"""Handle discovery update."""
if not discovery_payload:
# unregister the service through auto discovery
async_dispatcher_send(
hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
)
await _async_tear_down()
return

# update the service through auto discovery
await self.async_discovery_update(discovery_payload)
_LOGGER.debug(
"%s %s updated has been processed",
log_name,
discovery_hash,
)
async_dispatcher_send(
hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
)

async def async_device_removed(event):
"""Handle the removal of a device."""
nonlocal _device_removed
event_device_id = event.data["device_id"]
if (
event.data["action"] != "remove"
or event_device_id != device_id
or _device_removed
):
return
_device_removed = True
# Clear the discovery topic so the service is not rediscovered after a restart
publish(hass, discovery_topic, "", retain=True)
await _async_tear_down()

async def _async_tear_down() -> None:
"""Handle the removal of the service."""
nonlocal _device_removed, self
await self.async_tear_down()
# remove the service for auto discovery updates and clean up the device registry
if not _device_removed and config_entry:
_device_removed = True
await cleanup_device_registry(hass, device_id, config_entry.entry_id)
clear_discovery_hash(hass, discovery_hash)
_remove_discovery()
_LOGGER.info(
"%s %s has been removed",
log_name,
discovery_hash,
)
del self

_remove_discovery = async_dispatcher_connect(
hass,
MQTT_DISCOVERY_UPDATED.format(discovery_hash),
async_discovery_update,
)
if device_id is not None:
self._remove_device_updated = hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED, async_device_removed
)
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
_LOGGER.info(
"%s %s has been initialized",
log_name,
discovery_hash,
)

async def async_tear_down(self) -> None:
"""Handle the cleanup of platform specific parts."""
raise NotImplementedError()

async def async_discovery_update(
self,
discovery_payload: DiscoveryInfoType,
) -> None:
"""Update the configuration through discovery."""
raise NotImplementedError()


class MqttDiscoveryUpdate(Entity):
"""Mixin used to handle updated discovery message."""
"""Mixin used to handle updated discovery message for entity based platforms."""

def __init__(self, discovery_data, discovery_update=None) -> None:
"""Initialize the discovery update mixin."""
Expand Down
136 changes: 26 additions & 110 deletions homeassistant/components/mqtt/notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.template import Template
Expand All @@ -36,11 +32,11 @@
CONF_RETAIN,
DOMAIN,
)
from .discovery import MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_UPDATED, clear_discovery_hash
from .discovery import MQTT_DISCOVERY_DONE, clear_discovery_hash
from .mixins import (
MQTT_ENTITY_DEVICE_INFO_SCHEMA,
MqttDiscoveryDeviceUpdateService,
async_setup_entry_helper,
cleanup_device_registry,
device_info_from_config,
)

Expand Down Expand Up @@ -69,6 +65,8 @@
extra=vol.REMOVE_EXTRA,
)

LOG_NAME = "Notify service"

_LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -155,7 +153,7 @@ async def _async_setup_notify(
config,
config_entry,
device_id,
discovery_hash,
discovery_data,
)
hass.data[MQTT_NOTIFY_SERVICES_SETUP][service_name] = service

Expand Down Expand Up @@ -184,86 +182,9 @@ async def async_get_service(
return service


class MqttNotificationServiceUpdater:
"""Add support for auto discovery updates."""

def __init__(self, hass: HomeAssistant, service: MqttNotificationService) -> None:
"""Initialize the update service."""

async def async_discovery_update(
discovery_payload: DiscoveryInfoType | None,
) -> None:
"""Handle discovery update."""
if not discovery_payload:
# unregister notify service through auto discovery
async_dispatcher_send(
hass, MQTT_DISCOVERY_DONE.format(service.discovery_hash), None
)
await async_tear_down_service()
return

# update notify service through auto discovery
await service.async_update_service(discovery_payload)
_LOGGER.debug(
"Notify service %s updated has been processed",
service.discovery_hash,
)
async_dispatcher_send(
hass, MQTT_DISCOVERY_DONE.format(service.discovery_hash), None
)

async def async_device_removed(event):
"""Handle the removal of a device."""
device_id = event.data["device_id"]
if (
event.data["action"] != "remove"
or device_id != service.device_id
or self._device_removed
):
return
self._device_removed = True
await async_tear_down_service()

async def async_tear_down_service():
"""Handle the removal of the service."""
services = hass.data[MQTT_NOTIFY_SERVICES_SETUP]
if self._service.service_name in services.keys():
del services[self._service.service_name]
if not self._device_removed and service.config_entry:
self._device_removed = True
await cleanup_device_registry(
hass, service.device_id, service.config_entry.entry_id
)
clear_discovery_hash(hass, service.discovery_hash)
self._remove_discovery()
await service.async_unregister_services()
_LOGGER.info(
"Notify service %s has been removed",
service.discovery_hash,
)
del self._service

self._service = service
self._remove_discovery = async_dispatcher_connect(
hass,
MQTT_DISCOVERY_UPDATED.format(service.discovery_hash),
async_discovery_update,
)
if service.device_id:
self._remove_device_updated = hass.bus.async_listen(
EVENT_DEVICE_REGISTRY_UPDATED, async_device_removed
)
self._device_removed = False
async_dispatcher_send(
hass, MQTT_DISCOVERY_DONE.format(service.discovery_hash), None
)
_LOGGER.info(
"Notify service %s has been initialized",
service.discovery_hash,
)


class MqttNotificationService(notify.BaseNotificationService):
class MqttNotificationService(
MqttDiscoveryDeviceUpdateService, notify.BaseNotificationService
):
"""Implement the notification service for MQTT."""

def __init__(
Expand All @@ -272,44 +193,37 @@ def __init__(
service_config: MqttNotificationConfig,
config_entry: ConfigEntry | None = None,
device_id: str | None = None,
discovery_hash: tuple | None = None,
discovery_data: dict[str, Any] | None = None,
) -> None:
"""Initialize the service."""
self.hass = hass
self._config = service_config
self._config_entry = config_entry
self._commmand_template = MqttCommandTemplate(
service_config.get(CONF_COMMAND_TEMPLATE), hass=hass
)
self._device_id = device_id
self._discovery_hash = discovery_hash
self._config_entry = config_entry
self._service_name = slugify(service_config[CONF_NAME])

self._updater = (
MqttNotificationServiceUpdater(hass, self) if discovery_hash else None
MqttDiscoveryDeviceUpdateService.__init__(
self, hass, LOG_NAME, discovery_data, device_id, config_entry
)

@property
def device_id(self) -> str | None:
"""Return the device ID."""
return self._device_id

@property
def config_entry(self) -> ConfigEntry | None:
"""Return the config_entry."""
return self._config_entry

@property
def discovery_hash(self) -> tuple | None:
"""Return the discovery hash."""
return self._discovery_hash

@property
def service_name(self) -> str:
"""Return the service ma,e."""
"""Return the service name."""
return self._service_name

async def async_update_service(
@property
def targets(self) -> dict[str, str]:
"""Return a dictionary of registered targets."""
return {target: target for target in self._config[CONF_TARGETS]}

async def async_discovery_update(
self,
discovery_payload: DiscoveryInfoType,
) -> None:
Expand Down Expand Up @@ -342,10 +256,12 @@ async def async_update_service(
)
_update_device(self.hass, self._config_entry, config)

@property
def targets(self) -> dict[str, str]:
"""Return a dictionary of registered targets."""
return {target: target for target in self._config[CONF_TARGETS]}
async def async_tear_down(self) -> None:
"""Cleanup when the service is removed."""
await self.async_unregister_services()
services = self.hass.data[MQTT_NOTIFY_SERVICES_SETUP]
if self._service_name in services:
del services[self._service_name]

async def async_send_message(self, message: str = "", **kwargs):
"""Build and send a MQTT message."""
Expand Down
Loading