diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 1a73de885c0877..f37909fe518f7e 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -23,15 +23,16 @@ ) from homeassistant.core import Context, CoreState, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import condition, extract_domain_configs, script +from homeassistant.helpers import condition, extract_domain_configs import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.restore_state import RestoreEntity +from homeassistant.helpers.script import SCRIPT_PARALLEL_CHOICES, Script from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.typing import TemplateVarsType from homeassistant.loader import bind_hass -from homeassistant.util.dt import parse_datetime, utcnow +from homeassistant.util.dt import parse_datetime # mypy: allow-untyped-calls, allow-untyped-defs # mypy: no-check-untyped-defs, no-warn-return-any @@ -50,6 +51,7 @@ CONF_TRIGGER = "trigger" CONF_CONDITION_TYPE = "condition_type" CONF_INITIAL_STATE = "initial_state" +CONF_PARALLEL_ACTION = "parallel_action" CONF_SKIP_CONDITION = "skip_condition" CONDITION_USE_TRIGGER_VALUES = "use_trigger_values" @@ -106,6 +108,7 @@ def _platform_validator(config): vol.Required(CONF_TRIGGER): _TRIGGER_SCHEMA, vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA, vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA, + vol.Optional(CONF_PARALLEL_ACTION): vol.In(SCRIPT_PARALLEL_CHOICES), } ), ) @@ -394,13 +397,16 @@ async def async_trigger(self, variables, skip_condition=False, context=None): _LOGGER.info("Executing %s", self._name) try: - await self.action_script.async_run(variables, trigger_context) - except Exception as err: # pylint: disable=broad-except - self.action_script.async_log_exception( - _LOGGER, f"Error while executing automation {self.entity_id}", err + await self.action_script.async_run( + variables, + trigger_context, + _LOGGER, + f"Error while executing automation {self.entity_id}", ) + except Exception: # pylint: disable=broad-except + pass - self._last_triggered = utcnow() + self._last_triggered = self.action_script.last_triggered await self.async_update_ha_state() async def async_will_remove_from_hass(self): @@ -508,7 +514,12 @@ async def _async_process_config(hass, config, component): hidden = config_block[CONF_HIDE_ENTITY] initial_state = config_block.get(CONF_INITIAL_STATE) - action_script = script.Script(hass, config_block.get(CONF_ACTION, {}), name) + action_script = Script( + hass, + config_block.get(CONF_ACTION, {}), + name, + mode=config_block.get(CONF_PARALLEL_ACTION), + ) if CONF_CONDITION in config_block: cond_func = await _async_process_if(hass, config, config_block) diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index 44684656372eea..b64b9eb4881109 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -21,7 +21,7 @@ from homeassistant.helpers.config_validation import make_entity_service_schema from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.script import Script +from homeassistant.helpers.script import SCRIPT_PARALLEL_ERROR, Script from homeassistant.helpers.service import async_set_service_schema from homeassistant.loader import bind_hass @@ -231,7 +231,9 @@ def __init__(self, hass, object_id, name, sequence): """Initialize the script.""" self.object_id = object_id self.entity_id = ENTITY_ID_FORMAT.format(object_id) - self.script = Script(hass, sequence, name, self.async_update_ha_state) + self.script = Script( + hass, sequence, name, self.async_update_ha_state, SCRIPT_PARALLEL_ERROR + ) @property def should_poll(self): @@ -268,13 +270,12 @@ async def async_turn_on(self, **kwargs): {ATTR_NAME: self.script.name, ATTR_ENTITY_ID: self.entity_id}, context=context, ) - try: - await self.script.async_run(kwargs.get(ATTR_VARIABLES), context) - except Exception as err: - self.script.async_log_exception( - _LOGGER, f"Error executing script {self.entity_id}", err - ) - raise err + await self.script.async_run( + kwargs.get(ATTR_VARIABLES), + context, + _LOGGER, + f"Error executing script {self.entity_id}", + ) async def async_turn_off(self, **kwargs): """Turn script off.""" diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 1cac4679d827fd..effda81d914c0a 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -32,7 +32,7 @@ ) from homeassistant.helpers.typing import ConfigType from homeassistant.util.async_ import run_callback_threadsafe -import homeassistant.util.dt as date_util +from homeassistant.util.dt import utcnow # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs @@ -60,6 +60,18 @@ ACTION_ACTIVATE_SCENE = "scene" +SCRIPT_PARALLEL_ALLOW = "allow" +SCRIPT_PARALLEL_ERROR = "error" +SCRIPT_PARALLEL_RESTART = "restart" +SCRIPT_PARALLEL_SKIP = "skip" +SCRIPT_PARALLEL_CHOICES = [ + SCRIPT_PARALLEL_ALLOW, + SCRIPT_PARALLEL_ERROR, + SCRIPT_PARALLEL_RESTART, + SCRIPT_PARALLEL_SKIP, +] + + def _determine_action(action): """Determine action type.""" if CONF_DELAY in action: @@ -130,6 +142,7 @@ def __init__( sequence: Sequence[Dict[str, Any]], name: Optional[str] = None, change_listener: Optional[Callable[..., Any]] = None, + mode: Optional[str] = None, ) -> None: """Initialize the script.""" self.hass = hass @@ -137,32 +150,35 @@ def __init__( template.attach(hass, self.sequence) self.name = name self._change_listener = change_listener - self._cur = -1 - self._exception_step: Optional[int] = None - self.last_action = None + self._runs: List[Script._ScriptRun] = [] self.last_triggered: Optional[datetime] = None self.can_cancel = any( CONF_DELAY in action or CONF_WAIT_TEMPLATE in action for action in self.sequence ) - self._async_listener: List[CALLBACK_TYPE] = [] self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {} - self._actions = { - ACTION_DELAY: self._async_delay, - ACTION_WAIT_TEMPLATE: self._async_wait_template, - ACTION_CHECK_CONDITION: self._async_check_condition, - ACTION_FIRE_EVENT: self._async_fire_event, - ACTION_CALL_SERVICE: self._async_call_service, - ACTION_DEVICE_AUTOMATION: self._async_device_automation, - ACTION_ACTIVATE_SCENE: self._async_activate_scene, - } + self._mode = ( + mode + if mode + else SCRIPT_PARALLEL_RESTART + if self.can_cancel + else SCRIPT_PARALLEL_ALLOW + ) self._referenced_entities: Optional[Set[str]] = None self._referenced_devices: Optional[Set[str]] = None + @property + def last_action(self): + """Return last action.""" + try: + return self._runs[0].last_action + except IndexError: + return None + @property def is_running(self) -> bool: """Return true if script is on.""" - return self._cur != -1 + return len(self._runs) > 0 @property def referenced_devices(self): @@ -220,287 +236,349 @@ def referenced_entities(self): self._referenced_entities = referenced return referenced - def run(self, variables=None, context=None): + def run(self, variables=None, context=None, logger=None, message_base=None): """Run script.""" asyncio.run_coroutine_threadsafe( - self.async_run(variables, context), self.hass.loop + self.async_run(variables, context, logger, message_base), self.hass.loop ).result() - async def async_run( - self, variables: Optional[Sequence] = None, context: Optional[Context] = None - ) -> None: - """Run script. + class _ScriptRun: + def __init__( + self, + hass: HomeAssistant, + parent: "Script", + variables: Optional[Sequence] = None, + context: Optional[Context] = None, + logger: Optional[logging.Logger] = None, + message_base: Optional[str] = None, + ) -> None: + self.hass = hass + self._parent = parent + self._variables = variables + self._context = context + self._logger = logger or _LOGGER + self._message_base = message_base or "Error executing script" + self._actions = { + ACTION_DELAY: self._async_delay, + ACTION_WAIT_TEMPLATE: self._async_wait_template, + ACTION_CHECK_CONDITION: self._async_check_condition, + ACTION_FIRE_EVENT: self._async_fire_event, + ACTION_CALL_SERVICE: self._async_call_service, + ACTION_DEVICE_AUTOMATION: self._async_device_automation, + ACTION_ACTIVATE_SCENE: self._async_activate_scene, + } + self.task: Optional[asyncio.Task] = None + self.last_action = None + self._cur = -1 + self._async_listener: List[CALLBACK_TYPE] = [] + + async def async_run(self) -> None: + """Run script.""" + if self._cur == -1: + self._log("Running script") + self._cur = 0 + + assert not self._async_listener + + for cur, action in islice( + enumerate(self._parent.sequence), self._cur, None + ): + try: + await self._handle_action(action) + except _SuspendScript: + # Store next step to take and notify change listeners + self.task = None + self._cur = cur + 1 + # pylint: disable=protected-access + if self._parent._change_listener: + self.hass.async_add_job(self._parent._change_listener) + return + except _StopScript: + break + except Exception as err: + self._async_log_exception(cur, action, err) + self._async_stop() + # Pass exception on. + raise + + self._async_stop() + # pylint: disable=protected-access + if self._parent._change_listener: + self.hass.async_add_job(self._parent._change_listener) - This method is a coroutine. - """ - self.last_triggered = date_util.utcnow() - if self._cur == -1: - self._log("Running script") - self._cur = 0 + @callback + def _async_stop(self): + self._async_remove_listener() + self._parent._runs.remove(self) # pylint: disable=protected-access - # Unregister callback if we were in a delay or wait but turn on is - # called again. In that case we just continue execution. - self._async_remove_listener() + @callback + def async_stop(self): + """Stop script run.""" + self._async_stop() + with suppress(AttributeError): + self.task.cancel() - for cur, action in islice(enumerate(self.sequence), self._cur, None): - try: - await self._handle_action(action, variables, context) - except _SuspendScript: - # Store next step to take and notify change listeners - self._cur = cur + 1 - if self._change_listener: - self.hass.async_add_job(self._change_listener) - return - except _StopScript: - break - except Exception: - # Store the step that had an exception - self._exception_step = cur - # Set script to not running - self._cur = -1 - self.last_action = None - # Pass exception on. - raise - - # Set script to not-running. - self._cur = -1 - self.last_action = None - if self._change_listener: - self.hass.async_add_job(self._change_listener) + @callback + def _async_log_exception(self, step, action, exception): + action_type = _determine_action(action) - def stop(self) -> None: - """Stop running script.""" - run_callback_threadsafe(self.hass.loop, self.async_stop).result() + error = None + meth = self._logger.error - @callback - def async_stop(self) -> None: - """Stop running script.""" - if self._cur == -1: - return + if isinstance(exception, vol.Invalid): + error_desc = "Invalid data" - self._cur = -1 - self._async_remove_listener() - if self._change_listener: - self.hass.async_add_job(self._change_listener) + elif isinstance(exception, exceptions.TemplateError): + error_desc = "Error rendering template" - @callback - def async_log_exception(self, logger, message_base, exception): - """Log an exception for this script. - - Should only be called on exceptions raised by this scripts async_run. - """ - step = self._exception_step - action = self.sequence[step] - action_type = _determine_action(action) - - error = None - meth = logger.error - - if isinstance(exception, vol.Invalid): - error_desc = "Invalid data" - - elif isinstance(exception, exceptions.TemplateError): - error_desc = "Error rendering template" - - elif isinstance(exception, exceptions.Unauthorized): - error_desc = "Unauthorized" - - elif isinstance(exception, exceptions.ServiceNotFound): - error_desc = "Service not found" - - else: - # Print the full stack trace, unknown error - error_desc = "Unknown error" - meth = logger.exception - error = "" - - if error is None: - error = str(exception) - - meth( - "%s. %s for %s at pos %s: %s", - message_base, - error_desc, - action_type, - step + 1, - error, - ) + elif isinstance(exception, exceptions.Unauthorized): + error_desc = "Unauthorized" - async def _handle_action(self, action, variables, context): - """Handle an action.""" - await self._actions[_determine_action(action)](action, variables, context) + elif isinstance(exception, exceptions.ServiceNotFound): + error_desc = "Service not found" - async def _async_delay(self, action, variables, context): - """Handle delay.""" - # Call ourselves in the future to continue work - unsub = None + else: + # Print the full stack trace, unknown error + error_desc = "Unknown error" + meth = self._logger.exception + error = "" + + if error is None: + error = str(exception) + + meth( + "%s. %s for %s at pos %s: %s", + self._message_base, + error_desc, + action_type, + step + 1, + error, + ) - @callback - def async_script_delay(now): + async def _handle_action(self, action): + """Handle an action.""" + await self._actions[_determine_action(action)](action) + + async def _async_delay(self, action): """Handle delay.""" - with suppress(ValueError): - self._async_listener.remove(unsub) + # Call ourselves in the future to continue work + unsub = None - self.hass.async_create_task(self.async_run(variables, context)) + @callback + def async_script_delay(now): + """Handle delay.""" + with suppress(ValueError): + self._async_listener.remove(unsub) + self.task = self.hass.async_create_task(self.async_run()) - delay = action[CONF_DELAY] + delay = action[CONF_DELAY] - try: - if isinstance(delay, template.Template): - delay = vol.All(cv.time_period, cv.positive_timedelta)( - delay.async_render(variables) + try: + if isinstance(delay, template.Template): + delay = vol.All(cv.time_period, cv.positive_timedelta)( + delay.async_render(self._variables) + ) + elif isinstance(delay, dict): + delay_data = {} + delay_data.update(template.render_complex(delay, self._variables)) + delay = cv.time_period(delay_data) + except (exceptions.TemplateError, vol.Invalid) as ex: + _LOGGER.error( + "Error rendering '%s' delay template: %s", self._parent.name, ex ) - elif isinstance(delay, dict): - delay_data = {} - delay_data.update(template.render_complex(delay, variables)) - delay = cv.time_period(delay_data) - except (exceptions.TemplateError, vol.Invalid) as ex: - _LOGGER.error("Error rendering '%s' delay template: %s", self.name, ex) - raise _StopScript - - self.last_action = action.get(CONF_ALIAS, f"delay {delay}") - self._log("Executing step %s" % self.last_action) - - unsub = async_track_point_in_utc_time( - self.hass, async_script_delay, date_util.utcnow() + delay - ) - self._async_listener.append(unsub) - raise _SuspendScript + raise _StopScript - async def _async_wait_template(self, action, variables, context): - """Handle a wait template.""" - # Call ourselves in the future to continue work - wait_template = action[CONF_WAIT_TEMPLATE] - wait_template.hass = self.hass + self.last_action = action.get(CONF_ALIAS, f"delay {delay}") + self._log("Executing step %s" % self.last_action) - self.last_action = action.get(CONF_ALIAS, "wait template") - self._log("Executing step %s" % self.last_action) + unsub = async_track_point_in_utc_time( + self.hass, async_script_delay, utcnow() + delay + ) + self._async_listener.append(unsub) + raise _SuspendScript - # check if condition already okay - if condition.async_template(self.hass, wait_template, variables): - return + async def _async_wait_template(self, action): + """Handle a wait template.""" + # Call ourselves in the future to continue work + wait_template = action[CONF_WAIT_TEMPLATE] + wait_template.hass = self.hass - @callback - def async_script_wait(entity_id, from_s, to_s): - """Handle script after template condition is true.""" - self._async_remove_listener() - self.hass.async_create_task(self.async_run(variables, context)) + self.last_action = action.get(CONF_ALIAS, "wait template") + self._log("Executing step %s" % self.last_action) - self._async_listener.append( - async_track_template(self.hass, wait_template, async_script_wait, variables) - ) - - if CONF_TIMEOUT in action: - self._async_set_timeout( - action, variables, context, action.get(CONF_CONTINUE, True) - ) + # check if condition already okay + if condition.async_template(self.hass, wait_template, self._variables): + return - raise _SuspendScript - - async def _async_call_service(self, action, variables, context): - """Call the service specified in the action. - - This method is a coroutine. - """ - self.last_action = action.get(CONF_ALIAS, "call service") - self._log("Executing step %s" % self.last_action) - await service.async_call_from_config( - self.hass, - action, - blocking=True, - variables=variables, - validate_config=False, - context=context, - ) + @callback + def async_script_wait(entity_id, from_s, to_s): + """Handle script after template condition is true.""" + self._async_remove_listener() + self.task = self.hass.async_create_task(self.async_run()) - async def _async_device_automation(self, action, variables, context): - """Perform the device automation specified in the action. + self._async_listener.append( + async_track_template( + self.hass, wait_template, async_script_wait, self._variables + ) + ) - This method is a coroutine. - """ - self.last_action = action.get(CONF_ALIAS, "device automation") - self._log("Executing step %s" % self.last_action) - platform = await device_automation.async_get_device_automation_platform( - self.hass, action[CONF_DOMAIN], "action" - ) - await platform.async_call_action_from_config( - self.hass, action, variables, context - ) + if CONF_TIMEOUT in action: + self._async_set_timeout(action) + + raise _SuspendScript + + async def _async_call_service(self, action): + """Call the service specified in the action.""" + self.last_action = action.get(CONF_ALIAS, "call service") + self._log("Executing step %s" % self.last_action) + await service.async_call_from_config( + self.hass, + action, + blocking=True, + variables=self._variables, + validate_config=False, + context=self._context, + ) - async def _async_activate_scene(self, action, variables, context): - """Activate the scene specified in the action. - - This method is a coroutine. - """ - self.last_action = action.get(CONF_ALIAS, "activate scene") - self._log("Executing step %s" % self.last_action) - await self.hass.services.async_call( - scene.DOMAIN, - SERVICE_TURN_ON, - {ATTR_ENTITY_ID: action[CONF_SCENE]}, - blocking=True, - context=context, - ) + async def _async_device_automation(self, action): + """Perform the device automation specified in the action.""" + self.last_action = action.get(CONF_ALIAS, "device automation") + self._log("Executing step %s" % self.last_action) + platform = await device_automation.async_get_device_automation_platform( + self.hass, action[CONF_DOMAIN], "action" + ) + await platform.async_call_action_from_config( + self.hass, action, self._variables, self._context + ) - async def _async_fire_event(self, action, variables, context): - """Fire an event.""" - self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT]) - self._log("Executing step %s" % self.last_action) - event_data = dict(action.get(CONF_EVENT_DATA, {})) - if CONF_EVENT_DATA_TEMPLATE in action: - try: - event_data.update( - template.render_complex(action[CONF_EVENT_DATA_TEMPLATE], variables) - ) - except exceptions.TemplateError as ex: - _LOGGER.error("Error rendering event data template: %s", ex) + async def _async_activate_scene(self, action): + """Activate the scene specified in the action.""" + self.last_action = action.get(CONF_ALIAS, "activate scene") + self._log("Executing step %s" % self.last_action) + await self.hass.services.async_call( + scene.DOMAIN, + SERVICE_TURN_ON, + {ATTR_ENTITY_ID: action[CONF_SCENE]}, + blocking=True, + context=self._context, + ) - self.hass.bus.async_fire(action[CONF_EVENT], event_data, context=context) + async def _async_fire_event(self, action): + """Fire an event.""" + self.last_action = action.get(CONF_ALIAS, action[CONF_EVENT]) + self._log("Executing step %s" % self.last_action) + event_data = dict(action.get(CONF_EVENT_DATA, {})) + if CONF_EVENT_DATA_TEMPLATE in action: + try: + event_data.update( + template.render_complex( + action[CONF_EVENT_DATA_TEMPLATE], self._variables + ) + ) + except exceptions.TemplateError as ex: + _LOGGER.error("Error rendering event data template: %s", ex) + + self.hass.bus.async_fire( + action[CONF_EVENT], event_data, context=self._context + ) - async def _async_check_condition(self, action, variables, context): - """Test if condition is matching.""" - config_cache_key = frozenset((k, str(v)) for k, v in action.items()) - config = self._config_cache.get(config_cache_key) - if not config: - config = await condition.async_from_config(self.hass, action, False) - self._config_cache[config_cache_key] = config + async def _async_check_condition(self, action): + """Test if condition is matching.""" + config_cache_key = frozenset((k, str(v)) for k, v in action.items()) + # pylint: disable=protected-access + config = self._parent._config_cache.get(config_cache_key) + if not config: + config = await condition.async_from_config(self.hass, action, False) + self._parent._config_cache[config_cache_key] = config + + self.last_action = action.get(CONF_ALIAS, action[CONF_CONDITION]) + check = config(self.hass, self._variables) + self._log(f"Test condition {self.last_action}: {check}") + + if not check: + raise _StopScript + + def _async_set_timeout(self, action): + """Schedule a timeout to abort or continue script.""" + timeout = action[CONF_TIMEOUT] + unsub = None + + @callback + def async_script_timeout(now): + """Call after timeout is retrieve.""" + with suppress(ValueError): + self._async_listener.remove(unsub) + self._async_remove_listener() + + # Check if we want to continue to execute + # the script after the timeout + if action.get(CONF_CONTINUE, True): + self.task = self.hass.async_create_task(self.async_run()) + else: + self._log("Timeout reached, abort script.") + self._async_stop() + + unsub = async_track_point_in_utc_time( + self.hass, async_script_timeout, utcnow() + timeout + ) + self._async_listener.append(unsub) - self.last_action = action.get(CONF_ALIAS, action[CONF_CONDITION]) - check = config(self.hass, variables) - self._log(f"Test condition {self.last_action}: {check}") + def _async_remove_listener(self): + """Remove listeners, if any.""" + for unsub in self._async_listener: + unsub() + self._async_listener.clear() - if not check: - raise _StopScript + def _log(self, msg): + """Logger helper.""" + if self._parent.name is not None: + msg = f"Script {self._parent.name}: {msg}" - def _async_set_timeout(self, action, variables, context, continue_on_timeout): - """Schedule a timeout to abort or continue script.""" - timeout = action[CONF_TIMEOUT] - unsub = None + _LOGGER.info(msg) - @callback - def async_script_timeout(now): - """Call after timeout is retrieve.""" - with suppress(ValueError): - self._async_listener.remove(unsub) - - # Check if we want to continue to execute - # the script after the timeout - if continue_on_timeout: - self.hass.async_create_task(self.async_run(variables, context)) - else: - self._log("Timeout reached, abort script.") + async def async_run( + self, + variables: Optional[Sequence] = None, + context: Optional[Context] = None, + logger: Optional[logging.Logger] = None, + message_base: Optional[str] = None, + ) -> None: + """Run script.""" + if self.is_running: + if self._mode == SCRIPT_PARALLEL_SKIP: + self._log("Skipping script") + return + if self._mode == SCRIPT_PARALLEL_ERROR: + if logger: + logger.error("%s. Already running", message_base) + raise exceptions.HomeAssistantError( + f"{self.name if self.name else 'Script'} already running" + ) + if self._mode == SCRIPT_PARALLEL_RESTART: + self._log("Restarting script") self.async_stop() - unsub = async_track_point_in_utc_time( - self.hass, async_script_timeout, date_util.utcnow() + timeout + self.last_triggered = utcnow() + run = Script._ScriptRun( + self.hass, self, variables, context, logger, message_base ) - self._async_listener.append(unsub) + self._runs.append(run) + run.task = self.hass.async_create_task(run.async_run()) + await run.task + + def stop(self) -> None: + """Stop running script.""" + run_callback_threadsafe(self.hass.loop, self.async_stop).result() - def _async_remove_listener(self): - """Remove point in time listener, if any.""" - for unsub in self._async_listener: - unsub() - self._async_listener.clear() + @callback + def async_stop(self) -> None: + """Stop running script.""" + if not self.is_running: + return + for run in self._runs: + run.async_stop() + if self._change_listener: + self.hass.async_add_job(self._change_listener) def _log(self, msg): """Logger helper.""" diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index c27a0262a4e7b6..b77ee27ca8931e 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -71,7 +71,7 @@ async def test_service_specify_data(hass, calls): time = dt_util.utcnow() - with patch("homeassistant.components.automation.utcnow", return_value=time): + with patch("homeassistant.helpers.script.utcnow", return_value=time): hass.bus.async_fire("test_event") await hass.async_block_till_done() @@ -114,7 +114,7 @@ async def test_action_delay(hass, calls): time = dt_util.utcnow() - with patch("homeassistant.components.automation.utcnow", return_value=time): + with patch("homeassistant.helpers.script.utcnow", return_value=time): hass.bus.async_fire("test_event") await hass.async_block_till_done() @@ -133,6 +133,313 @@ async def test_action_delay(hass, calls): assert state.attributes.get("last_triggered") == time +async def test_action_delay_retrigger_allow(hass, calls): + """Test action delay with parallel_action: allow.""" + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "alias": "hello", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [ + { + "service": "test.automation", + "data_template": { + "some": "{{ trigger.platform }} - " + "{{ trigger.event.event_type }} - 1" + }, + }, + {"delay": {"minutes": "10"}}, + { + "service": "test.automation", + "data_template": { + "some": "{{ trigger.platform }} - " + "{{ trigger.event.event_type }} - 2" + }, + }, + ], + "parallel_action": "allow", + } + }, + ) + + time1 = dt_util.utcnow() + + with patch("homeassistant.helpers.script.utcnow", return_value=time1): + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + + assert len(calls) == 1 + assert calls[0].data["some"] == "event - test_event - 1" + + state = hass.states.get("automation.hello") + assert state is not None + assert state.attributes.get("last_triggered") == time1 + + time2 = dt_util.utcnow() + timedelta(minutes=5) + + async_fire_time_changed(hass, time2) + await hass.async_block_till_done() + + assert len(calls) == 1 + + with patch("homeassistant.helpers.script.utcnow", return_value=time2): + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + + assert len(calls) == 2 + assert calls[1].data["some"] == "event - test_event - 1" + + state = hass.states.get("automation.hello") + assert state is not None + assert state.attributes.get("last_triggered") == time2 + + time3 = dt_util.utcnow() + timedelta(minutes=10) + + async_fire_time_changed(hass, time3) + await hass.async_block_till_done() + + assert len(calls) == 3 + assert calls[2].data["some"] == "event - test_event - 2" + + time4 = dt_util.utcnow() + timedelta(minutes=15) + + async_fire_time_changed(hass, time4) + await hass.async_block_till_done() + + assert len(calls) == 4 + assert calls[3].data["some"] == "event - test_event - 2" + + state = hass.states.get("automation.hello") + assert state is not None + assert state.attributes.get("last_triggered") == time2 + + +async def test_action_delay_retrigger_error(hass, calls, caplog): + """Test action delay with parallel_action: error.""" + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "alias": "hello", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [ + { + "service": "test.automation", + "data_template": { + "some": "{{ trigger.platform }} - " + "{{ trigger.event.event_type }} - 1" + }, + }, + {"delay": {"minutes": "10"}}, + { + "service": "test.automation", + "data_template": { + "some": "{{ trigger.platform }} - " + "{{ trigger.event.event_type }} - 2" + }, + }, + ], + "parallel_action": "error", + } + }, + ) + + time1 = dt_util.utcnow() + + with patch("homeassistant.helpers.script.utcnow", return_value=time1): + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + + assert len(calls) == 1 + assert calls[0].data["some"] == "event - test_event - 1" + + state = hass.states.get("automation.hello") + assert state is not None + assert state.attributes.get("last_triggered") == time1 + + time2 = dt_util.utcnow() + timedelta(minutes=5) + + async_fire_time_changed(hass, time2) + await hass.async_block_till_done() + + assert len(calls) == 1 + + with patch("homeassistant.helpers.script.utcnow", return_value=time2): + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + + assert len(calls) == 1 + + state = hass.states.get("automation.hello") + assert state is not None + assert state.attributes.get("last_triggered") == time1 + + time3 = dt_util.utcnow() + timedelta(minutes=10) + + async_fire_time_changed(hass, time3) + await hass.async_block_till_done() + + assert len(calls) == 2 + assert calls[1].data["some"] == "event - test_event - 2" + + state = hass.states.get("automation.hello") + assert state is not None + assert state.attributes.get("last_triggered") == time1 + + assert "Error while executing automation" in caplog.text + + +async def test_action_delay_retrigger_restart(hass, calls): + """Test action delay with parallel_action: restart.""" + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "alias": "hello", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [ + { + "service": "test.automation", + "data_template": { + "some": "{{ trigger.platform }} - " + "{{ trigger.event.event_type }} - 1" + }, + }, + {"delay": {"minutes": "10"}}, + { + "service": "test.automation", + "data_template": { + "some": "{{ trigger.platform }} - " + "{{ trigger.event.event_type }} - 2" + }, + }, + ], + } + }, + ) + + time1 = dt_util.utcnow() + + with patch("homeassistant.helpers.script.utcnow", return_value=time1): + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + + assert len(calls) == 1 + assert calls[0].data["some"] == "event - test_event - 1" + + state = hass.states.get("automation.hello") + assert state is not None + assert state.attributes.get("last_triggered") == time1 + + time2 = dt_util.utcnow() + timedelta(minutes=5) + + async_fire_time_changed(hass, time2) + await hass.async_block_till_done() + + assert len(calls) == 1 + + with patch("homeassistant.helpers.script.utcnow", return_value=time2): + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + + assert len(calls) == 2 + assert calls[1].data["some"] == "event - test_event - 1" + + state = hass.states.get("automation.hello") + assert state is not None + assert state.attributes.get("last_triggered") == time2 + + time3 = dt_util.utcnow() + timedelta(minutes=15) + + async_fire_time_changed(hass, time3) + await hass.async_block_till_done() + + assert len(calls) == 3 + assert calls[2].data["some"] == "event - test_event - 2" + + state = hass.states.get("automation.hello") + assert state is not None + assert state.attributes.get("last_triggered") == time2 + + +async def test_action_delay_retrigger_skip(hass, calls): + """Test action delay with parallel_action: skip.""" + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: { + "alias": "hello", + "trigger": {"platform": "event", "event_type": "test_event"}, + "action": [ + { + "service": "test.automation", + "data_template": { + "some": "{{ trigger.platform }} - " + "{{ trigger.event.event_type }} - 1" + }, + }, + {"delay": {"minutes": "10"}}, + { + "service": "test.automation", + "data_template": { + "some": "{{ trigger.platform }} - " + "{{ trigger.event.event_type }} - 2" + }, + }, + ], + "parallel_action": "skip", + } + }, + ) + + time1 = dt_util.utcnow() + + with patch("homeassistant.helpers.script.utcnow", return_value=time1): + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + + assert len(calls) == 1 + assert calls[0].data["some"] == "event - test_event - 1" + + state = hass.states.get("automation.hello") + assert state is not None + assert state.attributes.get("last_triggered") == time1 + + time2 = dt_util.utcnow() + timedelta(minutes=5) + + async_fire_time_changed(hass, time2) + await hass.async_block_till_done() + + assert len(calls) == 1 + + with patch("homeassistant.helpers.script.utcnow", return_value=time2): + hass.bus.async_fire("test_event") + await hass.async_block_till_done() + + assert len(calls) == 1 + + state = hass.states.get("automation.hello") + assert state is not None + assert state.attributes.get("last_triggered") == time1 + + time3 = dt_util.utcnow() + timedelta(minutes=10) + + async_fire_time_changed(hass, time3) + await hass.async_block_till_done() + + assert len(calls) == 2 + assert calls[1].data["some"] == "event - test_event - 2" + + state = hass.states.get("automation.hello") + assert state is not None + assert state.attributes.get("last_triggered") == time1 + + async def test_service_specify_entity_id(hass, calls): """Test service data.""" assert await async_setup_component( diff --git a/tests/helpers/test_script.py b/tests/helpers/test_script.py index 5e748e3adfe643..3325cca4347459 100644 --- a/tests/helpers/test_script.py +++ b/tests/helpers/test_script.py @@ -5,7 +5,6 @@ from unittest import mock import asynctest -import jinja2 import pytest import voluptuous as vol @@ -414,6 +413,56 @@ def record_event(event): assert len(events) == 0 +async def test_run_while_suspended(hass): + """Test running when already running.""" + event = "test_event" + events = [] + context = Context() + delay_alias = "delay step" + + @callback + def record_event(event): + """Add recorded event to set.""" + events.append(event) + + hass.bus.async_listen(event, record_event) + + script_obj = script.Script( + hass, + cv.SCRIPT_SCHEMA( + [ + {"event": event}, + {"delay": {"seconds": 5}, "alias": delay_alias}, + {"event": event}, + ] + ), + mode="error", + ) + + await script_obj.async_run(context=context) + await hass.async_block_till_done() + + assert script_obj.is_running + assert script_obj.can_cancel + assert script_obj.last_action == delay_alias + assert len(events) == 1 + + with pytest.raises(exceptions.HomeAssistantError): + await script_obj.async_run(context=context) + + assert script_obj.is_running + assert len(events) == 1 + + future = dt_util.utcnow() + timedelta(seconds=5) + async_fire_time_changed(hass, future) + await hass.async_block_till_done() + + assert not script_obj.is_running + assert len(events) == 2 + assert events[0].context is context + assert events[1].context is context + + async def test_wait_template(hass): """Test the wait template.""" event = "test_event" @@ -897,7 +946,7 @@ async def test_last_triggered(hass): assert script_obj.last_triggered is None time = dt_util.utcnow() - with mock.patch("homeassistant.helpers.script.date_util.utcnow", return_value=time): + with mock.patch("homeassistant.helpers.script.utcnow", return_value=time): await script_obj.async_run() await hass.async_block_till_done() @@ -922,7 +971,7 @@ def record_event(event): await script_obj.async_run() assert len(events) == 0 - assert script_obj._cur == -1 + assert not script_obj.is_running async def test_propagate_error_invalid_service_data(hass): @@ -958,7 +1007,7 @@ def record_call(service): assert len(events) == 0 assert len(calls) == 0 - assert script_obj._cur == -1 + assert not script_obj.is_running async def test_propagate_error_service_exception(hass): @@ -989,39 +1038,7 @@ def record_call(service): assert len(events) == 0 assert len(calls) == 0 - assert script_obj._cur == -1 - - -def test_log_exception(): - """Test logged output.""" - script_obj = script.Script( - None, cv.SCRIPT_SCHEMA([{"service": "test.script"}, {"event": "test_event"}]) - ) - script_obj._exception_step = 1 - - for exc, msg in ( - (vol.Invalid("Invalid number"), "Invalid data"), - ( - exceptions.TemplateError(jinja2.TemplateError("Unclosed bracket")), - "Error rendering template", - ), - (exceptions.Unauthorized(), "Unauthorized"), - (exceptions.ServiceNotFound("light", "turn_on"), "Service not found"), - (ValueError("Cannot parse JSON"), "Unknown error"), - ): - logger = mock.Mock() - script_obj.async_log_exception(logger, "Test error", exc) - - assert len(logger.mock_calls) == 1 - _, _, p_error_desc, p_action_type, p_step, p_error = logger.mock_calls[0][1] - - assert p_error_desc == msg - assert p_action_type == script.ACTION_FIRE_EVENT - assert p_step == 2 - if isinstance(exc, ValueError): - assert p_error == "" - else: - assert p_error == str(exc) + assert not script_obj.is_running async def test_referenced_entities():