Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions homeassistant/components/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
CONF_MAX,
CONF_MAX_EXCEEDED,
Script,
script_stack_cv,
)
from homeassistant.helpers.service import async_set_service_schema
from homeassistant.helpers.trace import trace_get, trace_path
Expand Down Expand Up @@ -398,10 +399,14 @@ async def async_turn_on(self, **kwargs):
return

# Caller does not want to wait for called script to finish so let script run in
# separate Task. However, wait for first state change so we can guarantee that
# it is written to the State Machine before we return.
# separate Task. Make a new empty script stack; scripts are allowed to
# recursively turn themselves on when not waiting.
script_stack_cv.set([])

self._changed.clear()
self.hass.async_create_task(coro)
# Wait for first state change so we can guarantee that
# it is written to the State Machine before we return.
await self._changed.wait()

async def _async_run(self, variables, context):
Expand Down
6 changes: 0 additions & 6 deletions tests/components/automation/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -1791,18 +1791,12 @@ async def stop_scripts_at_shutdown(*args):
)

service_called = asyncio.Event()
service_called_late = []

async def async_service_handler(service):
if service.service == "automation_done":
service_called.set()
if service.service == "automation_started_late":
service_called_late.append(service)

hass.services.async_register("test", "automation_done", async_service_handler)
hass.services.async_register(
"test", "automation_started_late", async_service_handler
)

hass.bus.async_fire("trigger_automation")
await asyncio.wait_for(service_called.wait(), 1)
Expand Down
89 changes: 88 additions & 1 deletion tests/components/script/test_init.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""The tests for the Script component."""
# pylint: disable=protected-access
import asyncio
from datetime import timedelta
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -33,12 +34,13 @@
SCRIPT_MODE_QUEUED,
SCRIPT_MODE_RESTART,
SCRIPT_MODE_SINGLE,
_async_stop_scripts_at_shutdown,
)
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util

from tests.common import async_mock_service, mock_restore_cache
from tests.common import async_fire_time_changed, async_mock_service, mock_restore_cache
from tests.components.logbook.test_init import MockLazyEventPartialState

ENTITY_ID = "script.test"
Expand Down Expand Up @@ -919,6 +921,91 @@ async def async_service_handler(service):
assert warning_msg in caplog.text


@pytest.mark.parametrize(
"script_mode", [SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUED, SCRIPT_MODE_RESTART]
)
async def test_recursive_script_turn_on(hass: HomeAssistant, script_mode, caplog):
"""Test script turning itself on.

- Illegal recursion detection should not be triggered
- Home Assistant should not hang on shut down
- SCRIPT_MODE_SINGLE is not relevant because suca script can't turn itself on
"""
# Make sure we cover all script modes
assert SCRIPT_MODE_CHOICES == [
SCRIPT_MODE_PARALLEL,
SCRIPT_MODE_QUEUED,
SCRIPT_MODE_RESTART,
SCRIPT_MODE_SINGLE,
]
stop_scripts_at_shutdown_called = asyncio.Event()
real_stop_scripts_at_shutdown = _async_stop_scripts_at_shutdown

async def stop_scripts_at_shutdown(*args):
await real_stop_scripts_at_shutdown(*args)
stop_scripts_at_shutdown_called.set()

with patch(
"homeassistant.helpers.script._async_stop_scripts_at_shutdown",
wraps=stop_scripts_at_shutdown,
):
assert await async_setup_component(
hass,
script.DOMAIN,
{
script.DOMAIN: {
"script1": {
"mode": script_mode,
"sequence": [
{
"choose": {
"conditions": {
"condition": "template",
"value_template": "{{ request == 'step_2' }}",
},
"sequence": {"service": "test.script_done"},
},
"default": {
"service": "script.turn_on",
"data": {
"entity_id": "script.script1",
"variables": {"request": "step_2"},
},
},
},
{
"service": "script.turn_on",
"data": {"entity_id": "script.script1"},
},
],
}
}
},
)

service_called = asyncio.Event()

async def async_service_handler(service):
if service.service == "script_done":
service_called.set()

hass.services.async_register("test", "script_done", async_service_handler)

await hass.services.async_call("script", "script1")
await asyncio.wait_for(service_called.wait(), 1)

# Trigger 1st stage script shutdown
hass.state = CoreState.stopping
hass.bus.async_fire("homeassistant_stop")
await asyncio.wait_for(stop_scripts_at_shutdown_called.wait(), 1)

# Trigger 2nd stage script shutdown
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(seconds=90))
await hass.async_block_till_done()

assert "Disallowed recursion detected" not in caplog.text


async def test_setup_with_duplicate_scripts(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
Expand Down