Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions homeassistant/components/switch/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
For more details about this platform, please refer to the documentation at
https://home-assistant.io/components/switch.mqtt/
"""
import asyncio
import logging

import voluptuous as vol
Expand All @@ -17,9 +16,10 @@
from homeassistant.components.switch import SwitchDevice
from homeassistant.const import (
CONF_NAME, CONF_OPTIMISTIC, CONF_VALUE_TEMPLATE, CONF_PAYLOAD_OFF,
CONF_PAYLOAD_ON, CONF_ICON)
CONF_PAYLOAD_ON, CONF_ICON, STATE_ON)
import homeassistant.components.mqtt as mqtt
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.restore_state import async_get_last_state

_LOGGER = logging.getLogger(__name__)

Expand All @@ -39,8 +39,8 @@
}).extend(mqtt.MQTT_AVAILABILITY_SCHEMA.schema)


@asyncio.coroutine
def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
async def async_setup_platform(hass, config, async_add_devices,
discovery_info=None):
"""Set up the MQTT switch."""
if discovery_info is not None:
config = PLATFORM_SCHEMA(discovery_info)
Expand Down Expand Up @@ -88,10 +88,9 @@ def __init__(self, name, icon,
self._optimistic = optimistic
self._template = value_template

@asyncio.coroutine
def async_added_to_hass(self):
async def async_added_to_hass(self):
"""Subscribe to MQTT events."""
yield from super().async_added_to_hass()
await super().async_added_to_hass()

@callback
def state_message_received(topic, payload, qos):
Expand All @@ -110,10 +109,16 @@ def state_message_received(topic, payload, qos):
# Force into optimistic mode.
self._optimistic = True
else:
yield from mqtt.async_subscribe(
await mqtt.async_subscribe(
self.hass, self._state_topic, state_message_received,
self._qos)

if self._optimistic:
last_state = await async_get_last_state(self.hass,
self.entity_id)
if last_state:
self._state = last_state.state == STATE_ON

@property
def should_poll(self):
"""Return the polling state."""
Expand All @@ -139,8 +144,7 @@ def icon(self):
"""Return the icon."""
return self._icon

@asyncio.coroutine
def async_turn_on(self, **kwargs):
async def async_turn_on(self, **kwargs):
"""Turn the device on.

This method is a coroutine.
Expand All @@ -153,8 +157,7 @@ def async_turn_on(self, **kwargs):
self._state = True
self.async_schedule_update_ha_state()

@asyncio.coroutine
def async_turn_off(self, **kwargs):
async def async_turn_off(self, **kwargs):
"""Turn the device off.

This method is a coroutine.
Expand Down
30 changes: 18 additions & 12 deletions tests/components/switch/test_mqtt.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""The tests for the MQTT switch platform."""
import unittest
from unittest.mock import patch

from homeassistant.setup import setup_component
from homeassistant.const import STATE_ON, STATE_OFF, STATE_UNAVAILABLE,\
ATTR_ASSUMED_STATE
import homeassistant.core as ha
import homeassistant.components.switch as switch
from tests.common import (
mock_mqtt_component, fire_mqtt_message, get_test_home_assistant)
mock_mqtt_component, fire_mqtt_message, get_test_home_assistant, mock_coro)


class TestSwitchMQTT(unittest.TestCase):
Expand Down Expand Up @@ -52,19 +54,23 @@ def test_controlling_state_via_topic(self):

def test_sending_mqtt_commands_and_optimistic(self):
"""Test the sending MQTT commands in optimistic mode."""
assert setup_component(self.hass, switch.DOMAIN, {
switch.DOMAIN: {
'platform': 'mqtt',
'name': 'test',
'command_topic': 'command-topic',
'payload_on': 'beer on',
'payload_off': 'beer off',
'qos': '2'
}
})
fake_state = ha.State('switch.test', 'on')

with patch('homeassistant.components.switch.mqtt.async_get_last_state',
return_value=mock_coro(fake_state)):
assert setup_component(self.hass, switch.DOMAIN, {
switch.DOMAIN: {
'platform': 'mqtt',
'name': 'test',
'command_topic': 'command-topic',
'payload_on': 'beer on',
'payload_off': 'beer off',
'qos': '2'
}
})

state = self.hass.states.get('switch.test')
self.assertEqual(STATE_OFF, state.state)
self.assertEqual(STATE_ON, state.state)
self.assertTrue(state.attributes.get(ATTR_ASSUMED_STATE))

switch.turn_on(self.hass, 'switch.test')
Expand Down