-
-
Notifications
You must be signed in to change notification settings - Fork 37.4k
Fix MQTT retained message not being re-dispatched #12004
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
61cf745
9020d50
d65a312
9b4e937
176665c
5847a5d
7f22a25
466fe3f
af2a772
ff236b3
3e5bbd9
fa0cf75
2c54d36
7e86228
fa761cb
bbf701d
467a56a
bae00b6
12a1c11
ceba81c
997b09d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,10 @@ | |
| https://home-assistant.io/components/mqtt/ | ||
| """ | ||
| import asyncio | ||
| from collections import namedtuple | ||
| from itertools import groupby | ||
| from typing import Optional | ||
| from operator import attrgetter | ||
| import logging | ||
| import os | ||
| import socket | ||
|
|
@@ -15,13 +19,12 @@ | |
|
|
||
| import voluptuous as vol | ||
|
|
||
| from homeassistant.helpers.typing import HomeAssistantType | ||
| from homeassistant.core import callback | ||
| from homeassistant.setup import async_prepare_setup_platform | ||
| from homeassistant.exceptions import HomeAssistantError | ||
| from homeassistant.loader import bind_hass | ||
| from homeassistant.helpers import template, config_validation as cv | ||
| from homeassistant.helpers.dispatcher import ( | ||
| async_dispatcher_connect, dispatcher_send) | ||
| from homeassistant.helpers import template, ConfigType, config_validation as cv | ||
| from homeassistant.helpers.entity import Entity | ||
| from homeassistant.util.async import ( | ||
| run_coroutine_threadsafe, run_callback_threadsafe) | ||
|
|
@@ -39,7 +42,6 @@ | |
| DATA_MQTT = 'mqtt' | ||
|
|
||
| SERVICE_PUBLISH = 'publish' | ||
| SIGNAL_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received' | ||
|
|
||
| CONF_EMBEDDED = 'embedded' | ||
| CONF_BROKER = 'broker' | ||
|
|
@@ -173,7 +175,6 @@ def valid_discovery_topic(value): | |
| vol.Optional(CONF_VALUE_TEMPLATE): cv.template, | ||
| }) | ||
|
|
||
|
|
||
| # Service call validation schema | ||
| MQTT_PUBLISH_SCHEMA = vol.Schema({ | ||
| vol.Required(ATTR_TOPIC): valid_publish_topic, | ||
|
|
@@ -221,32 +222,13 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None): | |
| @bind_hass | ||
| def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS, | ||
| encoding='utf-8'): | ||
| """Subscribe to an MQTT topic.""" | ||
| @callback | ||
| def async_mqtt_topic_subscriber(dp_topic, dp_payload, dp_qos): | ||
| """Match subscribed MQTT topic.""" | ||
| if not _match_topic(topic, dp_topic): | ||
| return | ||
|
|
||
| if encoding is not None: | ||
| try: | ||
| payload = dp_payload.decode(encoding) | ||
| _LOGGER.debug("Received message on %s: %s", dp_topic, payload) | ||
| except (AttributeError, UnicodeDecodeError): | ||
| _LOGGER.error("Illegal payload encoding %s from " | ||
| "MQTT topic: %s, Payload: %s", | ||
| encoding, dp_topic, dp_payload) | ||
| return | ||
| else: | ||
| _LOGGER.debug("Received binary message on %s", dp_topic) | ||
| payload = dp_payload | ||
| """Subscribe to an MQTT topic. | ||
|
|
||
| hass.async_run_job(msg_callback, dp_topic, payload, dp_qos) | ||
|
|
||
| async_remove = async_dispatcher_connect( | ||
| hass, SIGNAL_MQTT_MESSAGE_RECEIVED, async_mqtt_topic_subscriber) | ||
|
|
||
| yield from hass.data[DATA_MQTT].async_subscribe(topic, qos) | ||
| Call the return value to unsubscribe. | ||
| """ | ||
| async_remove = \ | ||
| yield from hass.data[DATA_MQTT].async_subscribe(topic, msg_callback, | ||
| qos, encoding) | ||
| return async_remove | ||
|
|
||
|
|
||
|
|
@@ -308,7 +290,7 @@ def _async_setup_discovery(hass, config): | |
|
|
||
|
|
||
| @asyncio.coroutine | ||
| def async_setup(hass, config): | ||
| def async_setup(hass: HomeAssistantType, config: ConfigType): | ||
| """Start the MQTT protocol service.""" | ||
| conf = config.get(DOMAIN) | ||
|
|
||
|
|
@@ -351,17 +333,21 @@ def async_setup(hass, config): | |
| return False | ||
|
|
||
| # For cloudmqtt.com, secured connection, auto fill in certificate | ||
| if certificate is None and 19999 < port < 30000 and \ | ||
| broker.endswith('.cloudmqtt.com'): | ||
| if (certificate is None and 19999 < port < 30000 and | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please don't include these "style" fixes in PRs. It makes it difficult to see what this PR actually adds. If you want to do those, do it in a separate PR. Just know that generally, style fixes best case will make the code more readable, worst case introduce bugs that didn't exist before.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For this PR it's fine, just don't do it in the future please.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure. I'll try to remember this. |
||
| broker.endswith('.cloudmqtt.com')): | ||
| certificate = os.path.join(os.path.dirname(__file__), | ||
| 'addtrustexternalcaroot.crt') | ||
|
|
||
| # When the certificate is set to auto, use bundled certs from requests | ||
| if certificate == 'auto': | ||
| certificate = requests.certs.where() | ||
|
|
||
| will_message = conf.get(CONF_WILL_MESSAGE) | ||
| birth_message = conf.get(CONF_BIRTH_MESSAGE) | ||
| will_message = None | ||
| if conf.get(CONF_WILL_MESSAGE) is not None: | ||
| will_message = Message(**conf.get(CONF_WILL_MESSAGE)) | ||
| birth_message = None | ||
| if conf.get(CONF_BIRTH_MESSAGE) is not None: | ||
| birth_message = Message(**conf.get(CONF_BIRTH_MESSAGE)) | ||
|
|
||
| # Be able to override versions other than TLSv1.0 under Python3.6 | ||
| conf_tls_version = conf.get(CONF_TLS_VERSION) | ||
|
|
@@ -414,8 +400,8 @@ def async_publish_service(call): | |
| template.Template(payload_template, hass).async_render() | ||
| except template.jinja2.TemplateError as exc: | ||
| _LOGGER.error( | ||
| "Unable to publish to '%s': rendering payload template of " | ||
| "'%s' failed because %s", | ||
| "Unable to publish to %s: rendering payload template of " | ||
| "%s failed because %s", | ||
| msg_topic, payload_template, exc) | ||
| return | ||
|
|
||
|
|
@@ -432,23 +418,29 @@ def async_publish_service(call): | |
| return True | ||
|
|
||
|
|
||
| Subscription = namedtuple('Subscription', | ||
| ['topic', 'callback', 'qos', 'encoding']) | ||
| Subscription.__new__.__defaults__ = (0, 'utf-8') | ||
|
|
||
| Message = namedtuple('Message', ['topic', 'payload', 'qos', 'retain']) | ||
| Message.__new__.__defaults__ = (0, False) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While we're at creating |
||
|
|
||
|
|
||
| class MQTT(object): | ||
| """Home Assistant MQTT client.""" | ||
|
|
||
| def __init__(self, hass, broker, port, client_id, keepalive, username, | ||
| password, certificate, client_key, client_cert, | ||
| tls_insecure, protocol, will_message, birth_message, | ||
| tls_version): | ||
| tls_insecure, protocol, will_message: Optional[Message], | ||
| birth_message: Optional[Message], tls_version): | ||
| """Initialize Home Assistant MQTT client.""" | ||
| import paho.mqtt.client as mqtt | ||
|
|
||
| self.hass = hass | ||
| self.broker = broker | ||
| self.port = port | ||
| self.keepalive = keepalive | ||
| self.wanted_topics = {} | ||
| self.subscribed_topics = {} | ||
| self.progress = {} | ||
| self.subscriptions = [] | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This list now holds all active subscriptions with the |
||
| self.birth_message = birth_message | ||
| self._mqttc = None | ||
| self._paho_lock = asyncio.Lock(loop=hass.loop) | ||
|
|
@@ -474,17 +466,12 @@ def __init__(self, hass, broker, port, client_id, keepalive, username, | |
| if tls_insecure is not None: | ||
| self._mqttc.tls_insecure_set(tls_insecure) | ||
|
|
||
| self._mqttc.on_subscribe = self._mqtt_on_subscribe | ||
| self._mqttc.on_unsubscribe = self._mqtt_on_unsubscribe | ||
| self._mqttc.on_connect = self._mqtt_on_connect | ||
| self._mqttc.on_disconnect = self._mqtt_on_disconnect | ||
| self._mqttc.on_message = self._mqtt_on_message | ||
|
|
||
| if will_message: | ||
| self._mqttc.will_set(will_message.get(ATTR_TOPIC), | ||
| will_message.get(ATTR_PAYLOAD), | ||
| will_message.get(ATTR_QOS), | ||
| will_message.get(ATTR_RETAIN)) | ||
| self._mqttc.will_set(*will_message) | ||
|
|
||
| @asyncio.coroutine | ||
| def async_publish(self, topic, payload, qos, retain): | ||
|
|
@@ -526,36 +513,53 @@ def stop(): | |
| return self.hass.async_add_job(stop) | ||
|
|
||
| @asyncio.coroutine | ||
| def async_subscribe(self, topic, qos): | ||
| """Subscribe to a topic. | ||
| def async_subscribe(self, topic, msg_callback, qos, encoding): | ||
| """Set up a subscription to a topic with the provided qos. | ||
|
|
||
| This method is a coroutine. | ||
| """ | ||
| if not isinstance(topic, str): | ||
| raise HomeAssistantError("topic need to be a string!") | ||
| raise HomeAssistantError("topic needs to be a string!") | ||
|
|
||
| with (yield from self._paho_lock): | ||
| if topic in self.subscribed_topics: | ||
| subscription = Subscription(topic, msg_callback, qos, encoding) | ||
| self.subscriptions.append(subscription) | ||
|
|
||
| yield from self._async_perform_subscription(topic, qos) | ||
|
|
||
| @callback | ||
| def async_remove(): | ||
| """Remove subscription.""" | ||
| if subscription not in self.subscriptions: | ||
| raise HomeAssistantError("Can't remove subscription twice") | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As this line of code should only ever be executed when |
||
| self.subscriptions.remove(subscription) | ||
|
|
||
| if any(other.topic == topic for other in self.subscriptions): | ||
| # Other subscriptions on topic remaining - don't unsubscribe. | ||
| return | ||
| self.wanted_topics[topic] = qos | ||
| result, mid = yield from self.hass.async_add_job( | ||
| self._mqttc.subscribe, topic, qos) | ||
| self.hass.async_add_job(self._async_unsubscribe(topic)) | ||
|
|
||
| _raise_on_error(result) | ||
| self.progress[mid] = topic | ||
| return async_remove | ||
|
|
||
| @asyncio.coroutine | ||
| def async_unsubscribe(self, topic): | ||
| """Unsubscribe from topic. | ||
| def _async_unsubscribe(self, topic): | ||
| """Unsubscribe from a topic. | ||
|
|
||
| This method is a coroutine. | ||
| """ | ||
| self.wanted_topics.pop(topic, None) | ||
| result, mid = yield from self.hass.async_add_job( | ||
| self._mqttc.unsubscribe, topic) | ||
| with (yield from self._paho_lock): | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could potentially result in a race condition - so better safe than sorry and acquire the lock. |
||
| result, _ = yield from self.hass.async_add_job( | ||
| self._mqttc.unsubscribe, topic) | ||
| _raise_on_error(result) | ||
|
|
||
| @asyncio.coroutine | ||
| def _async_perform_subscription(self, topic, qos): | ||
| """Perform a paho-mqtt subscription.""" | ||
| _LOGGER.debug("Subscribing to %s", topic) | ||
|
|
||
| _raise_on_error(result) | ||
| self.progress[mid] = topic | ||
| with (yield from self._paho_lock): | ||
| result, _ = yield from self.hass.async_add_job( | ||
| self._mqttc.subscribe, topic, qos) | ||
| _raise_on_error(result) | ||
|
|
||
| def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code): | ||
| """On connect callback. | ||
|
|
@@ -571,50 +575,50 @@ def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code): | |
| self._mqttc.disconnect() | ||
| return | ||
|
|
||
| self.progress = {} | ||
| self.subscribed_topics = {} | ||
| for topic, qos in self.wanted_topics.items(): | ||
| self.hass.add_job(self.async_subscribe, topic, qos) | ||
| # Group subscriptions to only re-subscribe once for each topic. | ||
| keyfunc = attrgetter('topic') | ||
| for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc), | ||
| keyfunc): | ||
| # Re-subscribe with the highest requested qos | ||
| max_qos = max(subscription.qos for subscription in subs) | ||
| self.hass.add_job(self._async_perform_subscription, topic, max_qos) | ||
|
|
||
| if self.birth_message: | ||
| self.hass.add_job(self.async_publish( | ||
| self.birth_message.get(ATTR_TOPIC), | ||
| self.birth_message.get(ATTR_PAYLOAD), | ||
| self.birth_message.get(ATTR_QOS), | ||
| self.birth_message.get(ATTR_RETAIN))) | ||
|
|
||
| def _mqtt_on_subscribe(self, _mqttc, _userdata, mid, granted_qos): | ||
| """Subscribe successful callback.""" | ||
| topic = self.progress.pop(mid, None) | ||
| if topic is None: | ||
| return | ||
| self.subscribed_topics[topic] = granted_qos[0] | ||
| self.hass.add_job(self.async_publish(*self.birth_message)) | ||
|
|
||
| def _mqtt_on_message(self, _mqttc, _userdata, msg): | ||
| """Message received callback.""" | ||
| dispatcher_send( | ||
| self.hass, SIGNAL_MQTT_MESSAGE_RECEIVED, msg.topic, msg.payload, | ||
| msg.qos | ||
| ) | ||
|
|
||
| def _mqtt_on_unsubscribe(self, _mqttc, _userdata, mid, granted_qos): | ||
| """Unsubscribe successful callback.""" | ||
| topic = self.progress.pop(mid, None) | ||
| if topic is None: | ||
| return | ||
| self.subscribed_topics.pop(topic, None) | ||
| self.hass.async_add_job(self._mqtt_handle_message, msg) | ||
|
|
||
| @callback | ||
| def _mqtt_handle_message(self, msg): | ||
| _LOGGER.debug("Received message on %s: %s", msg.topic, msg.payload) | ||
|
|
||
| for subscription in self.subscriptions: | ||
| if not _match_topic(subscription.topic, msg.topic): | ||
| continue | ||
|
|
||
| payload = msg.payload | ||
| if subscription.encoding is not None: | ||
| try: | ||
| payload = msg.payload.decode(subscription.encoding) | ||
| except (AttributeError, UnicodeDecodeError): | ||
| _LOGGER.warning("Can't decode payload %s on %s " | ||
| "with encoding %s", | ||
| msg.payload, msg.topic, | ||
| subscription.encoding) | ||
| return | ||
|
|
||
| self.hass.async_run_job(subscription.callback, | ||
| msg.topic, payload, msg.qos) | ||
|
|
||
| def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code): | ||
| """Disconnected callback.""" | ||
| self.progress = {} | ||
| self.subscribed_topics = {} | ||
|
|
||
| # When disconnected because of calling disconnect() | ||
| if result_code == 0: | ||
| return | ||
|
|
||
| tries = 0 | ||
| wait_time = 0 | ||
|
|
||
| while True: | ||
| try: | ||
|
|
@@ -693,7 +697,7 @@ def availability_message_received(topic, payload, qos): | |
| if self._availability_topic is not None: | ||
| yield from async_subscribe( | ||
| self.hass, self._availability_topic, | ||
| availability_message_received, self. _availability_qos) | ||
| availability_message_received, self._availability_qos) | ||
|
|
||
| @property | ||
| def available(self) -> bool: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As all callbacks are now handled by the
MQTTclass, the dispatcher calls are no longer required (*except for tests)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tests should test the actual code, so let's make sure it is also not needed in tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I'll do that