Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 46 additions & 17 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Optional,
Set,
TypeVar,
Union,
)
import uuid

Expand Down Expand Up @@ -1214,29 +1215,57 @@ async def async_call(
context=context,
)

coro = self._execute_service(handler, service_call)
if not blocking:
self._hass.async_create_task(self._safe_execute(handler, service_call))
self._run_service_in_background(coro, service_call)
return None

task = self._hass.async_create_task(coro)
try:
async with timeout(limit):
await asyncio.shield(self._execute_service(handler, service_call))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the biggest change here is that we're no longer shielding things. I suggest we merge this as-is and see how it highlights any issues before 109 is released. If we find some issues, we can decide to apply shields to the scenarios that are failing (thinking Rest/Websocket call service APIs)

await asyncio.wait({task}, timeout=limit)
except asyncio.CancelledError:
# Task calling us was cancelled, so cancel service call task, and wait for
# it to be cancelled, within reason, before leaving.
_LOGGER.debug("Service call was cancelled: %s", service_call)
task.cancel()
await asyncio.wait({task}, timeout=SERVICE_CALL_LIMIT)
raise

if task.cancelled():
# Service call task was cancelled some other way, such as during shutdown.
_LOGGER.debug("Service was cancelled: %s", service_call)
raise asyncio.CancelledError
if task.done():
# Propagate any exceptions that might have happened during service call.
task.result()
# Service call completed successfully!
return True
except asyncio.TimeoutError:
return False
# Service call task did not complete before timeout expired.
# Let it keep running in background.
self._run_service_in_background(task, service_call)
_LOGGER.debug("Service did not complete before timeout: %s", service_call)
return False

async def _safe_execute(self, handler: Service, service_call: ServiceCall) -> None:
"""Execute a service and catch exceptions."""
try:
await self._execute_service(handler, service_call)
except Unauthorized:
_LOGGER.warning(
"Unauthorized service called %s/%s",
service_call.domain,
service_call.service,
)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error executing service %s", service_call)
def _run_service_in_background(
self, coro_or_task: Union[Coroutine, asyncio.Task], service_call: ServiceCall
) -> None:
"""Run service call in background, catching and logging any exceptions."""

async def catch_exceptions() -> None:
try:
await coro_or_task
except Unauthorized:
_LOGGER.warning(
"Unauthorized service called %s/%s",
service_call.domain,
service_call.service,
)
except asyncio.CancelledError:
_LOGGER.debug("Service was cancelled: %s", service_call)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error executing service: %s", service_call)

self._hass.async_create_task(catch_exceptions())

async def _execute_service(
self, handler: Service, service_call: ServiceCall
Expand Down
36 changes: 36 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,42 @@ async def service_handler(call):
assert len(runs) == 3


@pytest.mark.parametrize("cancel_call", [True, False])
async def test_cancel_service_task(hass, cancel_call):
"""Test cancellation."""
service_called = asyncio.Event()
service_cancelled = False

async def service_handler(call):
nonlocal service_cancelled
service_called.set()
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
service_cancelled = True
raise

hass.services.async_register("test_domain", "test_service", service_handler)
call_task = hass.async_create_task(
hass.services.async_call("test_domain", "test_service", blocking=True)
)

tasks_1 = asyncio.all_tasks()
await asyncio.wait_for(service_called.wait(), timeout=1)
tasks_2 = asyncio.all_tasks() - tasks_1
assert len(tasks_2) == 1
service_task = tasks_2.pop()

if cancel_call:
call_task.cancel()
else:
service_task.cancel()
with pytest.raises(asyncio.CancelledError):
await call_task

assert service_cancelled


def test_valid_entity_id():
"""Test valid entity ID."""
for invalid in [
Expand Down