Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions homeassistant/components/automation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from homeassistant.helpers import extract_domain_configs, script, condition
from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.loader import get_platform
from homeassistant.util.dt import utcnow
import homeassistant.helpers.config_validation as cv
Expand Down Expand Up @@ -265,9 +266,15 @@ def is_on(self) -> bool:

@asyncio.coroutine
def async_added_to_hass(self) -> None:
"""Startup if initial_state."""
if self._initial_state:
yield from self.async_enable()
"""Startup with initial state or previous state."""
state = yield from async_get_last_state(self.hass, self.entity_id)
if state is None:
if self._initial_state:
yield from self.async_enable()
else:
self._last_triggered = state.attributes.get('last_triggered')
if state.state == STATE_ON:
yield from self.async_enable()

@asyncio.coroutine
def async_turn_on(self, **kwargs) -> None:
Expand Down
17 changes: 8 additions & 9 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def mock_async_start():

@ha.callback
def clear_instance(event):
"""Clear global instance."""
global INST_COUNT
INST_COUNT -= 1

Expand All @@ -152,20 +153,18 @@ def get_test_instance_port():


def mock_service(hass, domain, service):
"""Setup a fake service.

Return a list that logs all calls to fake service.
"""
"""Setup a fake service & return a list that logs calls to this service."""
calls = []

# pylint: disable=redefined-outer-name
@ha.callback
def mock_service(call):
@asyncio.coroutine
def mock_service_log(call): # pylint: disable=unnecessary-lambda
""""Mocked service call."""
calls.append(call)

# pylint: disable=unnecessary-lambda
hass.services.register(domain, service, mock_service)
if hass.loop.__dict__.get("_thread_ident", 0) == threading.get_ident():
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer if we make an async_mock_service and have mock_service call that

hass.services.async_register(domain, service, mock_service_log)
else:
hass.services.register(domain, service, mock_service_log)

return calls

Expand Down
76 changes: 62 additions & 14 deletions tests/components/automation/test_init.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
"""The tests for the automation component."""
import unittest
import asyncio
from datetime import timedelta
import unittest
from unittest.mock import patch

from homeassistant.core import callback
from homeassistant.bootstrap import setup_component
from homeassistant.core import State
from homeassistant.bootstrap import setup_component, async_setup_component
import homeassistant.components.automation as automation
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.const import ATTR_ENTITY_ID, STATE_ON, STATE_OFF
from homeassistant.exceptions import HomeAssistantError
import homeassistant.util.dt as dt_util

from tests.common import get_test_home_assistant, assert_setup_component, \
fire_time_changed, mock_component
from tests.common import (
assert_setup_component, get_test_home_assistant, fire_time_changed,
mock_component, mock_service, mock_restore_cache)


# pylint: disable=invalid-name
Expand All @@ -22,14 +24,7 @@ def setUp(self):
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
mock_component(self.hass, 'group')
self.calls = []

@callback
def record_call(service):
"""Helper to record calls."""
self.calls.append(service)

self.hass.services.register('test', 'automation', record_call)
self.calls = mock_service(self.hass, 'test', 'automation')

def tearDown(self):
"""Stop everything that was started."""
Expand Down Expand Up @@ -572,3 +567,56 @@ def test_reload_config_handles_load_fails(self):
self.hass.bus.fire('test_event')
self.hass.block_till_done()
assert len(self.calls) == 2


@asyncio.coroutine
def test_automation_restore_state(hass):
"""Ensure states are restored on startup."""
time = dt_util.utcnow()

mock_restore_cache(hass, (
State('automation.hello', STATE_ON),
State('automation.bye', STATE_OFF, {'last_triggered': time}),
))

config = {automation.DOMAIN: [{
'alias': 'hello',
'trigger': {
'platform': 'event',
'event_type': 'test_event_hello',
},
'action': {'service': 'test.automation'}
}, {
'alias': 'bye',
'trigger': {
'platform': 'event',
'event_type': 'test_event_bye',
},
'action': {'service': 'test.automation'}
}]}

assert (yield from async_setup_component(hass, automation.DOMAIN, config))

state = hass.states.get('automation.hello')
assert state
assert state.state == STATE_ON

state = hass.states.get('automation.bye')
assert state
assert state.state == STATE_OFF
assert state.attributes.get('last_triggered') == time

calls = mock_service(hass, 'test', 'automation')

assert automation.is_on(hass, 'automation.bye') is False

hass.bus.async_fire('test_event_bye')
yield from hass.async_block_till_done()
assert len(calls) == 0

assert automation.is_on(hass, 'automation.hello')

hass.bus.async_fire('test_event_hello')
yield from hass.async_block_till_done()

assert len(calls) == 1