Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 101 additions & 97 deletions homeassistant/components/mqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
https://home-assistant.io/components/mqtt/
"""
import asyncio
from collections import namedtuple
from itertools import groupby
from typing import Optional
from operator import attrgetter
import logging
import os
import socket
Expand All @@ -15,13 +19,12 @@

import voluptuous as vol

from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.core import callback
from homeassistant.setup import async_prepare_setup_platform
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import bind_hass
from homeassistant.helpers import template, config_validation as cv
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, dispatcher_send)
from homeassistant.helpers import template, ConfigType, config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.util.async import (
run_coroutine_threadsafe, run_callback_threadsafe)
Expand All @@ -39,7 +42,6 @@
DATA_MQTT = 'mqtt'

SERVICE_PUBLISH = 'publish'
SIGNAL_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received'
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As all callbacks are now handled by the MQTT class, the dispatcher calls are no longer required (*except for tests)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests should test the actual code, so let's make sure it is also not needed in tests.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'll do that


CONF_EMBEDDED = 'embedded'
CONF_BROKER = 'broker'
Expand Down Expand Up @@ -173,7 +175,6 @@ def valid_discovery_topic(value):
vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
})


# Service call validation schema
MQTT_PUBLISH_SCHEMA = vol.Schema({
vol.Required(ATTR_TOPIC): valid_publish_topic,
Expand Down Expand Up @@ -221,32 +222,13 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
@bind_hass
def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS,
encoding='utf-8'):
"""Subscribe to an MQTT topic."""
@callback
def async_mqtt_topic_subscriber(dp_topic, dp_payload, dp_qos):
"""Match subscribed MQTT topic."""
if not _match_topic(topic, dp_topic):
return

if encoding is not None:
try:
payload = dp_payload.decode(encoding)
_LOGGER.debug("Received message on %s: %s", dp_topic, payload)
except (AttributeError, UnicodeDecodeError):
_LOGGER.error("Illegal payload encoding %s from "
"MQTT topic: %s, Payload: %s",
encoding, dp_topic, dp_payload)
return
else:
_LOGGER.debug("Received binary message on %s", dp_topic)
payload = dp_payload
"""Subscribe to an MQTT topic.

hass.async_run_job(msg_callback, dp_topic, payload, dp_qos)

async_remove = async_dispatcher_connect(
hass, SIGNAL_MQTT_MESSAGE_RECEIVED, async_mqtt_topic_subscriber)

