Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions homeassistant/helpers/config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,9 +888,11 @@ def script_action(value: Any) -> dict:

SCRIPT_SCHEMA = vol.All(ensure_list, [script_action])

SCRIPT_ACTION_BASE_SCHEMA = {vol.Optional(CONF_ALIAS): string}

EVENT_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(CONF_EVENT): string,
vol.Optional(CONF_EVENT_DATA): vol.All(dict, template_complex),
vol.Optional(CONF_EVENT_DATA_TEMPLATE): vol.All(dict, template_complex),
Expand All @@ -900,7 +902,7 @@ def script_action(value: Any) -> dict:
SERVICE_SCHEMA = vol.All(
vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Exclusive(CONF_SERVICE, "service name"): vol.Any(
service, dynamic_template
),
Expand All @@ -920,9 +922,12 @@ def script_action(value: Any) -> dict:
vol.Coerce(float), vol.All(str, entity_domain("input_number"))
)

CONDITION_BASE_SCHEMA = {vol.Optional(CONF_ALIAS): string}

NUMERIC_STATE_CONDITION_SCHEMA = vol.All(
vol.Schema(
{
**CONDITION_BASE_SCHEMA,
vol.Required(CONF_CONDITION): "numeric_state",
vol.Required(CONF_ENTITY_ID): entity_ids,
vol.Optional(CONF_ATTRIBUTE): str,
Expand All @@ -935,6 +940,7 @@ def script_action(value: Any) -> dict:
)

STATE_CONDITION_BASE_SCHEMA = {
**CONDITION_BASE_SCHEMA,
vol.Required(CONF_CONDITION): "state",
vol.Required(CONF_ENTITY_ID): entity_ids,
vol.Optional(CONF_ATTRIBUTE): str,
Expand Down Expand Up @@ -975,6 +981,7 @@ def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name
SUN_CONDITION_SCHEMA = vol.All(
vol.Schema(
{
**CONDITION_BASE_SCHEMA,
vol.Required(CONF_CONDITION): "sun",
vol.Optional("before"): sun_event,
vol.Optional("before_offset"): time_period,
Expand All @@ -989,6 +996,7 @@ def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name

TEMPLATE_CONDITION_SCHEMA = vol.Schema(
{
**CONDITION_BASE_SCHEMA,
vol.Required(CONF_CONDITION): "template",
vol.Required(CONF_VALUE_TEMPLATE): template,
}
Expand All @@ -997,6 +1005,7 @@ def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name
TIME_CONDITION_SCHEMA = vol.All(
vol.Schema(
{
**CONDITION_BASE_SCHEMA,
vol.Required(CONF_CONDITION): "time",
"before": vol.Any(time, vol.All(str, entity_domain("input_datetime"))),
"after": vol.Any(time, vol.All(str, entity_domain("input_datetime"))),
Expand All @@ -1008,6 +1017,7 @@ def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name

ZONE_CONDITION_SCHEMA = vol.Schema(
{
**CONDITION_BASE_SCHEMA,
vol.Required(CONF_CONDITION): "zone",
vol.Required(CONF_ENTITY_ID): entity_ids,
"zone": entity_ids,
Expand All @@ -1019,6 +1029,7 @@ def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name

AND_CONDITION_SCHEMA = vol.Schema(
{
**CONDITION_BASE_SCHEMA,
vol.Required(CONF_CONDITION): "and",
vol.Required(CONF_CONDITIONS): vol.All(
ensure_list,
Expand All @@ -1030,6 +1041,7 @@ def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name

OR_CONDITION_SCHEMA = vol.Schema(
{
**CONDITION_BASE_SCHEMA,
vol.Required(CONF_CONDITION): "or",
vol.Required(CONF_CONDITIONS): vol.All(
ensure_list,
Expand All @@ -1041,6 +1053,7 @@ def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name

NOT_CONDITION_SCHEMA = vol.Schema(
{
**CONDITION_BASE_SCHEMA,
vol.Required(CONF_CONDITION): "not",
vol.Required(CONF_CONDITIONS): vol.All(
ensure_list,
Expand All @@ -1052,6 +1065,7 @@ def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name

DEVICE_CONDITION_BASE_SCHEMA = vol.Schema(
{
**CONDITION_BASE_SCHEMA,
vol.Required(CONF_CONDITION): "device",
vol.Required(CONF_DEVICE_ID): str,
vol.Required(CONF_DOMAIN): str,
Expand Down Expand Up @@ -1087,31 +1101,37 @@ def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name

_SCRIPT_DELAY_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(CONF_DELAY): positive_time_period_template,
}
)

_SCRIPT_WAIT_TEMPLATE_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(CONF_WAIT_TEMPLATE): template,
vol.Optional(CONF_TIMEOUT): positive_time_period_template,
vol.Optional(CONF_CONTINUE_ON_TIMEOUT): boolean,
}
)

DEVICE_ACTION_BASE_SCHEMA = vol.Schema(
{vol.Required(CONF_DEVICE_ID): string, vol.Required(CONF_DOMAIN): str}
{
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(CONF_DEVICE_ID): string,
vol.Required(CONF_DOMAIN): str,
}
)

DEVICE_ACTION_SCHEMA = DEVICE_ACTION_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)

_SCRIPT_SCENE_SCHEMA = vol.Schema({vol.Required(CONF_SCENE): entity_domain("scene")})
_SCRIPT_SCENE_SCHEMA = vol.Schema(
{**SCRIPT_ACTION_BASE_SCHEMA, vol.Required(CONF_SCENE): entity_domain("scene")}
)

_SCRIPT_REPEAT_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(CONF_REPEAT): vol.All(
{
vol.Exclusive(CONF_COUNT, "repeat"): vol.Any(vol.Coerce(int), template),
Expand All @@ -1130,11 +1150,12 @@ def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name

_SCRIPT_CHOOSE_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(CONF_CHOOSE): vol.All(
ensure_list,
[
{
vol.Optional(CONF_ALIAS): string,
vol.Required(CONF_CONDITIONS): vol.All(
ensure_list, [CONDITION_SCHEMA]
),
Expand All @@ -1148,7 +1169,7 @@ def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name

_SCRIPT_WAIT_FOR_TRIGGER_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(CONF_WAIT_FOR_TRIGGER): TRIGGER_SCHEMA,
vol.Optional(CONF_TIMEOUT): positive_time_period_template,
vol.Optional(CONF_CONTINUE_ON_TIMEOUT): boolean,
Expand All @@ -1157,7 +1178,7 @@ def STATE_CONDITION_SCHEMA(value: Any) -> dict: # pylint: disable=invalid-name

_SCRIPT_SET_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ALIAS): string,
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(CONF_VARIABLES): SCRIPT_VARIABLES_SCHEMA,
}
)
Expand Down
68 changes: 29 additions & 39 deletions homeassistant/helpers/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
cast,
)

from async_timeout import timeout
import async_timeout
import voluptuous as vol

from homeassistant import exceptions
Expand Down Expand Up @@ -235,6 +235,13 @@ def _log(
msg, *args, level=level, **kwargs
)

def _step_log(self, default_message, timeout=None):
self._script.last_action = self._action.get(CONF_ALIAS, default_message)
_timeout = (
"" if timeout is None else f" (timeout: {timedelta(seconds=timeout)})"
)
self._log("Executing step %s%s", self._script.last_action, _timeout)

async def async_run(self) -> None:
"""Run script."""
try:
Expand Down Expand Up @@ -327,32 +334,26 @@ async def _async_delay_step(self):
"""Handle delay."""
delay = self._get_pos_time_period_template(CONF_DELAY)

self._script.last_action = self._action.get(CONF_ALIAS, f"delay {delay}")
self._log("Executing step %s", self._script.last_action)
self._step_log(f"delay {delay}")

delay = delay.total_seconds()
self._changed()
try:
async with timeout(delay):
async with async_timeout.timeout(delay):
await self._stop.wait()
except asyncio.TimeoutError:
pass

async def _async_wait_template_step(self):
"""Handle a wait template."""
if CONF_TIMEOUT in self._action:
delay = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
timeout = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
else:
delay = None
timeout = None

self._script.last_action = self._action.get(CONF_ALIAS, "wait template")
self._log(
"Executing step %s%s",
self._script.last_action,
"" if delay is None else f" (timeout: {timedelta(seconds=delay)})",
)
self._step_log("wait template", timeout)

self._variables["wait"] = {"remaining": delay, "completed": False}
self._variables["wait"] = {"remaining": timeout, "completed": False}

wait_template = self._action[CONF_WAIT_TEMPLATE]
wait_template.hass = self._hass
Expand All @@ -366,7 +367,7 @@ async def _async_wait_template_step(self):
def async_script_wait(entity_id, from_s, to_s):
"""Handle script after template condition is true."""
self._variables["wait"] = {
"remaining": to_context.remaining if to_context else delay,
"remaining": to_context.remaining if to_context else timeout,
"completed": True,
}
done.set()
Expand All @@ -382,7 +383,7 @@ def async_script_wait(entity_id, from_s, to_s):
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
]
try:
async with timeout(delay) as to_context:
async with async_timeout.timeout(timeout) as to_context:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
except asyncio.TimeoutError as ex:
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
Expand Down Expand Up @@ -431,8 +432,7 @@ async def async_cancel_long_task() -> None:

async def _async_call_service_step(self):
"""Call the service specified in the action."""
self._script.last_action = self._action.get(CONF_ALIAS, "call service")
self._log("Executing step %s", self._script.last_action)
self._step_log("call service")

params = service.async_prepare_call_from_config(
self._hass, self._action, self._variables
Expand Down Expand Up @@ -467,8 +467,7 @@ async def _async_call_service_step(self):

async def _async_device_step(self):
"""Perform the device automation specified in the action."""
self._script.last_action = self._action.get(CONF_ALIAS, "device automation")
self._log("Executing step %s", self._script.last_action)
self._step_log("device automation")
platform = await device_automation.async_get_device_automation_platform(
self._hass, self._action[CONF_DOMAIN], "action"
)
Expand All @@ -478,8 +477,7 @@ async def _async_device_step(self):

async def _async_scene_step(self):
"""Activate the scene specified in the action."""
self._script.last_action = self._action.get(CONF_ALIAS, "activate scene")
self._log("Executing step %s", self._script.last_action)
self._step_log("activate scene")
await self._hass.services.async_call(
scene.DOMAIN,
SERVICE_TURN_ON,
Expand All @@ -490,10 +488,7 @@ async def _async_scene_step(self):

async def _async_event_step(self):
"""Fire an event."""
self._script.last_action = self._action.get(
CONF_ALIAS, self._action[CONF_EVENT]
)
self._log("Executing step %s", self._script.last_action)
self._step_log(self._action.get(CONF_ALIAS, self._action[CONF_EVENT]))
event_data = {}
for conf in [CONF_EVENT_DATA, CONF_EVENT_DATA_TEMPLATE]:
if conf not in self._action:
Expand Down Expand Up @@ -627,25 +622,20 @@ async def _async_choose_step(self) -> None:
async def _async_wait_for_trigger_step(self):
"""Wait for a trigger event."""
if CONF_TIMEOUT in self._action:
delay = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
timeout = self._get_pos_time_period_template(CONF_TIMEOUT).total_seconds()
else:
delay = None
timeout = None

self._script.last_action = self._action.get(CONF_ALIAS, "wait for trigger")
self._log(
"Executing step %s%s",
self._script.last_action,
"" if delay is None else f" (timeout: {timedelta(seconds=delay)})",
)
self._step_log("wait for trigger", timeout)

variables = {**self._variables}
self._variables["wait"] = {"remaining": delay, "trigger": None}
self._variables["wait"] = {"remaining": timeout, "trigger": None}

done = asyncio.Event()

async def async_done(variables, context=None):
self._variables["wait"] = {
"remaining": to_context.remaining if to_context else delay,
"remaining": to_context.remaining if to_context else timeout,
"trigger": variables["trigger"],
}
done.set()
Expand All @@ -671,7 +661,7 @@ def log_cb(level, msg, **kwargs):
self._hass.async_create_task(flag.wait()) for flag in (self._stop, done)
]
try:
async with timeout(delay) as to_context:
async with async_timeout.timeout(timeout) as to_context:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
except asyncio.TimeoutError as ex:
if not self._action.get(CONF_CONTINUE_ON_TIMEOUT, True):
Expand All @@ -685,8 +675,7 @@ def log_cb(level, msg, **kwargs):

async def _async_variables_step(self):
"""Set a variable value."""
self._script.last_action = self._action.get(CONF_ALIAS, "setting variables")
self._log("Executing step %s", self._script.last_action)
self._step_log("setting variables")
self._variables = self._action[CONF_VARIABLES].async_render(
self._hass, self._variables, render_as_defaults=False
)
Expand Down Expand Up @@ -1111,10 +1100,11 @@ async def _async_prep_choose_data(self, step):
await self._async_get_condition(config)
for config in choice.get(CONF_CONDITIONS, [])
]
choice_name = choice.get(CONF_ALIAS, f"choice {idx}")
sub_script = Script(
self._hass,
choice[CONF_SEQUENCE],
f"{self.name}: {step_name}: choice {idx}",
f"{self.name}: {step_name}: {choice_name}",
self.domain,
running_description=self.running_description,
script_mode=SCRIPT_MODE_PARALLEL,
Expand Down
Loading