Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions homeassistant/components/automation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
CONF_MAX,
CONF_MAX_EXCEEDED,
Script,
script_stack_cv,
)
from homeassistant.helpers.script_variables import ScriptVariables
from homeassistant.helpers.service import (
Expand Down Expand Up @@ -505,6 +506,10 @@ def started_action():
EVENT_AUTOMATION_TRIGGERED, event_data, context=trigger_context
)

# Make a new empty script stack; automations are allowed
# to recursively trigger themselves
script_stack_cv.set([])

try:
with trace_path("action"):
await self.action_script.async_run(
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/helpers/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,7 +1246,7 @@ async def async_run(
and id(self) in script_stack
):
script_execution_set("disallowed_recursion_detected")
_LOGGER.warning("Disallowed recursion detected")
self._log("Disallowed recursion detected", level=logging.WARNING)
return

if self.script_mode != SCRIPT_MODE_QUEUED:
Expand Down
207 changes: 206 additions & 1 deletion tests/components/automation/test_init.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The tests for the automation component."""
import asyncio
from datetime import timedelta
import logging
from unittest.mock import Mock, patch

Expand All @@ -25,14 +26,30 @@
STATE_OFF,
STATE_ON,
)
from homeassistant.core import Context, CoreState, State, callback
from homeassistant.core import (
Context,
CoreState,
HomeAssistant,
ServiceCall,
State,
callback,
)
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers.script import (
SCRIPT_MODE_CHOICES,
SCRIPT_MODE_PARALLEL,
SCRIPT_MODE_QUEUED,
SCRIPT_MODE_RESTART,
SCRIPT_MODE_SINGLE,
_async_stop_scripts_at_shutdown,
)
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util

from tests.common import (
assert_setup_component,
async_capture_events,
async_fire_time_changed,
async_mock_service,
mock_restore_cache,
)
Expand Down Expand Up @@ -1570,3 +1587,191 @@ async def test_trigger_condition_explicit_id(hass, calls):
await hass.async_block_till_done()
assert len(calls) == 2
assert calls[-1].data.get("param") == "two"


@pytest.mark.parametrize(
"automation_mode,automation_runs",
(
(SCRIPT_MODE_PARALLEL, 2),
(SCRIPT_MODE_QUEUED, 2),
(SCRIPT_MODE_RESTART, 2),
(SCRIPT_MODE_SINGLE, 1),
),
)
@pytest.mark.parametrize(
"script_mode,script_warning_msg",
(
(SCRIPT_MODE_PARALLEL, "script1: Maximum number of runs exceeded"),
(SCRIPT_MODE_QUEUED, "script1: Disallowed recursion detected"),
(SCRIPT_MODE_RESTART, "script1: Disallowed recursion detected"),
(SCRIPT_MODE_SINGLE, "script1: Already running"),
),
)
async def test_recursive_automation_starting_script(
hass: HomeAssistant,
automation_mode,
automation_runs,
script_mode,
script_warning_msg,
caplog,
):
"""Test starting automations does not interfere with script deadlock prevention."""

# Fail if additional script modes are added to
# make sure we cover all script modes in tests
assert SCRIPT_MODE_CHOICES == [
SCRIPT_MODE_PARALLEL,
SCRIPT_MODE_QUEUED,
SCRIPT_MODE_RESTART,
SCRIPT_MODE_SINGLE,
]

stop_scripts_at_shutdown_called = asyncio.Event()
real_stop_scripts_at_shutdown = _async_stop_scripts_at_shutdown

async def mock_stop_scripts_at_shutdown(*args):
await real_stop_scripts_at_shutdown(*args)
stop_scripts_at_shutdown_called.set()

with patch(
"homeassistant.helpers.script._async_stop_scripts_at_shutdown",
wraps=mock_stop_scripts_at_shutdown,
):
assert await async_setup_component(
hass,
"script",
{
"script": {
"script1": {
"mode": script_mode,
"sequence": [
{"event": "trigger_automation"},
{
"wait_template": f"{{{{ float(states('sensor.test'), 0) >= {automation_runs} }}}}"
},
{"service": "script.script1"},
{"service": "test.script_done"},
],
},
}
},
)

assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"mode": automation_mode,
"trigger": [
{"platform": "event", "event_type": "trigger_automation"},
],
"action": [
{"service": "test.automation_started"},
{"service": "script.script1"},
],
}
},
)

script_done_event = asyncio.Event()
script_done = []
automation_started = []
automation_triggered = []

async def async_service_handler(service: ServiceCall):
if service.service == "automation_started":
automation_started.append(service)
elif service.service == "script_done":
script_done.append(service)
if len(script_done) == 1:
script_done_event.set()

async def async_automation_triggered(event):
"""Listen to automation_triggered event from the automation integration."""
automation_triggered.append(event)
hass.states.async_set("sensor.test", str(len(automation_triggered)))

hass.services.async_register("test", "script_done", async_service_handler)
hass.services.async_register(
"test", "automation_started", async_service_handler
)
hass.bus.async_listen("automation_triggered", async_automation_triggered)

hass.bus.async_fire("trigger_automation")
await asyncio.wait_for(script_done_event.wait(), 1)

# Trigger 1st stage script shutdown
hass.state = CoreState.stopping
hass.bus.async_fire("homeassistant_stop")
await asyncio.wait_for(stop_scripts_at_shutdown_called.wait(), 1)

# Trigger 2nd stage script shutdown
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=60))
await hass.async_block_till_done()

assert script_warning_msg in caplog.text


@pytest.mark.parametrize("automation_mode", SCRIPT_MODE_CHOICES)
async def test_recursive_automation(hass: HomeAssistant, automation_mode, caplog):
"""Test automation triggering itself.

- Illegal recursion detection should not be triggered
- Home Assistant should not hang on shut down
"""
stop_scripts_at_shutdown_called = asyncio.Event()
real_stop_scripts_at_shutdown = _async_stop_scripts_at_shutdown

async def stop_scripts_at_shutdown(*args):
await real_stop_scripts_at_shutdown(*args)
stop_scripts_at_shutdown_called.set()

with patch(
"homeassistant.helpers.script._async_stop_scripts_at_shutdown",
wraps=stop_scripts_at_shutdown,
):
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"mode": automation_mode,
"trigger": [
{"platform": "event", "event_type": "trigger_automation"},
],
"action": [
{"event": "trigger_automation"},
{"service": "test.automation_done"},
],
}
},
)

service_called = asyncio.Event()
service_called_late = []

async def async_service_handler(service):
if service.service == "automation_done":
service_called.set()
if service.service == "automation_started_late":
service_called_late.append(service)

hass.services.async_register("test", "automation_done", async_service_handler)
hass.services.async_register(
"test", "automation_started_late", async_service_handler
)

hass.bus.async_fire("trigger_automation")
await asyncio.wait_for(service_called.wait(), 1)

# Trigger 1st stage script shutdown
hass.state = CoreState.stopping
hass.bus.async_fire("homeassistant_stop")
await asyncio.wait_for(stop_scripts_at_shutdown_called.wait(), 1)

# Trigger 2nd stage script shutdown
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=90))
await hass.async_block_till_done()

assert "Disallowed recursion detected" not in caplog.text
4 changes: 0 additions & 4 deletions tests/components/script/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,8 +840,6 @@ async def async_service_handler(service):
service_called.set()

hass.services.async_register("test", "script", async_service_handler)
hass.states.async_set("input_boolean.test", "on")
hass.states.async_set("input_boolean.test2", "off")

await hass.services.async_call("script", "script1")
await asyncio.wait_for(service_called.wait(), 1)
Expand Down Expand Up @@ -908,8 +906,6 @@ async def async_service_handler(service):
service_called.set()

hass.services.async_register("test", "script", async_service_handler)
hass.states.async_set("input_boolean.test", "on")
hass.states.async_set("input_boolean.test2", "off")

await hass.services.async_call("script", "script1")
await asyncio.wait_for(service_called.wait(), 1)
Expand Down