Skip to content

Commit

Permalink
Make sure MQTT client is available when starting depending platforms (h…
Browse files Browse the repository at this point in the history
…ome-assistant#91164)

* Make sure MQTT is available starting mqtt_json

* Wait for mqtt client

* Sync client connect

* Simplify

* Addiitional tests async_wait_for_mqtt_client

* Improve comment waiting for mqtt

* Improve docstr

* Do not wait unless the MQTT client is in setup

* Handle entry errors during setup

* More comments - do not clear event

* Add snips and mqtt_room

* Add manual_mqtt

* Update homeassistant/components/mqtt/__init__.py

Co-authored-by: J. Nick Koston <[email protected]>

* Use a fixture, improve tests

* Simplify

---------

Co-authored-by: J. Nick Koston <[email protected]>
  • Loading branch information
jbouwh and bdraco authored Apr 20, 2023
1 parent adc4728 commit 0bcda9f
Show file tree
Hide file tree
Showing 11 changed files with 347 additions and 35 deletions.
8 changes: 7 additions & 1 deletion homeassistant/components/manual_mqtt/alarm_control_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,19 @@ def _state_schema(state):
)


def setup_platform(
async def async_setup_platform(
hass: HomeAssistant,
config: ConfigType,
add_entities: AddEntitiesCallback,
discovery_info: DiscoveryInfoType | None = None,
) -> None:
"""Set up the manual MQTT alarm platform."""
# Make sure MQTT integration is enabled and the client is available
# We cannot count on dependencies as the alarm_control_panel platform setup
# also will be triggered when mqtt is loading the `alarm_control_panel` platform
if not await mqtt.async_wait_for_mqtt_client(hass):
_LOGGER.error("MQTT integration is not available")
return
add_entities(
[
ManualMQTTAlarm(
Expand Down
79 changes: 52 additions & 27 deletions homeassistant/components/mqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
CONF_WS_HEADERS,
CONF_WS_PATH,
DATA_MQTT,
DATA_MQTT_AVAILABLE,
DEFAULT_DISCOVERY,
DEFAULT_ENCODING,
DEFAULT_PREFIX,
Expand All @@ -87,8 +88,9 @@
ReceiveMessage,
ReceivePayloadType,
)
from .util import (
from .util import ( # noqa: F401
async_create_certificate_temp_files,
async_wait_for_mqtt_client,
get_mqtt_data,
mqtt_config_entry_enabled,
valid_publish_topic,
Expand Down Expand Up @@ -183,34 +185,54 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -

async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Load a config entry."""
conf = dict(entry.data)
# Fetch configuration
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_yaml = PLATFORM_CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
client = MQTT(hass, entry, conf)
if DOMAIN in hass.data:
mqtt_data = get_mqtt_data(hass)
mqtt_data.config = mqtt_yaml
mqtt_data.client = client
else:
# Initial setup
websocket_api.async_register_command(hass, websocket_subscribe)
websocket_api.async_register_command(hass, websocket_mqtt_info)
hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client)
client.start(mqtt_data)

await async_create_certificate_temp_files(hass, dict(entry.data))
# Restore saved subscriptions
if mqtt_data.subscriptions_to_restore:
mqtt_data.client.async_restore_tracked_subscriptions(
mqtt_data.subscriptions_to_restore
conf: dict[str, Any]
mqtt_data: MqttData

async def _setup_client() -> tuple[MqttData, dict[str, Any]]:
"""Set up the MQTT client."""
# Fetch configuration
conf = dict(entry.data)
hass_config = await conf_util.async_hass_config_yaml(hass)
mqtt_yaml = PLATFORM_CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {}))
client = MQTT(hass, entry, conf)
if DOMAIN in hass.data:
mqtt_data = get_mqtt_data(hass)
mqtt_data.config = mqtt_yaml
mqtt_data.client = client
else:
# Initial setup
websocket_api.async_register_command(hass, websocket_subscribe)
websocket_api.async_register_command(hass, websocket_mqtt_info)
hass.data[DATA_MQTT] = mqtt_data = MqttData(config=mqtt_yaml, client=client)
client.start(mqtt_data)

await async_create_certificate_temp_files(hass, dict(entry.data))
# Restore saved subscriptions
if mqtt_data.subscriptions_to_restore:
mqtt_data.client.async_restore_tracked_subscriptions(
mqtt_data.subscriptions_to_restore
)
mqtt_data.subscriptions_to_restore = []
mqtt_data.reload_dispatchers.append(
entry.add_update_listener(_async_config_entry_updated)
)
mqtt_data.subscriptions_to_restore = []
mqtt_data.reload_dispatchers.append(
entry.add_update_listener(_async_config_entry_updated)
)

await mqtt_data.client.async_connect()
await mqtt_data.client.async_connect()
return (mqtt_data, conf)

client_available: asyncio.Future[bool]
if DATA_MQTT_AVAILABLE not in hass.data:
client_available = hass.data[DATA_MQTT_AVAILABLE] = asyncio.Future()
else:
client_available = hass.data[DATA_MQTT_AVAILABLE]

setup_ok: bool = False
try:
mqtt_data, conf = await _setup_client()
setup_ok = True
finally:
if not client_available.done():
client_available.set_result(setup_ok)

async def async_publish_service(call: ServiceCall) -> None:
"""Handle MQTT publish service calls."""
Expand Down Expand Up @@ -565,6 +587,9 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
registry_hooks.popitem()[1]()
# Wait for all ACKs and stop the loop
await mqtt_client.async_disconnect()

# Cleanup MQTT client availability
hass.data.pop(DATA_MQTT_AVAILABLE, None)
# Store remaining subscriptions to be able to restore or reload them
# when the entry is set up again
if subscriptions := mqtt_client.subscriptions:
Expand Down
1 change: 1 addition & 0 deletions homeassistant/components/mqtt/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
CONF_TLS_INSECURE = "tls_insecure"

DATA_MQTT = "mqtt"
DATA_MQTT_AVAILABLE = "mqtt_client_available"

DEFAULT_PREFIX = "homeassistant"
DEFAULT_BIRTH_WILL_TOPIC = DEFAULT_PREFIX + "/status"
Expand Down
37 changes: 37 additions & 0 deletions homeassistant/components/mqtt/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

from __future__ import annotations

import asyncio
import os
from pathlib import Path
import tempfile
from typing import Any

import async_timeout
import voluptuous as vol

from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers.typing import ConfigType
Expand All @@ -22,13 +25,16 @@
CONF_CLIENT_CERT,
CONF_CLIENT_KEY,
DATA_MQTT,
DATA_MQTT_AVAILABLE,
DEFAULT_ENCODING,
DEFAULT_QOS,
DEFAULT_RETAIN,
DOMAIN,
)
from .models import MqttData

AVAILABILITY_TIMEOUT = 30.0

TEMP_DIR_NAME = f"home-assistant-{DOMAIN}"

_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
Expand All @@ -41,6 +47,37 @@ def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None:
return not bool(hass.config_entries.async_entries(DOMAIN)[0].disabled_by)


async def async_wait_for_mqtt_client(hass: HomeAssistant) -> bool:
"""Wait for the MQTT client to become available.
Waits when mqtt set up is in progress,
It is not needed that the client is connected.
Returns True if the mqtt client is available.
Returns False when the client is not available.
"""
if not mqtt_config_entry_enabled(hass):
return False

entry = hass.config_entries.async_entries(DOMAIN)[0]
if entry.state == ConfigEntryState.LOADED:
return True

state_reached_future: asyncio.Future[bool]
if DATA_MQTT_AVAILABLE not in hass.data:
hass.data[DATA_MQTT_AVAILABLE] = state_reached_future = asyncio.Future()
else:
state_reached_future = hass.data[DATA_MQTT_AVAILABLE]
if state_reached_future.done():
return state_reached_future.result()

try:
async with async_timeout.timeout(AVAILABILITY_TIMEOUT):
# Await the client setup or an error state was received
return await state_reached_future
except asyncio.TimeoutError:
return False


def valid_topic(topic: Any) -> str:
"""Validate that this is a valid topic name/filter."""
validated_topic = cv.string(topic)
Expand Down
7 changes: 7 additions & 0 deletions homeassistant/components/mqtt_json/device_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ async def async_setup_scanner(
discovery_info: DiscoveryInfoType | None = None,
) -> bool:
"""Set up the MQTT JSON tracker."""
# Make sure MQTT integration is enabled and the client is available
# We cannot count on dependencies as the device_tracker platform setup
# also will be triggered when mqtt is loading the `device_tracker` platform
if not await mqtt.async_wait_for_mqtt_client(hass):
_LOGGER.error("MQTT integration is not available")
return False

devices = config[CONF_DEVICES]
qos = config[CONF_QOS]

Expand Down
6 changes: 6 additions & 0 deletions homeassistant/components/mqtt_room/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ async def async_setup_platform(
discovery_info: DiscoveryInfoType | None = None,
) -> None:
"""Set up MQTT room Sensor."""
# Make sure MQTT integration is enabled and the client is available
# We cannot count on dependencies as the sensor platform setup
# also will be triggered when mqtt is loading the `sensor` platform
if not await mqtt.async_wait_for_mqtt_client(hass):
_LOGGER.error("MQTT integration is not available")
return
async_add_entities(
[
MQTTRoomSensor(
Expand Down
8 changes: 2 additions & 6 deletions homeassistant/components/snips/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,8 @@

async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Activate Snips component."""
# Make sure MQTT is available and the entry is loaded
if not hass.config_entries.async_entries(
mqtt.DOMAIN
) or not await hass.config_entries.async_wait_component(
hass.config_entries.async_entries(mqtt.DOMAIN)[0]
):
# Make sure MQTT integration is enabled and the client is available
if not await mqtt.async_wait_for_mqtt_client(hass):
_LOGGER.error("MQTT integration is not available")
return False

Expand Down
21 changes: 21 additions & 0 deletions tests/components/manual_mqtt/test_alarm_control_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1506,3 +1506,24 @@ async def test_state_changes_are_published_to_mqtt(
mqtt_mock.async_publish.assert_called_once_with(
"alarm/state", STATE_ALARM_DISARMED, 0, True
)


async def test_no_mqtt(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None:
"""Test publishing of MQTT messages when state changes."""
assert await async_setup_component(
hass,
alarm_control_panel.DOMAIN,
{
alarm_control_panel.DOMAIN: {
"platform": "manual_mqtt",
"name": "test",
"state_topic": "alarm/state",
"command_topic": "alarm/command",
}
},
)
await hass.async_block_till_done()

entity_id = "alarm_control_panel.test"
assert hass.states.get(entity_id) is None
assert "MQTT integration is not available" in caplog.text
Loading

0 comments on commit 0bcda9f

Please sign in to comment.