yield from hass.data[DATA_MQTT].async_subscribe(topic, qos)
Call the return value to unsubscribe.
"""
async_remove = \
yield from hass.data[DATA_MQTT].async_subscribe(topic, msg_callback,
qos, encoding)
return async_remove


Expand Down Expand Up @@ -308,7 +290,7 @@ def _async_setup_discovery(hass, config):


@asyncio.coroutine
def async_setup(hass, config):
def async_setup(hass: HomeAssistantType, config: ConfigType):
"""Start the MQTT protocol service."""
conf = config.get(DOMAIN)

Expand Down Expand Up @@ -351,17 +333,21 @@ def async_setup(hass, config):
return False

# For cloudmqtt.com, secured connection, auto fill in certificate
if certificate is None and 19999 < port < 30000 and \
broker.endswith('.cloudmqtt.com'):
if (certificate is None and 19999 < port < 30000 and
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't include these "style" fixes in PRs. It makes it difficult to see what this PR actually adds.

If you want to do those, do it in a separate PR. Just know that generally, style fixes best case will make the code more readable, worst case introduce bugs that didn't exist before.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this PR it's fine, just don't do it in the future please.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I'll try to remember this.

broker.endswith('.cloudmqtt.com')):
certificate = os.path.join(os.path.dirname(__file__),
'addtrustexternalcaroot.crt')

# When the certificate is set to auto, use bundled certs from requests
if certificate == 'auto':
certificate = requests.certs.where()

will_message = conf.get(CONF_WILL_MESSAGE)
birth_message = conf.get(CONF_BIRTH_MESSAGE)
will_message = None
if conf.get(CONF_WILL_MESSAGE) is not None:
will_message = Message(**conf.get(CONF_WILL_MESSAGE))
birth_message = None
if conf.get(CONF_BIRTH_MESSAGE) is not None:
birth_message = Message(**conf.get(CONF_BIRTH_MESSAGE))

# Be able to override versions other than TLSv1.0 under Python3.6
conf_tls_version = conf.get(CONF_TLS_VERSION)
Expand Down Expand Up @@ -414,8 +400,8 @@ def async_publish_service(call):
template.Template(payload_template, hass).async_render()
except template.jinja2.TemplateError as exc:
_LOGGER.error(
"Unable to publish to '%s': rendering payload template of "
"'%s' failed because %s",
"Unable to publish to %s: rendering payload template of "
"%s failed because %s",
msg_topic, payload_template, exc)
return

Expand All @@ -432,23 +418,29 @@ def async_publish_service(call):
return True


Subscription = namedtuple('Subscription',
['topic', 'callback', 'qos', 'encoding'])
Subscription.__new__.__defaults__ = (0, 'utf-8')

Message = namedtuple('Message', ['topic', 'payload', 'qos', 'retain'])
Message.__new__.__defaults__ = (0, False)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While we're at creating namedtuples, let's also create another struct to copy MQTT messages around.



class MQTT(object):
"""Home Assistant MQTT client."""

def __init__(self, hass, broker, port, client_id, keepalive, username,
password, certificate, client_key, client_cert,
tls_insecure, protocol, will_message, birth_message,
tls_version):
tls_insecure, protocol, will_message: Optional[Message],
birth_message: Optional[Message], tls_version):
"""Initialize Home Assistant MQTT client."""
import paho.mqtt.client as mqtt

self.hass = hass
self.broker = broker
self.port = port
self.keepalive = keepalive
self.wanted_topics = {}
self.subscribed_topics = {}
self.progress = {}
self.subscriptions = []
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This list now holds all active subscriptions with the Subscription tuple defined before. When a new subscription arrives it's first added to this list and then the paho subscribe call is made.
On a reconnect, this is also used to re-create all subscriptions as before.

self.birth_message = birth_message
self._mqttc = None
self._paho_lock = asyncio.Lock(loop=hass.loop)
Expand All @@ -474,17 +466,12 @@ def __init__(self, hass, broker, port, client_id, keepalive, username,
if tls_insecure is not None:
self._mqttc.tls_insecure_set(tls_insecure)

self._mqttc.on_subscribe = self._mqtt_on_subscribe
self._mqttc.on_unsubscribe = self._mqtt_on_unsubscribe
self._mqttc.on_connect = self._mqtt_on_connect
self._mqttc.on_disconnect = self._mqtt_on_disconnect
self._mqttc.on_message = self._mqtt_on_message

if will_message:
self._mqttc.will_set(will_message.get(ATTR_TOPIC),
will_message.get(ATTR_PAYLOAD),
will_message.get(ATTR_QOS),
will_message.get(ATTR_RETAIN))
self._mqttc.will_set(*will_message)

@asyncio.coroutine
def async_publish(self, topic, payload, qos, retain):
Expand Down Expand Up @@ -526,36 +513,53 @@ def stop():
return self.hass.async_add_job(stop)

@asyncio.coroutine
def async_subscribe(self, topic, qos):
"""Subscribe to a topic.
def async_subscribe(self, topic, msg_callback, qos, encoding):
"""Set up a subscription to a topic with the provided qos.

This method is a coroutine.
"""
if not isinstance(topic, str):
raise HomeAssistantError("topic need to be a string!")
raise HomeAssistantError("topic needs to be a string!")

with (yield from self._paho_lock):
if topic in self.subscribed_topics:
subscription = Subscription(topic, msg_callback, qos, encoding)
self.subscriptions.append(subscription)

yield from self._async_perform_subscription(topic, qos)

@callback
def async_remove():
"""Remove subscription."""
if subscription not in self.subscriptions:
raise HomeAssistantError("Can't remove subscription twice")
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this line of code should only ever be executed when async_remove is called twice, I think it's good to really fail and not just put out an error log message.

self.subscriptions.remove(subscription)

