Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions homeassistant/components/device_tracker/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,14 @@ def async_setup_scanner(hass, config, async_see, discovery_info=None):
devices = config[CONF_DEVICES]
qos = config[CONF_QOS]

dev_id_lookup = {}

@callback
def async_tracker_message_received(topic, payload, qos):
"""Handle received MQTT message."""
hass.async_add_job(
async_see(dev_id=dev_id_lookup[topic], location_name=payload))

for dev_id, topic in devices.items():
dev_id_lookup[topic] = dev_id
@callback
def async_message_received(topic, payload, qos, dev_id=dev_id):
"""Handle received MQTT message."""
hass.async_add_job(
async_see(dev_id=dev_id, location_name=payload))

yield from mqtt.async_subscribe(
hass, topic, async_tracker_message_received, qos)
hass, topic, async_message_received, qos)

return True
42 changes: 18 additions & 24 deletions homeassistant/components/device_tracker/mqtt_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,32 +41,26 @@ def async_setup_scanner(hass, config, async_see, discovery_info=None):
devices = config[CONF_DEVICES]
qos = config[CONF_QOS]

dev_id_lookup = {}

@callback
def async_tracker_message_received(topic, payload, qos):
"""Handle received MQTT message."""
dev_id = dev_id_lookup[topic]

try:
data = GPS_JSON_PAYLOAD_SCHEMA(json.loads(payload))
except vol.MultipleInvalid:
_LOGGER.error("Skipping update for following data "
"because of missing or malformatted data: %s",
payload)
return
except ValueError:
_LOGGER.error("Error parsing JSON payload: %s", payload)
return

kwargs = _parse_see_args(dev_id, data)
hass.async_add_job(
async_see(**kwargs))

for dev_id, topic in devices.items():
dev_id_lookup[topic] = dev_id
@callback
def async_message_received(topic, payload, qos, dev_id=dev_id):
"""Handle received MQTT message."""
try:
data = GPS_JSON_PAYLOAD_SCHEMA(json.loads(payload))
except vol.MultipleInvalid:
_LOGGER.error("Skipping update for following data "
"because of missing or malformatted data: %s",
payload)
return
except ValueError:
_LOGGER.error("Error parsing JSON payload: %s", payload)
return

kwargs = _parse_see_args(dev_id, data)
hass.async_add_job(async_see(**kwargs))

yield from mqtt.async_subscribe(
hass, topic, async_tracker_message_received, qos)
hass, topic, async_message_received, qos)

return True

Expand Down
76 changes: 76 additions & 0 deletions tests/components/device_tracker/test_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,79 @@ def test_new_message(self):
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
self.assertEqual(location, self.hass.states.get(entity_id).state)

def test_single_level_wildcard_topic(self):
"""Test single level wildcard topic."""
dev_id = 'paulus'
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
subscription = '/location/+/paulus'
topic = '/location/room/paulus'
location = 'work'

self.hass.config.components = set(['mqtt', 'zone'])
assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
self.assertEqual(location, self.hass.states.get(entity_id).state)

def test_multi_level_wildcard_topic(self):
"""Test multi level wildcard topic."""
dev_id = 'paulus'
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
subscription = '/location/#'
topic = '/location/room/paulus'
location = 'work'

self.hass.config.components = set(['mqtt', 'zone'])
assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
self.assertEqual(location, self.hass.states.get(entity_id).state)

def test_single_level_wildcard_topic_not_matching(self):
"""Test not matching single level wildcard topic."""
dev_id = 'paulus'
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
subscription = '/location/+/paulus'
topic = '/location/paulus'
location = 'work'

self.hass.config.components = set(['mqtt', 'zone'])
assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
self.assertIsNone(self.hass.states.get(entity_id))

def test_multi_level_wildcard_topic_not_matching(self):
"""Test not matching multi level wildcard topic."""
dev_id = 'paulus'
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
subscription = '/location/#'
topic = '/somewhere/room/paulus'
location = 'work'

self.hass.config.components = set(['mqtt', 'zone'])
assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
self.assertIsNone(self.hass.states.get(entity_id))
74 changes: 74 additions & 0 deletions tests/components/device_tracker/test_mqtt_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,77 @@ def test_incomplete_message(self):
"Skipping update for following data because of missing "
"or malformatted data: {\"longitude\": 2.0}",
test_handle.output[0])

def test_single_level_wildcard_topic(self):
"""Test single level wildcard topic."""
dev_id = 'zanzito'
subscription = 'location/+/zanzito'
topic = 'location/room/zanzito'
location = json.dumps(LOCATION_MESSAGE)

assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt_json',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
state = self.hass.states.get('device_tracker.zanzito')
self.assertEqual(state.attributes.get('latitude'), 2.0)
self.assertEqual(state.attributes.get('longitude'), 1.0)

def test_multi_level_wildcard_topic(self):
"""Test multi level wildcard topic."""
dev_id = 'zanzito'
subscription = 'location/#'
topic = 'location/zanzito'
location = json.dumps(LOCATION_MESSAGE)

assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt_json',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
state = self.hass.states.get('device_tracker.zanzito')
self.assertEqual(state.attributes.get('latitude'), 2.0)
self.assertEqual(state.attributes.get('longitude'), 1.0)

def test_single_level_wildcard_topic_not_matching(self):
"""Test not matching single level wildcard topic."""
dev_id = 'zanzito'
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
subscription = 'location/+/zanzito'
topic = 'location/zanzito'
location = json.dumps(LOCATION_MESSAGE)

assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt_json',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
self.assertIsNone(self.hass.states.get(entity_id))

def test_multi_level_wildcard_topic_not_matching(self):
"""Test not matching multi level wildcard topic."""
dev_id = 'zanzito'
entity_id = device_tracker.ENTITY_ID_FORMAT.format(dev_id)
subscription = 'location/#'
topic = 'somewhere/zanzito'
location = json.dumps(LOCATION_MESSAGE)

assert setup_component(self.hass, device_tracker.DOMAIN, {
device_tracker.DOMAIN: {
CONF_PLATFORM: 'mqtt_json',
'devices': {dev_id: subscription}
}
})
fire_mqtt_message(self.hass, topic, location)
self.hass.block_till_done()
self.assertIsNone(self.hass.states.get(entity_id))