Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions homeassistant/helpers/config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,8 @@ def custom_serializer(schema):

DEVICE_ACTION_SCHEMA = DEVICE_ACTION_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)

_SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required("scene"): entity_domain("scene")})

SCRIPT_SCHEMA = vol.All(
ensure_list,
[
Expand All @@ -895,6 +897,7 @@ def custom_serializer(schema):
EVENT_SCHEMA,
CONDITION_SCHEMA,
DEVICE_ACTION_SCHEMA,
_SCRIPT_SCENE_SCHEMA,
)
],
)
24 changes: 24 additions & 0 deletions homeassistant/helpers/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
import voluptuous as vol

import homeassistant.components.device_automation as device_automation
import homeassistant.components.scene as scene
from homeassistant.core import HomeAssistant, Context, callback, CALLBACK_TYPE
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_CONDITION,
CONF_DEVICE_ID,
CONF_DOMAIN,
CONF_TIMEOUT,
SERVICE_TURN_ON,
)
from homeassistant import exceptions
from homeassistant.helpers import (
Expand Down Expand Up @@ -46,6 +49,7 @@
CONF_DELAY = "delay"
CONF_WAIT_TEMPLATE = "wait_template"
CONF_CONTINUE = "continue_on_timeout"
CONF_SCENE = "scene"


ACTION_DELAY = "delay"
Expand All @@ -54,6 +58,7 @@
ACTION_FIRE_EVENT = "event"
ACTION_CALL_SERVICE = "call_service"
ACTION_DEVICE_AUTOMATION = "device"
ACTION_ACTIVATE_SCENE = "scene"


def _determine_action(action):
Expand All @@ -73,6 +78,9 @@ def _determine_action(action):
if CONF_DEVICE_ID in action:
return ACTION_DEVICE_AUTOMATION

if CONF_SCENE in action:
return ACTION_ACTIVATE_SCENE

return ACTION_CALL_SERVICE


Expand Down Expand Up @@ -147,6 +155,7 @@ def __init__(
ACTION_FIRE_EVENT: self._async_fire_event,
ACTION_CALL_SERVICE: self._async_call_service,
ACTION_DEVICE_AUTOMATION: self._async_device_automation,
ACTION_ACTIVATE_SCENE: self._async_activate_scene,
}

@property
Expand Down Expand Up @@ -362,6 +371,21 @@ async def _async_device_automation(self, action, variables, context):
self.hass, action, variables, context
)

async def _async_activate_scene(self, action, variables, context):
"""Activate the scene specified in the action.

This method is a coroutine.
"""
self.last_action = action.get(CONF_ALIAS, "activate scene")
self._log("Executing step %s" % self.last_action)
await self.hass.services.async_call(
scene.DOMAIN,
SERVICE_TURN_ON,
{ATTR_ENTITY_ID: action[CONF_SCENE]},
blocking=True,
context=context,
)

async def _async_fire_event(self, action, variables, context):
"""Fire an event."""
self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT])
Expand Down
27 changes: 27 additions & 0 deletions tests/helpers/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import voluptuous as vol
import pytest

import homeassistant.components.scene as scene
from homeassistant import exceptions
from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON
from homeassistant.core import Context, callback

# Otherwise can't test just this file (import order issue)
Expand Down Expand Up @@ -120,6 +122,31 @@ def record_call(service):
assert calls[0].data.get("hello") == "world"


async def test_activating_scene(hass):
"""Test the activation of a scene."""
calls = []
context = Context()

@callback
def record_call(service):
"""Add recorded event to set."""
calls.append(service)

hass.services.async_register(scene.DOMAIN, SERVICE_TURN_ON, record_call)

hass.async_add_job(
ft.partial(
script.call_from_config, hass, {"scene": "scene.hello"}, context=context
)
)

await hass.async_block_till_done()

assert len(calls) == 1
assert calls[0].context is context
assert calls[0].data.get(ATTR_ENTITY_ID) == "scene.hello"


async def test_calling_service_template(hass):
"""Test the calling of a service."""
calls = []
Expand Down