if any(other.topic == topic for other in self.subscriptions):
# Other subscriptions on topic remaining - don't unsubscribe.
return
self.wanted_topics[topic] = qos
result, mid = yield from self.hass.async_add_job(
self._mqttc.subscribe, topic, qos)
self.hass.async_add_job(self._async_unsubscribe(topic))

_raise_on_error(result)
self.progress[mid] = topic
return async_remove

@asyncio.coroutine
def async_unsubscribe(self, topic):
"""Unsubscribe from topic.
def _async_unsubscribe(self, topic):
"""Unsubscribe from a topic.

This method is a coroutine.
"""
self.wanted_topics.pop(topic, None)
result, mid = yield from self.hass.async_add_job(
self._mqttc.unsubscribe, topic)
with (yield from self._paho_lock):
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could potentially result in a race condition - so better safe than sorry and acquire the lock.

result, _ = yield from self.hass.async_add_job(
self._mqttc.unsubscribe, topic)
_raise_on_error(result)

@asyncio.coroutine
def _async_perform_subscription(self, topic, qos):
"""Perform a paho-mqtt subscription."""
_LOGGER.debug("Subscribing to %s", topic)

_raise_on_error(result)
self.progress[mid] = topic
with (yield from self._paho_lock):
result, _ = yield from self.hass.async_add_job(
self._mqttc.subscribe, topic, qos)
_raise_on_error(result)

def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code):
"""On connect callback.
Expand All @@ -571,50 +575,50 @@ def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code):
self._mqttc.disconnect()
return

self.progress = {}
self.subscribed_topics = {}
for topic, qos in self.wanted_topics.items():
self.hass.add_job(self.async_subscribe, topic, qos)
# Group subscriptions to only re-subscribe once for each topic.
keyfunc = attrgetter('topic')
for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc),
keyfunc):
# Re-subscribe with the highest requested qos
max_qos = max(subscription.qos for subscription in subs)
self.hass.add_job(self._async_perform_subscription, topic, max_qos)

if self.birth_message:
self.hass.add_job(self.async_publish(
self.birth_message.get(ATTR_TOPIC),
self.birth_message.get(ATTR_PAYLOAD),
self.birth_message.get(ATTR_QOS),
self.birth_message.get(ATTR_RETAIN)))

def _mqtt_on_subscribe(self, _mqttc, _userdata, mid, granted_qos):
"""Subscribe successful callback."""
topic = self.progress.pop(mid, None)
if topic is None:
return
self.subscribed_topics[topic] = granted_qos[0]
self.hass.add_job(self.async_publish(*self.birth_message))

def _mqtt_on_message(self, _mqttc, _userdata, msg):
"""Message received callback."""
dispatcher_send(
self.hass, SIGNAL_MQTT_MESSAGE_RECEIVED, msg.topic, msg.payload,
msg.qos
)

def _mqtt_on_unsubscribe(self, _mqttc, _userdata, mid, granted_qos):
"""Unsubscribe successful callback."""
topic = self.progress.pop(mid, None)
if topic is None:
return
self.subscribed_topics.pop(topic, None)
self.hass.async_add_job(self._mqtt_handle_message, msg)

@callback
def _mqtt_handle_message(self, msg):
_LOGGER.debug("Received message on %s: %s", msg.topic, msg.payload)

for subscription in self.subscriptions:
if not _match_topic(subscription.topic, msg.topic):
continue

payload = msg.payload
if subscription.encoding is not None:
try:
payload = msg.payload.decode(subscription.encoding)
except (AttributeError, UnicodeDecodeError):
_LOGGER.warning("Can't decode payload %s on %s "
"with encoding %s",
msg.payload, msg.topic,
subscription.encoding)
return

self.hass.async_run_job(subscription.callback,
msg.topic, payload, msg.qos)

def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code):
"""Disconnected callback."""
self.progress = {}
self.subscribed_topics = {}

# When disconnected because of calling disconnect()
if result_code == 0:
return

tries = 0
wait_time = 0

while True:
try:
Expand Down Expand Up @@ -693,7 +697,7 @@ def availability_message_received(topic, payload, qos):
if self._availability_topic is not None:
yield from async_subscribe(
self.hass, self._availability_topic,
availability_message_received, self. _availability_qos)
availability_message_received, self._availability_qos)

@property
def available(self) -> bool:
Expand Down
Loading