Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions homeassistant/components/mqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,7 @@ async def _async_unsubscribe(self, topic: str) -> None:

This method is a coroutine.
"""
_LOGGER.debug("Unsubscribing from %s", topic)
async with self._paho_lock:
result: int = None
result, _ = await self.hass.async_add_executor_job(
Expand Down
12 changes: 8 additions & 4 deletions homeassistant/components/mqtt/device_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,17 @@ async def update_trigger(self, config, discovery_hash, remove_signal):
self.remove_signal = remove_signal
self.type = config[CONF_TYPE]
self.subtype = config[CONF_SUBTYPE]
self.topic = config[CONF_TOPIC]
self.payload = config[CONF_PAYLOAD]
self.qos = config[CONF_QOS]
topic_changed = self.topic != config[CONF_TOPIC]
self.topic = config[CONF_TOPIC]

# Unsubscribe+subscribe if this trigger is in use
for trig in self.trigger_instances:
await trig.async_attach_trigger()
# Unsubscribe+subscribe if this trigger is in use and topic has changed
# If topic is same unsubscribe+subscribe will execute in the wrong order
# because unsubscribe is done with help of async_create_task
if topic_changed:
for trig in self.trigger_instances:
await trig.async_attach_trigger()

def detach_trigger(self):
"""Remove MQTT device trigger."""
Expand Down
62 changes: 61 additions & 1 deletion tests/components/mqtt/test_device_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
assert_lists_same,
async_fire_mqtt_message,
async_get_device_automations,
async_mock_mqtt_component,
async_mock_service,
mock_device_registry,
mock_registry,
Expand Down Expand Up @@ -456,7 +457,7 @@ async def test_if_fires_on_mqtt_message_after_update(
await hass.async_block_till_done()
assert len(calls) == 1

# Update the trigger
# Update the trigger with different topic
async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data2)
await hass.async_block_till_done()

Expand All @@ -468,6 +469,65 @@ async def test_if_fires_on_mqtt_message_after_update(
await hass.async_block_till_done()
assert len(calls) == 2

# Update the trigger with same topic
async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data2)
await hass.async_block_till_done()

async_fire_mqtt_message(hass, "foobar/triggers/button1", "")
await hass.async_block_till_done()
assert len(calls) == 2

async_fire_mqtt_message(hass, "foobar/triggers/buttonOne", "")
await hass.async_block_till_done()
assert len(calls) == 3


async def test_no_resubscribe_same_topic(hass, device_reg, mqtt_mock):
"""Test subscription to topics without change."""
mock_mqtt = await async_mock_mqtt_component(hass)
config_entry = MockConfigEntry(domain=DOMAIN, data={})
config_entry.add_to_hass(hass)
await async_start(hass, "homeassistant", {}, config_entry)

data1 = (
'{ "automation_type":"trigger",'
' "device":{"identifiers":["0AFFD2"]},'
' "topic": "foobar/triggers/button1",'
' "type": "button_short_press",'
' "subtype": "button_1" }'
)
async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data1)
await hass.async_block_till_done()
device_entry = device_reg.async_get_device({("mqtt", "0AFFD2")}, set())

assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger": {
"platform": "device",
"domain": DOMAIN,
"device_id": device_entry.id,
"discovery_id": "bla1",
"type": "button_short_press",
"subtype": "button_1",
},
"action": {
"service": "test.automation",
"data_template": {"some": ("short_press")},
},
},
]
},
)

call_count = mock_mqtt.async_subscribe.call_count
async_fire_mqtt_message(hass, "homeassistant/device_automation/bla1/config", data1)
await hass.async_block_till_done()
assert mock_mqtt.async_subscribe.call_count == call_count


async def test_not_fires_on_mqtt_message_after_remove_by_mqtt(
hass, device_reg, calls, mqtt_mock
Expand Down